From cee77695956de909dad8de0ca626db0793016803 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:32:03 -0500 Subject: [PATCH] Moved data structures to cpp --- gtsam/discrete/DiscreteSearch.cpp | 172 ++++++++++++++++++++++-------- gtsam/discrete/DiscreteSearch.h | 90 +--------------- 2 files changed, 126 insertions(+), 136 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 70f7f871a..5fca5c820 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -20,61 +20,139 @@ namespace gtsam { -SearchNode SearchNode::Root(size_t numConditionals, double bound) { - return {.assignment = DiscreteValues(), - .error = 0.0, - .bound = bound, - .nextConditional = static_cast(numConditionals) - 1}; -} +using Value = size_t; -SearchNode SearchNode::expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const { - // Combine the new frontal assignment with the current partial assignment - DiscreteValues newAssignment = assignment; - for (auto& [key, value] : fa) { - newAssignment[key] = value; +/** + * @brief Represents a node in the search tree for discrete search algorithms. + * + * @details Each SearchNode contains a partial assignment of discrete variables, + * the current error, a bound on the final error, and the index of the next + * conditional to be assigned. + */ +struct SearchNode { + DiscreteValues assignment; ///< Partial assignment of discrete variables. + double error; ///< Current error for the partial assignment. + double bound; ///< Lower bound on the final error for unassigned variables. + int nextConditional; ///< Index of the next conditional to be assigned. + + /** + * @brief Construct the root node for the search. + */ + static SearchNode Root(size_t numConditionals, double bound) { + return {.assignment = DiscreteValues(), + .error = 0.0, + .bound = bound, + .nextConditional = static_cast(numConditionals) - 1}; } - return {.assignment = newAssignment, - .error = error + conditional.error(newAssignment), - .bound = 0.0, - .nextConditional = nextConditional - 1}; -} + struct Compare { + bool operator()(const SearchNode& a, const SearchNode& b) const { + return a.bound > b.bound; // smallest bound -> highest priority + } + }; -bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) { - const bool full = pq_.size() == maxSize_; - if (full && error >= pq_.top().error) return false; - if (full) pq_.pop(); - pq_.emplace(error, assignment); - return true; -} + /** + * @brief Checks if the node represents a complete assignment. + * + * @return True if all variables have been assigned, false otherwise. + */ + inline bool isComplete() const { return nextConditional < 0; } -std::ostream& operator<<(std::ostream& os, const Solutions& sn) { - os << "Solutions (top " << sn.pq_.size() << "):\n"; - auto pq = sn.pq_; - while (!pq.empty()) { - os << pq.top() << "\n"; - pq.pop(); + /** + * @brief Expands the node by assigning the next variable. + * + * @param conditional The discrete conditional representing the next variable + * to be assigned. + * @param fa The frontal assignment for the next variable. + * @return A new SearchNode representing the expanded state. + */ + SearchNode expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + DiscreteValues newAssignment = assignment; + for (auto& [key, value] : fa) { + newAssignment[key] = value; + } + + return {.assignment = newAssignment, + .error = error + conditional.error(newAssignment), + .bound = 0.0, + .nextConditional = nextConditional - 1}; } - return os; -} -bool Solutions::prune(double bound) const { - if (pq_.size() < maxSize_) return false; - return bound >= pq_.top().error; -} - -std::vector Solutions::extractSolutions() { - std::vector result; - while (!pq_.empty()) { - result.push_back(pq_.top()); - pq_.pop(); + /** + * @brief Prints the SearchNode to an output stream. + * + * @param os The output stream. + * @param node The SearchNode to be printed. + * @return The output stream. + */ + friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { + os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; + return os; } - std::sort( - result.begin(), result.end(), - [](const Solution& a, const Solution& b) { return a.error < b.error; }); - return result; -} +}; + +struct CompareSolution { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } +}; + +// Define the Solutions class +class Solutions { + private: + size_t maxSize_; + std::priority_queue, CompareSolution> pq_; + + public: + Solutions(size_t maxSize) : maxSize_(maxSize) {} + + /// Add a solution to the priority queue, possibly evicting the worst one. + /// Return true if we added the solution. + bool maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; + } + + /// Check if we have any solutions + bool empty() const { return pq_.empty(); } + + // Method to print all solutions + friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + os << "Solutions (top " << sn.pq_.size() << "):\n"; + auto pq = sn.pq_; + while (!pq.empty()) { + os << pq.top() << "\n"; + pq.pop(); + } + return os; + } + + /// Check if (partial) solution with given bound can be pruned. If we have + /// room, we never prune. Otherwise, prune if lower bound on error is worse + /// than our current worst error. + bool prune(double bound) const { + if (pq_.size() < maxSize_) return false; + return bound >= pq_.top().error; + } + + // Method to extract solutions in ascending order of error + std::vector extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; + } +}; DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { std::vector conditionals; diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index f5bd67bae..78558a127 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -23,63 +23,9 @@ namespace gtsam { -using Value = size_t; - /** - * @brief Represents a node in the search tree for discrete search algorithms. - * - * @details Each SearchNode contains a partial assignment of discrete variables, - * the current error, a bound on the final error, and the index of the next - * conditional to be assigned. + * @brief A solution to a discrete search problem. */ -struct SearchNode { - DiscreteValues assignment; ///< Partial assignment of discrete variables. - double error; ///< Current error for the partial assignment. - double bound; ///< Lower bound on the final error for unassigned variables. - int nextConditional; ///< Index of the next conditional to be assigned. - - /** - * @brief Construct the root node for the search. - */ - static SearchNode Root(size_t numConditionals, double bound); - - struct Compare { - bool operator()(const SearchNode& a, const SearchNode& b) const { - return a.bound > b.bound; // smallest bound -> highest priority - } - }; - - /** - * @brief Checks if the node represents a complete assignment. - * - * @return True if all variables have been assigned, false otherwise. - */ - inline bool isComplete() const { return nextConditional < 0; } - - /** - * @brief Expands the node by assigning the next variable. - * - * @param conditional The discrete conditional representing the next variable - * to be assigned. - * @param fa The frontal assignment for the next variable. - * @return A new SearchNode representing the expanded state. - */ - SearchNode expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const; - - /** - * @brief Prints the SearchNode to an output stream. - * - * @param os The output stream. - * @param node The SearchNode to be printed. - * @return The output stream. - */ - friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { - os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; - return os; - } -}; - struct Solution { double error; DiscreteValues assignment; @@ -89,40 +35,6 @@ struct Solution { os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; return os; } - - struct Compare { - bool operator()(const Solution& a, const Solution& b) const { - return a.error < b.error; - } - }; -}; - -// Define the Solutions class -class Solutions { - private: - size_t maxSize_; - std::priority_queue, Solution::Compare> pq_; - - public: - Solutions(size_t maxSize) : maxSize_(maxSize) {} - - /// Add a solution to the priority queue, possibly evicting the worst one. - /// Return true if we added the solution. - bool maybeAdd(double error, const DiscreteValues& assignment); - - /// Check if we have any solutions - bool empty() const { return pq_.empty(); } - - // Method to print all solutions - friend std::ostream& operator<<(std::ostream& os, const Solutions& sn); - - /// Check if (partial) solution with given bound can be pruned. If we have - /// room, we never prune. Otherwise, prune if lower bound on error is worse - /// than our current worst error. - bool prune(double bound) const; - - // Method to extract solutions in ascending order of error - std::vector extractSolutions(); }; /**