Merge pull request #1045 from borglab/feature/discrete_wrapping

release/4.3a0
Frank Dellaert 2022-01-19 20:58:13 -05:00 committed by GitHub
commit d8abdc280d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 175 additions and 108 deletions

View File

@ -143,67 +143,64 @@ void DiscreteConditional::print(const string& s,
}
}
cout << "):\n";
ADT::print("");
ADT::print("", formatter);
cout << endl;
}
/* ******************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other))
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false;
else {
const DecisionTreeFactor& f(
static_cast<const DecisionTreeFactor&>(other));
} else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return DecisionTreeFactor::equals(f, tol);
}
}
/* ******************************************************************************** */
/* ************************************************************************** */
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
const DiscreteValues& parentsValues) {
const DiscreteValues& given,
bool forceComplete = true) {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional);
size_t value;
for (Key j : conditional.parents()) {
try {
value = parentsValues.at(j);
value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (std::out_of_range&) {
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
if (forceComplete) {
given.print("parentsValues: ");
throw runtime_error(
"DiscreteConditional::Choose: parent value missing");
}
}
}
return adt;
}
/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
const DiscreteValues& parentsValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
ADT adt(*this);
size_t value;
for (Key j : parents()) {
try {
value = parentsValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}
/* ************************************************************************** */
DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& given) const {
ADT adt = Choose(*this, given, false); // P(F|S=given)
// Convert ADT to factor.
DiscreteKeys discreteKeys;
// Collect all keys not in given.
DiscreteKeys dKeys;
for (Key j : frontals()) {
discreteKeys.emplace_back(j, this->cardinality(j));
dKeys.emplace_back(j, this->cardinality(j));
}
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
for (size_t i = nrFrontals(); i < size(); i++) {
Key j = keys_[i];
if (given.count(j) == 0) {
dKeys.emplace_back(j, this->cardinality(j));
}
}
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
}
/* ******************************************************************************** */
/* ************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the
@ -217,7 +214,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
} catch (exception&) {
frontalValues.print("frontalValues: ");
throw runtime_error("DiscreteConditional::choose: frontal value missing");
};
}
}
// Convert ADT to factor.
@ -242,7 +239,6 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
/* ************************************************************************** */
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// TODO(Abhijit): is this really the fastest way? He thinks it is.
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
// Initialize
@ -276,11 +272,9 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
(*values)[j] = sampled; // store result in partial solution
}
/* ******************************************************************************** */
/* ************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
// TODO: is this really the fastest way? I think it is.
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way

View File

@ -157,9 +157,20 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values);
}
/** Restrict to given parent values, returns DecisionTreeFactor */
DecisionTreeFactor::shared_ptr choose(
const DiscreteValues& parentsValues) const;
/**
* @brief restrict to given *parent* values.
*
* Note: does not need be complete set. Examples:
*
* P(C|D,E) + . -> P(C|D,E)
* P(C|D,E) + E -> P(C|D)
* P(C|D,E) + D -> P(C|E)
* P(C|D,E) + D,E -> P(C)
* P(C|D,E) + C -> error!
*
* @return a shared_ptr to a new DiscreteConditional
*/
shared_ptr choose(const DiscreteValues& given) const;
/** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood(

View File

@ -64,33 +64,35 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor
*/
class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
public EliminateableFactorGraph<DiscreteFactorGraph> {
public:
class GTSAM_EXPORT DiscreteFactorGraph
: public FactorGraph<DiscreteFactor>,
public EliminateableFactorGraph<DiscreteFactorGraph> {
public:
using This = DiscreteFactorGraph; ///< this class
using Base = FactorGraph<DiscreteFactor>; ///< base factor graph type
using BaseEliminateable =
EliminateableFactorGraph<This>; ///< for elimination
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
typedef DiscreteFactorGraph This; ///< Typedef to this class
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
using Values = DiscreteValues; ///< backwards compatibility
using Values = DiscreteValues; ///< backwards compatibility
/** A map from keys to values */
typedef KeyVector Indices;
using Indices = KeyVector; ///> map from keys to values
/** Default constructor */
DiscreteFactorGraph() {}
/** Construct from iterator over factors */
template<typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
template <typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor)
: Base(firstFactor, lastFactor) {}
/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
template <class CONTAINER>
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDFACTOR>
/** Implicit copy/downcast constructor to override explicit template container
* constructor */
template <class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
/// Destructor
@ -108,7 +110,7 @@ public:
void add(Args&&... args) {
emplace_shared<DecisionTreeFactor>(std::forward<Args>(args)...);
}
/** Return the set of variables involved in the factors (set union) */
KeySet keys() const;
@ -163,9 +165,10 @@ public:
const DiscreteFactor::Names& names = {}) const;
/// @}
}; // \ DiscreteFactorGraph
}; // \ DiscreteFactorGraph
/// traits
template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
} // \ namespace gtsam
} // namespace gtsam

