From 455554c80384fcdc1848eb38d394d7874755f449 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 22:39:10 -0500 Subject: [PATCH] DiscreteSearch --- gtsam/discrete/DiscreteSearch.h | 276 ++++++++++++++++++++ gtsam/discrete/tests/testDiscreteSearch.cpp | 59 +++++ 2 files changed, 335 insertions(+) create mode 100644 gtsam/discrete/DiscreteSearch.h create mode 100644 gtsam/discrete/tests/testDiscreteSearch.cpp diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h new file mode 100644 index 000000000..d9ffae768 --- /dev/null +++ b/gtsam/discrete/DiscreteSearch.h @@ -0,0 +1,276 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * DiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include + +namespace gtsam { + +using Value = size_t; + +/** + * @brief Represents a node in the search tree for discrete search algorithms. + * + * @details Each SearchNode contains a partial assignment of discrete variables, + * the current error, a bound on the final error, and the index of the next + * conditional to be assigned. + */ +struct SearchNode { + DiscreteValues assignment; ///< Partial assignment of discrete variables. + double error; ///< Current error for the partial assignment. + double bound; ///< Lower bound on the final error for unassigned variables. + int nextConditional; ///< Index of the next conditional to be assigned. + + struct CompareByBound { + bool operator()(const SearchNode& a, const SearchNode& b) const { + return a.bound > b.bound; // smallest bound -> highest priority + } + }; + + /** + * @brief Checks if the node represents a complete assignment. + * + * @return True if all variables have been assigned, false otherwise. + */ + bool isComplete() const { return nextConditional < 0; } + + /** + * @brief Computes a lower bound on the final error for unassigned variables. + * + * @details This is a stub implementation that returns 0. Real implementations + * might perform partial factor analysis or use heuristics to compute the + * bound. + * + * @return A lower bound on the final error. + */ + double computeBound() const { + // Real code might do partial factor analysis or heuristics. + return 0.0; + } + + /** + * @brief Expands the node by assigning the next variable. + * + * @param conditional The discrete conditional representing the next variable + * to be assigned. + * @param fa The frontal assignment for the next variable. + * @return A new SearchNode representing the expanded state. + */ + SearchNode expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + SearchNode child; + child.assignment = assignment; + for (auto& kv : fa) { + child.assignment[kv.first] = kv.second; + } + + // Compute the incremental error for this factor + child.error = error + conditional.error(child.assignment); + + // Compute new bound + child.bound = computeBound(); + + // Update the index of the next conditional + child.nextConditional = nextConditional - 1; + + return child; + } + + /** + * @brief Prints the SearchNode to an output stream. + * + * @param os The output stream. + * @param node The SearchNode to be printed. + * @return The output stream. + */ + friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { + os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; + return os; + } +}; + +struct Solution { + double error; + DiscreteValues assignment; + Solution(double err, const DiscreteValues& assign) + : error(err), assignment(assign) {} + friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { + os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; + return os; + } +}; + +struct CompareByError { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } +}; + +// Define the Solutions class +class Solutions { + private: + size_t maxSize_; + std::priority_queue, CompareByError> pq_; + + public: + Solutions(size_t maxSize) : maxSize_(maxSize) {} + + /// Add a solution to the priority queue, possibly evicting the worst one. + /// Return true if we added the solution. + bool maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; + } + + /// Check if we have any solutions + bool empty() const { return pq_.empty(); } + + // Method to print all solutions + friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + auto pq = sn.pq_; + while (!pq.empty()) { + const Solution& best = pq.top(); + os << "Error: " << best.error << ", Values: " << best.assignment + << std::endl; + pq.pop(); + } + return os; + } + + /// Check if (partial) solution with given bound can be pruned. If we have + /// room, we never prune. Otherwise, prune if lower bound on error is worse + /// than our current worst error. + bool prune(double bound) const { + if (pq_.size() < maxSize_) return false; + double worstError = pq_.top().error; + return (bound >= worstError); + } + + // Method to extract solutions in ascending order of error + std::vector extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; + } +}; + +/** + * DiscreteSearch: Search for the K best solutions. + */ +class DiscreteSearch { + public: + /** + * Construct from a DiscreteBayesNet and K. + */ + DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) : solutions_(K) { + // Copy out the conditionals + for (auto& factor : bayesNet) { + conditionals_.push_back(factor); + } + + // Calculate the cost-to-go for each conditional. If there are n + // conditionals, we start with nextConditional = n-1, and the minimum error + // obtainable is the sum of all the minimum errors. We start with + // 0, and that is the minimum error of the conditional with that index: + double error = 0.0; + for (const auto& conditional : conditionals_) { + Ordering ordering(conditional->begin(), conditional->end()); + auto maxx = conditional->max(ordering); + assert(maxx->size() == 1); + error -= std::log(maxx->evaluate({})); + costToGo_.push_back(error); + } + + // Create the root node: no variables assigned, nextConditional = last. + SearchNode root{ + .assignment = DiscreteValues(), + .error = 0.0, + .nextConditional = static_cast(conditionals_.size()) - 1}; + if (!costToGo_.empty()) root.bound = costToGo_.back(); + expansions_.push(root); + } + + /** + * @brief Search for the K best solutions. + * + * This method performs a search to find the K best solutions for the given + * DiscreteBayesNet. It uses a priority queue to manage the search nodes, + * expanding nodes with the smallest bound first. The search continues until + * all possible nodes have been expanded or pruned. + * + * @return A vector of the K best solutions found during the search. + */ + std::vector run() { + while (!expansions_.empty()) { + expandNextNode(); + } + + // Extract solutions from bestSolutions in ascending order of error + return solutions_.extractSolutions(); + } + + private: + void expandNextNode() { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions_.top(); + expansions_.pop(); + + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions_.prune(current.bound)) { + return; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions_.maybeAdd(current.error, current.assignment); + return; + } + + // Expand on the next factor + const auto& conditional = conditionals_[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + auto childNode = current.expand(*conditional, fa); + if (childNode.nextConditional >= 0) + childNode.bound = + childNode.error + costToGo_[childNode.nextConditional]; + + // Again, prune if we cannot beat the worst solution + if (!solutions_.prune(childNode.bound)) { + expansions_.push(childNode); + } + } + } + + std::vector> conditionals_; + std::vector costToGo_; + std::priority_queue, + SearchNode::CompareByBound> + expansions_; + Solutions solutions_; +}; +} // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp new file mode 100644 index 000000000..4d7adbac5 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -0,0 +1,59 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "AsiaExample.h" + +using namespace gtsam; + +TEST(DiscreteBayesNet, EmptyKBest) { + DiscreteBayesNet net; // no factors + DiscreteSearch search(net, 3); + auto solutions = search.run(); + // Expect one solution with empty assignment, error=0 + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +TEST(DiscreteBayesNet, AsiaKBest) { + using namespace asia_example; + DiscreteBayesNet asia = createAsiaExample(); + DiscreteSearch search(asia, 4); + auto solutions = search.run(); + EXPECT(!solutions.empty()); + // Regression test: check the first solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */