From 4027e860576c29337617d4e6fb5cfe62324ae05a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 16:57:37 -0500 Subject: [PATCH 01/20] stream operator --- gtsam/discrete/DiscreteValues.cpp | 16 ++++++++++++++-- gtsam/discrete/DiscreteValues.h | 3 +++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DiscreteValues.cpp b/gtsam/discrete/DiscreteValues.cpp index 416dfb888..3c3ed4468 100644 --- a/gtsam/discrete/DiscreteValues.cpp +++ b/gtsam/discrete/DiscreteValues.cpp @@ -26,12 +26,24 @@ using std::stringstream; namespace gtsam { +/* ************************************************************************ */ +static void stream(std::ostream& os, const DiscreteValues& x, + const KeyFormatter& keyFormatter) { + for (const auto& kv : x) + os << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; +} + +/* ************************************************************************ */ +std::ostream& operator<<(std::ostream& os, const DiscreteValues& x) { + stream(os, x, DefaultKeyFormatter); + return os; +} + /* ************************************************************************ */ void DiscreteValues::print(const string& s, const KeyFormatter& keyFormatter) const { cout << s << ": "; - for (auto&& kv : *this) - cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; + stream(cout, *this, keyFormatter); cout << endl; } diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index fa8a8a846..df4ecdbff 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -64,6 +64,9 @@ class GTSAM_EXPORT DiscreteValues : public Assignment { /// @name Standard Interface /// @{ + /// ostream operator: + friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x); + // insert in base class; std::pair insert( const value_type& value ){ return Base::insert(value); From 2808982903032524d9dfe90b36d67f0b4f4c0ad5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 16:57:58 -0500 Subject: [PATCH 02/20] Best-K search, v0 --- gtsam/discrete/tests/testDiscreteBayesNet.cpp | 276 ++++++++++++++++-- 1 file changed, 253 insertions(+), 23 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index d2033909c..0bd81a197 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -31,8 +32,15 @@ using namespace std; using namespace gtsam; -static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), - LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); +namespace keys { +static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), + B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), + S = Symbol('S', 7), A = Symbol('A', 8); +} + +static const DiscreteKey Dyspnea(keys::D, 2), XRay(keys::X, 2), + Either(keys::E, 2), Bronchitis(keys::B, 2), LungCancer(keys::L, 2), + Tuberculosis(keys::T, 2), Smoking(keys::S, 2), Asia(keys::A, 2); using ADT = AlgebraicDecisionTree; @@ -40,17 +48,15 @@ using ADT = AlgebraicDecisionTree; DiscreteBayesNet constructAsiaExample() { DiscreteBayesNet asia; - asia.add(Asia, "99/1"); - asia.add(Smoking % "50/50"); // Signature version - - asia.add(Tuberculosis | Asia = "99/1 95/5"); - asia.add(LungCancer | Smoking = "99/1 90/10"); - asia.add(Bronchitis | Smoking = "70/30 40/60"); - - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - asia.add(XRay | Either = "95/5 2/98"); + // Add in topological sort order, parents last: asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + asia.add(XRay | Either = "95/5 2/98"); + asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + asia.add(LungCancer | Smoking = "99/1 90/10"); + asia.add(Tuberculosis | Asia = "99/1 95/5"); + asia.add(Smoking % "50/50"); // Signature version + asia.add(Asia, "99/1"); return asia; } @@ -99,7 +105,8 @@ TEST(DiscreteBayesNet, Asia) { EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); // Create solver and eliminate - const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7}; + const Ordering ordering{keys::A, keys::D, keys::T, keys::X, + keys::S, keys::E, keys::L, keys::B}; DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); @@ -157,16 +164,16 @@ TEST(DiscreteBayesNet, Dot) { "digraph {\n" " size=\"5,5\";\n" "\n" - " var0[label=\"0\"];\n" - " var3[label=\"3\"];\n" - " var4[label=\"4\"];\n" - " var5[label=\"5\"];\n" - " var6[label=\"6\"];\n" + " var4683743612465315840[label=\"A0\"];\n" + " var4971973988617027589[label=\"E5\"];\n" + " var5476377146882523142[label=\"L6\"];\n" + " var5980780305148018692[label=\"S4\"];\n" + " var6052837899185946627[label=\"T3\"];\n" "\n" - " var3->var5\n" - " var6->var5\n" - " var4->var6\n" - " var0->var3\n" + " var6052837899185946627->var4971973988617027589\n" + " var5476377146882523142->var4971973988617027589\n" + " var5980780305148018692->var5476377146882523142\n" + " var4683743612465315840->var6052837899185946627\n" "}"); } @@ -191,11 +198,234 @@ TEST(DiscreteBayesNet, markdown) { "|:-:|:-:|:-:|\n" "|0|0.8|0.2|\n" "|1|0.7|0.3|\n\n"; - auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; + auto formatter = [](Key key) { return key == keys::A ? "Asia" : "Smoking"; }; string actual = fragment.markdown(formatter); EXPECT(actual == expected); } +/* ************************************************************************* */ +#include +#include +#include +#include +#include +#include + +using Value = size_t; + +// ---------------------------------------------------------------------------- +// 1) SearchNode: store partial assignment and next factor to expand +// ---------------------------------------------------------------------------- +struct SearchNode { + DiscreteValues assignment; + double error; + double bound; + int nextConditional; // index into conditionals + + /// if nextConditional < 0, we've assigned everything. + bool isComplete() const { return nextConditional < 0; } + + /// lower bound on final error for unassigned variables. Stub=0. + double computeBound() const { + // Real code might do partial factor analysis or heuristics. + return 0.0; + } + + /// Expand this node by assigning the next variable + SearchNode expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + SearchNode child; + child.assignment = assignment; + for (auto& kv : fa) { + child.assignment[kv.first] = kv.second; + } + + // Compute the incremental error for this factor + child.error = error + conditional.error(child.assignment); + + // Compute new bound + child.bound = child.error + computeBound(); + + // Next factor index + child.nextConditional = nextConditional - 1; + + return child; + } + + friend std::ostream& operator<<(std::ostream& os, const SearchNode& sn) { + os << "[ error=" << sn.error << " bound=" << sn.bound + << " nextConditional=" << sn.nextConditional << " assignment={" + << sn.assignment << "}]"; + return os; + } +}; + +// ---------------------------------------------------------------------------- +// 2) Priority functor to make a min-heap by bound +// ---------------------------------------------------------------------------- +struct CompareByBound { + bool operator()(const SearchNode& a, const SearchNode& b) const { + return a.bound > b.bound; // smallest bound -> highest priority + } +}; + +// ---------------------------------------------------------------------------- +// 4) A Solution +// ---------------------------------------------------------------------------- +struct Solution { + double error; + DiscreteValues assignment; + Solution(double err, const DiscreteValues& assign) + : error(err), assignment(assign) {} + friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { + os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; + return os; + } +}; + +struct CompareByError { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } +}; + +using Solutions = + std::priority_queue, CompareByError>; + +// Function to print the priority queue +void printPriorityQueue(Solutions pq) { + size_t i = 0; + while (!pq.empty()) { + const Solution& best = pq.top(); + cout << " " << ++i << " Error: " << best.error + << ", Values: " << best.assignment << endl; + pq.pop(); + } +} + +// ---------------------------------------------------------------------------- +// 5) The main branch-and-bound routine for DiscreteBayesNet +// ---------------------------------------------------------------------------- +std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, + size_t K) { + // Copy out the conditionals + std::vector> conditionals; + for (auto& factor : bayesNet) { + conditionals.push_back(factor); + } + + // Min-heap of expansions by bound + std::priority_queue, CompareByBound> + expansions; + + // Max-heap of best solutions found so far, keyed by error. + // We keep the largest error on top. + Solutions solutions; + + // 1) Create the root node: no variables assigned, nextConditional = last. + SearchNode root{.assignment = DiscreteValues(), + .error = 0.0, + .nextConditional = static_cast(conditionals.size()) - 1}; + root.bound = root.computeBound(); + std::cout << "Root: " << root << std::endl; + expansions.push(root); + + // 2) Main loop + size_t numExpansions = 0; + while (!expansions.empty()) { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions.top(); + expansions.pop(); + numExpansions++; + std::cout << "Expanding: " << current << std::endl; + + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions.size() == K) { + double worstError = solutions.top().error; + if (current.bound >= worstError) { + std::cout << "Pruning: bound=" << current.bound + << " worstError=" << worstError << std::endl; + continue; + } + } + + // Check if we have a complete assignment + if (current.isComplete()) { + // We have a new candidate solution + double finalError = current.error; + if (solutions.size() < K) { + std::cout << "Adding solution: " << current << std::endl; + solutions.emplace(finalError, current.assignment); + // print out all solutions: + std::cout << "Best solutions so far:" << std::endl; + printPriorityQueue(solutions); + } else { + double worstError = solutions.top().error; + if (finalError < worstError) { + std::cout << "Replacing solution: " << current << std::endl; + solutions.pop(); + solutions.emplace(finalError, current.assignment); + } + } + continue; + } + + // Expand on the next factor + const auto& conditional = conditionals[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + std::cout << "Frontal assignment: " << fa << std::endl; + auto childNode = current.expand(*conditional, fa); + + // Again, prune if we cannot beat the worst solution + if (solutions.size() == K) { + double worstError = solutions.top().error; + if (childNode.bound >= worstError) { + std::cout << "Pruning: bound=" << childNode.bound + << " worstError=" << worstError << std::endl; + continue; + } + } + expansions.push(childNode); + } + } + + std::cout << "Expansions: " << numExpansions << std::endl; + + // 3) Extract solutions from bestSolutions in ascending order of error + std::vector result; + while (!solutions.empty()) { + result.push_back(solutions.top()); + solutions.pop(); + } + std::sort(result.begin(), result.end(), + [](auto& a, auto& b) { return a.error < b.error; }); + + return result; +} + +// ---------------------------------------------------------------------------- +// Example “Unit Tests” (trivial stubs) +// ---------------------------------------------------------------------------- + +TEST_DISABLED(DiscreteBayesNet, EmptyKBest) { + DiscreteBayesNet net; // no factors + auto solutions = branchAndBoundKSolutions(net, 3); + // Expect one solution with empty assignment, error=0 + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +TEST(DiscreteBayesNet, AsiaKBest) { + DiscreteBayesNet net = constructAsiaExample(); + + auto solutions = branchAndBoundKSolutions(net, 4); + EXPECT(!solutions.empty()); + // All solutions likely tie with error=0 in this stub + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + /* ************************************************************************* */ int main() { TestResult tr; From 74b7a962d5f5453f57b9b13b14e362af716f95e7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 17:40:44 -0500 Subject: [PATCH 03/20] Methods in Solutions --- gtsam/discrete/tests/testDiscreteBayesNet.cpp | 138 ++++++++++-------- 1 file changed, 77 insertions(+), 61 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 0bd81a197..52fc660e5 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -96,7 +96,7 @@ TEST(DiscreteBayesNet, Asia) { // Convert to factor graph DiscreteFactorGraph fg(asia); - LONGS_EQUAL(3, fg.back()->size()); + LONGS_EQUAL(1, fg.back()->size()); // Check the marginals we know (of the parent-less nodes) DiscreteMarginals marginals(fg); @@ -164,16 +164,16 @@ TEST(DiscreteBayesNet, Dot) { "digraph {\n" " size=\"5,5\";\n" "\n" - " var4683743612465315840[label=\"A0\"];\n" - " var4971973988617027589[label=\"E5\"];\n" - " var5476377146882523142[label=\"L6\"];\n" - " var5980780305148018692[label=\"S4\"];\n" - " var6052837899185946627[label=\"T3\"];\n" + " var4683743612465315848[label=\"A8\"];\n" + " var4971973988617027587[label=\"E3\"];\n" + " var5476377146882523141[label=\"L5\"];\n" + " var5980780305148018695[label=\"S7\"];\n" + " var6052837899185946630[label=\"T6\"];\n" "\n" - " var6052837899185946627->var4971973988617027589\n" - " var5476377146882523142->var4971973988617027589\n" - " var5980780305148018692->var5476377146882523142\n" - " var4683743612465315840->var6052837899185946627\n" + " var6052837899185946630->var4971973988617027587\n" + " var5476377146882523141->var4971973988617027587\n" + " var5980780305148018695->var5476377146882523141\n" + " var4683743612465315848->var6052837899185946630\n" "}"); } @@ -290,19 +290,61 @@ struct CompareByError { } }; -using Solutions = - std::priority_queue, CompareByError>; +// Define the Solutions class +class Solutions { + private: + size_t maxSize_; + std::priority_queue, CompareByError> pq_; -// Function to print the priority queue -void printPriorityQueue(Solutions pq) { - size_t i = 0; - while (!pq.empty()) { - const Solution& best = pq.top(); - cout << " " << ++i << " Error: " << best.error - << ", Values: " << best.assignment << endl; - pq.pop(); + public: + Solutions(size_t maxSize) : maxSize_(maxSize) {} + + /// Add a solution to the priority queue, possibly evicting the worst one. + /// Return true if we added the solution. + bool maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; } -} + + /// Check if we have any solutions + bool empty() const { return pq_.empty(); } + + // Method to print all solutions + void print() const { + auto pq = pq_; + while (!pq.empty()) { + const Solution& best = pq.top(); + std::cout << "Error: " << best.error << ", Values: " << best.assignment + << std::endl; + pq.pop(); + } + } + + /// Check if (partial) solution with given bound can be pruned. If we have + /// room, we never prune. Otherwise, prune if lower bound on error is worse + /// than our current worst error. + bool prune(double bound) const { + if (pq_.size() < maxSize_) return false; + double worstError = pq_.top().error; + return (bound >= worstError); + } + + // Method to extract solutions in ascending order of error + std::vector extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; + } +}; // ---------------------------------------------------------------------------- // 5) The main branch-and-bound routine for DiscreteBayesNet @@ -321,7 +363,7 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, // Max-heap of best solutions found so far, keyed by error. // We keep the largest error on top. - Solutions solutions; + Solutions solutions(K); // 1) Create the root node: no variables assigned, nextConditional = last. SearchNode root{.assignment = DiscreteValues(), @@ -341,32 +383,17 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, std::cout << "Expanding: " << current << std::endl; // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions.size() == K) { - double worstError = solutions.top().error; - if (current.bound >= worstError) { - std::cout << "Pruning: bound=" << current.bound - << " worstError=" << worstError << std::endl; - continue; - } + if (solutions.prune(current.bound)) { + std::cout << "Pruning: bound=" << current.bound << std::endl; + continue; } // Check if we have a complete assignment if (current.isComplete()) { - // We have a new candidate solution - double finalError = current.error; - if (solutions.size() < K) { - std::cout << "Adding solution: " << current << std::endl; - solutions.emplace(finalError, current.assignment); - // print out all solutions: + const bool added = solutions.maybeAdd(current.error, current.assignment); + if (added) { std::cout << "Best solutions so far:" << std::endl; - printPriorityQueue(solutions); - } else { - double worstError = solutions.top().error; - if (finalError < worstError) { - std::cout << "Replacing solution: " << current << std::endl; - solutions.pop(); - solutions.emplace(finalError, current.assignment); - } + solutions.print(); } continue; } @@ -379,14 +406,11 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, auto childNode = current.expand(*conditional, fa); // Again, prune if we cannot beat the worst solution - if (solutions.size() == K) { - double worstError = solutions.top().error; - if (childNode.bound >= worstError) { - std::cout << "Pruning: bound=" << childNode.bound - << " worstError=" << worstError << std::endl; - continue; - } + if (solutions.prune(current.bound)) { + std::cout << "Pruning: bound=" << childNode.bound << std::endl; + continue; } + expansions.push(childNode); } } @@ -394,15 +418,7 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, std::cout << "Expansions: " << numExpansions << std::endl; // 3) Extract solutions from bestSolutions in ascending order of error - std::vector result; - while (!solutions.empty()) { - result.push_back(solutions.top()); - solutions.pop(); - } - std::sort(result.begin(), result.end(), - [](auto& a, auto& b) { return a.error < b.error; }); - - return result; + return solutions.extractSolutions(); } // ---------------------------------------------------------------------------- @@ -422,8 +438,8 @@ TEST(DiscreteBayesNet, AsiaKBest) { auto solutions = branchAndBoundKSolutions(net, 4); EXPECT(!solutions.empty()); - // All solutions likely tie with error=0 in this stub - EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); + // Regression test: check the first solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); } /* ************************************************************************* */ From 1f4d9bbd7ecdb90c2401c374000a4142486c7c5b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 17:59:31 -0500 Subject: [PATCH 04/20] Separate class --- gtsam/discrete/tests/testDiscreteBayesNet.cpp | 114 +++++++++++------- 1 file changed, 68 insertions(+), 46 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 52fc660e5..648b26770 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -346,97 +346,119 @@ class Solutions { } }; -// ---------------------------------------------------------------------------- -// 5) The main branch-and-bound routine for DiscreteBayesNet -// ---------------------------------------------------------------------------- -std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, - size_t K) { - // Copy out the conditionals - std::vector> conditionals; - for (auto& factor : bayesNet) { - conditionals.push_back(factor); +/** + * BestKSearch: Search for the K best solutions. + */ +class BestKSearch { + public: + /** + * Construct from a DiscreteBayesNet and K. + */ + BestKSearch(const DiscreteBayesNet& bayesNet, size_t K) + : bayesNet_(bayesNet), solutions_(K) { + // Copy out the conditionals + for (auto& factor : bayesNet_) { + conditionals_.push_back(factor); + } + + // Create the root node: no variables assigned, nextConditional = last. + SearchNode root{ + .assignment = DiscreteValues(), + .error = 0.0, + .nextConditional = static_cast(conditionals_.size()) - 1}; + root.bound = root.computeBound(); + std::cout << "Root: " << root << std::endl; + expansions_.push(root); } - // Min-heap of expansions by bound - std::priority_queue, CompareByBound> - expansions; + /** + * @brief Search for the K best solutions. + * + * This method performs a search to find the K best solutions for the given + * DiscreteBayesNet. It uses a priority queue to manage the search nodes, + * expanding nodes with the smallest bound first. The search continues until + * all possible nodes have been expanded or pruned. + * + * @return A vector of the K best solutions found during the search. + */ + std::vector run() { + size_t numExpansions = 0; + while (!expansions_.empty()) { + expandNextNode(); + numExpansions++; + } - // Max-heap of best solutions found so far, keyed by error. - // We keep the largest error on top. - Solutions solutions(K); + std::cout << "Expansions: " << numExpansions << std::endl; - // 1) Create the root node: no variables assigned, nextConditional = last. - SearchNode root{.assignment = DiscreteValues(), - .error = 0.0, - .nextConditional = static_cast(conditionals.size()) - 1}; - root.bound = root.computeBound(); - std::cout << "Root: " << root << std::endl; - expansions.push(root); + // Extract solutions from bestSolutions in ascending order of error + return solutions_.extractSolutions(); + } - // 2) Main loop - size_t numExpansions = 0; - while (!expansions.empty()) { + private: + // + void expandNextNode() { // Pop the partial assignment with the smallest bound - SearchNode current = expansions.top(); - expansions.pop(); - numExpansions++; + SearchNode current = expansions_.top(); + expansions_.pop(); std::cout << "Expanding: " << current << std::endl; // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions.prune(current.bound)) { + if (solutions_.prune(current.bound)) { std::cout << "Pruning: bound=" << current.bound << std::endl; - continue; + return; } // Check if we have a complete assignment if (current.isComplete()) { - const bool added = solutions.maybeAdd(current.error, current.assignment); + const bool added = solutions_.maybeAdd(current.error, current.assignment); if (added) { std::cout << "Best solutions so far:" << std::endl; - solutions.print(); + solutions_.print(); } - continue; + return; } // Expand on the next factor - const auto& conditional = conditionals[current.nextConditional]; + const auto& conditional = conditionals_[current.nextConditional]; for (auto& fa : conditional->frontalAssignments()) { std::cout << "Frontal assignment: " << fa << std::endl; auto childNode = current.expand(*conditional, fa); // Again, prune if we cannot beat the worst solution - if (solutions.prune(current.bound)) { + if (solutions_.prune(current.bound)) { std::cout << "Pruning: bound=" << childNode.bound << std::endl; continue; } - expansions.push(childNode); + expansions_.push(childNode); } } - std::cout << "Expansions: " << numExpansions << std::endl; - - // 3) Extract solutions from bestSolutions in ascending order of error - return solutions.extractSolutions(); -} + const DiscreteBayesNet& bayesNet_; + std::vector> conditionals_; + std::priority_queue, CompareByBound> + expansions_; + Solutions solutions_; +}; // ---------------------------------------------------------------------------- // Example “Unit Tests” (trivial stubs) // ---------------------------------------------------------------------------- -TEST_DISABLED(DiscreteBayesNet, EmptyKBest) { +TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors - auto solutions = branchAndBoundKSolutions(net, 3); + BestKSearch search(net, 3); + auto solutions = search.run(); // Expect one solution with empty assignment, error=0 EXPECT_LONGS_EQUAL(1, solutions.size()); EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); } TEST(DiscreteBayesNet, AsiaKBest) { - DiscreteBayesNet net = constructAsiaExample(); - - auto solutions = branchAndBoundKSolutions(net, 4); + DiscreteBayesNet asia = constructAsiaExample(); + BestKSearch search(asia, 4); + auto solutions = search.run(); EXPECT(!solutions.empty()); // Regression test: check the first solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); From d879b156f8de4da0f0cbdaf16b7ed5b0d13bf20f Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 22:08:04 -0500 Subject: [PATCH 05/20] Asia example --- gtsam/discrete/tests/AsiaExample.h | 61 +++ gtsam/discrete/tests/testDiscreteBayesNet.cpp | 367 ++---------------- 2 files changed, 98 insertions(+), 330 deletions(-) create mode 100644 gtsam/discrete/tests/AsiaExample.h diff --git a/gtsam/discrete/tests/AsiaExample.h b/gtsam/discrete/tests/AsiaExample.h new file mode 100644 index 000000000..7413f6899 --- /dev/null +++ b/gtsam/discrete/tests/AsiaExample.h @@ -0,0 +1,61 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * AsiaExample.h + * + * @date Jan, 2025 + * @author Frank Dellaert + */ + +#include +#include + +namespace gtsam { +namespace asia_example { + +static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), + B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), + S = Symbol('S', 7), A = Symbol('A', 8); + +static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), + Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), + Asia(A, 2); + +// Function to construct the incomplete Asia example +DiscreteBayesNet createPriors() { + DiscreteBayesNet priors; + priors.add(Smoking % "50/50"); + priors.add(Asia, "99/1"); + return priors; +} + +// Function to construct the incomplete Asia example +DiscreteBayesNet createFragment() { + DiscreteBayesNet fragment; + fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); + fragment.add(LungCancer | Smoking = "99/1 90/10"); + fragment.add(Tuberculosis | Asia = "99/1 95/5"); + for (const auto& factor : createPriors()) fragment.push_back(factor); + return fragment; +} + +// Function to construct the Asia example +DiscreteBayesNet createAsiaExample() { + DiscreteBayesNet asia; + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + asia.add(XRay | Either = "95/5 2/98"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + for (const auto& factor : createFragment()) asia.push_back(factor); + return asia; +} +} // namespace asia_example +} // namespace gtsam \ No newline at end of file diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 648b26770..dd5d218f8 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -29,40 +29,13 @@ #include #include -using namespace std; +#include "AsiaExample.h" + using namespace gtsam; -namespace keys { -static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), - B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), - S = Symbol('S', 7), A = Symbol('A', 8); -} - -static const DiscreteKey Dyspnea(keys::D, 2), XRay(keys::X, 2), - Either(keys::E, 2), Bronchitis(keys::B, 2), LungCancer(keys::L, 2), - Tuberculosis(keys::T, 2), Smoking(keys::S, 2), Asia(keys::A, 2); - -using ADT = AlgebraicDecisionTree; - -// Function to construct the Asia example -DiscreteBayesNet constructAsiaExample() { - DiscreteBayesNet asia; - - // Add in topological sort order, parents last: - asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); - asia.add(XRay | Either = "95/5 2/98"); - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - asia.add(Bronchitis | Smoking = "70/30 40/60"); - asia.add(LungCancer | Smoking = "99/1 90/10"); - asia.add(Tuberculosis | Asia = "99/1 95/5"); - asia.add(Smoking % "50/50"); // Signature version - asia.add(Asia, "99/1"); - - return asia; -} - /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { + using ADT = AlgebraicDecisionTree; DiscreteBayesNet bayesNet; DiscreteKey Parent(0, 2), Child(1, 2); @@ -92,7 +65,8 @@ TEST(DiscreteBayesNet, bayesNet) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { - DiscreteBayesNet asia = constructAsiaExample(); + using namespace asia_example; + const DiscreteBayesNet asia = createAsiaExample(); // Convert to factor graph DiscreteFactorGraph fg(asia); @@ -105,8 +79,7 @@ TEST(DiscreteBayesNet, Asia) { EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); // Create solver and eliminate - const Ordering ordering{keys::A, keys::D, keys::T, keys::X, - keys::S, keys::E, keys::L, keys::B}; + const Ordering ordering{A, D, T, X, S, E, L, B}; DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); @@ -151,319 +124,53 @@ TEST(DiscreteBayesNet, Sugar) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Dot) { - DiscreteBayesNet fragment; - fragment.add(Asia % "99/1"); - fragment.add(Smoking % "50/50"); + using namespace asia_example; + const DiscreteBayesNet fragment = createFragment(); - fragment.add(Tuberculosis | Asia = "99/1 95/5"); - fragment.add(LungCancer | Smoking = "99/1 90/10"); - fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - string actual = fragment.dot(); - EXPECT(actual == - "digraph {\n" - " size=\"5,5\";\n" - "\n" - " var4683743612465315848[label=\"A8\"];\n" - " var4971973988617027587[label=\"E3\"];\n" - " var5476377146882523141[label=\"L5\"];\n" - " var5980780305148018695[label=\"S7\"];\n" - " var6052837899185946630[label=\"T6\"];\n" - "\n" - " var6052837899185946630->var4971973988617027587\n" - " var5476377146882523141->var4971973988617027587\n" - " var5980780305148018695->var5476377146882523141\n" - " var4683743612465315848->var6052837899185946630\n" - "}"); + std::string expected = + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var4683743612465315848[label=\"A8\"];\n" + " var4971973988617027587[label=\"E3\"];\n" + " var5476377146882523141[label=\"L5\"];\n" + " var5980780305148018695[label=\"S7\"];\n" + " var6052837899185946630[label=\"T6\"];\n" + "\n" + " var4683743612465315848->var6052837899185946630\n" + " var5980780305148018695->var5476377146882523141\n" + " var6052837899185946630->var4971973988617027587\n" + " var5476377146882523141->var4971973988617027587\n" + "}"; + std::string actual = fragment.dot(); + EXPECT(actual.compare(expected) == 0); } /* ************************************************************************* */ // Check markdown representation looks as expected. TEST(DiscreteBayesNet, markdown) { - DiscreteBayesNet fragment; - fragment.add(Asia % "99/1"); - fragment.add(Smoking | Asia = "8/2 7/3"); + using namespace asia_example; + DiscreteBayesNet priors = createPriors(); - string expected = + std::string expected = "`DiscreteBayesNet` of size 2\n" "\n" + " *P(Smoking):*\n\n" + "|Smoking|value|\n" + "|:-:|:-:|\n" + "|0|0.5|\n" + "|1|0.5|\n" + "\n" " *P(Asia):*\n\n" "|Asia|value|\n" "|:-:|:-:|\n" "|0|0.99|\n" - "|1|0.01|\n" - "\n" - " *P(Smoking|Asia):*\n\n" - "|*Asia*|0|1|\n" - "|:-:|:-:|:-:|\n" - "|0|0.8|0.2|\n" - "|1|0.7|0.3|\n\n"; - auto formatter = [](Key key) { return key == keys::A ? "Asia" : "Smoking"; }; - string actual = fragment.markdown(formatter); + "|1|0.01|\n\n"; + auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; }; + std::string actual = priors.markdown(formatter); EXPECT(actual == expected); } -/* ************************************************************************* */ -#include -#include -#include -#include -#include -#include - -using Value = size_t; - -// ---------------------------------------------------------------------------- -// 1) SearchNode: store partial assignment and next factor to expand -// ---------------------------------------------------------------------------- -struct SearchNode { - DiscreteValues assignment; - double error; - double bound; - int nextConditional; // index into conditionals - - /// if nextConditional < 0, we've assigned everything. - bool isComplete() const { return nextConditional < 0; } - - /// lower bound on final error for unassigned variables. Stub=0. - double computeBound() const { - // Real code might do partial factor analysis or heuristics. - return 0.0; - } - - /// Expand this node by assigning the next variable - SearchNode expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const { - // Combine the new frontal assignment with the current partial assignment - SearchNode child; - child.assignment = assignment; - for (auto& kv : fa) { - child.assignment[kv.first] = kv.second; - } - - // Compute the incremental error for this factor - child.error = error + conditional.error(child.assignment); - - // Compute new bound - child.bound = child.error + computeBound(); - - // Next factor index - child.nextConditional = nextConditional - 1; - - return child; - } - - friend std::ostream& operator<<(std::ostream& os, const SearchNode& sn) { - os << "[ error=" << sn.error << " bound=" << sn.bound - << " nextConditional=" << sn.nextConditional << " assignment={" - << sn.assignment << "}]"; - return os; - } -}; - -// ---------------------------------------------------------------------------- -// 2) Priority functor to make a min-heap by bound -// ---------------------------------------------------------------------------- -struct CompareByBound { - bool operator()(const SearchNode& a, const SearchNode& b) const { - return a.bound > b.bound; // smallest bound -> highest priority - } -}; - -// ---------------------------------------------------------------------------- -// 4) A Solution -// ---------------------------------------------------------------------------- -struct Solution { - double error; - DiscreteValues assignment; - Solution(double err, const DiscreteValues& assign) - : error(err), assignment(assign) {} - friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { - os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; - return os; - } -}; - -struct CompareByError { - bool operator()(const Solution& a, const Solution& b) const { - return a.error < b.error; - } -}; - -// Define the Solutions class -class Solutions { - private: - size_t maxSize_; - std::priority_queue, CompareByError> pq_; - - public: - Solutions(size_t maxSize) : maxSize_(maxSize) {} - - /// Add a solution to the priority queue, possibly evicting the worst one. - /// Return true if we added the solution. - bool maybeAdd(double error, const DiscreteValues& assignment) { - const bool full = pq_.size() == maxSize_; - if (full && error >= pq_.top().error) return false; - if (full) pq_.pop(); - pq_.emplace(error, assignment); - return true; - } - - /// Check if we have any solutions - bool empty() const { return pq_.empty(); } - - // Method to print all solutions - void print() const { - auto pq = pq_; - while (!pq.empty()) { - const Solution& best = pq.top(); - std::cout << "Error: " << best.error << ", Values: " << best.assignment - << std::endl; - pq.pop(); - } - } - - /// Check if (partial) solution with given bound can be pruned. If we have - /// room, we never prune. Otherwise, prune if lower bound on error is worse - /// than our current worst error. - bool prune(double bound) const { - if (pq_.size() < maxSize_) return false; - double worstError = pq_.top().error; - return (bound >= worstError); - } - - // Method to extract solutions in ascending order of error - std::vector extractSolutions() { - std::vector result; - while (!pq_.empty()) { - result.push_back(pq_.top()); - pq_.pop(); - } - std::sort( - result.begin(), result.end(), - [](const Solution& a, const Solution& b) { return a.error < b.error; }); - return result; - } -}; - -/** - * BestKSearch: Search for the K best solutions. - */ -class BestKSearch { - public: - /** - * Construct from a DiscreteBayesNet and K. - */ - BestKSearch(const DiscreteBayesNet& bayesNet, size_t K) - : bayesNet_(bayesNet), solutions_(K) { - // Copy out the conditionals - for (auto& factor : bayesNet_) { - conditionals_.push_back(factor); - } - - // Create the root node: no variables assigned, nextConditional = last. - SearchNode root{ - .assignment = DiscreteValues(), - .error = 0.0, - .nextConditional = static_cast(conditionals_.size()) - 1}; - root.bound = root.computeBound(); - std::cout << "Root: " << root << std::endl; - expansions_.push(root); - } - - /** - * @brief Search for the K best solutions. - * - * This method performs a search to find the K best solutions for the given - * DiscreteBayesNet. It uses a priority queue to manage the search nodes, - * expanding nodes with the smallest bound first. The search continues until - * all possible nodes have been expanded or pruned. - * - * @return A vector of the K best solutions found during the search. - */ - std::vector run() { - size_t numExpansions = 0; - while (!expansions_.empty()) { - expandNextNode(); - numExpansions++; - } - - std::cout << "Expansions: " << numExpansions << std::endl; - - // Extract solutions from bestSolutions in ascending order of error - return solutions_.extractSolutions(); - } - - private: - // - void expandNextNode() { - // Pop the partial assignment with the smallest bound - SearchNode current = expansions_.top(); - expansions_.pop(); - std::cout << "Expanding: " << current << std::endl; - - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions_.prune(current.bound)) { - std::cout << "Pruning: bound=" << current.bound << std::endl; - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - const bool added = solutions_.maybeAdd(current.error, current.assignment); - if (added) { - std::cout << "Best solutions so far:" << std::endl; - solutions_.print(); - } - return; - } - - // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - std::cout << "Frontal assignment: " << fa << std::endl; - auto childNode = current.expand(*conditional, fa); - - // Again, prune if we cannot beat the worst solution - if (solutions_.prune(current.bound)) { - std::cout << "Pruning: bound=" << childNode.bound << std::endl; - continue; - } - - expansions_.push(childNode); - } - } - - const DiscreteBayesNet& bayesNet_; - std::vector> conditionals_; - std::priority_queue, CompareByBound> - expansions_; - Solutions solutions_; -}; - -// ---------------------------------------------------------------------------- -// Example “Unit Tests” (trivial stubs) -// ---------------------------------------------------------------------------- - -TEST(DiscreteBayesNet, EmptyKBest) { - DiscreteBayesNet net; // no factors - BestKSearch search(net, 3); - auto solutions = search.run(); - // Expect one solution with empty assignment, error=0 - EXPECT_LONGS_EQUAL(1, solutions.size()); - EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); -} - -TEST(DiscreteBayesNet, AsiaKBest) { - DiscreteBayesNet asia = constructAsiaExample(); - BestKSearch search(asia, 4); - auto solutions = search.run(); - EXPECT(!solutions.empty()); - // Regression test: check the first solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); -} - /* ************************************************************************* */ int main() { TestResult tr; From 455554c80384fcdc1848eb38d394d7874755f449 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 22:39:10 -0500 Subject: [PATCH 06/20] DiscreteSearch --- gtsam/discrete/DiscreteSearch.h | 276 ++++++++++++++++++++ gtsam/discrete/tests/testDiscreteSearch.cpp | 59 +++++ 2 files changed, 335 insertions(+) create mode 100644 gtsam/discrete/DiscreteSearch.h create mode 100644 gtsam/discrete/tests/testDiscreteSearch.cpp diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h new file mode 100644 index 000000000..d9ffae768 --- /dev/null +++ b/gtsam/discrete/DiscreteSearch.h @@ -0,0 +1,276 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * DiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include + +namespace gtsam { + +using Value = size_t; + +/** + * @brief Represents a node in the search tree for discrete search algorithms. + * + * @details Each SearchNode contains a partial assignment of discrete variables, + * the current error, a bound on the final error, and the index of the next + * conditional to be assigned. + */ +struct SearchNode { + DiscreteValues assignment; ///< Partial assignment of discrete variables. + double error; ///< Current error for the partial assignment. + double bound; ///< Lower bound on the final error for unassigned variables. + int nextConditional; ///< Index of the next conditional to be assigned. + + struct CompareByBound { + bool operator()(const SearchNode& a, const SearchNode& b) const { + return a.bound > b.bound; // smallest bound -> highest priority + } + }; + + /** + * @brief Checks if the node represents a complete assignment. + * + * @return True if all variables have been assigned, false otherwise. + */ + bool isComplete() const { return nextConditional < 0; } + + /** + * @brief Computes a lower bound on the final error for unassigned variables. + * + * @details This is a stub implementation that returns 0. Real implementations + * might perform partial factor analysis or use heuristics to compute the + * bound. + * + * @return A lower bound on the final error. + */ + double computeBound() const { + // Real code might do partial factor analysis or heuristics. + return 0.0; + } + + /** + * @brief Expands the node by assigning the next variable. + * + * @param conditional The discrete conditional representing the next variable + * to be assigned. + * @param fa The frontal assignment for the next variable. + * @return A new SearchNode representing the expanded state. + */ + SearchNode expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + SearchNode child; + child.assignment = assignment; + for (auto& kv : fa) { + child.assignment[kv.first] = kv.second; + } + + // Compute the incremental error for this factor + child.error = error + conditional.error(child.assignment); + + // Compute new bound + child.bound = computeBound(); + + // Update the index of the next conditional + child.nextConditional = nextConditional - 1; + + return child; + } + + /** + * @brief Prints the SearchNode to an output stream. + * + * @param os The output stream. + * @param node The SearchNode to be printed. + * @return The output stream. + */ + friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { + os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; + return os; + } +}; + +struct Solution { + double error; + DiscreteValues assignment; + Solution(double err, const DiscreteValues& assign) + : error(err), assignment(assign) {} + friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { + os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; + return os; + } +}; + +struct CompareByError { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } +}; + +// Define the Solutions class +class Solutions { + private: + size_t maxSize_; + std::priority_queue, CompareByError> pq_; + + public: + Solutions(size_t maxSize) : maxSize_(maxSize) {} + + /// Add a solution to the priority queue, possibly evicting the worst one. + /// Return true if we added the solution. + bool maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; + } + + /// Check if we have any solutions + bool empty() const { return pq_.empty(); } + + // Method to print all solutions + friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + auto pq = sn.pq_; + while (!pq.empty()) { + const Solution& best = pq.top(); + os << "Error: " << best.error << ", Values: " << best.assignment + << std::endl; + pq.pop(); + } + return os; + } + + /// Check if (partial) solution with given bound can be pruned. If we have + /// room, we never prune. Otherwise, prune if lower bound on error is worse + /// than our current worst error. + bool prune(double bound) const { + if (pq_.size() < maxSize_) return false; + double worstError = pq_.top().error; + return (bound >= worstError); + } + + // Method to extract solutions in ascending order of error + std::vector extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; + } +}; + +/** + * DiscreteSearch: Search for the K best solutions. + */ +class DiscreteSearch { + public: + /** + * Construct from a DiscreteBayesNet and K. + */ + DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) : solutions_(K) { + // Copy out the conditionals + for (auto& factor : bayesNet) { + conditionals_.push_back(factor); + } + + // Calculate the cost-to-go for each conditional. If there are n + // conditionals, we start with nextConditional = n-1, and the minimum error + // obtainable is the sum of all the minimum errors. We start with + // 0, and that is the minimum error of the conditional with that index: + double error = 0.0; + for (const auto& conditional : conditionals_) { + Ordering ordering(conditional->begin(), conditional->end()); + auto maxx = conditional->max(ordering); + assert(maxx->size() == 1); + error -= std::log(maxx->evaluate({})); + costToGo_.push_back(error); + } + + // Create the root node: no variables assigned, nextConditional = last. + SearchNode root{ + .assignment = DiscreteValues(), + .error = 0.0, + .nextConditional = static_cast(conditionals_.size()) - 1}; + if (!costToGo_.empty()) root.bound = costToGo_.back(); + expansions_.push(root); + } + + /** + * @brief Search for the K best solutions. + * + * This method performs a search to find the K best solutions for the given + * DiscreteBayesNet. It uses a priority queue to manage the search nodes, + * expanding nodes with the smallest bound first. The search continues until + * all possible nodes have been expanded or pruned. + * + * @return A vector of the K best solutions found during the search. + */ + std::vector run() { + while (!expansions_.empty()) { + expandNextNode(); + } + + // Extract solutions from bestSolutions in ascending order of error + return solutions_.extractSolutions(); + } + + private: + void expandNextNode() { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions_.top(); + expansions_.pop(); + + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions_.prune(current.bound)) { + return; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions_.maybeAdd(current.error, current.assignment); + return; + } + + // Expand on the next factor + const auto& conditional = conditionals_[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + auto childNode = current.expand(*conditional, fa); + if (childNode.nextConditional >= 0) + childNode.bound = + childNode.error + costToGo_[childNode.nextConditional]; + + // Again, prune if we cannot beat the worst solution + if (!solutions_.prune(childNode.bound)) { + expansions_.push(childNode); + } + } + } + + std::vector> conditionals_; + std::vector costToGo_; + std::priority_queue, + SearchNode::CompareByBound> + expansions_; + Solutions solutions_; +}; +} // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp new file mode 100644 index 000000000..4d7adbac5 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -0,0 +1,59 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "AsiaExample.h" + +using namespace gtsam; + +TEST(DiscreteBayesNet, EmptyKBest) { + DiscreteBayesNet net; // no factors + DiscreteSearch search(net, 3); + auto solutions = search.run(); + // Expect one solution with empty assignment, error=0 + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +TEST(DiscreteBayesNet, AsiaKBest) { + using namespace asia_example; + DiscreteBayesNet asia = createAsiaExample(); + DiscreteSearch search(asia, 4); + auto solutions = search.run(); + EXPECT(!solutions.empty()); + // Regression test: check the first solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ From 54f493358d5a57e2d3034159c18bfca2f444cfa5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 22:48:40 -0500 Subject: [PATCH 07/20] BayesTree WIP --- gtsam/discrete/DiscreteSearch.h | 7 +++++ gtsam/discrete/tests/testDiscreteSearch.cpp | 31 +++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index d9ffae768..c267001b5 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -17,6 +17,7 @@ */ #include +#include namespace gtsam { @@ -214,6 +215,12 @@ class DiscreteSearch { expansions_.push(root); } + /** + * Construct from a DiscreteBayesNet and K. + */ + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) + : solutions_(K) {} + /** * @brief Search for the K best solutions. * diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 4d7adbac5..b0f2b03e3 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -32,6 +32,7 @@ using namespace gtsam; +/* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors DiscreteSearch search(net, 3); @@ -41,6 +42,7 @@ TEST(DiscreteBayesNet, EmptyKBest) { EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); } +/* ************************************************************************* */ TEST(DiscreteBayesNet, AsiaKBest) { using namespace asia_example; DiscreteBayesNet asia = createAsiaExample(); @@ -51,6 +53,35 @@ TEST(DiscreteBayesNet, AsiaKBest) { EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); } +/* ************************************************************************* */ +TEST(DiscreteBayesTree, testEmptyTree) { + DiscreteBayesTree bt; + + DiscreteSearch search(bt, 3); + auto solutions = search.run(); + + // We expect exactly 1 solution with error = 0.0 (the empty assignment). + assert(solutions.size() == 1 && "There should be exactly one empty solution"); + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscreteBayesTree, testTrivialOneClique) { + using namespace asia_example; + DiscreteFactorGraph asia(createAsiaExample()); + DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(); + GTSAM_PRINT(*bt); + + // Ask for top 4 solutions + DiscreteSearch search(*bt, 4); + auto solutions = search.run(); + + EXPECT(!solutions.empty()); + // Regression test: check the first solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); +} + /* ************************************************************************* */ int main() { TestResult tr; From 14eeaf93db346f04acd0704abaf0e91b07e209f2 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 23:24:09 -0500 Subject: [PATCH 08/20] Works for Bayes trees now as well --- gtsam/discrete/DiscreteSearch.h | 124 ++++++++++++-------- gtsam/discrete/tests/AsiaExample.h | 2 +- gtsam/discrete/tests/testDiscreteSearch.cpp | 12 +- 3 files changed, 84 insertions(+), 54 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index c267001b5..24517ee29 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -36,6 +36,16 @@ struct SearchNode { double bound; ///< Lower bound on the final error for unassigned variables. int nextConditional; ///< Index of the next conditional to be assigned. + /** + * @brief Construct the root node for the search. + */ + static SearchNode Root(size_t numConditionals, double bound) { + return {.assignment = DiscreteValues(), + .error = 0.0, + .bound = bound, + .nextConditional = static_cast(numConditionals) - 1}; + } + struct CompareByBound { bool operator()(const SearchNode& a, const SearchNode& b) const { return a.bound > b.bound; // smallest bound -> highest priority @@ -49,20 +59,6 @@ struct SearchNode { */ bool isComplete() const { return nextConditional < 0; } - /** - * @brief Computes a lower bound on the final error for unassigned variables. - * - * @details This is a stub implementation that returns 0. Real implementations - * might perform partial factor analysis or use heuristics to compute the - * bound. - * - * @return A lower bound on the final error. - */ - double computeBound() const { - // Real code might do partial factor analysis or heuristics. - return 0.0; - } - /** * @brief Expands the node by assigning the next variable. * @@ -74,22 +70,15 @@ struct SearchNode { SearchNode expand(const DiscreteConditional& conditional, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment - SearchNode child; - child.assignment = assignment; + DiscreteValues newAssignment = assignment; for (auto& kv : fa) { - child.assignment[kv.first] = kv.second; + newAssignment[kv.first] = kv.second; } - // Compute the incremental error for this factor - child.error = error + conditional.error(child.assignment); - - // Compute new bound - child.bound = computeBound(); - - // Update the index of the next conditional - child.nextConditional = nextConditional - 1; - - return child; + return {.assignment = newAssignment, + .error = error + conditional.error(newAssignment), + .bound = 0.0, + .nextConditional = nextConditional - 1}; } /** @@ -184,6 +173,8 @@ class Solutions { */ class DiscreteSearch { public: + size_t numExpansions = 0; + /** * Construct from a DiscreteBayesNet and K. */ @@ -193,33 +184,42 @@ class DiscreteSearch { conditionals_.push_back(factor); } - // Calculate the cost-to-go for each conditional. If there are n - // conditionals, we start with nextConditional = n-1, and the minimum error - // obtainable is the sum of all the minimum errors. We start with - // 0, and that is the minimum error of the conditional with that index: - double error = 0.0; - for (const auto& conditional : conditionals_) { - Ordering ordering(conditional->begin(), conditional->end()); - auto maxx = conditional->max(ordering); - assert(maxx->size() == 1); - error -= std::log(maxx->evaluate({})); - costToGo_.push_back(error); - } + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); - // Create the root node: no variables assigned, nextConditional = last. - SearchNode root{ - .assignment = DiscreteValues(), - .error = 0.0, - .nextConditional = static_cast(conditionals_.size()) - 1}; - if (!costToGo_.empty()) root.bound = costToGo_.back(); - expansions_.push(root); + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); } /** - * Construct from a DiscreteBayesNet and K. + * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) - : solutions_(K) {} + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) { + using CliquePtr = DiscreteBayesTree::sharedClique; + std::function collectConditionals = + [&](const CliquePtr& clique) -> void { + if (!clique) return; + + // Recursive post-order traversal: process children first + for (const auto& child : clique->children) { + collectConditionals(child); + } + + // Then add the current clique's conditional + conditionals_.push_back(clique->conditional()); + }; + + // Start traversal from each root in the tree + for (const auto& root : bayesTree.roots()) collectConditionals(root); + + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); + + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); + } /** * @brief Search for the K best solutions. @@ -233,6 +233,7 @@ class DiscreteSearch { */ std::vector run() { while (!expansions_.empty()) { + numExpansions++; expandNextNode(); } @@ -241,6 +242,29 @@ class DiscreteSearch { } private: + /** + * @brief Compute the cost-to-go for each conditional. + * + * @param conditionals The conditionals of the DiscreteBayesNet. + * @return A vector of cost-to-go values. + */ + static std::vector computeCostToGo( + const std::vector& conditionals) { + std::vector costToGo; + double error = 0.0; + for (const auto& conditional : conditionals) { + Ordering ordering(conditional->begin(), conditional->end()); + auto maxx = conditional->max(ordering); + assert(maxx->size() == 1); + error -= std::log(maxx->evaluate({})); + costToGo.push_back(error); + } + return costToGo; + } + + /** + * @brief Expand the next node in the search tree. + */ void expandNextNode() { // Pop the partial assignment with the smallest bound SearchNode current = expansions_.top(); @@ -273,7 +297,7 @@ class DiscreteSearch { } } - std::vector> conditionals_; + std::vector conditionals_; std::vector costToGo_; std::priority_queue, SearchNode::CompareByBound> diff --git a/gtsam/discrete/tests/AsiaExample.h b/gtsam/discrete/tests/AsiaExample.h index 7413f6899..6c327daec 100644 --- a/gtsam/discrete/tests/AsiaExample.h +++ b/gtsam/discrete/tests/AsiaExample.h @@ -30,7 +30,7 @@ static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), Asia(A, 2); -// Function to construct the incomplete Asia example +// Function to construct the Asia priors DiscreteBayesNet createPriors() { DiscreteBayesNet priors; priors.add(Smoking % "50/50"); diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index b0f2b03e3..041c3b465 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -49,8 +49,9 @@ TEST(DiscreteBayesNet, AsiaKBest) { DiscreteSearch search(asia, 4); auto solutions = search.run(); EXPECT(!solutions.empty()); - // Regression test: check the first solution + // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); } /* ************************************************************************* */ @@ -70,16 +71,21 @@ TEST(DiscreteBayesTree, testEmptyTree) { TEST(DiscreteBayesTree, testTrivialOneClique) { using namespace asia_example; DiscreteFactorGraph asia(createAsiaExample()); - DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(); + const Ordering ordering{D, X, B, E, L, T, S, A}; + DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); GTSAM_PRINT(*bt); // Ask for top 4 solutions DiscreteSearch search(*bt, 4); auto solutions = search.run(); + // print numExpansions + std::cout << "Number of expansions: " << search.numExpansions << std::endl; + EXPECT(!solutions.empty()); - // Regression test: check the first solution + // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); } /* ************************************************************************* */ From 70089a0fd4e0c3c7e87cb0eb2493e82b8165235d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 23:34:30 -0500 Subject: [PATCH 09/20] Move to cpp file --- gtsam/discrete/DiscreteSearch.cpp | 178 ++++++++++++++++++++ gtsam/discrete/DiscreteSearch.h | 169 +++---------------- gtsam/discrete/tests/testDiscreteSearch.cpp | 5 +- 3 files changed, 201 insertions(+), 151 deletions(-) create mode 100644 gtsam/discrete/DiscreteSearch.cpp diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp new file mode 100644 index 000000000..16794dcfc --- /dev/null +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -0,0 +1,178 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * DiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include + +namespace gtsam { + +SearchNode SearchNode::Root(size_t numConditionals, double bound) { + return {.assignment = DiscreteValues(), + .error = 0.0, + .bound = bound, + .nextConditional = static_cast(numConditionals) - 1}; +} + +SearchNode SearchNode::expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + DiscreteValues newAssignment = assignment; + for (auto& kv : fa) { + newAssignment[kv.first] = kv.second; + } + + return {.assignment = newAssignment, + .error = error + conditional.error(newAssignment), + .bound = 0.0, + .nextConditional = nextConditional - 1}; +} + +bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; +} + +std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + auto pq = sn.pq_; + while (!pq.empty()) { + const Solution& best = pq.top(); + os << "Error: " << best.error << ", Values: " << best.assignment + << std::endl; + pq.pop(); + } + return os; +} + +bool Solutions::prune(double bound) const { + if (pq_.size() < maxSize_) return false; + double worstError = pq_.top().error; + return (bound >= worstError); +} + +std::vector Solutions::extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; +} + +DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) + : solutions_(K) { + // Copy out the conditionals + for (auto& factor : bayesNet) { + conditionals_.push_back(factor); + } + + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); + + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); +} + +DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) + : solutions_(K) { + using CliquePtr = DiscreteBayesTree::sharedClique; + std::function collectConditionals = + [&](const CliquePtr& clique) -> void { + if (!clique) return; + + // Recursive post-order traversal: process children first + for (const auto& child : clique->children) { + collectConditionals(child); + } + + // Then add the current clique's conditional + conditionals_.push_back(clique->conditional()); + }; + + // Start traversal from each root in the tree + for (const auto& root : bayesTree.roots()) collectConditionals(root); + + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); + + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); +} + +std::vector DiscreteSearch::run() { + while (!expansions_.empty()) { + numExpansions++; + expandNextNode(); + } + + // Extract solutions from bestSolutions in ascending order of error + return solutions_.extractSolutions(); +} + +std::vector DiscreteSearch::computeCostToGo( + const std::vector& conditionals) { + std::vector costToGo; + double error = 0.0; + for (const auto& conditional : conditionals) { + Ordering ordering(conditional->begin(), conditional->end()); + auto maxx = conditional->max(ordering); + assert(maxx->size() == 1); + error -= std::log(maxx->evaluate({})); + costToGo.push_back(error); + } + return costToGo; +} + +void DiscreteSearch::expandNextNode() { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions_.top(); + expansions_.pop(); + + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions_.prune(current.bound)) { + return; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions_.maybeAdd(current.error, current.assignment); + return; + } + + // Expand on the next factor + const auto& conditional = conditionals_[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + auto childNode = current.expand(*conditional, fa); + if (childNode.nextConditional >= 0) + childNode.bound = childNode.error + costToGo_[childNode.nextConditional]; + + // Again, prune if we cannot beat the worst solution + if (!solutions_.prune(childNode.bound)) { + expansions_.push(childNode); + } + } +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 24517ee29..989437b9f 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -39,14 +39,9 @@ struct SearchNode { /** * @brief Construct the root node for the search. */ - static SearchNode Root(size_t numConditionals, double bound) { - return {.assignment = DiscreteValues(), - .error = 0.0, - .bound = bound, - .nextConditional = static_cast(numConditionals) - 1}; - } + static SearchNode Root(size_t numConditionals, double bound); - struct CompareByBound { + struct Compare { bool operator()(const SearchNode& a, const SearchNode& b) const { return a.bound > b.bound; // smallest bound -> highest priority } @@ -68,18 +63,7 @@ struct SearchNode { * @return A new SearchNode representing the expanded state. */ SearchNode expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const { - // Combine the new frontal assignment with the current partial assignment - DiscreteValues newAssignment = assignment; - for (auto& kv : fa) { - newAssignment[kv.first] = kv.second; - } - - return {.assignment = newAssignment, - .error = error + conditional.error(newAssignment), - .bound = 0.0, - .nextConditional = nextConditional - 1}; - } + const DiscreteValues& fa) const; /** * @brief Prints the SearchNode to an output stream. @@ -103,69 +87,40 @@ struct Solution { os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; return os; } -}; -struct CompareByError { - bool operator()(const Solution& a, const Solution& b) const { - return a.error < b.error; - } + struct Compare { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } + }; }; // Define the Solutions class class Solutions { private: size_t maxSize_; - std::priority_queue, CompareByError> pq_; + std::priority_queue, Solution::Compare> pq_; public: Solutions(size_t maxSize) : maxSize_(maxSize) {} /// Add a solution to the priority queue, possibly evicting the worst one. /// Return true if we added the solution. - bool maybeAdd(double error, const DiscreteValues& assignment) { - const bool full = pq_.size() == maxSize_; - if (full && error >= pq_.top().error) return false; - if (full) pq_.pop(); - pq_.emplace(error, assignment); - return true; - } + bool maybeAdd(double error, const DiscreteValues& assignment); /// Check if we have any solutions bool empty() const { return pq_.empty(); } // Method to print all solutions - friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { - auto pq = sn.pq_; - while (!pq.empty()) { - const Solution& best = pq.top(); - os << "Error: " << best.error << ", Values: " << best.assignment - << std::endl; - pq.pop(); - } - return os; - } + friend std::ostream& operator<<(std::ostream& os, const Solutions& sn); /// Check if (partial) solution with given bound can be pruned. If we have /// room, we never prune. Otherwise, prune if lower bound on error is worse /// than our current worst error. - bool prune(double bound) const { - if (pq_.size() < maxSize_) return false; - double worstError = pq_.top().error; - return (bound >= worstError); - } + bool prune(double bound) const; // Method to extract solutions in ascending order of error - std::vector extractSolutions() { - std::vector result; - while (!pq_.empty()) { - result.push_back(pq_.top()); - pq_.pop(); - } - std::sort( - result.begin(), result.end(), - [](const Solution& a, const Solution& b) { return a.error < b.error; }); - return result; - } + std::vector extractSolutions(); }; /** @@ -178,48 +133,12 @@ class DiscreteSearch { /** * Construct from a DiscreteBayesNet and K. */ - DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) : solutions_(K) { - // Copy out the conditionals - for (auto& factor : bayesNet) { - conditionals_.push_back(factor); - } - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); - } + DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K); /** * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) { - using CliquePtr = DiscreteBayesTree::sharedClique; - std::function collectConditionals = - [&](const CliquePtr& clique) -> void { - if (!clique) return; - - // Recursive post-order traversal: process children first - for (const auto& child : clique->children) { - collectConditionals(child); - } - - // Then add the current clique's conditional - conditionals_.push_back(clique->conditional()); - }; - - // Start traversal from each root in the tree - for (const auto& root : bayesTree.roots()) collectConditionals(root); - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); - } + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K); /** * @brief Search for the K best solutions. @@ -231,15 +150,7 @@ class DiscreteSearch { * * @return A vector of the K best solutions found during the search. */ - std::vector run() { - while (!expansions_.empty()) { - numExpansions++; - expandNextNode(); - } - - // Extract solutions from bestSolutions in ascending order of error - return solutions_.extractSolutions(); - } + std::vector run(); private: /** @@ -249,58 +160,16 @@ class DiscreteSearch { * @return A vector of cost-to-go values. */ static std::vector computeCostToGo( - const std::vector& conditionals) { - std::vector costToGo; - double error = 0.0; - for (const auto& conditional : conditionals) { - Ordering ordering(conditional->begin(), conditional->end()); - auto maxx = conditional->max(ordering); - assert(maxx->size() == 1); - error -= std::log(maxx->evaluate({})); - costToGo.push_back(error); - } - return costToGo; - } + const std::vector& conditionals); /** * @brief Expand the next node in the search tree. */ - void expandNextNode() { - // Pop the partial assignment with the smallest bound - SearchNode current = expansions_.top(); - expansions_.pop(); - - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions_.prune(current.bound)) { - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - solutions_.maybeAdd(current.error, current.assignment); - return; - } - - // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - auto childNode = current.expand(*conditional, fa); - if (childNode.nextConditional >= 0) - childNode.bound = - childNode.error + costToGo_[childNode.nextConditional]; - - // Again, prune if we cannot beat the worst solution - if (!solutions_.prune(childNode.bound)) { - expansions_.push(childNode); - } - } - } + void expandNextNode(); std::vector conditionals_; std::vector costToGo_; - std::priority_queue, - SearchNode::CompareByBound> + std::priority_queue, SearchNode::Compare> expansions_; Solutions solutions_; }; diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 041c3b465..a39bcd86b 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -48,6 +48,10 @@ TEST(DiscreteBayesNet, AsiaKBest) { DiscreteBayesNet asia = createAsiaExample(); DiscreteSearch search(asia, 4); auto solutions = search.run(); + + // print numExpansions + std::cout << "Number of expansions: " << search.numExpansions << std::endl; + EXPECT(!solutions.empty()); // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); @@ -73,7 +77,6 @@ TEST(DiscreteBayesTree, testTrivialOneClique) { DiscreteFactorGraph asia(createAsiaExample()); const Ordering ordering{D, X, B, E, L, T, S, A}; DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); - GTSAM_PRINT(*bt); // Ask for top 4 solutions DiscreteSearch search(*bt, 4); From b10ea06626a6d3de83ee01d20237c41cfb686d3b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:07:22 -0500 Subject: [PATCH 10/20] Clean up, MPE tests --- gtsam/discrete/DiscreteSearch.cpp | 60 +++++++-------------- gtsam/discrete/DiscreteSearch.h | 28 +++++----- gtsam/discrete/tests/testDiscreteSearch.cpp | 31 +++++++++-- 3 files changed, 61 insertions(+), 58 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 16794dcfc..81ebc6012 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -31,8 +31,8 @@ SearchNode SearchNode::expand(const DiscreteConditional& conditional, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; - for (auto& kv : fa) { - newAssignment[kv.first] = kv.second; + for (auto& [key, value] : fa) { + newAssignment[key] = value; } return {.assignment = newAssignment, @@ -50,11 +50,10 @@ bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) { } std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + os << "Solutions (top " << sn.pq_.size() << "):\n"; auto pq = sn.pq_; while (!pq.empty()) { - const Solution& best = pq.top(); - os << "Error: " << best.error << ", Values: " << best.assignment - << std::endl; + os << pq.top() << "\n"; pq.pop(); } return os; @@ -62,8 +61,7 @@ std::ostream& operator<<(std::ostream& os, const Solutions& sn) { bool Solutions::prune(double bound) const { if (pq_.size() < maxSize_) return false; - double worstError = pq_.top().error; - return (bound >= worstError); + return bound >= pq_.top().error; } std::vector Solutions::extractSolutions() { @@ -80,45 +78,23 @@ std::vector Solutions::extractSolutions() { DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) : solutions_(K) { - // Copy out the conditionals - for (auto& factor : bayesNet) { - conditionals_.push_back(factor); - } - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); + std::vector conditionals; + for (auto& factor : bayesNet) conditionals.push_back(factor); + initialize(conditionals); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) { - using CliquePtr = DiscreteBayesTree::sharedClique; - std::function collectConditionals = - [&](const CliquePtr& clique) -> void { - if (!clique) return; - - // Recursive post-order traversal: process children first - for (const auto& child : clique->children) { - collectConditionals(child); - } - - // Then add the current clique's conditional - conditionals_.push_back(clique->conditional()); - }; - - // Start traversal from each root in the tree + std::vector conditionals; + std::function + collectConditionals = [&](const auto& clique) { + if (!clique) return; + for (const auto& child : clique->children) collectConditionals(child); + conditionals.push_back(clique->conditional()); + }; for (const auto& root : bayesTree.roots()) collectConditionals(root); - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); -} + initialize(conditionals); +}; std::vector DiscreteSearch::run() { while (!expansions_.empty()) { @@ -170,7 +146,7 @@ void DiscreteSearch::expandNextNode() { // Again, prune if we cannot beat the worst solution if (!solutions_.prune(childNode.bound)) { - expansions_.push(childNode); + expansions_.emplace(childNode); } } } diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 989437b9f..cbf258b3f 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -19,6 +19,8 @@ #include #include +#include + namespace gtsam { using Value = size_t; @@ -52,7 +54,7 @@ struct SearchNode { * * @return True if all variables have been assigned, false otherwise. */ - bool isComplete() const { return nextConditional < 0; } + inline bool isComplete() const { return nextConditional < 0; } /** * @brief Expands the node by assigning the next variable. @@ -133,12 +135,12 @@ class DiscreteSearch { /** * Construct from a DiscreteBayesNet and K. */ - DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K); + DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K = 1); /** * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K); + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K = 1); /** * @brief Search for the K best solutions. @@ -153,18 +155,20 @@ class DiscreteSearch { std::vector run(); private: - /** - * @brief Compute the cost-to-go for each conditional. - * - * @param conditionals The conditionals of the DiscreteBayesNet. - * @return A vector of cost-to-go values. - */ + /// Initialize the search with the given conditionals. + void initialize( + const std::vector& conditionals) { + conditionals_ = conditionals; + costToGo_ = computeCostToGo(conditionals_); + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); + } + + /// Compute the cumulative cost-to-go for each conditional slot. static std::vector computeCostToGo( const std::vector& conditionals); - /** - * @brief Expand the next node in the search tree. - */ + /// Expand the next node in the search tree. void expandNextNode(); std::vector conditionals_; diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index a39bcd86b..e301ed3d4 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -46,20 +46,32 @@ TEST(DiscreteBayesNet, EmptyKBest) { TEST(DiscreteBayesNet, AsiaKBest) { using namespace asia_example; DiscreteBayesNet asia = createAsiaExample(); + + // Ask for the MPE + DiscreteSearch search1(asia); + auto mpe = search1.run(); + + // print numExpansions + std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + DiscreteSearch search(asia, 4); auto solutions = search.run(); // print numExpansions std::cout << "Number of expansions: " << search.numExpansions << std::endl; - EXPECT(!solutions.empty()); + EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); } /* ************************************************************************* */ -TEST(DiscreteBayesTree, testEmptyTree) { +TEST(DiscreteBayesTree, EmptyTree) { DiscreteBayesTree bt; DiscreteSearch search(bt, 3); @@ -72,12 +84,23 @@ TEST(DiscreteBayesTree, testEmptyTree) { } /* ************************************************************************* */ -TEST(DiscreteBayesTree, testTrivialOneClique) { +TEST(DiscreteBayesTree, AsiaTreeKBest) { using namespace asia_example; DiscreteFactorGraph asia(createAsiaExample()); const Ordering ordering{D, X, B, E, L, T, S, A}; DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); + // Ask for top 4 solutions + DiscreteSearch search1(*bt); + auto mpe = search1.run(); + + // print numExpansions + std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + // Ask for top 4 solutions DiscreteSearch search(*bt, 4); auto solutions = search.run(); @@ -85,7 +108,7 @@ TEST(DiscreteBayesTree, testTrivialOneClique) { // print numExpansions std::cout << "Number of expansions: " << search.numExpansions << std::endl; - EXPECT(!solutions.empty()); + EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); From f174a38eedb25720dfa37852abd8aac5990e516d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:26:35 -0500 Subject: [PATCH 11/20] Made run const, take K --- gtsam/discrete/DiscreteSearch.cpp | 107 ++++++++++++-------- gtsam/discrete/DiscreteSearch.h | 22 +--- gtsam/discrete/tests/testDiscreteSearch.cpp | 37 +++---- 3 files changed, 79 insertions(+), 87 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 81ebc6012..70f7f871a 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -76,34 +76,84 @@ std::vector Solutions::extractSolutions() { return result; } -DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) - : solutions_(K) { +DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { std::vector conditionals; - for (auto& factor : bayesNet) conditionals.push_back(factor); - initialize(conditionals); + for (auto& factor : bayesNet) conditionals_.push_back(factor); + costToGo_ = computeCostToGo(conditionals_); } -DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) - : solutions_(K) { - std::vector conditionals; +DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { std::function collectConditionals = [&](const auto& clique) { if (!clique) return; for (const auto& child : clique->children) collectConditionals(child); - conditionals.push_back(clique->conditional()); + conditionals_.push_back(clique->conditional()); }; for (const auto& root : bayesTree.roots()) collectConditionals(root); - initialize(conditionals); + costToGo_ = computeCostToGo(conditionals_); }; -std::vector DiscreteSearch::run() { - while (!expansions_.empty()) { - numExpansions++; +using SearchNodeQueue = std::priority_queue, + SearchNode::Compare>; + +std::vector DiscreteSearch::run(size_t K) const { + SearchNodeQueue expansions; + expansions.push(SearchNode::Root(conditionals_.size(), + costToGo_.empty() ? 0.0 : costToGo_.back())); + + Solutions solutions(K); + + auto expandNextNode = [&] { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions.top(); + expansions.pop(); + + // If we already have K solutions, prune if we cannot beat the worst + // one. + if (solutions.prune(current.bound)) { + return; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions.maybeAdd(current.error, current.assignment); + return; + } + + // Expand on the next factor + const auto& conditional = conditionals_[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + auto childNode = current.expand(*conditional, fa); + if (childNode.nextConditional >= 0) + childNode.bound = + childNode.error + costToGo_[childNode.nextConditional]; + + // Again, prune if we cannot beat the worst solution + if (!solutions.prune(childNode.bound)) { + expansions.emplace(childNode); + } + } + }; + +#ifdef DISCRETE_SEARCH_DEBUG + size_t numExpansions = 0; +#endif + + // Perform the search + while (!expansions.empty()) { expandNextNode(); +#ifdef DISCRETE_SEARCH_DEBUG + ++numExpansions; +#endif } +#ifdef DISCRETE_SEARCH_DEBUG + std::cout << "Number of expansions: " << numExpansions << std::endl; +#endif + // Extract solutions from bestSolutions in ascending order of error - return solutions_.extractSolutions(); + return solutions.extractSolutions(); } std::vector DiscreteSearch::computeCostToGo( @@ -120,35 +170,4 @@ std::vector DiscreteSearch::computeCostToGo( return costToGo; } -void DiscreteSearch::expandNextNode() { - // Pop the partial assignment with the smallest bound - SearchNode current = expansions_.top(); - expansions_.pop(); - - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions_.prune(current.bound)) { - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - solutions_.maybeAdd(current.error, current.assignment); - return; - } - - // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - auto childNode = current.expand(*conditional, fa); - if (childNode.nextConditional >= 0) - childNode.bound = childNode.error + costToGo_[childNode.nextConditional]; - - // Again, prune if we cannot beat the worst solution - if (!solutions_.prune(childNode.bound)) { - expansions_.emplace(childNode); - } - } -} - } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index cbf258b3f..f5bd67bae 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -130,17 +130,15 @@ class Solutions { */ class DiscreteSearch { public: - size_t numExpansions = 0; - /** * Construct from a DiscreteBayesNet and K. */ - DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K = 1); + DiscreteSearch(const DiscreteBayesNet& bayesNet); /** * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K = 1); + DiscreteSearch(const DiscreteBayesTree& bayesTree); /** * @brief Search for the K best solutions. @@ -152,29 +150,17 @@ class DiscreteSearch { * * @return A vector of the K best solutions found during the search. */ - std::vector run(); + std::vector run(size_t K = 1) const; private: - /// Initialize the search with the given conditionals. - void initialize( - const std::vector& conditionals) { - conditionals_ = conditionals; - costToGo_ = computeCostToGo(conditionals_); - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); - } - /// Compute the cumulative cost-to-go for each conditional slot. static std::vector computeCostToGo( const std::vector& conditionals); /// Expand the next node in the search tree. - void expandNextNode(); + void expandNextNode() const; std::vector conditionals_; std::vector costToGo_; - std::priority_queue, SearchNode::Compare> - expansions_; - Solutions solutions_; }; } // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index e301ed3d4..a3b50add4 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -35,8 +35,8 @@ using namespace gtsam; /* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors - DiscreteSearch search(net, 3); - auto solutions = search.run(); + DiscreteSearch search(net); + auto solutions = search.run(3); // Expect one solution with empty assignment, error=0 EXPECT_LONGS_EQUAL(1, solutions.size()); EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); @@ -46,23 +46,17 @@ TEST(DiscreteBayesNet, EmptyKBest) { TEST(DiscreteBayesNet, AsiaKBest) { using namespace asia_example; DiscreteBayesNet asia = createAsiaExample(); + DiscreteSearch search(asia); // Ask for the MPE - DiscreteSearch search1(asia); - auto mpe = search1.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + auto mpe = search.run(); EXPECT_LONGS_EQUAL(1, mpe.size()); // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); - DiscreteSearch search(asia, 4); - auto solutions = search.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search.numExpansions << std::endl; + // Ask for top 4 solutions + auto solutions = search.run(4); EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution @@ -74,8 +68,8 @@ TEST(DiscreteBayesNet, AsiaKBest) { TEST(DiscreteBayesTree, EmptyTree) { DiscreteBayesTree bt; - DiscreteSearch search(bt, 3); - auto solutions = search.run(); + DiscreteSearch search(bt); + auto solutions = search.run(3); // We expect exactly 1 solution with error = 0.0 (the empty assignment). assert(solutions.size() == 1 && "There should be exactly one empty solution"); @@ -89,24 +83,17 @@ TEST(DiscreteBayesTree, AsiaTreeKBest) { DiscreteFactorGraph asia(createAsiaExample()); const Ordering ordering{D, X, B, E, L, T, S, A}; DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); + DiscreteSearch search(*bt); - // Ask for top 4 solutions - DiscreteSearch search1(*bt); - auto mpe = search1.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + // Ask for MPE + auto mpe = search.run(); EXPECT_LONGS_EQUAL(1, mpe.size()); // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); // Ask for top 4 solutions - DiscreteSearch search(*bt, 4); - auto solutions = search.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search.numExpansions << std::endl; + auto solutions = search.run(4); EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution From cee77695956de909dad8de0ca626db0793016803 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:32:03 -0500 Subject: [PATCH 12/20] Moved data structures to cpp --- gtsam/discrete/DiscreteSearch.cpp | 172 ++++++++++++++++++++++-------- gtsam/discrete/DiscreteSearch.h | 90 +--------------- 2 files changed, 126 insertions(+), 136 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 70f7f871a..5fca5c820 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -20,61 +20,139 @@ namespace gtsam { -SearchNode SearchNode::Root(size_t numConditionals, double bound) { - return {.assignment = DiscreteValues(), - .error = 0.0, - .bound = bound, - .nextConditional = static_cast(numConditionals) - 1}; -} +using Value = size_t; -SearchNode SearchNode::expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const { - // Combine the new frontal assignment with the current partial assignment - DiscreteValues newAssignment = assignment; - for (auto& [key, value] : fa) { - newAssignment[key] = value; +/** + * @brief Represents a node in the search tree for discrete search algorithms. + * + * @details Each SearchNode contains a partial assignment of discrete variables, + * the current error, a bound on the final error, and the index of the next + * conditional to be assigned. + */ +struct SearchNode { + DiscreteValues assignment; ///< Partial assignment of discrete variables. + double error; ///< Current error for the partial assignment. + double bound; ///< Lower bound on the final error for unassigned variables. + int nextConditional; ///< Index of the next conditional to be assigned. + + /** + * @brief Construct the root node for the search. + */ + static SearchNode Root(size_t numConditionals, double bound) { + return {.assignment = DiscreteValues(), + .error = 0.0, + .bound = bound, + .nextConditional = static_cast(numConditionals) - 1}; } - return {.assignment = newAssignment, - .error = error + conditional.error(newAssignment), - .bound = 0.0, - .nextConditional = nextConditional - 1}; -} + struct Compare { + bool operator()(const SearchNode& a, const SearchNode& b) const { + return a.bound > b.bound; // smallest bound -> highest priority + } + }; -bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) { - const bool full = pq_.size() == maxSize_; - if (full && error >= pq_.top().error) return false; - if (full) pq_.pop(); - pq_.emplace(error, assignment); - return true; -} + /** + * @brief Checks if the node represents a complete assignment. + * + * @return True if all variables have been assigned, false otherwise. + */ + inline bool isComplete() const { return nextConditional < 0; } -std::ostream& operator<<(std::ostream& os, const Solutions& sn) { - os << "Solutions (top " << sn.pq_.size() << "):\n"; - auto pq = sn.pq_; - while (!pq.empty()) { - os << pq.top() << "\n"; - pq.pop(); + /** + * @brief Expands the node by assigning the next variable. + * + * @param conditional The discrete conditional representing the next variable + * to be assigned. + * @param fa The frontal assignment for the next variable. + * @return A new SearchNode representing the expanded state. + */ + SearchNode expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + DiscreteValues newAssignment = assignment; + for (auto& [key, value] : fa) { + newAssignment[key] = value; + } + + return {.assignment = newAssignment, + .error = error + conditional.error(newAssignment), + .bound = 0.0, + .nextConditional = nextConditional - 1}; } - return os; -} -bool Solutions::prune(double bound) const { - if (pq_.size() < maxSize_) return false; - return bound >= pq_.top().error; -} - -std::vector Solutions::extractSolutions() { - std::vector result; - while (!pq_.empty()) { - result.push_back(pq_.top()); - pq_.pop(); + /** + * @brief Prints the SearchNode to an output stream. + * + * @param os The output stream. + * @param node The SearchNode to be printed. + * @return The output stream. + */ + friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { + os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; + return os; } - std::sort( - result.begin(), result.end(), - [](const Solution& a, const Solution& b) { return a.error < b.error; }); - return result; -} +}; + +struct CompareSolution { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } +}; + +// Define the Solutions class +class Solutions { + private: + size_t maxSize_; + std::priority_queue, CompareSolution> pq_; + + public: + Solutions(size_t maxSize) : maxSize_(maxSize) {} + + /// Add a solution to the priority queue, possibly evicting the worst one. + /// Return true if we added the solution. + bool maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; + } + + /// Check if we have any solutions + bool empty() const { return pq_.empty(); } + + // Method to print all solutions + friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + os << "Solutions (top " << sn.pq_.size() << "):\n"; + auto pq = sn.pq_; + while (!pq.empty()) { + os << pq.top() << "\n"; + pq.pop(); + } + return os; + } + + /// Check if (partial) solution with given bound can be pruned. If we have + /// room, we never prune. Otherwise, prune if lower bound on error is worse + /// than our current worst error. + bool prune(double bound) const { + if (pq_.size() < maxSize_) return false; + return bound >= pq_.top().error; + } + + // Method to extract solutions in ascending order of error + std::vector extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; + } +}; DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { std::vector conditionals; diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index f5bd67bae..78558a127 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -23,63 +23,9 @@ namespace gtsam { -using Value = size_t; - /** - * @brief Represents a node in the search tree for discrete search algorithms. - * - * @details Each SearchNode contains a partial assignment of discrete variables, - * the current error, a bound on the final error, and the index of the next - * conditional to be assigned. + * @brief A solution to a discrete search problem. */ -struct SearchNode { - DiscreteValues assignment; ///< Partial assignment of discrete variables. - double error; ///< Current error for the partial assignment. - double bound; ///< Lower bound on the final error for unassigned variables. - int nextConditional; ///< Index of the next conditional to be assigned. - - /** - * @brief Construct the root node for the search. - */ - static SearchNode Root(size_t numConditionals, double bound); - - struct Compare { - bool operator()(const SearchNode& a, const SearchNode& b) const { - return a.bound > b.bound; // smallest bound -> highest priority - } - }; - - /** - * @brief Checks if the node represents a complete assignment. - * - * @return True if all variables have been assigned, false otherwise. - */ - inline bool isComplete() const { return nextConditional < 0; } - - /** - * @brief Expands the node by assigning the next variable. - * - * @param conditional The discrete conditional representing the next variable - * to be assigned. - * @param fa The frontal assignment for the next variable. - * @return A new SearchNode representing the expanded state. - */ - SearchNode expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const; - - /** - * @brief Prints the SearchNode to an output stream. - * - * @param os The output stream. - * @param node The SearchNode to be printed. - * @return The output stream. - */ - friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { - os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; - return os; - } -}; - struct Solution { double error; DiscreteValues assignment; @@ -89,40 +35,6 @@ struct Solution { os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; return os; } - - struct Compare { - bool operator()(const Solution& a, const Solution& b) const { - return a.error < b.error; - } - }; -}; - -// Define the Solutions class -class Solutions { - private: - size_t maxSize_; - std::priority_queue, Solution::Compare> pq_; - - public: - Solutions(size_t maxSize) : maxSize_(maxSize) {} - - /// Add a solution to the priority queue, possibly evicting the worst one. - /// Return true if we added the solution. - bool maybeAdd(double error, const DiscreteValues& assignment); - - /// Check if we have any solutions - bool empty() const { return pq_.empty(); } - - // Method to print all solutions - friend std::ostream& operator<<(std::ostream& os, const Solutions& sn); - - /// Check if (partial) solution with given bound can be pruned. If we have - /// room, we never prune. Otherwise, prune if lower bound on error is worse - /// than our current worst error. - bool prune(double bound) const; - - // Method to extract solutions in ascending order of error - std::vector extractSolutions(); }; /** From 9dcb52b697a25d856e49fe8655cc23e3654cfa36 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:37:17 -0500 Subject: [PATCH 13/20] Made queue a class --- gtsam/discrete/DiscreteSearch.cpp | 48 +++++++++++++++---------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 5fca5c820..e3090bb50 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -171,48 +171,48 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { costToGo_ = computeCostToGo(conditionals_); }; -using SearchNodeQueue = std::priority_queue, - SearchNode::Compare>; - -std::vector DiscreteSearch::run(size_t K) const { - SearchNodeQueue expansions; - expansions.push(SearchNode::Root(conditionals_.size(), - costToGo_.empty() ? 0.0 : costToGo_.back())); - - Solutions solutions(K); - - auto expandNextNode = [&] { +struct SearchNodeQueue + : public std::priority_queue, + SearchNode::Compare> { + void expandNextNode( + const std::vector& conditionals, + const std::vector& costToGo, Solutions* solutions) { // Pop the partial assignment with the smallest bound - SearchNode current = expansions.top(); - expansions.pop(); + SearchNode current = top(); + pop(); - // If we already have K solutions, prune if we cannot beat the worst - // one. - if (solutions.prune(current.bound)) { + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions->prune(current.bound)) { return; } // Check if we have a complete assignment if (current.isComplete()) { - solutions.maybeAdd(current.error, current.assignment); + solutions->maybeAdd(current.error, current.assignment); return; } // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; + const auto& conditional = conditionals[current.nextConditional]; for (auto& fa : conditional->frontalAssignments()) { auto childNode = current.expand(*conditional, fa); if (childNode.nextConditional >= 0) - childNode.bound = - childNode.error + costToGo_[childNode.nextConditional]; + childNode.bound = childNode.error + costToGo[childNode.nextConditional]; // Again, prune if we cannot beat the worst solution - if (!solutions.prune(childNode.bound)) { - expansions.emplace(childNode); + if (!solutions->prune(childNode.bound)) { + emplace(childNode); } } - }; + } +}; + +std::vector DiscreteSearch::run(size_t K) const { + Solutions solutions(K); + SearchNodeQueue expansions; + expansions.push(SearchNode::Root(conditionals_.size(), + costToGo_.empty() ? 0.0 : costToGo_.back())); #ifdef DISCRETE_SEARCH_DEBUG size_t numExpansions = 0; @@ -220,7 +220,7 @@ std::vector DiscreteSearch::run(size_t K) const { // Perform the search while (!expansions.empty()) { - expandNextNode(); + expansions.expandNextNode(conditionals_, costToGo_, &solutions); #ifdef DISCRETE_SEARCH_DEBUG ++numExpansions; #endif From 908e4dcd77d93cf69dd2ae04f82847af4586efb4 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:51:55 -0500 Subject: [PATCH 14/20] Check MPE --- gtsam/discrete/tests/testDiscreteSearch.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index a3b50add4..1eb274f2d 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -45,8 +45,8 @@ TEST(DiscreteBayesNet, EmptyKBest) { /* ************************************************************************* */ TEST(DiscreteBayesNet, AsiaKBest) { using namespace asia_example; - DiscreteBayesNet asia = createAsiaExample(); - DiscreteSearch search(asia); + const DiscreteBayesNet asia = createAsiaExample(); + const DiscreteSearch search(asia); // Ask for the MPE auto mpe = search.run(); @@ -55,6 +55,10 @@ TEST(DiscreteBayesNet, AsiaKBest) { // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + // Check it is equal to MPE via inference + const DiscreteFactorGraph asiaFG(asia); + EXPECT(assert_equal(mpe[0].assignment, asiaFG.optimize())); + // Ask for top 4 solutions auto solutions = search.run(4); @@ -72,7 +76,6 @@ TEST(DiscreteBayesTree, EmptyTree) { auto solutions = search.run(3); // We expect exactly 1 solution with error = 0.0 (the empty assignment). - assert(solutions.size() == 1 && "There should be exactly one empty solution"); EXPECT_LONGS_EQUAL(1, solutions.size()); EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); } @@ -80,9 +83,9 @@ TEST(DiscreteBayesTree, EmptyTree) { /* ************************************************************************* */ TEST(DiscreteBayesTree, AsiaTreeKBest) { using namespace asia_example; - DiscreteFactorGraph asia(createAsiaExample()); + DiscreteFactorGraph asiaFG(createAsiaExample()); const Ordering ordering{D, X, B, E, L, T, S, A}; - DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); + DiscreteBayesTree::shared_ptr bt = asiaFG.eliminateMultifrontal(ordering); DiscreteSearch search(*bt); // Ask for MPE @@ -92,6 +95,9 @@ TEST(DiscreteBayesTree, AsiaTreeKBest) { // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + // Check it is equal to MPE via inference + EXPECT(assert_equal(mpe[0].assignment, asiaFG.optimize())); + // Ask for top 4 solutions auto solutions = search.run(4); From 1594a2504560ded7419afedc65fce102d2662c96 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:52:07 -0500 Subject: [PATCH 15/20] Fix assert issue --- gtsam/discrete/DiscreteSearch.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index e3090bb50..1cd2999f2 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -241,7 +241,6 @@ std::vector DiscreteSearch::computeCostToGo( for (const auto& conditional : conditionals) { Ordering ordering(conditional->begin(), conditional->end()); auto maxx = conditional->max(ordering); - assert(maxx->size() == 1); error -= std::log(maxx->evaluate({})); costToGo.push_back(error); } From d53b48cf80364649e63e34dddaa0de3cc4f63617 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 09:38:57 -0500 Subject: [PATCH 16/20] Don't use c99 feature :-/ --- gtsam/discrete/DiscreteSearch.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 1cd2999f2..4a068d175 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -39,10 +39,8 @@ struct SearchNode { * @brief Construct the root node for the search. */ static SearchNode Root(size_t numConditionals, double bound) { - return {.assignment = DiscreteValues(), - .error = 0.0, - .bound = bound, - .nextConditional = static_cast(numConditionals) - 1}; + return {DiscreteValues(), 0.0, bound, + static_cast(numConditionals) - 1}; } struct Compare { @@ -74,10 +72,8 @@ struct SearchNode { newAssignment[key] = value; } - return {.assignment = newAssignment, - .error = error + conditional.error(newAssignment), - .bound = 0.0, - .nextConditional = nextConditional - 1}; + return {newAssignment, error + conditional.error(newAssignment), 0.0, + nextConditional - 1}; } /** From 2da5d764f15ada8845b49bed177019d4193b02dd Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 10:09:40 -0500 Subject: [PATCH 17/20] Pre-compute asia things --- gtsam/discrete/tests/testDiscreteSearch.cpp | 34 +++++++++------------ 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 1eb274f2d..b537dd2f0 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -20,18 +20,21 @@ #include #include -#include -#include -#include -#include -#include -#include -#include - #include "AsiaExample.h" using namespace gtsam; +// Create Asia Bayes net, FG, and Bayes tree once +namespace asia { +using namespace asia_example; +static const DiscreteBayesNet bayesNet = createAsiaExample(); +static const DiscreteFactorGraph factorGraph(bayesNet); +static const DiscreteValues mpe = factorGraph.optimize(); +static const Ordering ordering{D, X, B, E, L, T, S, A}; +static const DiscreteBayesTree bayesTree = + *factorGraph.eliminateMultifrontal(ordering); +} // namespace asia + /* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors @@ -44,9 +47,7 @@ TEST(DiscreteBayesNet, EmptyKBest) { /* ************************************************************************* */ TEST(DiscreteBayesNet, AsiaKBest) { - using namespace asia_example; - const DiscreteBayesNet asia = createAsiaExample(); - const DiscreteSearch search(asia); + const DiscreteSearch search(asia::bayesNet); // Ask for the MPE auto mpe = search.run(); @@ -56,8 +57,7 @@ TEST(DiscreteBayesNet, AsiaKBest) { EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); // Check it is equal to MPE via inference - const DiscreteFactorGraph asiaFG(asia); - EXPECT(assert_equal(mpe[0].assignment, asiaFG.optimize())); + EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); // Ask for top 4 solutions auto solutions = search.run(4); @@ -82,11 +82,7 @@ TEST(DiscreteBayesTree, EmptyTree) { /* ************************************************************************* */ TEST(DiscreteBayesTree, AsiaTreeKBest) { - using namespace asia_example; - DiscreteFactorGraph asiaFG(createAsiaExample()); - const Ordering ordering{D, X, B, E, L, T, S, A}; - DiscreteBayesTree::shared_ptr bt = asiaFG.eliminateMultifrontal(ordering); - DiscreteSearch search(*bt); + DiscreteSearch search(asia::bayesTree); // Ask for MPE auto mpe = search.run(); @@ -96,7 +92,7 @@ TEST(DiscreteBayesTree, AsiaTreeKBest) { EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); // Check it is equal to MPE via inference - EXPECT(assert_equal(mpe[0].assignment, asiaFG.optimize())); + EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); // Ask for top 4 solutions auto solutions = search.run(4); From cf4441826909aae8ad5cd91cc8d9ef46821b759a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 13:31:24 -0500 Subject: [PATCH 18/20] Get rid of extra ; --- gtsam/discrete/DiscreteSearch.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 4a068d175..531d994be 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -165,7 +165,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { }; for (const auto& root : bayesTree.roots()) collectConditionals(root); costToGo_ = computeCostToGo(conditionals_); -}; +} struct SearchNodeQueue : public std::priority_queue, From 16131fe8c92b6a84ca62dec758d845406781b2a6 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 14:02:40 -0500 Subject: [PATCH 19/20] Remove spurious typedef --- gtsam/discrete/DiscreteSearch.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 531d994be..d151e7a6f 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -20,8 +20,6 @@ namespace gtsam { -using Value = size_t; - /** * @brief Represents a node in the search tree for discrete search algorithms. * From 41fdeb6c04908be7b52bc099f734d2221d0fba19 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 14:51:16 -0500 Subject: [PATCH 20/20] Export --- gtsam/discrete/DiscreteSearch.cpp | 2 ++ gtsam/discrete/DiscreteSearch.h | 30 +++++++++++++++--------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index d151e7a6f..c5941862d 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -20,6 +20,8 @@ namespace gtsam { +using Solution = DiscreteSearch::Solution; + /** * @brief Represents a node in the search tree for discrete search algorithms. * diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 78558a127..6202880b2 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -23,25 +23,25 @@ namespace gtsam { -/** - * @brief A solution to a discrete search problem. - */ -struct Solution { - double error; - DiscreteValues assignment; - Solution(double err, const DiscreteValues& assign) - : error(err), assignment(assign) {} - friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { - os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; - return os; - } -}; - /** * DiscreteSearch: Search for the K best solutions. */ -class DiscreteSearch { +class GTSAM_EXPORT DiscreteSearch { public: + /** + * @brief A solution to a discrete search problem. + */ + struct Solution { + double error; + DiscreteValues assignment; + Solution(double err, const DiscreteValues& assign) + : error(err), assignment(assign) {} + friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { + os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; + return os; + } + }; + /** * Construct from a DiscreteBayesNet and K. */