From 70089a0fd4e0c3c7e87cb0eb2493e82b8165235d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 23:34:30 -0500 Subject: [PATCH] Move to cpp file --- gtsam/discrete/DiscreteSearch.cpp | 178 ++++++++++++++++++++ gtsam/discrete/DiscreteSearch.h | 169 +++---------------- gtsam/discrete/tests/testDiscreteSearch.cpp | 5 +- 3 files changed, 201 insertions(+), 151 deletions(-) create mode 100644 gtsam/discrete/DiscreteSearch.cpp diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp new file mode 100644 index 000000000..16794dcfc --- /dev/null +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -0,0 +1,178 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * DiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include + +namespace gtsam { + +SearchNode SearchNode::Root(size_t numConditionals, double bound) { + return {.assignment = DiscreteValues(), + .error = 0.0, + .bound = bound, + .nextConditional = static_cast(numConditionals) - 1}; +} + +SearchNode SearchNode::expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + DiscreteValues newAssignment = assignment; + for (auto& kv : fa) { + newAssignment[kv.first] = kv.second; + } + + return {.assignment = newAssignment, + .error = error + conditional.error(newAssignment), + .bound = 0.0, + .nextConditional = nextConditional - 1}; +} + +bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; +} + +std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + auto pq = sn.pq_; + while (!pq.empty()) { + const Solution& best = pq.top(); + os << "Error: " << best.error << ", Values: " << best.assignment + << std::endl; + pq.pop(); + } + return os; +} + +bool Solutions::prune(double bound) const { + if (pq_.size() < maxSize_) return false; + double worstError = pq_.top().error; + return (bound >= worstError); +} + +std::vector Solutions::extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; +} + +DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) + : solutions_(K) { + // Copy out the conditionals + for (auto& factor : bayesNet) { + conditionals_.push_back(factor); + } + + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); + + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); +} + +DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) + : solutions_(K) { + using CliquePtr = DiscreteBayesTree::sharedClique; + std::function collectConditionals = + [&](const CliquePtr& clique) -> void { + if (!clique) return; + + // Recursive post-order traversal: process children first + for (const auto& child : clique->children) { + collectConditionals(child); + } + + // Then add the current clique's conditional + conditionals_.push_back(clique->conditional()); + }; + + // Start traversal from each root in the tree + for (const auto& root : bayesTree.roots()) collectConditionals(root); + + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); + + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); +} + +std::vector DiscreteSearch::run() { + while (!expansions_.empty()) { + numExpansions++; + expandNextNode(); + } + + // Extract solutions from bestSolutions in ascending order of error + return solutions_.extractSolutions(); +} + +std::vector DiscreteSearch::computeCostToGo( + const std::vector& conditionals) { + std::vector costToGo; + double error = 0.0; + for (const auto& conditional : conditionals) { + Ordering ordering(conditional->begin(), conditional->end()); + auto maxx = conditional->max(ordering); + assert(maxx->size() == 1); + error -= std::log(maxx->evaluate({})); + costToGo.push_back(error); + } + return costToGo; +} + +void DiscreteSearch::expandNextNode() { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions_.top(); + expansions_.pop(); + + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions_.prune(current.bound)) { + return; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions_.maybeAdd(current.error, current.assignment); + return; + } + + // Expand on the next factor + const auto& conditional = conditionals_[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + auto childNode = current.expand(*conditional, fa); + if (childNode.nextConditional >= 0) + childNode.bound = childNode.error + costToGo_[childNode.nextConditional]; + + // Again, prune if we cannot beat the worst solution + if (!solutions_.prune(childNode.bound)) { + expansions_.push(childNode); + } + } +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 24517ee29..989437b9f 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -39,14 +39,9 @@ struct SearchNode { /** * @brief Construct the root node for the search. */ - static SearchNode Root(size_t numConditionals, double bound) { - return {.assignment = DiscreteValues(), - .error = 0.0, - .bound = bound, - .nextConditional = static_cast(numConditionals) - 1}; - } + static SearchNode Root(size_t numConditionals, double bound); - struct CompareByBound { + struct Compare { bool operator()(const SearchNode& a, const SearchNode& b) const { return a.bound > b.bound; // smallest bound -> highest priority } @@ -68,18 +63,7 @@ struct SearchNode { * @return A new SearchNode representing the expanded state. */ SearchNode expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const { - // Combine the new frontal assignment with the current partial assignment - DiscreteValues newAssignment = assignment; - for (auto& kv : fa) { - newAssignment[kv.first] = kv.second; - } - - return {.assignment = newAssignment, - .error = error + conditional.error(newAssignment), - .bound = 0.0, - .nextConditional = nextConditional - 1}; - } + const DiscreteValues& fa) const; /** * @brief Prints the SearchNode to an output stream. @@ -103,69 +87,40 @@ struct Solution { os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; return os; } -}; -struct CompareByError { - bool operator()(const Solution& a, const Solution& b) const { - return a.error < b.error; - } + struct Compare { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } + }; }; // Define the Solutions class class Solutions { private: size_t maxSize_; - std::priority_queue, CompareByError> pq_; + std::priority_queue, Solution::Compare> pq_; public: Solutions(size_t maxSize) : maxSize_(maxSize) {} /// Add a solution to the priority queue, possibly evicting the worst one. /// Return true if we added the solution. - bool maybeAdd(double error, const DiscreteValues& assignment) { - const bool full = pq_.size() == maxSize_; - if (full && error >= pq_.top().error) return false; - if (full) pq_.pop(); - pq_.emplace(error, assignment); - return true; - } + bool maybeAdd(double error, const DiscreteValues& assignment); /// Check if we have any solutions bool empty() const { return pq_.empty(); } // Method to print all solutions - friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { - auto pq = sn.pq_; - while (!pq.empty()) { - const Solution& best = pq.top(); - os << "Error: " << best.error << ", Values: " << best.assignment - << std::endl; - pq.pop(); - } - return os; - } + friend std::ostream& operator<<(std::ostream& os, const Solutions& sn); /// Check if (partial) solution with given bound can be pruned. If we have /// room, we never prune. Otherwise, prune if lower bound on error is worse /// than our current worst error. - bool prune(double bound) const { - if (pq_.size() < maxSize_) return false; - double worstError = pq_.top().error; - return (bound >= worstError); - } + bool prune(double bound) const; // Method to extract solutions in ascending order of error - std::vector extractSolutions() { - std::vector result; - while (!pq_.empty()) { - result.push_back(pq_.top()); - pq_.pop(); - } - std::sort( - result.begin(), result.end(), - [](const Solution& a, const Solution& b) { return a.error < b.error; }); - return result; - } + std::vector extractSolutions(); }; /** @@ -178,48 +133,12 @@ class DiscreteSearch { /** * Construct from a DiscreteBayesNet and K. */ - DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) : solutions_(K) { - // Copy out the conditionals - for (auto& factor : bayesNet) { - conditionals_.push_back(factor); - } - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); - } + DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K); /** * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) { - using CliquePtr = DiscreteBayesTree::sharedClique; - std::function collectConditionals = - [&](const CliquePtr& clique) -> void { - if (!clique) return; - - // Recursive post-order traversal: process children first - for (const auto& child : clique->children) { - collectConditionals(child); - } - - // Then add the current clique's conditional - conditionals_.push_back(clique->conditional()); - }; - - // Start traversal from each root in the tree - for (const auto& root : bayesTree.roots()) collectConditionals(root); - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); - } + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K); /** * @brief Search for the K best solutions. @@ -231,15 +150,7 @@ class DiscreteSearch { * * @return A vector of the K best solutions found during the search. */ - std::vector run() { - while (!expansions_.empty()) { - numExpansions++; - expandNextNode(); - } - - // Extract solutions from bestSolutions in ascending order of error - return solutions_.extractSolutions(); - } + std::vector run(); private: /** @@ -249,58 +160,16 @@ class DiscreteSearch { * @return A vector of cost-to-go values. */ static std::vector computeCostToGo( - const std::vector& conditionals) { - std::vector costToGo; - double error = 0.0; - for (const auto& conditional : conditionals) { - Ordering ordering(conditional->begin(), conditional->end()); - auto maxx = conditional->max(ordering); - assert(maxx->size() == 1); - error -= std::log(maxx->evaluate({})); - costToGo.push_back(error); - } - return costToGo; - } + const std::vector& conditionals); /** * @brief Expand the next node in the search tree. */ - void expandNextNode() { - // Pop the partial assignment with the smallest bound - SearchNode current = expansions_.top(); - expansions_.pop(); - - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions_.prune(current.bound)) { - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - solutions_.maybeAdd(current.error, current.assignment); - return; - } - - // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - auto childNode = current.expand(*conditional, fa); - if (childNode.nextConditional >= 0) - childNode.bound = - childNode.error + costToGo_[childNode.nextConditional]; - - // Again, prune if we cannot beat the worst solution - if (!solutions_.prune(childNode.bound)) { - expansions_.push(childNode); - } - } - } + void expandNextNode(); std::vector conditionals_; std::vector costToGo_; - std::priority_queue, - SearchNode::CompareByBound> + std::priority_queue, SearchNode::Compare> expansions_; Solutions solutions_; }; diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 041c3b465..a39bcd86b 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -48,6 +48,10 @@ TEST(DiscreteBayesNet, AsiaKBest) { DiscreteBayesNet asia = createAsiaExample(); DiscreteSearch search(asia, 4); auto solutions = search.run(); + + // print numExpansions + std::cout << "Number of expansions: " << search.numExpansions << std::endl; + EXPECT(!solutions.empty()); // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); @@ -73,7 +77,6 @@ TEST(DiscreteBayesTree, testTrivialOneClique) { DiscreteFactorGraph asia(createAsiaExample()); const Ordering ordering{D, X, B, E, L, T, S, A}; DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); - GTSAM_PRINT(*bt); // Ask for top 4 solutions DiscreteSearch search(*bt, 4);