diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index d9ffae768..c267001b5 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -17,6 +17,7 @@ */ #include +#include namespace gtsam { @@ -214,6 +215,12 @@ class DiscreteSearch { expansions_.push(root); } + /** + * Construct from a DiscreteBayesNet and K. + */ + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) + : solutions_(K) {} + /** * @brief Search for the K best solutions. * diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 4d7adbac5..b0f2b03e3 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -32,6 +32,7 @@ using namespace gtsam; +/* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors DiscreteSearch search(net, 3); @@ -41,6 +42,7 @@ TEST(DiscreteBayesNet, EmptyKBest) { EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); } +/* ************************************************************************* */ TEST(DiscreteBayesNet, AsiaKBest) { using namespace asia_example; DiscreteBayesNet asia = createAsiaExample(); @@ -51,6 +53,35 @@ TEST(DiscreteBayesNet, AsiaKBest) { EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); } +/* ************************************************************************* */ +TEST(DiscreteBayesTree, testEmptyTree) { + DiscreteBayesTree bt; + + DiscreteSearch search(bt, 3); + auto solutions = search.run(); + + // We expect exactly 1 solution with error = 0.0 (the empty assignment). + assert(solutions.size() == 1 && "There should be exactly one empty solution"); + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscreteBayesTree, testTrivialOneClique) { + using namespace asia_example; + DiscreteFactorGraph asia(createAsiaExample()); + DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(); + GTSAM_PRINT(*bt); + + // Ask for top 4 solutions + DiscreteSearch search(*bt, 4); + auto solutions = search.run(); + + EXPECT(!solutions.empty()); + // Regression test: check the first solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); +} + /* ************************************************************************* */ int main() { TestResult tr;