View File

@ -107,8 +107,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
void printSignature(
string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* choose(
const gtsam::DiscreteValues& parentsValues) const;
gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
@ -230,11 +229,16 @@ class DiscreteFactorGraph {
DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
void add(const gtsam::DiscreteKey& j, string table);
// Building the graph
void push_back(const gtsam::DiscreteFactor* factor);
void push_back(const gtsam::DiscreteConditional* conditional);
void push_back(const gtsam::DiscreteFactorGraph& graph);
void push_back(const gtsam::DiscreteBayesNet& bayesNet);
void push_back(const gtsam::DiscreteBayesTree& bayesTree);
void add(const gtsam::DiscreteKey& j, string spec);
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKeys& keys, string table);
void add(const std::vector<gtsam::DiscreteKey>& keys, string table);
void add(const gtsam::DiscreteKeys& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
bool empty() const;
size_t size() const;
@ -258,8 +262,12 @@ class DiscreteFactorGraph {
gtsam::DiscreteBayesNet eliminateSequential();
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>
eliminatePartialSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree eliminateMultifrontal();
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;

View File

@ -221,6 +221,34 @@ TEST(DiscreteConditional, likelihood) {
EXPECT(assert_equal(expected1, *actual1, 1e-9));
}
/* ************************************************************************* */
// Check choose on P(C|D,E)
TEST(DiscreteConditional, choose) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
// Case 1: no given values: no-op
DiscreteValues given;
auto actual1 = C_given_DE.choose(given);
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
// Case 2: 1 given value
given[D.first] = 1;
auto actual2 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
DiscreteConditional expected2(C | E = "1/1 1/4");
EXPECT(assert_equal(expected2, *actual2, 1e-9));
// Case 2: 2 given values
given[E.first] = 0;
auto actual3 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
DiscreteConditional expected3(C % "1/1");
EXPECT(assert_equal(expected3, *actual3, 1e-9));
}
/* ************************************************************************* */
// Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) {

View File

@ -376,8 +376,12 @@ TEST(DiscreteFactorGraph, Dot) {
" var1[label=\"1\"];\n"
" var2[label=\"2\"];\n"
"\n"
" var0--var1;\n"
" var0--var2;\n"
" factor0[label=\"\", shape=point];\n"
" var0--factor0;\n"
" var1--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" var0--factor1;\n"
" var2--factor1;\n"
"}\n";
EXPECT(actual == expected);
}
@ -397,12 +401,16 @@ TEST(DiscreteFactorGraph, DotWithNames) {
"graph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"C\"];\n"
" var1[label=\"A\"];\n"
" var2[label=\"B\"];\n"
" varC[label=\"C\"];\n"
" varA[label=\"A\"];\n"
" varB[label=\"B\"];\n"
"\n"
" var0--var1;\n"
" var0--var2;\n"
" factor0[label=\"\", shape=point];\n"
" varC--factor0;\n"
" varA--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" varC--factor1;\n"
" varB--factor1;\n"
"}\n";
EXPECT(actual == expected);
}

View File

@ -35,7 +35,8 @@ void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position,
ostream* os) {
// Label the node with the label from the KeyFormatter
*os << " var" << key << "[label=\"" << keyFormatter(key) << "\"";
*os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key)
<< "\"";
if (position) {
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
}
@ -51,22 +52,26 @@ void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position,
*os << "];\n";
}
void DotWriter::ConnectVariables(Key key1, Key key2, ostream* os) {
*os << " var" << key1 << "--"
<< "var" << key2 << ";\n";
static void ConnectVariables(Key key1, Key key2,
const KeyFormatter& keyFormatter,
ostream* os) {
*os << " var" << keyFormatter(key1) << "--"
<< "var" << keyFormatter(key2) << ";\n";
}
void DotWriter::ConnectVariableFactor(Key key, size_t i, ostream* os) {
*os << " var" << key << "--"
static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter,
size_t i, ostream* os) {
*os << " var" << keyFormatter(key) << "--"
<< "factor" << i << ";\n";
}
void DotWriter::processFactor(size_t i, const KeyVector& keys,
const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position,
ostream* os) const {
if (plotFactorPoints) {
if (binaryEdges && keys.size() == 2) {
ConnectVariables(keys[0], keys[1], os);
ConnectVariables(keys[0], keys[1], keyFormatter, os);
} else {
// Create dot for the factor.
DrawFactor(i, position, os);
@ -74,7 +79,7 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys,
// Make factor-variable connections
if (connectKeysToFactor) {
for (Key key : keys) {
ConnectVariableFactor(key, i, os);
ConnectVariableFactor(key, keyFormatter, i, os);
}
}
}
@ -83,7 +88,7 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys,
for (Key key1 : keys) {
for (Key key2 : keys) {
if (key2 > key1) {
ConnectVariables(key1, key2, os);
ConnectVariables(key1, key2, keyFormatter, os);
}
}
}

View File

@ -38,7 +38,7 @@ struct GTSAM_EXPORT DotWriter {
explicit DotWriter(double figureWidthInches = 5,
double figureHeightInches = 5,
bool plotFactorPoints = true,
bool connectKeysToFactor = true, bool binaryEdges = true)
bool connectKeysToFactor = true, bool binaryEdges = false)
: figureWidthInches(figureWidthInches),
figureHeightInches(figureHeightInches),
plotFactorPoints(plotFactorPoints),
@ -57,14 +57,9 @@ struct GTSAM_EXPORT DotWriter {
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
std::ostream* os);
/// Connect two variables.
static void ConnectVariables(Key key1, Key key2, std::ostream* os);
/// Connect variable and factor.
static void ConnectVariableFactor(Key key, size_t i, std::ostream* os);
/// Draw a single factor, specified by its index i and its variable keys.
void processFactor(size_t i, const KeyVector& keys,
const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position,
std::ostream* os) const;
};

View File

@ -144,7 +144,7 @@ void FactorGraph<FACTOR>::dot(std::ostream& os,
const auto& factor = at(i);
if (factor) {
const KeyVector& factorKeys = factor->keys();
writer.processFactor(i, factorKeys, boost::none, &os);
writer.processFactor(i, factorKeys, keyFormatter, boost::none, &os);
}
}

View File

@ -33,8 +33,10 @@
# include <tbb/parallel_for.h>
#endif
#include <algorithm>
#include <cmath>
#include <fstream>
#include <set>
using namespace std;
@ -127,7 +129,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
// Create factors and variable connections
size_t i = 0;
for (const KeyVector& factorKeys : structure) {
writer.processFactor(i++, factorKeys, boost::none, &os);
writer.processFactor(i++, factorKeys, keyFormatter, boost::none, &os);
}
} else {
// Create factors and variable connections
@ -135,7 +137,8 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
const NonlinearFactor::shared_ptr& factor = at(i);
if (factor) {
const KeyVector& factorKeys = factor->keys();
writer.processFactor(i, factorKeys, writer.factorPos(min, i), &os);
writer.processFactor(i, factorKeys, keyFormatter,
writer.factorPos(min, i), &os);
}
}
}

