diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 891030f49..3764a01c4 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -80,7 +80,7 @@ namespace gtsam { * @endcode * * The values in the table should be laid out so that the first key varies - * the slowest. and the last key the fastest. + * the slowest, and the last key the fastest. */ DecisionTreeFactor(const DiscreteKeys& keys, const std::vector& table); @@ -101,7 +101,7 @@ namespace gtsam { * @endcode * * The values in the table should be laid out so that the first key varies - * the slowest. and the last key the fastest. + * the slowest, and the last key the fastest. */ DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 03a9a2fc7..b37b64908 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -59,6 +59,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique //** evaluate conditional probability of subtree for given DiscreteValues */ double evaluate(const DiscreteValues& values) const; + + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues& values) const { + return evaluate(values); + } }; /* ************************************************************************* */ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 040dabf48..32f0439ea 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -215,6 +215,7 @@ class DiscreteBayesTreeClique { const string& s = "Clique: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; }; class DiscreteBayesTree { @@ -229,6 +230,7 @@ class DiscreteBayesTree { const DiscreteBayesTreeClique* operator[](size_t j) const; double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; string dot(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 78e254aed..00020f567 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -16,6 +16,7 @@ */ #include +#include #include #include #include @@ -26,7 +27,6 @@ #include #include -using namespace std; using namespace gtsam; static constexpr bool debug = false; @@ -108,7 +108,7 @@ TEST(DiscreteBayesTree, ThinTree) { /* ************************************************************************* */ // Check calculation of separator marginals -TEST(DiscreteBayesTree, separatorMarginal) { +TEST(DiscreteBayesTree, SeparatorMarginals) { TestFixture self; // Calculate some marginals for DiscreteValues==all1 @@ -141,7 +141,7 @@ TEST(DiscreteBayesTree, separatorMarginal) { /* ************************************************************************* */ // Check shortcuts in the tree -TEST(DiscreteBayesTree, shortcut) { +TEST(DiscreteBayesTree, Shortcuts) { TestFixture self; // Calculate some marginals for DiscreteValues==all1 @@ -199,7 +199,7 @@ TEST(DiscreteBayesTree, shortcut) { /* ************************************************************************* */ // Check all marginals -TEST(DiscreteBayesTree, marginalFactor) { +TEST(DiscreteBayesTree, MarginalFactors) { TestFixture self; Vector marginals = Vector::Zero(15); @@ -286,7 +286,7 @@ TEST(DiscreteBayesTree, Joints) { /* ************************************************************************* */ TEST(DiscreteBayesTree, Dot) { TestFixture self; - string actual = self.bayesTree->dot(); + std::string actual = self.bayesTree->dot(); EXPECT(actual == "digraph G{\n" "0[label=\"13, 11, 6, 7\"];\n" @@ -313,6 +313,61 @@ TEST(DiscreteBayesTree, Dot) { "}"); } +/* ************************************************************************* */ +// Check that we can have a multi-frontal lookup table +TEST(DiscreteBayesTree, Lookup) { + using gtsam::symbol_shorthand::A; + using gtsam::symbol_shorthand::X; + + // Make a small planning-like graph: 3 states, 2 actions + DiscreteFactorGraph graph; + const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3}; + const DiscreteKey a1{A(1), 2}, a2{A(2), 2}; + const DiscreteKeys keys{x1, x2, x3, a1, a2}; + // Constraint on start and goal + graph.add(DiscreteKeys{x1}, std::vector{1, 0, 0}); + graph.add(DiscreteKeys{x3}, std::vector{0, 0, 1}); + // Should I stay or should I go? + // "Reward" (exp(-cost)) for an action is 10, and rewards multiply: + const double r = 10; + std::vector table{ + r, 0, 0, 0, r, 0, // x1 = 0 + 0, r, 0, 0, 0, r, // x1 = 1 + 0, 0, r, 0, 0, r // x1 = 2 + }; + graph.add(DiscreteKeys{x1, a1, x2}, table); + graph.add(DiscreteKeys{x2, a2, x3}, table); + + // eliminate for MPE (maximum probable explanation). + Ordering ordering{A(2), X(3), X(1), A(1), X(2)}; + auto lookup = graph.eliminateMultifrontal(ordering, EliminateForMPE); + + // Check that the lookup table is correct + EXPECT_LONGS_EQUAL(2, lookup->size()); + auto lookup_x1_a1_x2 = (*lookup)[X(1)]->conditional(); + EXPECT_LONGS_EQUAL(3, lookup_x1_a1_x2->frontals().size()); + // check that sum is 100 + DiscreteValues empty; + EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2->sum(3))(empty), 1e-9); + // And that only non-zero reward is for x1 a1 x2 == 0 1 1 + EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2)({{X(1),0},{A(1),1},{X(2),1}}), 1e-9); + + auto lookup_a2_x3 = (*lookup)[X(3)]->conditional(); + // check that the sum depends on x2 and is non-zero only for x2 \in {1,2} + auto sum_x2 = lookup_a2_x3->sum(2); + EXPECT_DOUBLES_EQUAL(0, (*sum_x2)({{X(2),0}}), 1e-9); + EXPECT_DOUBLES_EQUAL(10, (*sum_x2)({{X(2),1}}), 1e-9); + EXPECT_DOUBLES_EQUAL(20, (*sum_x2)({{X(2),2}}), 1e-9); + EXPECT_LONGS_EQUAL(2, lookup_a2_x3->frontals().size()); + // And that the non-zero rewards are for + // x2 a2 x3 == 1 1 2 + EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),1},{A(2),1},{X(3),2}}), 1e-9); + // x2 a2 x3 == 2 0 2 + EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),0},{X(3),2}}), 1e-9); + // x2 a2 x3 == 2 1 2 + EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9); +} + /* ************************************************************************* */ int main() { TestResult tr;