diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 16794dcfc..81ebc6012 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -31,8 +31,8 @@ SearchNode SearchNode::expand(const DiscreteConditional& conditional, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; - for (auto& kv : fa) { - newAssignment[kv.first] = kv.second; + for (auto& [key, value] : fa) { + newAssignment[key] = value; } return {.assignment = newAssignment, @@ -50,11 +50,10 @@ bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) { } std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + os << "Solutions (top " << sn.pq_.size() << "):\n"; auto pq = sn.pq_; while (!pq.empty()) { - const Solution& best = pq.top(); - os << "Error: " << best.error << ", Values: " << best.assignment - << std::endl; + os << pq.top() << "\n"; pq.pop(); } return os; @@ -62,8 +61,7 @@ std::ostream& operator<<(std::ostream& os, const Solutions& sn) { bool Solutions::prune(double bound) const { if (pq_.size() < maxSize_) return false; - double worstError = pq_.top().error; - return (bound >= worstError); + return bound >= pq_.top().error; } std::vector Solutions::extractSolutions() { @@ -80,45 +78,23 @@ std::vector Solutions::extractSolutions() { DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) : solutions_(K) { - // Copy out the conditionals - for (auto& factor : bayesNet) { - conditionals_.push_back(factor); - } - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); + std::vector conditionals; + for (auto& factor : bayesNet) conditionals.push_back(factor); + initialize(conditionals); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) { - using CliquePtr = DiscreteBayesTree::sharedClique; - std::function collectConditionals = - [&](const CliquePtr& clique) -> void { - if (!clique) return; - - // Recursive post-order traversal: process children first - for (const auto& child : clique->children) { - collectConditionals(child); - } - - // Then add the current clique's conditional - conditionals_.push_back(clique->conditional()); - }; - - // Start traversal from each root in the tree + std::vector conditionals; + std::function + collectConditionals = [&](const auto& clique) { + if (!clique) return; + for (const auto& child : clique->children) collectConditionals(child); + conditionals.push_back(clique->conditional()); + }; for (const auto& root : bayesTree.roots()) collectConditionals(root); - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); -} + initialize(conditionals); +}; std::vector DiscreteSearch::run() { while (!expansions_.empty()) { @@ -170,7 +146,7 @@ void DiscreteSearch::expandNextNode() { // Again, prune if we cannot beat the worst solution if (!solutions_.prune(childNode.bound)) { - expansions_.push(childNode); + expansions_.emplace(childNode); } } } diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 989437b9f..cbf258b3f 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -19,6 +19,8 @@ #include #include +#include + namespace gtsam { using Value = size_t; @@ -52,7 +54,7 @@ struct SearchNode { * * @return True if all variables have been assigned, false otherwise. */ - bool isComplete() const { return nextConditional < 0; } + inline bool isComplete() const { return nextConditional < 0; } /** * @brief Expands the node by assigning the next variable. @@ -133,12 +135,12 @@ class DiscreteSearch { /** * Construct from a DiscreteBayesNet and K. */ - DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K); + DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K = 1); /** * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K); + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K = 1); /** * @brief Search for the K best solutions. @@ -153,18 +155,20 @@ class DiscreteSearch { std::vector run(); private: - /** - * @brief Compute the cost-to-go for each conditional. - * - * @param conditionals The conditionals of the DiscreteBayesNet. - * @return A vector of cost-to-go values. - */ + /// Initialize the search with the given conditionals. + void initialize( + const std::vector& conditionals) { + conditionals_ = conditionals; + costToGo_ = computeCostToGo(conditionals_); + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); + } + + /// Compute the cumulative cost-to-go for each conditional slot. static std::vector computeCostToGo( const std::vector& conditionals); - /** - * @brief Expand the next node in the search tree. - */ + /// Expand the next node in the search tree. void expandNextNode(); std::vector conditionals_; diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index a39bcd86b..e301ed3d4 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -46,20 +46,32 @@ TEST(DiscreteBayesNet, EmptyKBest) { TEST(DiscreteBayesNet, AsiaKBest) { using namespace asia_example; DiscreteBayesNet asia = createAsiaExample(); + + // Ask for the MPE + DiscreteSearch search1(asia); + auto mpe = search1.run(); + + // print numExpansions + std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + DiscreteSearch search(asia, 4); auto solutions = search.run(); // print numExpansions std::cout << "Number of expansions: " << search.numExpansions << std::endl; - EXPECT(!solutions.empty()); + EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); } /* ************************************************************************* */ -TEST(DiscreteBayesTree, testEmptyTree) { +TEST(DiscreteBayesTree, EmptyTree) { DiscreteBayesTree bt; DiscreteSearch search(bt, 3); @@ -72,12 +84,23 @@ TEST(DiscreteBayesTree, testEmptyTree) { } /* ************************************************************************* */ -TEST(DiscreteBayesTree, testTrivialOneClique) { +TEST(DiscreteBayesTree, AsiaTreeKBest) { using namespace asia_example; DiscreteFactorGraph asia(createAsiaExample()); const Ordering ordering{D, X, B, E, L, T, S, A}; DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); + // Ask for top 4 solutions + DiscreteSearch search1(*bt); + auto mpe = search1.run(); + + // print numExpansions + std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + // Ask for top 4 solutions DiscreteSearch search(*bt, 4); auto solutions = search.run(); @@ -85,7 +108,7 @@ TEST(DiscreteBayesTree, testTrivialOneClique) { // print numExpansions std::cout << "Number of expansions: " << search.numExpansions << std::endl; - EXPECT(!solutions.empty()); + EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);