View File

@ -335,15 +335,21 @@ TEST(NonlinearFactorGraph, dot) {
"graph {\n"
" size=\"5,5\";\n"
"\n"
" var7782220156096217089[label=\"l1\"];\n"
" var8646911284551352321[label=\"x1\"];\n"
" var8646911284551352322[label=\"x2\"];\n"
" varl1[label=\"l1\"];\n"
" varx1[label=\"x1\"];\n"
" varx2[label=\"x2\"];\n"
"\n"
" factor0[label=\"\", shape=point];\n"
" var8646911284551352321--factor0;\n"
" var8646911284551352321--var8646911284551352322;\n"
" var8646911284551352321--var7782220156096217089;\n"
" var8646911284551352322--var7782220156096217089;\n"
" varx1--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" varx1--factor1;\n"
" varx2--factor1;\n"
" factor2[label=\"\", shape=point];\n"
" varx1--factor2;\n"
" varl1--factor2;\n"
" factor3[label=\"\", shape=point];\n"
" varx2--factor3;\n"
" varl1--factor3;\n"
"}\n";
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
@ -357,15 +363,21 @@ TEST(NonlinearFactorGraph, dot_extra) {
"graph {\n"
" size=\"5,5\";\n"
"\n"
" var7782220156096217089[label=\"l1\", pos=\"0,0!\"];\n"
" var8646911284551352321[label=\"x1\", pos=\"1,0!\"];\n"
" var8646911284551352322[label=\"x2\", pos=\"1,1.5!\"];\n"
" varl1[label=\"l1\", pos=\"0,0!\"];\n"
" varx1[label=\"x1\", pos=\"1,0!\"];\n"
" varx2[label=\"x2\", pos=\"1,1.5!\"];\n"
"\n"
" factor0[label=\"\", shape=point];\n"
" var8646911284551352321--factor0;\n"
" var8646911284551352321--var8646911284551352322;\n"
" var8646911284551352321--var7782220156096217089;\n"
" var8646911284551352322--var7782220156096217089;\n"
" varx1--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" varx1--factor1;\n"
" varx2--factor1;\n"
" factor2[label=\"\", shape=point];\n"
" varx1--factor2;\n"
" varl1--factor2;\n"
" factor3[label=\"\", shape=point];\n"
" varx2--factor3;\n"
" varl1--factor3;\n"
"}\n";
const NonlinearFactorGraph fg = createNonlinearFactorGraph();