From 655f57ef0dc851366b0f8c67d60ef014620449db Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 26 Jan 2025 11:14:16 -0500 Subject: [PATCH 01/43] check for valid factor in HybridNonlinearFactor::errorTree --- gtsam/hybrid/HybridNonlinearFactor.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridNonlinearFactor.cpp b/gtsam/hybrid/HybridNonlinearFactor.cpp index 376bc66f1..fa22051e5 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.cpp +++ b/gtsam/hybrid/HybridNonlinearFactor.cpp @@ -99,7 +99,8 @@ AlgebraicDecisionTree HybridNonlinearFactor::errorTree( auto errorFunc = [continuousValues](const std::pair& f) { auto [factor, val] = f; - return factor->error(continuousValues) + val; + return factor ? factor->error(continuousValues) + val + : std::numeric_limits::infinity(); }; return {factors_, errorFunc}; } From 087c0cc5258bf1711f0c3f3aef74340d7de9355e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 26 Jan 2025 12:17:05 -0500 Subject: [PATCH 02/43] allow for only continuous variables in HybridBayesTree --- gtsam/hybrid/HybridBayesTree.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index c74930a8e..fbf892235 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -202,6 +202,11 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { /* ************************************************************************* */ void HybridBayesTree::prune(const size_t maxNrLeaves) { + if (!this->roots_.at(0)->conditional()->asDiscrete()) { + // Root of the BayesTree is not a discrete clique, so we do nothing. + return; + } + auto prunedDiscreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); From 4027e860576c29337617d4e6fb5cfe62324ae05a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 16:57:37 -0500 Subject: [PATCH 03/43] stream operator --- gtsam/discrete/DiscreteValues.cpp | 16 ++++++++++++++-- gtsam/discrete/DiscreteValues.h | 3 +++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DiscreteValues.cpp b/gtsam/discrete/DiscreteValues.cpp index 416dfb888..3c3ed4468 100644 --- a/gtsam/discrete/DiscreteValues.cpp +++ b/gtsam/discrete/DiscreteValues.cpp @@ -26,12 +26,24 @@ using std::stringstream; namespace gtsam { +/* ************************************************************************ */ +static void stream(std::ostream& os, const DiscreteValues& x, + const KeyFormatter& keyFormatter) { + for (const auto& kv : x) + os << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; +} + +/* ************************************************************************ */ +std::ostream& operator<<(std::ostream& os, const DiscreteValues& x) { + stream(os, x, DefaultKeyFormatter); + return os; +} + /* ************************************************************************ */ void DiscreteValues::print(const string& s, const KeyFormatter& keyFormatter) const { cout << s << ": "; - for (auto&& kv : *this) - cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; + stream(cout, *this, keyFormatter); cout << endl; } diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index fa8a8a846..df4ecdbff 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -64,6 +64,9 @@ class GTSAM_EXPORT DiscreteValues : public Assignment { /// @name Standard Interface /// @{ + /// ostream operator: + friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x); + // insert in base class; std::pair insert( const value_type& value ){ return Base::insert(value); From 2808982903032524d9dfe90b36d67f0b4f4c0ad5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 16:57:58 -0500 Subject: [PATCH 04/43] Best-K search, v0 --- gtsam/discrete/tests/testDiscreteBayesNet.cpp | 276 ++++++++++++++++-- 1 file changed, 253 insertions(+), 23 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index d2033909c..0bd81a197 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -31,8 +32,15 @@ using namespace std; using namespace gtsam; -static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), - LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); +namespace keys { +static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), + B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), + S = Symbol('S', 7), A = Symbol('A', 8); +} + +static const DiscreteKey Dyspnea(keys::D, 2), XRay(keys::X, 2), + Either(keys::E, 2), Bronchitis(keys::B, 2), LungCancer(keys::L, 2), + Tuberculosis(keys::T, 2), Smoking(keys::S, 2), Asia(keys::A, 2); using ADT = AlgebraicDecisionTree; @@ -40,17 +48,15 @@ using ADT = AlgebraicDecisionTree; DiscreteBayesNet constructAsiaExample() { DiscreteBayesNet asia; - asia.add(Asia, "99/1"); - asia.add(Smoking % "50/50"); // Signature version - - asia.add(Tuberculosis | Asia = "99/1 95/5"); - asia.add(LungCancer | Smoking = "99/1 90/10"); - asia.add(Bronchitis | Smoking = "70/30 40/60"); - - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - asia.add(XRay | Either = "95/5 2/98"); + // Add in topological sort order, parents last: asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + asia.add(XRay | Either = "95/5 2/98"); + asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + asia.add(LungCancer | Smoking = "99/1 90/10"); + asia.add(Tuberculosis | Asia = "99/1 95/5"); + asia.add(Smoking % "50/50"); // Signature version + asia.add(Asia, "99/1"); return asia; } @@ -99,7 +105,8 @@ TEST(DiscreteBayesNet, Asia) { EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); // Create solver and eliminate - const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7}; + const Ordering ordering{keys::A, keys::D, keys::T, keys::X, + keys::S, keys::E, keys::L, keys::B}; DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); @@ -157,16 +164,16 @@ TEST(DiscreteBayesNet, Dot) { "digraph {\n" " size=\"5,5\";\n" "\n" - " var0[label=\"0\"];\n" - " var3[label=\"3\"];\n" - " var4[label=\"4\"];\n" - " var5[label=\"5\"];\n" - " var6[label=\"6\"];\n" + " var4683743612465315840[label=\"A0\"];\n" + " var4971973988617027589[label=\"E5\"];\n" + " var5476377146882523142[label=\"L6\"];\n" + " var5980780305148018692[label=\"S4\"];\n" + " var6052837899185946627[label=\"T3\"];\n" "\n" - " var3->var5\n" - " var6->var5\n" - " var4->var6\n" - " var0->var3\n" + " var6052837899185946627->var4971973988617027589\n" + " var5476377146882523142->var4971973988617027589\n" + " var5980780305148018692->var5476377146882523142\n" + " var4683743612465315840->var6052837899185946627\n" "}"); } @@ -191,11 +198,234 @@ TEST(DiscreteBayesNet, markdown) { "|:-:|:-:|:-:|\n" "|0|0.8|0.2|\n" "|1|0.7|0.3|\n\n"; - auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; + auto formatter = [](Key key) { return key == keys::A ? "Asia" : "Smoking"; }; string actual = fragment.markdown(formatter); EXPECT(actual == expected); } +/* ************************************************************************* */ +#include +#include +#include +#include +#include +#include + +using Value = size_t; + +// ---------------------------------------------------------------------------- +// 1) SearchNode: store partial assignment and next factor to expand +// ---------------------------------------------------------------------------- +struct SearchNode { + DiscreteValues assignment; + double error; + double bound; + int nextConditional; // index into conditionals + + /// if nextConditional < 0, we've assigned everything. + bool isComplete() const { return nextConditional < 0; } + + /// lower bound on final error for unassigned variables. Stub=0. + double computeBound() const { + // Real code might do partial factor analysis or heuristics. + return 0.0; + } + + /// Expand this node by assigning the next variable + SearchNode expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + SearchNode child; + child.assignment = assignment; + for (auto& kv : fa) { + child.assignment[kv.first] = kv.second; + } + + // Compute the incremental error for this factor + child.error = error + conditional.error(child.assignment); + + // Compute new bound + child.bound = child.error + computeBound(); + + // Next factor index + child.nextConditional = nextConditional - 1; + + return child; + } + + friend std::ostream& operator<<(std::ostream& os, const SearchNode& sn) { + os << "[ error=" << sn.error << " bound=" << sn.bound + << " nextConditional=" << sn.nextConditional << " assignment={" + << sn.assignment << "}]"; + return os; + } +}; + +// ---------------------------------------------------------------------------- +// 2) Priority functor to make a min-heap by bound +// ---------------------------------------------------------------------------- +struct CompareByBound { + bool operator()(const SearchNode& a, const SearchNode& b) const { + return a.bound > b.bound; // smallest bound -> highest priority + } +}; + +// ---------------------------------------------------------------------------- +// 4) A Solution +// ---------------------------------------------------------------------------- +struct Solution { + double error; + DiscreteValues assignment; + Solution(double err, const DiscreteValues& assign) + : error(err), assignment(assign) {} + friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { + os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; + return os; + } +}; + +struct CompareByError { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } +}; + +using Solutions = + std::priority_queue, CompareByError>; + +// Function to print the priority queue +void printPriorityQueue(Solutions pq) { + size_t i = 0; + while (!pq.empty()) { + const Solution& best = pq.top(); + cout << " " << ++i << " Error: " << best.error + << ", Values: " << best.assignment << endl; + pq.pop(); + } +} + +// ---------------------------------------------------------------------------- +// 5) The main branch-and-bound routine for DiscreteBayesNet +// ---------------------------------------------------------------------------- +std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, + size_t K) { + // Copy out the conditionals + std::vector> conditionals; + for (auto& factor : bayesNet) { + conditionals.push_back(factor); + } + + // Min-heap of expansions by bound + std::priority_queue, CompareByBound> + expansions; + + // Max-heap of best solutions found so far, keyed by error. + // We keep the largest error on top. + Solutions solutions; + + // 1) Create the root node: no variables assigned, nextConditional = last. + SearchNode root{.assignment = DiscreteValues(), + .error = 0.0, + .nextConditional = static_cast(conditionals.size()) - 1}; + root.bound = root.computeBound(); + std::cout << "Root: " << root << std::endl; + expansions.push(root); + + // 2) Main loop + size_t numExpansions = 0; + while (!expansions.empty()) { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions.top(); + expansions.pop(); + numExpansions++; + std::cout << "Expanding: " << current << std::endl; + + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions.size() == K) { + double worstError = solutions.top().error; + if (current.bound >= worstError) { + std::cout << "Pruning: bound=" << current.bound + << " worstError=" << worstError << std::endl; + continue; + } + } + + // Check if we have a complete assignment + if (current.isComplete()) { + // We have a new candidate solution + double finalError = current.error; + if (solutions.size() < K) { + std::cout << "Adding solution: " << current << std::endl; + solutions.emplace(finalError, current.assignment); + // print out all solutions: + std::cout << "Best solutions so far:" << std::endl; + printPriorityQueue(solutions); + } else { + double worstError = solutions.top().error; + if (finalError < worstError) { + std::cout << "Replacing solution: " << current << std::endl; + solutions.pop(); + solutions.emplace(finalError, current.assignment); + } + } + continue; + } + + // Expand on the next factor + const auto& conditional = conditionals[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + std::cout << "Frontal assignment: " << fa << std::endl; + auto childNode = current.expand(*conditional, fa); + + // Again, prune if we cannot beat the worst solution + if (solutions.size() == K) { + double worstError = solutions.top().error; + if (childNode.bound >= worstError) { + std::cout << "Pruning: bound=" << childNode.bound + << " worstError=" << worstError << std::endl; + continue; + } + } + expansions.push(childNode); + } + } + + std::cout << "Expansions: " << numExpansions << std::endl; + + // 3) Extract solutions from bestSolutions in ascending order of error + std::vector result; + while (!solutions.empty()) { + result.push_back(solutions.top()); + solutions.pop(); + } + std::sort(result.begin(), result.end(), + [](auto& a, auto& b) { return a.error < b.error; }); + + return result; +} + +// ---------------------------------------------------------------------------- +// Example “Unit Tests” (trivial stubs) +// ---------------------------------------------------------------------------- + +TEST_DISABLED(DiscreteBayesNet, EmptyKBest) { + DiscreteBayesNet net; // no factors + auto solutions = branchAndBoundKSolutions(net, 3); + // Expect one solution with empty assignment, error=0 + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +TEST(DiscreteBayesNet, AsiaKBest) { + DiscreteBayesNet net = constructAsiaExample(); + + auto solutions = branchAndBoundKSolutions(net, 4); + EXPECT(!solutions.empty()); + // All solutions likely tie with error=0 in this stub + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + /* ************************************************************************* */ int main() { TestResult tr; From 74b7a962d5f5453f57b9b13b14e362af716f95e7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 17:40:44 -0500 Subject: [PATCH 05/43] Methods in Solutions --- gtsam/discrete/tests/testDiscreteBayesNet.cpp | 138 ++++++++++-------- 1 file changed, 77 insertions(+), 61 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 0bd81a197..52fc660e5 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -96,7 +96,7 @@ TEST(DiscreteBayesNet, Asia) { // Convert to factor graph DiscreteFactorGraph fg(asia); - LONGS_EQUAL(3, fg.back()->size()); + LONGS_EQUAL(1, fg.back()->size()); // Check the marginals we know (of the parent-less nodes) DiscreteMarginals marginals(fg); @@ -164,16 +164,16 @@ TEST(DiscreteBayesNet, Dot) { "digraph {\n" " size=\"5,5\";\n" "\n" - " var4683743612465315840[label=\"A0\"];\n" - " var4971973988617027589[label=\"E5\"];\n" - " var5476377146882523142[label=\"L6\"];\n" - " var5980780305148018692[label=\"S4\"];\n" - " var6052837899185946627[label=\"T3\"];\n" + " var4683743612465315848[label=\"A8\"];\n" + " var4971973988617027587[label=\"E3\"];\n" + " var5476377146882523141[label=\"L5\"];\n" + " var5980780305148018695[label=\"S7\"];\n" + " var6052837899185946630[label=\"T6\"];\n" "\n" - " var6052837899185946627->var4971973988617027589\n" - " var5476377146882523142->var4971973988617027589\n" - " var5980780305148018692->var5476377146882523142\n" - " var4683743612465315840->var6052837899185946627\n" + " var6052837899185946630->var4971973988617027587\n" + " var5476377146882523141->var4971973988617027587\n" + " var5980780305148018695->var5476377146882523141\n" + " var4683743612465315848->var6052837899185946630\n" "}"); } @@ -290,19 +290,61 @@ struct CompareByError { } }; -using Solutions = - std::priority_queue, CompareByError>; +// Define the Solutions class +class Solutions { + private: + size_t maxSize_; + std::priority_queue, CompareByError> pq_; -// Function to print the priority queue -void printPriorityQueue(Solutions pq) { - size_t i = 0; - while (!pq.empty()) { - const Solution& best = pq.top(); - cout << " " << ++i << " Error: " << best.error - << ", Values: " << best.assignment << endl; - pq.pop(); + public: + Solutions(size_t maxSize) : maxSize_(maxSize) {} + + /// Add a solution to the priority queue, possibly evicting the worst one. + /// Return true if we added the solution. + bool maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; } -} + + /// Check if we have any solutions + bool empty() const { return pq_.empty(); } + + // Method to print all solutions + void print() const { + auto pq = pq_; + while (!pq.empty()) { + const Solution& best = pq.top(); + std::cout << "Error: " << best.error << ", Values: " << best.assignment + << std::endl; + pq.pop(); + } + } + + /// Check if (partial) solution with given bound can be pruned. If we have + /// room, we never prune. Otherwise, prune if lower bound on error is worse + /// than our current worst error. + bool prune(double bound) const { + if (pq_.size() < maxSize_) return false; + double worstError = pq_.top().error; + return (bound >= worstError); + } + + // Method to extract solutions in ascending order of error + std::vector extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; + } +}; // ---------------------------------------------------------------------------- // 5) The main branch-and-bound routine for DiscreteBayesNet @@ -321,7 +363,7 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, // Max-heap of best solutions found so far, keyed by error. // We keep the largest error on top. - Solutions solutions; + Solutions solutions(K); // 1) Create the root node: no variables assigned, nextConditional = last. SearchNode root{.assignment = DiscreteValues(), @@ -341,32 +383,17 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, std::cout << "Expanding: " << current << std::endl; // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions.size() == K) { - double worstError = solutions.top().error; - if (current.bound >= worstError) { - std::cout << "Pruning: bound=" << current.bound - << " worstError=" << worstError << std::endl; - continue; - } + if (solutions.prune(current.bound)) { + std::cout << "Pruning: bound=" << current.bound << std::endl; + continue; } // Check if we have a complete assignment if (current.isComplete()) { - // We have a new candidate solution - double finalError = current.error; - if (solutions.size() < K) { - std::cout << "Adding solution: " << current << std::endl; - solutions.emplace(finalError, current.assignment); - // print out all solutions: + const bool added = solutions.maybeAdd(current.error, current.assignment); + if (added) { std::cout << "Best solutions so far:" << std::endl; - printPriorityQueue(solutions); - } else { - double worstError = solutions.top().error; - if (finalError < worstError) { - std::cout << "Replacing solution: " << current << std::endl; - solutions.pop(); - solutions.emplace(finalError, current.assignment); - } + solutions.print(); } continue; } @@ -379,14 +406,11 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, auto childNode = current.expand(*conditional, fa); // Again, prune if we cannot beat the worst solution - if (solutions.size() == K) { - double worstError = solutions.top().error; - if (childNode.bound >= worstError) { - std::cout << "Pruning: bound=" << childNode.bound - << " worstError=" << worstError << std::endl; - continue; - } + if (solutions.prune(current.bound)) { + std::cout << "Pruning: bound=" << childNode.bound << std::endl; + continue; } + expansions.push(childNode); } } @@ -394,15 +418,7 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, std::cout << "Expansions: " << numExpansions << std::endl; // 3) Extract solutions from bestSolutions in ascending order of error - std::vector result; - while (!solutions.empty()) { - result.push_back(solutions.top()); - solutions.pop(); - } - std::sort(result.begin(), result.end(), - [](auto& a, auto& b) { return a.error < b.error; }); - - return result; + return solutions.extractSolutions(); } // ---------------------------------------------------------------------------- @@ -422,8 +438,8 @@ TEST(DiscreteBayesNet, AsiaKBest) { auto solutions = branchAndBoundKSolutions(net, 4); EXPECT(!solutions.empty()); - // All solutions likely tie with error=0 in this stub - EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); + // Regression test: check the first solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); } /* ************************************************************************* */ From 1f4d9bbd7ecdb90c2401c374000a4142486c7c5b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 17:59:31 -0500 Subject: [PATCH 06/43] Separate class --- gtsam/discrete/tests/testDiscreteBayesNet.cpp | 114 +++++++++++------- 1 file changed, 68 insertions(+), 46 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 52fc660e5..648b26770 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -346,97 +346,119 @@ class Solutions { } }; -// ---------------------------------------------------------------------------- -// 5) The main branch-and-bound routine for DiscreteBayesNet -// ---------------------------------------------------------------------------- -std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, - size_t K) { - // Copy out the conditionals - std::vector> conditionals; - for (auto& factor : bayesNet) { - conditionals.push_back(factor); +/** + * BestKSearch: Search for the K best solutions. + */ +class BestKSearch { + public: + /** + * Construct from a DiscreteBayesNet and K. + */ + BestKSearch(const DiscreteBayesNet& bayesNet, size_t K) + : bayesNet_(bayesNet), solutions_(K) { + // Copy out the conditionals + for (auto& factor : bayesNet_) { + conditionals_.push_back(factor); + } + + // Create the root node: no variables assigned, nextConditional = last. + SearchNode root{ + .assignment = DiscreteValues(), + .error = 0.0, + .nextConditional = static_cast(conditionals_.size()) - 1}; + root.bound = root.computeBound(); + std::cout << "Root: " << root << std::endl; + expansions_.push(root); } - // Min-heap of expansions by bound - std::priority_queue, CompareByBound> - expansions; + /** + * @brief Search for the K best solutions. + * + * This method performs a search to find the K best solutions for the given + * DiscreteBayesNet. It uses a priority queue to manage the search nodes, + * expanding nodes with the smallest bound first. The search continues until + * all possible nodes have been expanded or pruned. + * + * @return A vector of the K best solutions found during the search. + */ + std::vector run() { + size_t numExpansions = 0; + while (!expansions_.empty()) { + expandNextNode(); + numExpansions++; + } - // Max-heap of best solutions found so far, keyed by error. - // We keep the largest error on top. - Solutions solutions(K); + std::cout << "Expansions: " << numExpansions << std::endl; - // 1) Create the root node: no variables assigned, nextConditional = last. - SearchNode root{.assignment = DiscreteValues(), - .error = 0.0, - .nextConditional = static_cast(conditionals.size()) - 1}; - root.bound = root.computeBound(); - std::cout << "Root: " << root << std::endl; - expansions.push(root); + // Extract solutions from bestSolutions in ascending order of error + return solutions_.extractSolutions(); + } - // 2) Main loop - size_t numExpansions = 0; - while (!expansions.empty()) { + private: + // + void expandNextNode() { // Pop the partial assignment with the smallest bound - SearchNode current = expansions.top(); - expansions.pop(); - numExpansions++; + SearchNode current = expansions_.top(); + expansions_.pop(); std::cout << "Expanding: " << current << std::endl; // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions.prune(current.bound)) { + if (solutions_.prune(current.bound)) { std::cout << "Pruning: bound=" << current.bound << std::endl; - continue; + return; } // Check if we have a complete assignment if (current.isComplete()) { - const bool added = solutions.maybeAdd(current.error, current.assignment); + const bool added = solutions_.maybeAdd(current.error, current.assignment); if (added) { std::cout << "Best solutions so far:" << std::endl; - solutions.print(); + solutions_.print(); } - continue; + return; } // Expand on the next factor - const auto& conditional = conditionals[current.nextConditional]; + const auto& conditional = conditionals_[current.nextConditional]; for (auto& fa : conditional->frontalAssignments()) { std::cout << "Frontal assignment: " << fa << std::endl; auto childNode = current.expand(*conditional, fa); // Again, prune if we cannot beat the worst solution - if (solutions.prune(current.bound)) { + if (solutions_.prune(current.bound)) { std::cout << "Pruning: bound=" << childNode.bound << std::endl; continue; } - expansions.push(childNode); + expansions_.push(childNode); } } - std::cout << "Expansions: " << numExpansions << std::endl; - - // 3) Extract solutions from bestSolutions in ascending order of error - return solutions.extractSolutions(); -} + const DiscreteBayesNet& bayesNet_; + std::vector> conditionals_; + std::priority_queue, CompareByBound> + expansions_; + Solutions solutions_; +}; // ---------------------------------------------------------------------------- // Example “Unit Tests” (trivial stubs) // ---------------------------------------------------------------------------- -TEST_DISABLED(DiscreteBayesNet, EmptyKBest) { +TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors - auto solutions = branchAndBoundKSolutions(net, 3); + BestKSearch search(net, 3); + auto solutions = search.run(); // Expect one solution with empty assignment, error=0 EXPECT_LONGS_EQUAL(1, solutions.size()); EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); } TEST(DiscreteBayesNet, AsiaKBest) { - DiscreteBayesNet net = constructAsiaExample(); - - auto solutions = branchAndBoundKSolutions(net, 4); + DiscreteBayesNet asia = constructAsiaExample(); + BestKSearch search(asia, 4); + auto solutions = search.run(); EXPECT(!solutions.empty()); // Regression test: check the first solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); From d879b156f8de4da0f0cbdaf16b7ed5b0d13bf20f Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 22:08:04 -0500 Subject: [PATCH 07/43] Asia example --- gtsam/discrete/tests/AsiaExample.h | 61 +++ gtsam/discrete/tests/testDiscreteBayesNet.cpp | 367 ++---------------- 2 files changed, 98 insertions(+), 330 deletions(-) create mode 100644 gtsam/discrete/tests/AsiaExample.h diff --git a/gtsam/discrete/tests/AsiaExample.h b/gtsam/discrete/tests/AsiaExample.h new file mode 100644 index 000000000..7413f6899 --- /dev/null +++ b/gtsam/discrete/tests/AsiaExample.h @@ -0,0 +1,61 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * AsiaExample.h + * + * @date Jan, 2025 + * @author Frank Dellaert + */ + +#include +#include + +namespace gtsam { +namespace asia_example { + +static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), + B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), + S = Symbol('S', 7), A = Symbol('A', 8); + +static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), + Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), + Asia(A, 2); + +// Function to construct the incomplete Asia example +DiscreteBayesNet createPriors() { + DiscreteBayesNet priors; + priors.add(Smoking % "50/50"); + priors.add(Asia, "99/1"); + return priors; +} + +// Function to construct the incomplete Asia example +DiscreteBayesNet createFragment() { + DiscreteBayesNet fragment; + fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); + fragment.add(LungCancer | Smoking = "99/1 90/10"); + fragment.add(Tuberculosis | Asia = "99/1 95/5"); + for (const auto& factor : createPriors()) fragment.push_back(factor); + return fragment; +} + +// Function to construct the Asia example +DiscreteBayesNet createAsiaExample() { + DiscreteBayesNet asia; + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + asia.add(XRay | Either = "95/5 2/98"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + for (const auto& factor : createFragment()) asia.push_back(factor); + return asia; +} +} // namespace asia_example +} // namespace gtsam \ No newline at end of file diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 648b26770..dd5d218f8 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -29,40 +29,13 @@ #include #include -using namespace std; +#include "AsiaExample.h" + using namespace gtsam; -namespace keys { -static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), - B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), - S = Symbol('S', 7), A = Symbol('A', 8); -} - -static const DiscreteKey Dyspnea(keys::D, 2), XRay(keys::X, 2), - Either(keys::E, 2), Bronchitis(keys::B, 2), LungCancer(keys::L, 2), - Tuberculosis(keys::T, 2), Smoking(keys::S, 2), Asia(keys::A, 2); - -using ADT = AlgebraicDecisionTree; - -// Function to construct the Asia example -DiscreteBayesNet constructAsiaExample() { - DiscreteBayesNet asia; - - // Add in topological sort order, parents last: - asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); - asia.add(XRay | Either = "95/5 2/98"); - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - asia.add(Bronchitis | Smoking = "70/30 40/60"); - asia.add(LungCancer | Smoking = "99/1 90/10"); - asia.add(Tuberculosis | Asia = "99/1 95/5"); - asia.add(Smoking % "50/50"); // Signature version - asia.add(Asia, "99/1"); - - return asia; -} - /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { + using ADT = AlgebraicDecisionTree; DiscreteBayesNet bayesNet; DiscreteKey Parent(0, 2), Child(1, 2); @@ -92,7 +65,8 @@ TEST(DiscreteBayesNet, bayesNet) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { - DiscreteBayesNet asia = constructAsiaExample(); + using namespace asia_example; + const DiscreteBayesNet asia = createAsiaExample(); // Convert to factor graph DiscreteFactorGraph fg(asia); @@ -105,8 +79,7 @@ TEST(DiscreteBayesNet, Asia) { EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); // Create solver and eliminate - const Ordering ordering{keys::A, keys::D, keys::T, keys::X, - keys::S, keys::E, keys::L, keys::B}; + const Ordering ordering{A, D, T, X, S, E, L, B}; DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); @@ -151,319 +124,53 @@ TEST(DiscreteBayesNet, Sugar) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Dot) { - DiscreteBayesNet fragment; - fragment.add(Asia % "99/1"); - fragment.add(Smoking % "50/50"); + using namespace asia_example; + const DiscreteBayesNet fragment = createFragment(); - fragment.add(Tuberculosis | Asia = "99/1 95/5"); - fragment.add(LungCancer | Smoking = "99/1 90/10"); - fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - string actual = fragment.dot(); - EXPECT(actual == - "digraph {\n" - " size=\"5,5\";\n" - "\n" - " var4683743612465315848[label=\"A8\"];\n" - " var4971973988617027587[label=\"E3\"];\n" - " var5476377146882523141[label=\"L5\"];\n" - " var5980780305148018695[label=\"S7\"];\n" - " var6052837899185946630[label=\"T6\"];\n" - "\n" - " var6052837899185946630->var4971973988617027587\n" - " var5476377146882523141->var4971973988617027587\n" - " var5980780305148018695->var5476377146882523141\n" - " var4683743612465315848->var6052837899185946630\n" - "}"); + std::string expected = + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var4683743612465315848[label=\"A8\"];\n" + " var4971973988617027587[label=\"E3\"];\n" + " var5476377146882523141[label=\"L5\"];\n" + " var5980780305148018695[label=\"S7\"];\n" + " var6052837899185946630[label=\"T6\"];\n" + "\n" + " var4683743612465315848->var6052837899185946630\n" + " var5980780305148018695->var5476377146882523141\n" + " var6052837899185946630->var4971973988617027587\n" + " var5476377146882523141->var4971973988617027587\n" + "}"; + std::string actual = fragment.dot(); + EXPECT(actual.compare(expected) == 0); } /* ************************************************************************* */ // Check markdown representation looks as expected. TEST(DiscreteBayesNet, markdown) { - DiscreteBayesNet fragment; - fragment.add(Asia % "99/1"); - fragment.add(Smoking | Asia = "8/2 7/3"); + using namespace asia_example; + DiscreteBayesNet priors = createPriors(); - string expected = + std::string expected = "`DiscreteBayesNet` of size 2\n" "\n" + " *P(Smoking):*\n\n" + "|Smoking|value|\n" + "|:-:|:-:|\n" + "|0|0.5|\n" + "|1|0.5|\n" + "\n" " *P(Asia):*\n\n" "|Asia|value|\n" "|:-:|:-:|\n" "|0|0.99|\n" - "|1|0.01|\n" - "\n" - " *P(Smoking|Asia):*\n\n" - "|*Asia*|0|1|\n" - "|:-:|:-:|:-:|\n" - "|0|0.8|0.2|\n" - "|1|0.7|0.3|\n\n"; - auto formatter = [](Key key) { return key == keys::A ? "Asia" : "Smoking"; }; - string actual = fragment.markdown(formatter); + "|1|0.01|\n\n"; + auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; }; + std::string actual = priors.markdown(formatter); EXPECT(actual == expected); } -/* ************************************************************************* */ -#include -#include -#include -#include -#include -#include - -using Value = size_t; - -// ---------------------------------------------------------------------------- -// 1) SearchNode: store partial assignment and next factor to expand -// ---------------------------------------------------------------------------- -struct SearchNode { - DiscreteValues assignment; - double error; - double bound; - int nextConditional; // index into conditionals - - /// if nextConditional < 0, we've assigned everything. - bool isComplete() const { return nextConditional < 0; } - - /// lower bound on final error for unassigned variables. Stub=0. - double computeBound() const { - // Real code might do partial factor analysis or heuristics. - return 0.0; - } - - /// Expand this node by assigning the next variable - SearchNode expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const { - // Combine the new frontal assignment with the current partial assignment - SearchNode child; - child.assignment = assignment; - for (auto& kv : fa) { - child.assignment[kv.first] = kv.second; - } - - // Compute the incremental error for this factor - child.error = error + conditional.error(child.assignment); - - // Compute new bound - child.bound = child.error + computeBound(); - - // Next factor index - child.nextConditional = nextConditional - 1; - - return child; - } - - friend std::ostream& operator<<(std::ostream& os, const SearchNode& sn) { - os << "[ error=" << sn.error << " bound=" << sn.bound - << " nextConditional=" << sn.nextConditional << " assignment={" - << sn.assignment << "}]"; - return os; - } -}; - -// ---------------------------------------------------------------------------- -// 2) Priority functor to make a min-heap by bound -// ---------------------------------------------------------------------------- -struct CompareByBound { - bool operator()(const SearchNode& a, const SearchNode& b) const { - return a.bound > b.bound; // smallest bound -> highest priority - } -}; - -// ---------------------------------------------------------------------------- -// 4) A Solution -// ---------------------------------------------------------------------------- -struct Solution { - double error; - DiscreteValues assignment; - Solution(double err, const DiscreteValues& assign) - : error(err), assignment(assign) {} - friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { - os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; - return os; - } -}; - -struct CompareByError { - bool operator()(const Solution& a, const Solution& b) const { - return a.error < b.error; - } -}; - -// Define the Solutions class -class Solutions { - private: - size_t maxSize_; - std::priority_queue, CompareByError> pq_; - - public: - Solutions(size_t maxSize) : maxSize_(maxSize) {} - - /// Add a solution to the priority queue, possibly evicting the worst one. - /// Return true if we added the solution. - bool maybeAdd(double error, const DiscreteValues& assignment) { - const bool full = pq_.size() == maxSize_; - if (full && error >= pq_.top().error) return false; - if (full) pq_.pop(); - pq_.emplace(error, assignment); - return true; - } - - /// Check if we have any solutions - bool empty() const { return pq_.empty(); } - - // Method to print all solutions - void print() const { - auto pq = pq_; - while (!pq.empty()) { - const Solution& best = pq.top(); - std::cout << "Error: " << best.error << ", Values: " << best.assignment - << std::endl; - pq.pop(); - } - } - - /// Check if (partial) solution with given bound can be pruned. If we have - /// room, we never prune. Otherwise, prune if lower bound on error is worse - /// than our current worst error. - bool prune(double bound) const { - if (pq_.size() < maxSize_) return false; - double worstError = pq_.top().error; - return (bound >= worstError); - } - - // Method to extract solutions in ascending order of error - std::vector extractSolutions() { - std::vector result; - while (!pq_.empty()) { - result.push_back(pq_.top()); - pq_.pop(); - } - std::sort( - result.begin(), result.end(), - [](const Solution& a, const Solution& b) { return a.error < b.error; }); - return result; - } -}; - -/** - * BestKSearch: Search for the K best solutions. - */ -class BestKSearch { - public: - /** - * Construct from a DiscreteBayesNet and K. - */ - BestKSearch(const DiscreteBayesNet& bayesNet, size_t K) - : bayesNet_(bayesNet), solutions_(K) { - // Copy out the conditionals - for (auto& factor : bayesNet_) { - conditionals_.push_back(factor); - } - - // Create the root node: no variables assigned, nextConditional = last. - SearchNode root{ - .assignment = DiscreteValues(), - .error = 0.0, - .nextConditional = static_cast(conditionals_.size()) - 1}; - root.bound = root.computeBound(); - std::cout << "Root: " << root << std::endl; - expansions_.push(root); - } - - /** - * @brief Search for the K best solutions. - * - * This method performs a search to find the K best solutions for the given - * DiscreteBayesNet. It uses a priority queue to manage the search nodes, - * expanding nodes with the smallest bound first. The search continues until - * all possible nodes have been expanded or pruned. - * - * @return A vector of the K best solutions found during the search. - */ - std::vector run() { - size_t numExpansions = 0; - while (!expansions_.empty()) { - expandNextNode(); - numExpansions++; - } - - std::cout << "Expansions: " << numExpansions << std::endl; - - // Extract solutions from bestSolutions in ascending order of error - return solutions_.extractSolutions(); - } - - private: - // - void expandNextNode() { - // Pop the partial assignment with the smallest bound - SearchNode current = expansions_.top(); - expansions_.pop(); - std::cout << "Expanding: " << current << std::endl; - - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions_.prune(current.bound)) { - std::cout << "Pruning: bound=" << current.bound << std::endl; - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - const bool added = solutions_.maybeAdd(current.error, current.assignment); - if (added) { - std::cout << "Best solutions so far:" << std::endl; - solutions_.print(); - } - return; - } - - // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - std::cout << "Frontal assignment: " << fa << std::endl; - auto childNode = current.expand(*conditional, fa); - - // Again, prune if we cannot beat the worst solution - if (solutions_.prune(current.bound)) { - std::cout << "Pruning: bound=" << childNode.bound << std::endl; - continue; - } - - expansions_.push(childNode); - } - } - - const DiscreteBayesNet& bayesNet_; - std::vector> conditionals_; - std::priority_queue, CompareByBound> - expansions_; - Solutions solutions_; -}; - -// ---------------------------------------------------------------------------- -// Example “Unit Tests” (trivial stubs) -// ---------------------------------------------------------------------------- - -TEST(DiscreteBayesNet, EmptyKBest) { - DiscreteBayesNet net; // no factors - BestKSearch search(net, 3); - auto solutions = search.run(); - // Expect one solution with empty assignment, error=0 - EXPECT_LONGS_EQUAL(1, solutions.size()); - EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); -} - -TEST(DiscreteBayesNet, AsiaKBest) { - DiscreteBayesNet asia = constructAsiaExample(); - BestKSearch search(asia, 4); - auto solutions = search.run(); - EXPECT(!solutions.empty()); - // Regression test: check the first solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); -} - /* ************************************************************************* */ int main() { TestResult tr; From 455554c80384fcdc1848eb38d394d7874755f449 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 22:39:10 -0500 Subject: [PATCH 08/43] DiscreteSearch --- gtsam/discrete/DiscreteSearch.h | 276 ++++++++++++++++++++ gtsam/discrete/tests/testDiscreteSearch.cpp | 59 +++++ 2 files changed, 335 insertions(+) create mode 100644 gtsam/discrete/DiscreteSearch.h create mode 100644 gtsam/discrete/tests/testDiscreteSearch.cpp diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h new file mode 100644 index 000000000..d9ffae768 --- /dev/null +++ b/gtsam/discrete/DiscreteSearch.h @@ -0,0 +1,276 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * DiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include + +namespace gtsam { + +using Value = size_t; + +/** + * @brief Represents a node in the search tree for discrete search algorithms. + * + * @details Each SearchNode contains a partial assignment of discrete variables, + * the current error, a bound on the final error, and the index of the next + * conditional to be assigned. + */ +struct SearchNode { + DiscreteValues assignment; ///< Partial assignment of discrete variables. + double error; ///< Current error for the partial assignment. + double bound; ///< Lower bound on the final error for unassigned variables. + int nextConditional; ///< Index of the next conditional to be assigned. + + struct CompareByBound { + bool operator()(const SearchNode& a, const SearchNode& b) const { + return a.bound > b.bound; // smallest bound -> highest priority + } + }; + + /** + * @brief Checks if the node represents a complete assignment. + * + * @return True if all variables have been assigned, false otherwise. + */ + bool isComplete() const { return nextConditional < 0; } + + /** + * @brief Computes a lower bound on the final error for unassigned variables. + * + * @details This is a stub implementation that returns 0. Real implementations + * might perform partial factor analysis or use heuristics to compute the + * bound. + * + * @return A lower bound on the final error. + */ + double computeBound() const { + // Real code might do partial factor analysis or heuristics. + return 0.0; + } + + /** + * @brief Expands the node by assigning the next variable. + * + * @param conditional The discrete conditional representing the next variable + * to be assigned. + * @param fa The frontal assignment for the next variable. + * @return A new SearchNode representing the expanded state. + */ + SearchNode expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + SearchNode child; + child.assignment = assignment; + for (auto& kv : fa) { + child.assignment[kv.first] = kv.second; + } + + // Compute the incremental error for this factor + child.error = error + conditional.error(child.assignment); + + // Compute new bound + child.bound = computeBound(); + + // Update the index of the next conditional + child.nextConditional = nextConditional - 1; + + return child; + } + + /** + * @brief Prints the SearchNode to an output stream. + * + * @param os The output stream. + * @param node The SearchNode to be printed. + * @return The output stream. + */ + friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { + os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; + return os; + } +}; + +struct Solution { + double error; + DiscreteValues assignment; + Solution(double err, const DiscreteValues& assign) + : error(err), assignment(assign) {} + friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { + os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; + return os; + } +}; + +struct CompareByError { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } +}; + +// Define the Solutions class +class Solutions { + private: + size_t maxSize_; + std::priority_queue, CompareByError> pq_; + + public: + Solutions(size_t maxSize) : maxSize_(maxSize) {} + + /// Add a solution to the priority queue, possibly evicting the worst one. + /// Return true if we added the solution. + bool maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; + } + + /// Check if we have any solutions + bool empty() const { return pq_.empty(); } + + // Method to print all solutions + friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + auto pq = sn.pq_; + while (!pq.empty()) { + const Solution& best = pq.top(); + os << "Error: " << best.error << ", Values: " << best.assignment + << std::endl; + pq.pop(); + } + return os; + } + + /// Check if (partial) solution with given bound can be pruned. If we have + /// room, we never prune. Otherwise, prune if lower bound on error is worse + /// than our current worst error. + bool prune(double bound) const { + if (pq_.size() < maxSize_) return false; + double worstError = pq_.top().error; + return (bound >= worstError); + } + + // Method to extract solutions in ascending order of error + std::vector extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; + } +}; + +/** + * DiscreteSearch: Search for the K best solutions. + */ +class DiscreteSearch { + public: + /** + * Construct from a DiscreteBayesNet and K. + */ + DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) : solutions_(K) { + // Copy out the conditionals + for (auto& factor : bayesNet) { + conditionals_.push_back(factor); + } + + // Calculate the cost-to-go for each conditional. If there are n + // conditionals, we start with nextConditional = n-1, and the minimum error + // obtainable is the sum of all the minimum errors. We start with + // 0, and that is the minimum error of the conditional with that index: + double error = 0.0; + for (const auto& conditional : conditionals_) { + Ordering ordering(conditional->begin(), conditional->end()); + auto maxx = conditional->max(ordering); + assert(maxx->size() == 1); + error -= std::log(maxx->evaluate({})); + costToGo_.push_back(error); + } + + // Create the root node: no variables assigned, nextConditional = last. + SearchNode root{ + .assignment = DiscreteValues(), + .error = 0.0, + .nextConditional = static_cast(conditionals_.size()) - 1}; + if (!costToGo_.empty()) root.bound = costToGo_.back(); + expansions_.push(root); + } + + /** + * @brief Search for the K best solutions. + * + * This method performs a search to find the K best solutions for the given + * DiscreteBayesNet. It uses a priority queue to manage the search nodes, + * expanding nodes with the smallest bound first. The search continues until + * all possible nodes have been expanded or pruned. + * + * @return A vector of the K best solutions found during the search. + */ + std::vector run() { + while (!expansions_.empty()) { + expandNextNode(); + } + + // Extract solutions from bestSolutions in ascending order of error + return solutions_.extractSolutions(); + } + + private: + void expandNextNode() { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions_.top(); + expansions_.pop(); + + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions_.prune(current.bound)) { + return; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions_.maybeAdd(current.error, current.assignment); + return; + } + + // Expand on the next factor + const auto& conditional = conditionals_[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + auto childNode = current.expand(*conditional, fa); + if (childNode.nextConditional >= 0) + childNode.bound = + childNode.error + costToGo_[childNode.nextConditional]; + + // Again, prune if we cannot beat the worst solution + if (!solutions_.prune(childNode.bound)) { + expansions_.push(childNode); + } + } + } + + std::vector> conditionals_; + std::vector costToGo_; + std::priority_queue, + SearchNode::CompareByBound> + expansions_; + Solutions solutions_; +}; +} // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp new file mode 100644 index 000000000..4d7adbac5 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -0,0 +1,59 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "AsiaExample.h" + +using namespace gtsam; + +TEST(DiscreteBayesNet, EmptyKBest) { + DiscreteBayesNet net; // no factors + DiscreteSearch search(net, 3); + auto solutions = search.run(); + // Expect one solution with empty assignment, error=0 + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +TEST(DiscreteBayesNet, AsiaKBest) { + using namespace asia_example; + DiscreteBayesNet asia = createAsiaExample(); + DiscreteSearch search(asia, 4); + auto solutions = search.run(); + EXPECT(!solutions.empty()); + // Regression test: check the first solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ From 54f493358d5a57e2d3034159c18bfca2f444cfa5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 22:48:40 -0500 Subject: [PATCH 09/43] BayesTree WIP --- gtsam/discrete/DiscreteSearch.h | 7 +++++ gtsam/discrete/tests/testDiscreteSearch.cpp | 31 +++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index d9ffae768..c267001b5 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -17,6 +17,7 @@ */ #include +#include namespace gtsam { @@ -214,6 +215,12 @@ class DiscreteSearch { expansions_.push(root); } + /** + * Construct from a DiscreteBayesNet and K. + */ + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) + : solutions_(K) {} + /** * @brief Search for the K best solutions. * diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 4d7adbac5..b0f2b03e3 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -32,6 +32,7 @@ using namespace gtsam; +/* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors DiscreteSearch search(net, 3); @@ -41,6 +42,7 @@ TEST(DiscreteBayesNet, EmptyKBest) { EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); } +/* ************************************************************************* */ TEST(DiscreteBayesNet, AsiaKBest) { using namespace asia_example; DiscreteBayesNet asia = createAsiaExample(); @@ -51,6 +53,35 @@ TEST(DiscreteBayesNet, AsiaKBest) { EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); } +/* ************************************************************************* */ +TEST(DiscreteBayesTree, testEmptyTree) { + DiscreteBayesTree bt; + + DiscreteSearch search(bt, 3); + auto solutions = search.run(); + + // We expect exactly 1 solution with error = 0.0 (the empty assignment). + assert(solutions.size() == 1 && "There should be exactly one empty solution"); + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscreteBayesTree, testTrivialOneClique) { + using namespace asia_example; + DiscreteFactorGraph asia(createAsiaExample()); + DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(); + GTSAM_PRINT(*bt); + + // Ask for top 4 solutions + DiscreteSearch search(*bt, 4); + auto solutions = search.run(); + + EXPECT(!solutions.empty()); + // Regression test: check the first solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); +} + /* ************************************************************************* */ int main() { TestResult tr; From 14eeaf93db346f04acd0704abaf0e91b07e209f2 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 23:24:09 -0500 Subject: [PATCH 10/43] Works for Bayes trees now as well --- gtsam/discrete/DiscreteSearch.h | 124 ++++++++++++-------- gtsam/discrete/tests/AsiaExample.h | 2 +- gtsam/discrete/tests/testDiscreteSearch.cpp | 12 +- 3 files changed, 84 insertions(+), 54 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index c267001b5..24517ee29 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -36,6 +36,16 @@ struct SearchNode { double bound; ///< Lower bound on the final error for unassigned variables. int nextConditional; ///< Index of the next conditional to be assigned. + /** + * @brief Construct the root node for the search. + */ + static SearchNode Root(size_t numConditionals, double bound) { + return {.assignment = DiscreteValues(), + .error = 0.0, + .bound = bound, + .nextConditional = static_cast(numConditionals) - 1}; + } + struct CompareByBound { bool operator()(const SearchNode& a, const SearchNode& b) const { return a.bound > b.bound; // smallest bound -> highest priority @@ -49,20 +59,6 @@ struct SearchNode { */ bool isComplete() const { return nextConditional < 0; } - /** - * @brief Computes a lower bound on the final error for unassigned variables. - * - * @details This is a stub implementation that returns 0. Real implementations - * might perform partial factor analysis or use heuristics to compute the - * bound. - * - * @return A lower bound on the final error. - */ - double computeBound() const { - // Real code might do partial factor analysis or heuristics. - return 0.0; - } - /** * @brief Expands the node by assigning the next variable. * @@ -74,22 +70,15 @@ struct SearchNode { SearchNode expand(const DiscreteConditional& conditional, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment - SearchNode child; - child.assignment = assignment; + DiscreteValues newAssignment = assignment; for (auto& kv : fa) { - child.assignment[kv.first] = kv.second; + newAssignment[kv.first] = kv.second; } - // Compute the incremental error for this factor - child.error = error + conditional.error(child.assignment); - - // Compute new bound - child.bound = computeBound(); - - // Update the index of the next conditional - child.nextConditional = nextConditional - 1; - - return child; + return {.assignment = newAssignment, + .error = error + conditional.error(newAssignment), + .bound = 0.0, + .nextConditional = nextConditional - 1}; } /** @@ -184,6 +173,8 @@ class Solutions { */ class DiscreteSearch { public: + size_t numExpansions = 0; + /** * Construct from a DiscreteBayesNet and K. */ @@ -193,33 +184,42 @@ class DiscreteSearch { conditionals_.push_back(factor); } - // Calculate the cost-to-go for each conditional. If there are n - // conditionals, we start with nextConditional = n-1, and the minimum error - // obtainable is the sum of all the minimum errors. We start with - // 0, and that is the minimum error of the conditional with that index: - double error = 0.0; - for (const auto& conditional : conditionals_) { - Ordering ordering(conditional->begin(), conditional->end()); - auto maxx = conditional->max(ordering); - assert(maxx->size() == 1); - error -= std::log(maxx->evaluate({})); - costToGo_.push_back(error); - } + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); - // Create the root node: no variables assigned, nextConditional = last. - SearchNode root{ - .assignment = DiscreteValues(), - .error = 0.0, - .nextConditional = static_cast(conditionals_.size()) - 1}; - if (!costToGo_.empty()) root.bound = costToGo_.back(); - expansions_.push(root); + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); } /** - * Construct from a DiscreteBayesNet and K. + * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) - : solutions_(K) {} + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) { + using CliquePtr = DiscreteBayesTree::sharedClique; + std::function collectConditionals = + [&](const CliquePtr& clique) -> void { + if (!clique) return; + + // Recursive post-order traversal: process children first + for (const auto& child : clique->children) { + collectConditionals(child); + } + + // Then add the current clique's conditional + conditionals_.push_back(clique->conditional()); + }; + + // Start traversal from each root in the tree + for (const auto& root : bayesTree.roots()) collectConditionals(root); + + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); + + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); + } /** * @brief Search for the K best solutions. @@ -233,6 +233,7 @@ class DiscreteSearch { */ std::vector run() { while (!expansions_.empty()) { + numExpansions++; expandNextNode(); } @@ -241,6 +242,29 @@ class DiscreteSearch { } private: + /** + * @brief Compute the cost-to-go for each conditional. + * + * @param conditionals The conditionals of the DiscreteBayesNet. + * @return A vector of cost-to-go values. + */ + static std::vector computeCostToGo( + const std::vector& conditionals) { + std::vector costToGo; + double error = 0.0; + for (const auto& conditional : conditionals) { + Ordering ordering(conditional->begin(), conditional->end()); + auto maxx = conditional->max(ordering); + assert(maxx->size() == 1); + error -= std::log(maxx->evaluate({})); + costToGo.push_back(error); + } + return costToGo; + } + + /** + * @brief Expand the next node in the search tree. + */ void expandNextNode() { // Pop the partial assignment with the smallest bound SearchNode current = expansions_.top(); @@ -273,7 +297,7 @@ class DiscreteSearch { } } - std::vector> conditionals_; + std::vector conditionals_; std::vector costToGo_; std::priority_queue, SearchNode::CompareByBound> diff --git a/gtsam/discrete/tests/AsiaExample.h b/gtsam/discrete/tests/AsiaExample.h index 7413f6899..6c327daec 100644 --- a/gtsam/discrete/tests/AsiaExample.h +++ b/gtsam/discrete/tests/AsiaExample.h @@ -30,7 +30,7 @@ static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), Asia(A, 2); -// Function to construct the incomplete Asia example +// Function to construct the Asia priors DiscreteBayesNet createPriors() { DiscreteBayesNet priors; priors.add(Smoking % "50/50"); diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index b0f2b03e3..041c3b465 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -49,8 +49,9 @@ TEST(DiscreteBayesNet, AsiaKBest) { DiscreteSearch search(asia, 4); auto solutions = search.run(); EXPECT(!solutions.empty()); - // Regression test: check the first solution + // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); } /* ************************************************************************* */ @@ -70,16 +71,21 @@ TEST(DiscreteBayesTree, testEmptyTree) { TEST(DiscreteBayesTree, testTrivialOneClique) { using namespace asia_example; DiscreteFactorGraph asia(createAsiaExample()); - DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(); + const Ordering ordering{D, X, B, E, L, T, S, A}; + DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); GTSAM_PRINT(*bt); // Ask for top 4 solutions DiscreteSearch search(*bt, 4); auto solutions = search.run(); + // print numExpansions + std::cout << "Number of expansions: " << search.numExpansions << std::endl; + EXPECT(!solutions.empty()); - // Regression test: check the first solution + // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); } /* ************************************************************************* */ From 70089a0fd4e0c3c7e87cb0eb2493e82b8165235d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Jan 2025 23:34:30 -0500 Subject: [PATCH 11/43] Move to cpp file --- gtsam/discrete/DiscreteSearch.cpp | 178 ++++++++++++++++++++ gtsam/discrete/DiscreteSearch.h | 169 +++---------------- gtsam/discrete/tests/testDiscreteSearch.cpp | 5 +- 3 files changed, 201 insertions(+), 151 deletions(-) create mode 100644 gtsam/discrete/DiscreteSearch.cpp diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp new file mode 100644 index 000000000..16794dcfc --- /dev/null +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -0,0 +1,178 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * DiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include + +namespace gtsam { + +SearchNode SearchNode::Root(size_t numConditionals, double bound) { + return {.assignment = DiscreteValues(), + .error = 0.0, + .bound = bound, + .nextConditional = static_cast(numConditionals) - 1}; +} + +SearchNode SearchNode::expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + DiscreteValues newAssignment = assignment; + for (auto& kv : fa) { + newAssignment[kv.first] = kv.second; + } + + return {.assignment = newAssignment, + .error = error + conditional.error(newAssignment), + .bound = 0.0, + .nextConditional = nextConditional - 1}; +} + +bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; +} + +std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + auto pq = sn.pq_; + while (!pq.empty()) { + const Solution& best = pq.top(); + os << "Error: " << best.error << ", Values: " << best.assignment + << std::endl; + pq.pop(); + } + return os; +} + +bool Solutions::prune(double bound) const { + if (pq_.size() < maxSize_) return false; + double worstError = pq_.top().error; + return (bound >= worstError); +} + +std::vector Solutions::extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; +} + +DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) + : solutions_(K) { + // Copy out the conditionals + for (auto& factor : bayesNet) { + conditionals_.push_back(factor); + } + + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); + + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); +} + +DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) + : solutions_(K) { + using CliquePtr = DiscreteBayesTree::sharedClique; + std::function collectConditionals = + [&](const CliquePtr& clique) -> void { + if (!clique) return; + + // Recursive post-order traversal: process children first + for (const auto& child : clique->children) { + collectConditionals(child); + } + + // Then add the current clique's conditional + conditionals_.push_back(clique->conditional()); + }; + + // Start traversal from each root in the tree + for (const auto& root : bayesTree.roots()) collectConditionals(root); + + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); + + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); +} + +std::vector DiscreteSearch::run() { + while (!expansions_.empty()) { + numExpansions++; + expandNextNode(); + } + + // Extract solutions from bestSolutions in ascending order of error + return solutions_.extractSolutions(); +} + +std::vector DiscreteSearch::computeCostToGo( + const std::vector& conditionals) { + std::vector costToGo; + double error = 0.0; + for (const auto& conditional : conditionals) { + Ordering ordering(conditional->begin(), conditional->end()); + auto maxx = conditional->max(ordering); + assert(maxx->size() == 1); + error -= std::log(maxx->evaluate({})); + costToGo.push_back(error); + } + return costToGo; +} + +void DiscreteSearch::expandNextNode() { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions_.top(); + expansions_.pop(); + + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions_.prune(current.bound)) { + return; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions_.maybeAdd(current.error, current.assignment); + return; + } + + // Expand on the next factor + const auto& conditional = conditionals_[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + auto childNode = current.expand(*conditional, fa); + if (childNode.nextConditional >= 0) + childNode.bound = childNode.error + costToGo_[childNode.nextConditional]; + + // Again, prune if we cannot beat the worst solution + if (!solutions_.prune(childNode.bound)) { + expansions_.push(childNode); + } + } +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 24517ee29..989437b9f 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -39,14 +39,9 @@ struct SearchNode { /** * @brief Construct the root node for the search. */ - static SearchNode Root(size_t numConditionals, double bound) { - return {.assignment = DiscreteValues(), - .error = 0.0, - .bound = bound, - .nextConditional = static_cast(numConditionals) - 1}; - } + static SearchNode Root(size_t numConditionals, double bound); - struct CompareByBound { + struct Compare { bool operator()(const SearchNode& a, const SearchNode& b) const { return a.bound > b.bound; // smallest bound -> highest priority } @@ -68,18 +63,7 @@ struct SearchNode { * @return A new SearchNode representing the expanded state. */ SearchNode expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const { - // Combine the new frontal assignment with the current partial assignment - DiscreteValues newAssignment = assignment; - for (auto& kv : fa) { - newAssignment[kv.first] = kv.second; - } - - return {.assignment = newAssignment, - .error = error + conditional.error(newAssignment), - .bound = 0.0, - .nextConditional = nextConditional - 1}; - } + const DiscreteValues& fa) const; /** * @brief Prints the SearchNode to an output stream. @@ -103,69 +87,40 @@ struct Solution { os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; return os; } -}; -struct CompareByError { - bool operator()(const Solution& a, const Solution& b) const { - return a.error < b.error; - } + struct Compare { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } + }; }; // Define the Solutions class class Solutions { private: size_t maxSize_; - std::priority_queue, CompareByError> pq_; + std::priority_queue, Solution::Compare> pq_; public: Solutions(size_t maxSize) : maxSize_(maxSize) {} /// Add a solution to the priority queue, possibly evicting the worst one. /// Return true if we added the solution. - bool maybeAdd(double error, const DiscreteValues& assignment) { - const bool full = pq_.size() == maxSize_; - if (full && error >= pq_.top().error) return false; - if (full) pq_.pop(); - pq_.emplace(error, assignment); - return true; - } + bool maybeAdd(double error, const DiscreteValues& assignment); /// Check if we have any solutions bool empty() const { return pq_.empty(); } // Method to print all solutions - friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { - auto pq = sn.pq_; - while (!pq.empty()) { - const Solution& best = pq.top(); - os << "Error: " << best.error << ", Values: " << best.assignment - << std::endl; - pq.pop(); - } - return os; - } + friend std::ostream& operator<<(std::ostream& os, const Solutions& sn); /// Check if (partial) solution with given bound can be pruned. If we have /// room, we never prune. Otherwise, prune if lower bound on error is worse /// than our current worst error. - bool prune(double bound) const { - if (pq_.size() < maxSize_) return false; - double worstError = pq_.top().error; - return (bound >= worstError); - } + bool prune(double bound) const; // Method to extract solutions in ascending order of error - std::vector extractSolutions() { - std::vector result; - while (!pq_.empty()) { - result.push_back(pq_.top()); - pq_.pop(); - } - std::sort( - result.begin(), result.end(), - [](const Solution& a, const Solution& b) { return a.error < b.error; }); - return result; - } + std::vector extractSolutions(); }; /** @@ -178,48 +133,12 @@ class DiscreteSearch { /** * Construct from a DiscreteBayesNet and K. */ - DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) : solutions_(K) { - // Copy out the conditionals - for (auto& factor : bayesNet) { - conditionals_.push_back(factor); - } - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); - } + DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K); /** * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) { - using CliquePtr = DiscreteBayesTree::sharedClique; - std::function collectConditionals = - [&](const CliquePtr& clique) -> void { - if (!clique) return; - - // Recursive post-order traversal: process children first - for (const auto& child : clique->children) { - collectConditionals(child); - } - - // Then add the current clique's conditional - conditionals_.push_back(clique->conditional()); - }; - - // Start traversal from each root in the tree - for (const auto& root : bayesTree.roots()) collectConditionals(root); - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); - } + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K); /** * @brief Search for the K best solutions. @@ -231,15 +150,7 @@ class DiscreteSearch { * * @return A vector of the K best solutions found during the search. */ - std::vector run() { - while (!expansions_.empty()) { - numExpansions++; - expandNextNode(); - } - - // Extract solutions from bestSolutions in ascending order of error - return solutions_.extractSolutions(); - } + std::vector run(); private: /** @@ -249,58 +160,16 @@ class DiscreteSearch { * @return A vector of cost-to-go values. */ static std::vector computeCostToGo( - const std::vector& conditionals) { - std::vector costToGo; - double error = 0.0; - for (const auto& conditional : conditionals) { - Ordering ordering(conditional->begin(), conditional->end()); - auto maxx = conditional->max(ordering); - assert(maxx->size() == 1); - error -= std::log(maxx->evaluate({})); - costToGo.push_back(error); - } - return costToGo; - } + const std::vector& conditionals); /** * @brief Expand the next node in the search tree. */ - void expandNextNode() { - // Pop the partial assignment with the smallest bound - SearchNode current = expansions_.top(); - expansions_.pop(); - - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions_.prune(current.bound)) { - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - solutions_.maybeAdd(current.error, current.assignment); - return; - } - - // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - auto childNode = current.expand(*conditional, fa); - if (childNode.nextConditional >= 0) - childNode.bound = - childNode.error + costToGo_[childNode.nextConditional]; - - // Again, prune if we cannot beat the worst solution - if (!solutions_.prune(childNode.bound)) { - expansions_.push(childNode); - } - } - } + void expandNextNode(); std::vector conditionals_; std::vector costToGo_; - std::priority_queue, - SearchNode::CompareByBound> + std::priority_queue, SearchNode::Compare> expansions_; Solutions solutions_; }; diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 041c3b465..a39bcd86b 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -48,6 +48,10 @@ TEST(DiscreteBayesNet, AsiaKBest) { DiscreteBayesNet asia = createAsiaExample(); DiscreteSearch search(asia, 4); auto solutions = search.run(); + + // print numExpansions + std::cout << "Number of expansions: " << search.numExpansions << std::endl; + EXPECT(!solutions.empty()); // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); @@ -73,7 +77,6 @@ TEST(DiscreteBayesTree, testTrivialOneClique) { DiscreteFactorGraph asia(createAsiaExample()); const Ordering ordering{D, X, B, E, L, T, S, A}; DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); - GTSAM_PRINT(*bt); // Ask for top 4 solutions DiscreteSearch search(*bt, 4); From b10ea06626a6d3de83ee01d20237c41cfb686d3b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:07:22 -0500 Subject: [PATCH 12/43] Clean up, MPE tests --- gtsam/discrete/DiscreteSearch.cpp | 60 +++++++-------------- gtsam/discrete/DiscreteSearch.h | 28 +++++----- gtsam/discrete/tests/testDiscreteSearch.cpp | 31 +++++++++-- 3 files changed, 61 insertions(+), 58 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 16794dcfc..81ebc6012 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -31,8 +31,8 @@ SearchNode SearchNode::expand(const DiscreteConditional& conditional, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; - for (auto& kv : fa) { - newAssignment[kv.first] = kv.second; + for (auto& [key, value] : fa) { + newAssignment[key] = value; } return {.assignment = newAssignment, @@ -50,11 +50,10 @@ bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) { } std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + os << "Solutions (top " << sn.pq_.size() << "):\n"; auto pq = sn.pq_; while (!pq.empty()) { - const Solution& best = pq.top(); - os << "Error: " << best.error << ", Values: " << best.assignment - << std::endl; + os << pq.top() << "\n"; pq.pop(); } return os; @@ -62,8 +61,7 @@ std::ostream& operator<<(std::ostream& os, const Solutions& sn) { bool Solutions::prune(double bound) const { if (pq_.size() < maxSize_) return false; - double worstError = pq_.top().error; - return (bound >= worstError); + return bound >= pq_.top().error; } std::vector Solutions::extractSolutions() { @@ -80,45 +78,23 @@ std::vector Solutions::extractSolutions() { DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) : solutions_(K) { - // Copy out the conditionals - for (auto& factor : bayesNet) { - conditionals_.push_back(factor); - } - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); + std::vector conditionals; + for (auto& factor : bayesNet) conditionals.push_back(factor); + initialize(conditionals); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) { - using CliquePtr = DiscreteBayesTree::sharedClique; - std::function collectConditionals = - [&](const CliquePtr& clique) -> void { - if (!clique) return; - - // Recursive post-order traversal: process children first - for (const auto& child : clique->children) { - collectConditionals(child); - } - - // Then add the current clique's conditional - conditionals_.push_back(clique->conditional()); - }; - - // Start traversal from each root in the tree + std::vector conditionals; + std::function + collectConditionals = [&](const auto& clique) { + if (!clique) return; + for (const auto& child : clique->children) collectConditionals(child); + conditionals.push_back(clique->conditional()); + }; for (const auto& root : bayesTree.roots()) collectConditionals(root); - - // Calculate the cost-to-go for each conditional - costToGo_ = computeCostToGo(conditionals_); - - // Create the root node and push it to the expansions queue - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); -} + initialize(conditionals); +}; std::vector DiscreteSearch::run() { while (!expansions_.empty()) { @@ -170,7 +146,7 @@ void DiscreteSearch::expandNextNode() { // Again, prune if we cannot beat the worst solution if (!solutions_.prune(childNode.bound)) { - expansions_.push(childNode); + expansions_.emplace(childNode); } } } diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 989437b9f..cbf258b3f 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -19,6 +19,8 @@ #include #include +#include + namespace gtsam { using Value = size_t; @@ -52,7 +54,7 @@ struct SearchNode { * * @return True if all variables have been assigned, false otherwise. */ - bool isComplete() const { return nextConditional < 0; } + inline bool isComplete() const { return nextConditional < 0; } /** * @brief Expands the node by assigning the next variable. @@ -133,12 +135,12 @@ class DiscreteSearch { /** * Construct from a DiscreteBayesNet and K. */ - DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K); + DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K = 1); /** * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K); + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K = 1); /** * @brief Search for the K best solutions. @@ -153,18 +155,20 @@ class DiscreteSearch { std::vector run(); private: - /** - * @brief Compute the cost-to-go for each conditional. - * - * @param conditionals The conditionals of the DiscreteBayesNet. - * @return A vector of cost-to-go values. - */ + /// Initialize the search with the given conditionals. + void initialize( + const std::vector& conditionals) { + conditionals_ = conditionals; + costToGo_ = computeCostToGo(conditionals_); + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); + } + + /// Compute the cumulative cost-to-go for each conditional slot. static std::vector computeCostToGo( const std::vector& conditionals); - /** - * @brief Expand the next node in the search tree. - */ + /// Expand the next node in the search tree. void expandNextNode(); std::vector conditionals_; diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index a39bcd86b..e301ed3d4 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -46,20 +46,32 @@ TEST(DiscreteBayesNet, EmptyKBest) { TEST(DiscreteBayesNet, AsiaKBest) { using namespace asia_example; DiscreteBayesNet asia = createAsiaExample(); + + // Ask for the MPE + DiscreteSearch search1(asia); + auto mpe = search1.run(); + + // print numExpansions + std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + DiscreteSearch search(asia, 4); auto solutions = search.run(); // print numExpansions std::cout << "Number of expansions: " << search.numExpansions << std::endl; - EXPECT(!solutions.empty()); + EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); } /* ************************************************************************* */ -TEST(DiscreteBayesTree, testEmptyTree) { +TEST(DiscreteBayesTree, EmptyTree) { DiscreteBayesTree bt; DiscreteSearch search(bt, 3); @@ -72,12 +84,23 @@ TEST(DiscreteBayesTree, testEmptyTree) { } /* ************************************************************************* */ -TEST(DiscreteBayesTree, testTrivialOneClique) { +TEST(DiscreteBayesTree, AsiaTreeKBest) { using namespace asia_example; DiscreteFactorGraph asia(createAsiaExample()); const Ordering ordering{D, X, B, E, L, T, S, A}; DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); + // Ask for top 4 solutions + DiscreteSearch search1(*bt); + auto mpe = search1.run(); + + // print numExpansions + std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + // Ask for top 4 solutions DiscreteSearch search(*bt, 4); auto solutions = search.run(); @@ -85,7 +108,7 @@ TEST(DiscreteBayesTree, testTrivialOneClique) { // print numExpansions std::cout << "Number of expansions: " << search.numExpansions << std::endl; - EXPECT(!solutions.empty()); + EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); From f174a38eedb25720dfa37852abd8aac5990e516d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:26:35 -0500 Subject: [PATCH 13/43] Made run const, take K --- gtsam/discrete/DiscreteSearch.cpp | 107 ++++++++++++-------- gtsam/discrete/DiscreteSearch.h | 22 +--- gtsam/discrete/tests/testDiscreteSearch.cpp | 37 +++---- 3 files changed, 79 insertions(+), 87 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 81ebc6012..70f7f871a 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -76,34 +76,84 @@ std::vector Solutions::extractSolutions() { return result; } -DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) - : solutions_(K) { +DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { std::vector conditionals; - for (auto& factor : bayesNet) conditionals.push_back(factor); - initialize(conditionals); + for (auto& factor : bayesNet) conditionals_.push_back(factor); + costToGo_ = computeCostToGo(conditionals_); } -DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) - : solutions_(K) { - std::vector conditionals; +DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { std::function collectConditionals = [&](const auto& clique) { if (!clique) return; for (const auto& child : clique->children) collectConditionals(child); - conditionals.push_back(clique->conditional()); + conditionals_.push_back(clique->conditional()); }; for (const auto& root : bayesTree.roots()) collectConditionals(root); - initialize(conditionals); + costToGo_ = computeCostToGo(conditionals_); }; -std::vector DiscreteSearch::run() { - while (!expansions_.empty()) { - numExpansions++; +using SearchNodeQueue = std::priority_queue, + SearchNode::Compare>; + +std::vector DiscreteSearch::run(size_t K) const { + SearchNodeQueue expansions; + expansions.push(SearchNode::Root(conditionals_.size(), + costToGo_.empty() ? 0.0 : costToGo_.back())); + + Solutions solutions(K); + + auto expandNextNode = [&] { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions.top(); + expansions.pop(); + + // If we already have K solutions, prune if we cannot beat the worst + // one. + if (solutions.prune(current.bound)) { + return; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions.maybeAdd(current.error, current.assignment); + return; + } + + // Expand on the next factor + const auto& conditional = conditionals_[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + auto childNode = current.expand(*conditional, fa); + if (childNode.nextConditional >= 0) + childNode.bound = + childNode.error + costToGo_[childNode.nextConditional]; + + // Again, prune if we cannot beat the worst solution + if (!solutions.prune(childNode.bound)) { + expansions.emplace(childNode); + } + } + }; + +#ifdef DISCRETE_SEARCH_DEBUG + size_t numExpansions = 0; +#endif + + // Perform the search + while (!expansions.empty()) { expandNextNode(); +#ifdef DISCRETE_SEARCH_DEBUG + ++numExpansions; +#endif } +#ifdef DISCRETE_SEARCH_DEBUG + std::cout << "Number of expansions: " << numExpansions << std::endl; +#endif + // Extract solutions from bestSolutions in ascending order of error - return solutions_.extractSolutions(); + return solutions.extractSolutions(); } std::vector DiscreteSearch::computeCostToGo( @@ -120,35 +170,4 @@ std::vector DiscreteSearch::computeCostToGo( return costToGo; } -void DiscreteSearch::expandNextNode() { - // Pop the partial assignment with the smallest bound - SearchNode current = expansions_.top(); - expansions_.pop(); - - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions_.prune(current.bound)) { - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - solutions_.maybeAdd(current.error, current.assignment); - return; - } - - // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - auto childNode = current.expand(*conditional, fa); - if (childNode.nextConditional >= 0) - childNode.bound = childNode.error + costToGo_[childNode.nextConditional]; - - // Again, prune if we cannot beat the worst solution - if (!solutions_.prune(childNode.bound)) { - expansions_.emplace(childNode); - } - } -} - } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index cbf258b3f..f5bd67bae 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -130,17 +130,15 @@ class Solutions { */ class DiscreteSearch { public: - size_t numExpansions = 0; - /** * Construct from a DiscreteBayesNet and K. */ - DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K = 1); + DiscreteSearch(const DiscreteBayesNet& bayesNet); /** * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K = 1); + DiscreteSearch(const DiscreteBayesTree& bayesTree); /** * @brief Search for the K best solutions. @@ -152,29 +150,17 @@ class DiscreteSearch { * * @return A vector of the K best solutions found during the search. */ - std::vector run(); + std::vector run(size_t K = 1) const; private: - /// Initialize the search with the given conditionals. - void initialize( - const std::vector& conditionals) { - conditionals_ = conditionals; - costToGo_ = computeCostToGo(conditionals_); - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); - } - /// Compute the cumulative cost-to-go for each conditional slot. static std::vector computeCostToGo( const std::vector& conditionals); /// Expand the next node in the search tree. - void expandNextNode(); + void expandNextNode() const; std::vector conditionals_; std::vector costToGo_; - std::priority_queue, SearchNode::Compare> - expansions_; - Solutions solutions_; }; } // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index e301ed3d4..a3b50add4 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -35,8 +35,8 @@ using namespace gtsam; /* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors - DiscreteSearch search(net, 3); - auto solutions = search.run(); + DiscreteSearch search(net); + auto solutions = search.run(3); // Expect one solution with empty assignment, error=0 EXPECT_LONGS_EQUAL(1, solutions.size()); EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); @@ -46,23 +46,17 @@ TEST(DiscreteBayesNet, EmptyKBest) { TEST(DiscreteBayesNet, AsiaKBest) { using namespace asia_example; DiscreteBayesNet asia = createAsiaExample(); + DiscreteSearch search(asia); // Ask for the MPE - DiscreteSearch search1(asia); - auto mpe = search1.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + auto mpe = search.run(); EXPECT_LONGS_EQUAL(1, mpe.size()); // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); - DiscreteSearch search(asia, 4); - auto solutions = search.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search.numExpansions << std::endl; + // Ask for top 4 solutions + auto solutions = search.run(4); EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution @@ -74,8 +68,8 @@ TEST(DiscreteBayesNet, AsiaKBest) { TEST(DiscreteBayesTree, EmptyTree) { DiscreteBayesTree bt; - DiscreteSearch search(bt, 3); - auto solutions = search.run(); + DiscreteSearch search(bt); + auto solutions = search.run(3); // We expect exactly 1 solution with error = 0.0 (the empty assignment). assert(solutions.size() == 1 && "There should be exactly one empty solution"); @@ -89,24 +83,17 @@ TEST(DiscreteBayesTree, AsiaTreeKBest) { DiscreteFactorGraph asia(createAsiaExample()); const Ordering ordering{D, X, B, E, L, T, S, A}; DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); + DiscreteSearch search(*bt); - // Ask for top 4 solutions - DiscreteSearch search1(*bt); - auto mpe = search1.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + // Ask for MPE + auto mpe = search.run(); EXPECT_LONGS_EQUAL(1, mpe.size()); // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); // Ask for top 4 solutions - DiscreteSearch search(*bt, 4); - auto solutions = search.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search.numExpansions << std::endl; + auto solutions = search.run(4); EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution From cee77695956de909dad8de0ca626db0793016803 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:32:03 -0500 Subject: [PATCH 14/43] Moved data structures to cpp --- gtsam/discrete/DiscreteSearch.cpp | 172 ++++++++++++++++++++++-------- gtsam/discrete/DiscreteSearch.h | 90 +--------------- 2 files changed, 126 insertions(+), 136 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 70f7f871a..5fca5c820 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -20,61 +20,139 @@ namespace gtsam { -SearchNode SearchNode::Root(size_t numConditionals, double bound) { - return {.assignment = DiscreteValues(), - .error = 0.0, - .bound = bound, - .nextConditional = static_cast(numConditionals) - 1}; -} +using Value = size_t; -SearchNode SearchNode::expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const { - // Combine the new frontal assignment with the current partial assignment - DiscreteValues newAssignment = assignment; - for (auto& [key, value] : fa) { - newAssignment[key] = value; +/** + * @brief Represents a node in the search tree for discrete search algorithms. + * + * @details Each SearchNode contains a partial assignment of discrete variables, + * the current error, a bound on the final error, and the index of the next + * conditional to be assigned. + */ +struct SearchNode { + DiscreteValues assignment; ///< Partial assignment of discrete variables. + double error; ///< Current error for the partial assignment. + double bound; ///< Lower bound on the final error for unassigned variables. + int nextConditional; ///< Index of the next conditional to be assigned. + + /** + * @brief Construct the root node for the search. + */ + static SearchNode Root(size_t numConditionals, double bound) { + return {.assignment = DiscreteValues(), + .error = 0.0, + .bound = bound, + .nextConditional = static_cast(numConditionals) - 1}; } - return {.assignment = newAssignment, - .error = error + conditional.error(newAssignment), - .bound = 0.0, - .nextConditional = nextConditional - 1}; -} + struct Compare { + bool operator()(const SearchNode& a, const SearchNode& b) const { + return a.bound > b.bound; // smallest bound -> highest priority + } + }; -bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) { - const bool full = pq_.size() == maxSize_; - if (full && error >= pq_.top().error) return false; - if (full) pq_.pop(); - pq_.emplace(error, assignment); - return true; -} + /** + * @brief Checks if the node represents a complete assignment. + * + * @return True if all variables have been assigned, false otherwise. + */ + inline bool isComplete() const { return nextConditional < 0; } -std::ostream& operator<<(std::ostream& os, const Solutions& sn) { - os << "Solutions (top " << sn.pq_.size() << "):\n"; - auto pq = sn.pq_; - while (!pq.empty()) { - os << pq.top() << "\n"; - pq.pop(); + /** + * @brief Expands the node by assigning the next variable. + * + * @param conditional The discrete conditional representing the next variable + * to be assigned. + * @param fa The frontal assignment for the next variable. + * @return A new SearchNode representing the expanded state. + */ + SearchNode expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + DiscreteValues newAssignment = assignment; + for (auto& [key, value] : fa) { + newAssignment[key] = value; + } + + return {.assignment = newAssignment, + .error = error + conditional.error(newAssignment), + .bound = 0.0, + .nextConditional = nextConditional - 1}; } - return os; -} -bool Solutions::prune(double bound) const { - if (pq_.size() < maxSize_) return false; - return bound >= pq_.top().error; -} - -std::vector Solutions::extractSolutions() { - std::vector result; - while (!pq_.empty()) { - result.push_back(pq_.top()); - pq_.pop(); + /** + * @brief Prints the SearchNode to an output stream. + * + * @param os The output stream. + * @param node The SearchNode to be printed. + * @return The output stream. + */ + friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { + os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; + return os; } - std::sort( - result.begin(), result.end(), - [](const Solution& a, const Solution& b) { return a.error < b.error; }); - return result; -} +}; + +struct CompareSolution { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } +}; + +// Define the Solutions class +class Solutions { + private: + size_t maxSize_; + std::priority_queue, CompareSolution> pq_; + + public: + Solutions(size_t maxSize) : maxSize_(maxSize) {} + + /// Add a solution to the priority queue, possibly evicting the worst one. + /// Return true if we added the solution. + bool maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; + } + + /// Check if we have any solutions + bool empty() const { return pq_.empty(); } + + // Method to print all solutions + friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + os << "Solutions (top " << sn.pq_.size() << "):\n"; + auto pq = sn.pq_; + while (!pq.empty()) { + os << pq.top() << "\n"; + pq.pop(); + } + return os; + } + + /// Check if (partial) solution with given bound can be pruned. If we have + /// room, we never prune. Otherwise, prune if lower bound on error is worse + /// than our current worst error. + bool prune(double bound) const { + if (pq_.size() < maxSize_) return false; + return bound >= pq_.top().error; + } + + // Method to extract solutions in ascending order of error + std::vector extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; + } +}; DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { std::vector conditionals; diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index f5bd67bae..78558a127 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -23,63 +23,9 @@ namespace gtsam { -using Value = size_t; - /** - * @brief Represents a node in the search tree for discrete search algorithms. - * - * @details Each SearchNode contains a partial assignment of discrete variables, - * the current error, a bound on the final error, and the index of the next - * conditional to be assigned. + * @brief A solution to a discrete search problem. */ -struct SearchNode { - DiscreteValues assignment; ///< Partial assignment of discrete variables. - double error; ///< Current error for the partial assignment. - double bound; ///< Lower bound on the final error for unassigned variables. - int nextConditional; ///< Index of the next conditional to be assigned. - - /** - * @brief Construct the root node for the search. - */ - static SearchNode Root(size_t numConditionals, double bound); - - struct Compare { - bool operator()(const SearchNode& a, const SearchNode& b) const { - return a.bound > b.bound; // smallest bound -> highest priority - } - }; - - /** - * @brief Checks if the node represents a complete assignment. - * - * @return True if all variables have been assigned, false otherwise. - */ - inline bool isComplete() const { return nextConditional < 0; } - - /** - * @brief Expands the node by assigning the next variable. - * - * @param conditional The discrete conditional representing the next variable - * to be assigned. - * @param fa The frontal assignment for the next variable. - * @return A new SearchNode representing the expanded state. - */ - SearchNode expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const; - - /** - * @brief Prints the SearchNode to an output stream. - * - * @param os The output stream. - * @param node The SearchNode to be printed. - * @return The output stream. - */ - friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { - os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; - return os; - } -}; - struct Solution { double error; DiscreteValues assignment; @@ -89,40 +35,6 @@ struct Solution { os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; return os; } - - struct Compare { - bool operator()(const Solution& a, const Solution& b) const { - return a.error < b.error; - } - }; -}; - -// Define the Solutions class -class Solutions { - private: - size_t maxSize_; - std::priority_queue, Solution::Compare> pq_; - - public: - Solutions(size_t maxSize) : maxSize_(maxSize) {} - - /// Add a solution to the priority queue, possibly evicting the worst one. - /// Return true if we added the solution. - bool maybeAdd(double error, const DiscreteValues& assignment); - - /// Check if we have any solutions - bool empty() const { return pq_.empty(); } - - // Method to print all solutions - friend std::ostream& operator<<(std::ostream& os, const Solutions& sn); - - /// Check if (partial) solution with given bound can be pruned. If we have - /// room, we never prune. Otherwise, prune if lower bound on error is worse - /// than our current worst error. - bool prune(double bound) const; - - // Method to extract solutions in ascending order of error - std::vector extractSolutions(); }; /** From 9dcb52b697a25d856e49fe8655cc23e3654cfa36 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:37:17 -0500 Subject: [PATCH 15/43] Made queue a class --- gtsam/discrete/DiscreteSearch.cpp | 48 +++++++++++++++---------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 5fca5c820..e3090bb50 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -171,48 +171,48 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { costToGo_ = computeCostToGo(conditionals_); }; -using SearchNodeQueue = std::priority_queue, - SearchNode::Compare>; - -std::vector DiscreteSearch::run(size_t K) const { - SearchNodeQueue expansions; - expansions.push(SearchNode::Root(conditionals_.size(), - costToGo_.empty() ? 0.0 : costToGo_.back())); - - Solutions solutions(K); - - auto expandNextNode = [&] { +struct SearchNodeQueue + : public std::priority_queue, + SearchNode::Compare> { + void expandNextNode( + const std::vector& conditionals, + const std::vector& costToGo, Solutions* solutions) { // Pop the partial assignment with the smallest bound - SearchNode current = expansions.top(); - expansions.pop(); + SearchNode current = top(); + pop(); - // If we already have K solutions, prune if we cannot beat the worst - // one. - if (solutions.prune(current.bound)) { + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions->prune(current.bound)) { return; } // Check if we have a complete assignment if (current.isComplete()) { - solutions.maybeAdd(current.error, current.assignment); + solutions->maybeAdd(current.error, current.assignment); return; } // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; + const auto& conditional = conditionals[current.nextConditional]; for (auto& fa : conditional->frontalAssignments()) { auto childNode = current.expand(*conditional, fa); if (childNode.nextConditional >= 0) - childNode.bound = - childNode.error + costToGo_[childNode.nextConditional]; + childNode.bound = childNode.error + costToGo[childNode.nextConditional]; // Again, prune if we cannot beat the worst solution - if (!solutions.prune(childNode.bound)) { - expansions.emplace(childNode); + if (!solutions->prune(childNode.bound)) { + emplace(childNode); } } - }; + } +}; + +std::vector DiscreteSearch::run(size_t K) const { + Solutions solutions(K); + SearchNodeQueue expansions; + expansions.push(SearchNode::Root(conditionals_.size(), + costToGo_.empty() ? 0.0 : costToGo_.back())); #ifdef DISCRETE_SEARCH_DEBUG size_t numExpansions = 0; @@ -220,7 +220,7 @@ std::vector DiscreteSearch::run(size_t K) const { // Perform the search while (!expansions.empty()) { - expandNextNode(); + expansions.expandNextNode(conditionals_, costToGo_, &solutions); #ifdef DISCRETE_SEARCH_DEBUG ++numExpansions; #endif From 908e4dcd77d93cf69dd2ae04f82847af4586efb4 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:51:55 -0500 Subject: [PATCH 16/43] Check MPE --- gtsam/discrete/tests/testDiscreteSearch.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index a3b50add4..1eb274f2d 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -45,8 +45,8 @@ TEST(DiscreteBayesNet, EmptyKBest) { /* ************************************************************************* */ TEST(DiscreteBayesNet, AsiaKBest) { using namespace asia_example; - DiscreteBayesNet asia = createAsiaExample(); - DiscreteSearch search(asia); + const DiscreteBayesNet asia = createAsiaExample(); + const DiscreteSearch search(asia); // Ask for the MPE auto mpe = search.run(); @@ -55,6 +55,10 @@ TEST(DiscreteBayesNet, AsiaKBest) { // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + // Check it is equal to MPE via inference + const DiscreteFactorGraph asiaFG(asia); + EXPECT(assert_equal(mpe[0].assignment, asiaFG.optimize())); + // Ask for top 4 solutions auto solutions = search.run(4); @@ -72,7 +76,6 @@ TEST(DiscreteBayesTree, EmptyTree) { auto solutions = search.run(3); // We expect exactly 1 solution with error = 0.0 (the empty assignment). - assert(solutions.size() == 1 && "There should be exactly one empty solution"); EXPECT_LONGS_EQUAL(1, solutions.size()); EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); } @@ -80,9 +83,9 @@ TEST(DiscreteBayesTree, EmptyTree) { /* ************************************************************************* */ TEST(DiscreteBayesTree, AsiaTreeKBest) { using namespace asia_example; - DiscreteFactorGraph asia(createAsiaExample()); + DiscreteFactorGraph asiaFG(createAsiaExample()); const Ordering ordering{D, X, B, E, L, T, S, A}; - DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); + DiscreteBayesTree::shared_ptr bt = asiaFG.eliminateMultifrontal(ordering); DiscreteSearch search(*bt); // Ask for MPE @@ -92,6 +95,9 @@ TEST(DiscreteBayesTree, AsiaTreeKBest) { // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + // Check it is equal to MPE via inference + EXPECT(assert_equal(mpe[0].assignment, asiaFG.optimize())); + // Ask for top 4 solutions auto solutions = search.run(4); From 1594a2504560ded7419afedc65fce102d2662c96 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 00:52:07 -0500 Subject: [PATCH 17/43] Fix assert issue --- gtsam/discrete/DiscreteSearch.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index e3090bb50..1cd2999f2 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -241,7 +241,6 @@ std::vector DiscreteSearch::computeCostToGo( for (const auto& conditional : conditionals) { Ordering ordering(conditional->begin(), conditional->end()); auto maxx = conditional->max(ordering); - assert(maxx->size() == 1); error -= std::log(maxx->evaluate({})); costToGo.push_back(error); } From d53b48cf80364649e63e34dddaa0de3cc4f63617 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 09:38:57 -0500 Subject: [PATCH 18/43] Don't use c99 feature :-/ --- gtsam/discrete/DiscreteSearch.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 1cd2999f2..4a068d175 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -39,10 +39,8 @@ struct SearchNode { * @brief Construct the root node for the search. */ static SearchNode Root(size_t numConditionals, double bound) { - return {.assignment = DiscreteValues(), - .error = 0.0, - .bound = bound, - .nextConditional = static_cast(numConditionals) - 1}; + return {DiscreteValues(), 0.0, bound, + static_cast(numConditionals) - 1}; } struct Compare { @@ -74,10 +72,8 @@ struct SearchNode { newAssignment[key] = value; } - return {.assignment = newAssignment, - .error = error + conditional.error(newAssignment), - .bound = 0.0, - .nextConditional = nextConditional - 1}; + return {newAssignment, error + conditional.error(newAssignment), 0.0, + nextConditional - 1}; } /** From 2da5d764f15ada8845b49bed177019d4193b02dd Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 10:09:40 -0500 Subject: [PATCH 19/43] Pre-compute asia things --- gtsam/discrete/tests/testDiscreteSearch.cpp | 34 +++++++++------------ 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 1eb274f2d..b537dd2f0 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -20,18 +20,21 @@ #include #include -#include -#include -#include -#include -#include -#include -#include - #include "AsiaExample.h" using namespace gtsam; +// Create Asia Bayes net, FG, and Bayes tree once +namespace asia { +using namespace asia_example; +static const DiscreteBayesNet bayesNet = createAsiaExample(); +static const DiscreteFactorGraph factorGraph(bayesNet); +static const DiscreteValues mpe = factorGraph.optimize(); +static const Ordering ordering{D, X, B, E, L, T, S, A}; +static const DiscreteBayesTree bayesTree = + *factorGraph.eliminateMultifrontal(ordering); +} // namespace asia + /* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors @@ -44,9 +47,7 @@ TEST(DiscreteBayesNet, EmptyKBest) { /* ************************************************************************* */ TEST(DiscreteBayesNet, AsiaKBest) { - using namespace asia_example; - const DiscreteBayesNet asia = createAsiaExample(); - const DiscreteSearch search(asia); + const DiscreteSearch search(asia::bayesNet); // Ask for the MPE auto mpe = search.run(); @@ -56,8 +57,7 @@ TEST(DiscreteBayesNet, AsiaKBest) { EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); // Check it is equal to MPE via inference - const DiscreteFactorGraph asiaFG(asia); - EXPECT(assert_equal(mpe[0].assignment, asiaFG.optimize())); + EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); // Ask for top 4 solutions auto solutions = search.run(4); @@ -82,11 +82,7 @@ TEST(DiscreteBayesTree, EmptyTree) { /* ************************************************************************* */ TEST(DiscreteBayesTree, AsiaTreeKBest) { - using namespace asia_example; - DiscreteFactorGraph asiaFG(createAsiaExample()); - const Ordering ordering{D, X, B, E, L, T, S, A}; - DiscreteBayesTree::shared_ptr bt = asiaFG.eliminateMultifrontal(ordering); - DiscreteSearch search(*bt); + DiscreteSearch search(asia::bayesTree); // Ask for MPE auto mpe = search.run(); @@ -96,7 +92,7 @@ TEST(DiscreteBayesTree, AsiaTreeKBest) { EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); // Check it is equal to MPE via inference - EXPECT(assert_equal(mpe[0].assignment, asiaFG.optimize())); + EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); // Ask for top 4 solutions auto solutions = search.run(4); From cf4441826909aae8ad5cd91cc8d9ef46821b759a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 13:31:24 -0500 Subject: [PATCH 20/43] Get rid of extra ; --- gtsam/discrete/DiscreteSearch.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 4a068d175..531d994be 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -165,7 +165,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { }; for (const auto& root : bayesTree.roots()) collectConditionals(root); costToGo_ = computeCostToGo(conditionals_); -}; +} struct SearchNodeQueue : public std::priority_queue, From 16131fe8c92b6a84ca62dec758d845406781b2a6 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 14:02:40 -0500 Subject: [PATCH 21/43] Remove spurious typedef --- gtsam/discrete/DiscreteSearch.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 531d994be..d151e7a6f 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -20,8 +20,6 @@ namespace gtsam { -using Value = size_t; - /** * @brief Represents a node in the search tree for discrete search algorithms. * From 41fdeb6c04908be7b52bc099f734d2221d0fba19 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 14:51:16 -0500 Subject: [PATCH 22/43] Export --- gtsam/discrete/DiscreteSearch.cpp | 2 ++ gtsam/discrete/DiscreteSearch.h | 30 +++++++++++++++--------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index d151e7a6f..c5941862d 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -20,6 +20,8 @@ namespace gtsam { +using Solution = DiscreteSearch::Solution; + /** * @brief Represents a node in the search tree for discrete search algorithms. * diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 78558a127..6202880b2 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -23,25 +23,25 @@ namespace gtsam { -/** - * @brief A solution to a discrete search problem. - */ -struct Solution { - double error; - DiscreteValues assignment; - Solution(double err, const DiscreteValues& assign) - : error(err), assignment(assign) {} - friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { - os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; - return os; - } -}; - /** * DiscreteSearch: Search for the K best solutions. */ -class DiscreteSearch { +class GTSAM_EXPORT DiscreteSearch { public: + /** + * @brief A solution to a discrete search problem. + */ + struct Solution { + double error; + DiscreteValues assignment; + Solution(double err, const DiscreteValues& assign) + : error(err), assignment(assign) {} + friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { + os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; + return os; + } + }; + /** * Construct from a DiscreteBayesNet and K. */ From 35e7acbf16548a286d3677717f8dfc1b9398d1f8 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 13:05:20 -0500 Subject: [PATCH 23/43] Print factors --- gtsam/discrete/DiscreteJunctionTree.cpp | 41 +++++++--- gtsam/discrete/DiscreteJunctionTree.h | 99 +++++++++++++++---------- 2 files changed, 90 insertions(+), 50 deletions(-) diff --git a/gtsam/discrete/DiscreteJunctionTree.cpp b/gtsam/discrete/DiscreteJunctionTree.cpp index dc24860eb..b4657ec8d 100644 --- a/gtsam/discrete/DiscreteJunctionTree.cpp +++ b/gtsam/discrete/DiscreteJunctionTree.cpp @@ -16,19 +16,42 @@ * @author Richard Roberts */ -#include -#include #include +#include +#include namespace gtsam { - // Instantiate base classes - template class EliminatableClusterTree; - template class JunctionTree; +// Instantiate base classes +template class EliminatableClusterTree; +template class JunctionTree; - /* ************************************************************************* */ - DiscreteJunctionTree::DiscreteJunctionTree( - const DiscreteEliminationTree& eliminationTree) : - Base(eliminationTree) {} +/* ************************************************************************* */ +DiscreteJunctionTree::DiscreteJunctionTree( + const DiscreteEliminationTree& eliminationTree) + : Base(eliminationTree) {} +/* ************************************************************************* */ +namespace { +struct PrintForestVisitorPre { + const KeyFormatter& formatter; + PrintForestVisitorPre(const KeyFormatter& formatter) : formatter(formatter) {} + std::string operator()( + const std::shared_ptr& node, + const std::string& parentString) { + // Print the current node + node->print(parentString + "-", formatter); + node->factors.print(parentString + "-", formatter); + std::cout << std::endl; + // Increment the indentation + return parentString + "| "; + } +}; +} // namespace +void DiscreteJunctionTree::print(const std::string& s, + const KeyFormatter& keyFormatter) const { + PrintForestVisitorPre visitor(keyFormatter); + treeTraversal::DepthFirstForest(*this, s, visitor); } + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteJunctionTree.h b/gtsam/discrete/DiscreteJunctionTree.h index f6171c672..4b9241036 100644 --- a/gtsam/discrete/DiscreteJunctionTree.h +++ b/gtsam/discrete/DiscreteJunctionTree.h @@ -18,54 +18,71 @@ #pragma once -#include #include +#include #include namespace gtsam { - // Forward declarations - class DiscreteEliminationTree; +// Forward declarations +class DiscreteEliminationTree; + +/** + * An EliminatableClusterTree, i.e., a set of variable clusters with factors, + * arranged in a tree, with the additional property that it represents the + * clique tree associated with a Bayes net. + * + * In GTSAM a junction tree is an intermediate data structure in multifrontal + * variable elimination. Each node is a cluster of factors, along with a + * clique of variables that are eliminated all at once. In detail, every node k + * represents a clique (maximal fully connected subset) of an associated chordal + * graph, such as a chordal Bayes net resulting from elimination. + * + * The difference with the BayesTree is that a JunctionTree stores factors, + * whereas a BayesTree stores conditionals, that are the product of eliminating + * the factors in the corresponding JunctionTree cliques. + * + * The tree structure and elimination method are exactly analogous to the + * EliminationTree, except that in the JunctionTree, at each node multiple + * variables are eliminated at a time. + * + * \ingroup Multifrontal + * @ingroup discrete + * \nosubgrouping + */ +class GTSAM_EXPORT DiscreteJunctionTree + : public JunctionTree { + public: + typedef JunctionTree + Base; ///< Base class + typedef DiscreteJunctionTree This; ///< This class + typedef std::shared_ptr shared_ptr; ///< Shared pointer to this class + + /// @name Constructors + /// @{ /** - * An EliminatableClusterTree, i.e., a set of variable clusters with factors, arranged in a tree, - * with the additional property that it represents the clique tree associated with a Bayes net. - * - * In GTSAM a junction tree is an intermediate data structure in multifrontal - * variable elimination. Each node is a cluster of factors, along with a - * clique of variables that are eliminated all at once. In detail, every node k represents - * a clique (maximal fully connected subset) of an associated chordal graph, such as a - * chordal Bayes net resulting from elimination. - * - * The difference with the BayesTree is that a JunctionTree stores factors, whereas a - * BayesTree stores conditionals, that are the product of eliminating the factors in the - * corresponding JunctionTree cliques. - * - * The tree structure and elimination method are exactly analogous to the EliminationTree, - * except that in the JunctionTree, at each node multiple variables are eliminated at a time. - * - * \ingroup Multifrontal - * @ingroup discrete - * \nosubgrouping + * Build the elimination tree of a factor graph using precomputed column + * structure. + * @param factorGraph The factor graph for which to build the elimination tree + * @param structure The set of factors involving each variable. If this is + * not precomputed, you can call the Create(const FactorGraph&) + * named constructor instead. + * @return The elimination tree */ - class GTSAM_EXPORT DiscreteJunctionTree : - public JunctionTree { - public: - typedef JunctionTree Base; ///< Base class - typedef DiscreteJunctionTree This; ///< This class - typedef std::shared_ptr shared_ptr; ///< Shared pointer to this class + DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); - /** - * Build the elimination tree of a factor graph using precomputed column structure. - * @param factorGraph The factor graph for which to build the elimination tree - * @param structure The set of factors involving each variable. If this is not - * precomputed, you can call the Create(const FactorGraph&) - * named constructor instead. - * @return The elimination tree - */ - DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); - }; + /// @} + /// @name Testable + /// @{ - /// typedef for wrapper: - using DiscreteCluster = DiscreteJunctionTree::Cluster; -} + /** Print the tree to cout */ + void print(const std::string& name = "DiscreteJunctionTree: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const; + + /// @} +}; + +/// typedef for wrapper: +using DiscreteCluster = DiscreteJunctionTree::Cluster; +} // namespace gtsam From d8ed60aeada546b5145ee64ec067ba04c2c0c436 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 14:01:20 -0500 Subject: [PATCH 24/43] Refactor to slots --- gtsam/discrete/DiscreteSearch.cpp | 79 ++++++++++++++++++------------- gtsam/discrete/DiscreteSearch.h | 43 ++++++++++++----- 2 files changed, 77 insertions(+), 45 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index c5941862d..e3722c13e 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -20,6 +20,7 @@ namespace gtsam { +using Slot = DiscreteSearch::Slot; using Solution = DiscreteSearch::Solution; /** @@ -59,12 +60,12 @@ struct SearchNode { /** * @brief Expands the node by assigning the next variable. * - * @param conditional The discrete conditional representing the next variable + * @param factor The discrete factor associated with the next variable * to be assigned. * @param fa The frontal assignment for the next variable. * @return A new SearchNode representing the expanded state. */ - SearchNode expand(const DiscreteConditional& conditional, + SearchNode expand(const DiscreteFactor& factor, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; @@ -72,7 +73,7 @@ struct SearchNode { newAssignment[key] = value; } - return {newAssignment, error + conditional.error(newAssignment), 0.0, + return {newAssignment, error + factor.error(newAssignment), 0.0, nextConditional - 1}; } @@ -150,10 +151,20 @@ class Solutions { } }; +DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph) { + slots_.reserve(factorGraph.size()); + for (auto& factor : factorGraph) { + slots_.emplace_back(factor, std::vector{}, 0.0); + } + computeHeuristic(); +} + DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { - std::vector conditionals; - for (auto& factor : bayesNet) conditionals_.push_back(factor); - costToGo_ = computeCostToGo(conditionals_); + slots_.reserve(bayesNet.size()); + for (auto& conditional : bayesNet) { + slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0); + } + computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { @@ -161,22 +172,21 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { collectConditionals = [&](const auto& clique) { if (!clique) return; for (const auto& child : clique->children) collectConditionals(child); - conditionals_.push_back(clique->conditional()); + auto conditional = clique->conditional(); + slots_.emplace_back(conditional, conditional->frontalAssignments(), + 0.0); }; + + slots_.reserve(bayesTree.size()); for (const auto& root : bayesTree.roots()) collectConditionals(root); - costToGo_ = computeCostToGo(conditionals_); + computeHeuristic(); } struct SearchNodeQueue : public std::priority_queue, SearchNode::Compare> { - void expandNextNode( - const std::vector& conditionals, - const std::vector& costToGo, Solutions* solutions) { - // Pop the partial assignment with the smallest bound - SearchNode current = top(); - pop(); - + void expandNextNode(const SearchNode& current, const Slot& slot, + Solutions* solutions) { // If we already have K solutions, prune if we cannot beat the worst one. if (solutions->prune(current.bound)) { return; @@ -188,13 +198,11 @@ struct SearchNodeQueue return; } - // Expand on the next factor - const auto& conditional = conditionals[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - auto childNode = current.expand(*conditional, fa); + for (auto& fa : slot.assignments) { + auto childNode = current.expand(*slot.factor, fa); if (childNode.nextConditional >= 0) - childNode.bound = childNode.error + costToGo[childNode.nextConditional]; + // TODO(frank): this might be wrong ! + childNode.bound = childNode.error + slot.heuristic; // Again, prune if we cannot beat the worst solution if (!solutions->prune(childNode.bound)) { @@ -207,8 +215,7 @@ struct SearchNodeQueue std::vector DiscreteSearch::run(size_t K) const { Solutions solutions(K); SearchNodeQueue expansions; - expansions.push(SearchNode::Root(conditionals_.size(), - costToGo_.empty() ? 0.0 : costToGo_.back())); + expansions.push(SearchNode::Root(slots_.size(), slots_.back().heuristic)); #ifdef DISCRETE_SEARCH_DEBUG size_t numExpansions = 0; @@ -216,7 +223,13 @@ std::vector DiscreteSearch::run(size_t K) const { // Perform the search while (!expansions.empty()) { - expansions.expandNextNode(conditionals_, costToGo_, &solutions); + // Pop the partial assignment with the smallest bound + SearchNode current = expansions.top(); + expansions.pop(); + + // Get the next slot to expand + const auto& slot = slots_[current.nextConditional]; + expansions.expandNextNode(current, slot, &solutions); #ifdef DISCRETE_SEARCH_DEBUG ++numExpansions; #endif @@ -230,17 +243,19 @@ std::vector DiscreteSearch::run(size_t K) const { return solutions.extractSolutions(); } -std::vector DiscreteSearch::computeCostToGo( - const std::vector& conditionals) { - std::vector costToGo; +// We have a number of factors, each with a max value, and we want to compute +// the a lower-bound on the cost-to-go for each slot. For the first slot, this +// -log(max(factor[0])), as we only have one factor to resolve. For the second +// slot, we need to add -log(max(factor[1])) to it, etc... +void DiscreteSearch::computeHeuristic() { double error = 0.0; - for (const auto& conditional : conditionals) { - Ordering ordering(conditional->begin(), conditional->end()); - auto maxx = conditional->max(ordering); + for (size_t i = 0; i < slots_.size(); ++i) { + const auto& factor = slots_[i].factor; + Ordering ordering(factor->begin(), factor->end()); + auto maxx = factor->max(ordering); error -= std::log(maxx->evaluate({})); - costToGo.push_back(error); + slots_[i].heuristic = error; } - return costToGo; } } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 6202880b2..ddeaf38f1 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -28,9 +28,25 @@ namespace gtsam { */ class GTSAM_EXPORT DiscreteSearch { public: - /** - * @brief A solution to a discrete search problem. - */ + /// We structure the search as a set of slots, each with a factor and + /// a set of variable assignments that need to be chosen. In addition, each + /// slot has a heuristic associated with it. + struct Slot { + /// The factors in the search problem, + /// e.g., [P(B|A),P(A)] + DiscreteFactor::shared_ptr factor; + + /// The assignments for each factor, + /// e.g., [[B0,B1] [A0,A1]] + std::vector assignments; + + /// A lower bound on the cost-to-go for each slot, e.g., + /// [-log(max_B P(B|A)), -log(max_A P(A))] + double heuristic; + }; + + /// A solution is then a set of assignments, covering all the slots. + /// as well as an associated error = -log(probability) struct Solution { double error; DiscreteValues assignment; @@ -42,13 +58,19 @@ class GTSAM_EXPORT DiscreteSearch { } }; + public: /** - * Construct from a DiscreteBayesNet and K. + * Construct from a DiscreteFactorGraph. + */ + DiscreteSearch(const DiscreteFactorGraph& bayesNet); + + /** + * Construct from a DiscreteBayesNet. */ DiscreteSearch(const DiscreteBayesNet& bayesNet); /** - * Construct from a DiscreteBayesTree and K. + * Construct from a DiscreteBayesTree. */ DiscreteSearch(const DiscreteBayesTree& bayesTree); @@ -65,14 +87,9 @@ class GTSAM_EXPORT DiscreteSearch { std::vector run(size_t K = 1) const; private: - /// Compute the cumulative cost-to-go for each conditional slot. - static std::vector computeCostToGo( - const std::vector& conditionals); + /// Compute the cumulative lower-bound cost-to-go for each slot. + void computeHeuristic(); - /// Expand the next node in the search tree. - void expandNextNode() const; - - std::vector conditionals_; - std::vector costToGo_; + std::vector slots_; }; } // namespace gtsam From 9800e110aab52433ce90494d398c1bbbf9f37523 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 14:01:31 -0500 Subject: [PATCH 25/43] Build etree and jtree --- gtsam/discrete/tests/testDiscreteSearch.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index b537dd2f0..7e715ca62 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -18,6 +18,8 @@ #include #include +#include +#include #include #include "AsiaExample.h" @@ -28,13 +30,29 @@ using namespace gtsam; namespace asia { using namespace asia_example; static const DiscreteBayesNet bayesNet = createAsiaExample(); + +// Create factor graph and optimize with max-product for MPE static const DiscreteFactorGraph factorGraph(bayesNet); static const DiscreteValues mpe = factorGraph.optimize(); + +// Create junction tree static const Ordering ordering{D, X, B, E, L, T, S, A}; + +static const DiscreteEliminationTree etree(factorGraph, ordering); +static const DiscreteJunctionTree junctionTree(etree); + +// Create Bayes tree static const DiscreteBayesTree bayesTree = *factorGraph.eliminateMultifrontal(ordering); } // namespace asia +/* ************************************************************************* */ +TEST(DiscreteBayesNet, AsiaFactorGraphKBest) { + GTSAM_PRINT(asia::etree); + GTSAM_PRINT(asia::junctionTree); + DiscreteSearch search(asia::factorGraph); +} + /* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors From c4870cc840e32a444754cd2b61ef5d0a344099c0 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 14:49:54 -0500 Subject: [PATCH 26/43] Fix heuristic --- gtsam/discrete/DiscreteSearch.cpp | 55 ++++++++++++--------- gtsam/discrete/DiscreteSearch.h | 21 ++++++-- gtsam/discrete/tests/testDiscreteSearch.cpp | 11 +---- 3 files changed, 51 insertions(+), 36 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index e3722c13e..439331fa1 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -16,6 +16,8 @@ * @author Frank Dellaert */ +#include +#include #include namespace gtsam { @@ -39,9 +41,8 @@ struct SearchNode { /** * @brief Construct the root node for the search. */ - static SearchNode Root(size_t numConditionals, double bound) { - return {DiscreteValues(), 0.0, bound, - static_cast(numConditionals) - 1}; + static SearchNode Root(size_t numSlots, double bound) { + return {DiscreteValues(), 0.0, bound, static_cast(numSlots) - 1}; } struct Compare { @@ -60,20 +61,18 @@ struct SearchNode { /** * @brief Expands the node by assigning the next variable. * - * @param factor The discrete factor associated with the next variable - * to be assigned. + * @param slot The slot to be filled. * @param fa The frontal assignment for the next variable. * @return A new SearchNode representing the expanded state. */ - SearchNode expand(const DiscreteFactor& factor, - const DiscreteValues& fa) const { + SearchNode expand(const Slot& slot, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; for (auto& [key, value] : fa) { newAssignment[key] = value; } - - return {newAssignment, error + factor.error(newAssignment), 0.0, + double errorSoFar = error + slot.factor->error(newAssignment); + return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextConditional - 1}; } @@ -151,12 +150,19 @@ class Solutions { } }; -DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph) { +DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph, + const Ordering& ordering, + bool buildJunctionTree) { + const DiscreteEliminationTree etree(factorGraph, ordering); + const DiscreteJunctionTree junctionTree(etree); + + // GTSAM_PRINT(asia::etree); + // GTSAM_PRINT(asia::junctionTree); slots_.reserve(factorGraph.size()); for (auto& factor : factorGraph) { slots_.emplace_back(factor, std::vector{}, 0.0); } - computeHeuristic(); + lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { @@ -164,7 +170,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { for (auto& conditional : bayesNet) { slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0); } - computeHeuristic(); + lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { @@ -179,7 +185,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { slots_.reserve(bayesTree.size()); for (const auto& root : bayesTree.roots()) collectConditionals(root); - computeHeuristic(); + lowerBound_ = computeHeuristic(); } struct SearchNodeQueue @@ -199,10 +205,7 @@ struct SearchNodeQueue } for (auto& fa : slot.assignments) { - auto childNode = current.expand(*slot.factor, fa); - if (childNode.nextConditional >= 0) - // TODO(frank): this might be wrong ! - childNode.bound = childNode.error + slot.heuristic; + auto childNode = current.expand(slot, fa); // Again, prune if we cannot beat the worst solution if (!solutions->prune(childNode.bound)) { @@ -212,10 +215,12 @@ struct SearchNodeQueue } }; +#define DISCRETE_SEARCH_DEBUG + std::vector DiscreteSearch::run(size_t K) const { Solutions solutions(K); SearchNodeQueue expansions; - expansions.push(SearchNode::Root(slots_.size(), slots_.back().heuristic)); + expansions.push(SearchNode::Root(slots_.size(), lowerBound_)); #ifdef DISCRETE_SEARCH_DEBUG size_t numExpansions = 0; @@ -244,18 +249,22 @@ std::vector DiscreteSearch::run(size_t K) const { } // We have a number of factors, each with a max value, and we want to compute -// the a lower-bound on the cost-to-go for each slot. For the first slot, this -// -log(max(factor[0])), as we only have one factor to resolve. For the second -// slot, we need to add -log(max(factor[1])) to it, etc... -void DiscreteSearch::computeHeuristic() { +// a lower-bound on the cost-to-go for each slot, *not* including this factor. +// For the first slot, this is 0.0, as this is the last slot to be filled, so +// the cost after that is zero. For the second slot, it is h0 = +// -log(max(factor[0])), because after we assign slot[1] we still need to assign +// slot[0], which will cost *at least* h0. +// We return the estimated lower bound of the cost for *all* slots. +double DiscreteSearch::computeHeuristic() { double error = 0.0; for (size_t i = 0; i < slots_.size(); ++i) { + slots_[i].heuristic = error; const auto& factor = slots_[i].factor; Ordering ordering(factor->begin(), factor->end()); auto maxx = factor->max(ordering); error -= std::log(maxx->evaluate({})); - slots_[i].heuristic = error; } + return error; } } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index ddeaf38f1..3d5ee1d50 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -61,8 +61,19 @@ class GTSAM_EXPORT DiscreteSearch { public: /** * Construct from a DiscreteFactorGraph. + * + * Internally creates either an elimination tree or a junction tree. The + * latter incurs more up-front computation but the search itself might be + * faster. Then again, for the elimination tree, the heuristic will be more + * fine-grained (more slots). + * + * @param factorGraph The factor graph to search over. + * @param ordering The ordering of the variables to search over. + * @param buildJunctionTree Whether to build a junction tree for the factor + * graph. */ - DiscreteSearch(const DiscreteFactorGraph& bayesNet); + DiscreteSearch(const DiscreteFactorGraph& factorGraph, + const Ordering& ordering, bool buildJunctionTree = false); /** * Construct from a DiscreteBayesNet. @@ -87,9 +98,11 @@ class GTSAM_EXPORT DiscreteSearch { std::vector run(size_t K = 1) const; private: - /// Compute the cumulative lower-bound cost-to-go for each slot. - void computeHeuristic(); + /// Compute the cumulative lower-bound cost-to-go after each slot is filled. + /// @return the estimated lower bound of the cost for *all* slots. + double computeHeuristic(); - std::vector slots_; + double lowerBound_; ///< Lower bound on the cost-to-go for the entire search. + std::vector slots_; ///< The slots to fill in the search. }; } // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 7e715ca62..0f424bd45 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -18,8 +18,6 @@ #include #include -#include -#include #include #include "AsiaExample.h" @@ -35,12 +33,9 @@ static const DiscreteBayesNet bayesNet = createAsiaExample(); static const DiscreteFactorGraph factorGraph(bayesNet); static const DiscreteValues mpe = factorGraph.optimize(); -// Create junction tree +// Create ordering static const Ordering ordering{D, X, B, E, L, T, S, A}; -static const DiscreteEliminationTree etree(factorGraph, ordering); -static const DiscreteJunctionTree junctionTree(etree); - // Create Bayes tree static const DiscreteBayesTree bayesTree = *factorGraph.eliminateMultifrontal(ordering); @@ -48,9 +43,7 @@ static const DiscreteBayesTree bayesTree = /* ************************************************************************* */ TEST(DiscreteBayesNet, AsiaFactorGraphKBest) { - GTSAM_PRINT(asia::etree); - GTSAM_PRINT(asia::junctionTree); - DiscreteSearch search(asia::factorGraph); + DiscreteSearch search(asia::factorGraph, asia::ordering); } /* ************************************************************************* */ From 0bc566f69227c5ae3f198daaf510e331c677e7ef Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 19:06:01 -0500 Subject: [PATCH 27/43] Working etree and jtree versions --- gtsam/discrete/DiscreteSearch.cpp | 75 ++++++++++++++++++++++++------- gtsam/discrete/DiscreteSearch.h | 54 +++++++++++++++++++--- 2 files changed, 107 insertions(+), 22 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 439331fa1..4270dfbf7 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -150,21 +150,58 @@ class Solutions { } }; -DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph, - const Ordering& ordering, - bool buildJunctionTree) { - const DiscreteEliminationTree etree(factorGraph, ordering); - const DiscreteJunctionTree junctionTree(etree); +DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { + using NodePtr = std::shared_ptr; + auto visitor = [this](const NodePtr& node, int data) { + const auto& factors = node->factors; + const auto factor = factors.size() == 1 + ? factors.back() + : DiscreteFactorGraph(factors).product(); + const size_t cardinality = factor->cardinality(node->key); + std::vector> pairs{{node->key, cardinality}}; + slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + return data + 1; + }; - // GTSAM_PRINT(asia::etree); - // GTSAM_PRINT(asia::junctionTree); - slots_.reserve(factorGraph.size()); - for (auto& factor : factorGraph) { - slots_.emplace_back(factor, std::vector{}, 0.0); - } + const int data = 0; // unused + treeTraversal::DepthFirstForest(etree, data, visitor); + std::reverse(slots_.begin(), slots_.end()); // reverse slots lowerBound_ = computeHeuristic(); } +DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { + using NodePtr = std::shared_ptr; + auto visitor = [this](const NodePtr& cluster, int data) { + const auto& factors = cluster->factors; + const auto factor = factors.size() == 1 + ? factors.back() + : DiscreteFactorGraph(factors).product(); + std::vector> pairs; + for (Key key : cluster->orderedFrontalKeys) { + pairs.emplace_back(key, factor->cardinality(key)); + } + slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + return data + 1; + }; + + const int data = 0; // unused + treeTraversal::DepthFirstForest(junctionTree, data, visitor); + std::reverse(slots_.begin(), slots_.end()); // reverse slots + lowerBound_ = computeHeuristic(); +} + +DiscreteSearch DiscreteSearch::FromFactorGraph( + const DiscreteFactorGraph& factorGraph, const Ordering& ordering, + bool buildJunctionTree) { + const DiscreteEliminationTree etree(factorGraph, ordering); + if (buildJunctionTree) { + const DiscreteJunctionTree junctionTree(etree); + return DiscreteSearch(junctionTree); + } else { + return DiscreteSearch(etree); + } +} + DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { slots_.reserve(bayesNet.size()); for (auto& conditional : bayesNet) { @@ -188,6 +225,14 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { lowerBound_ = computeHeuristic(); } +void DiscreteSearch::print(const std::string& name, + const KeyFormatter& formatter) const { + std::cout << name << " with " << slots_.size() << " slots:\n"; + for (size_t i = 0; i < slots_.size(); ++i) { + std::cout << i << ": " << slots_[i] << std::endl; + } +} + struct SearchNodeQueue : public std::priority_queue, SearchNode::Compare> { @@ -215,8 +260,6 @@ struct SearchNodeQueue } }; -#define DISCRETE_SEARCH_DEBUG - std::vector DiscreteSearch::run(size_t K) const { Solutions solutions(K); SearchNodeQueue expansions; @@ -252,9 +295,9 @@ std::vector DiscreteSearch::run(size_t K) const { // a lower-bound on the cost-to-go for each slot, *not* including this factor. // For the first slot, this is 0.0, as this is the last slot to be filled, so // the cost after that is zero. For the second slot, it is h0 = -// -log(max(factor[0])), because after we assign slot[1] we still need to assign -// slot[0], which will cost *at least* h0. -// We return the estimated lower bound of the cost for *all* slots. +// -log(max(factor[0])), because after we assign slot[1] we still need to +// assign slot[0], which will cost *at least* h0. We return the estimated +// lower bound of the cost for *all* slots. double DiscreteSearch::computeHeuristic() { double error = 0.0; for (size_t i = 0; i < slots_.size(); ++i) { diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 3d5ee1d50..44e605e34 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -10,7 +10,11 @@ * -------------------------------------------------------------------------- */ /* - * DiscreteSearch.cpp + * @file DiscreteSearch.h + * @brief Defines the DiscreteSearch class for discrete search algorithms. + * + * @details This file contains the definition of the DiscreteSearch class, which + * is used in discrete search algorithms to find the K best solutions. * * @date January, 2025 * @author Frank Dellaert @@ -43,6 +47,13 @@ class GTSAM_EXPORT DiscreteSearch { /// A lower bound on the cost-to-go for each slot, e.g., /// [-log(max_B P(B|A)), -log(max_A P(A))] double heuristic; + + friend std::ostream& operator<<(std::ostream& os, const Slot& slot) { + os << "Slot with " << slot.assignments.size() + << " assignments, heuristic=" << slot.heuristic; + os << ", factor:\n" << slot.factor->markdown() << std::endl; + return os; + } }; /// A solution is then a set of assignments, covering all the slots. @@ -59,6 +70,9 @@ class GTSAM_EXPORT DiscreteSearch { }; public: + /// @name Standard Constructors + /// @{ + /** * Construct from a DiscreteFactorGraph. * @@ -68,12 +82,26 @@ class GTSAM_EXPORT DiscreteSearch { * fine-grained (more slots). * * @param factorGraph The factor graph to search over. - * @param ordering The ordering of the variables to search over. - * @param buildJunctionTree Whether to build a junction tree for the factor - * graph. + * @param ordering The ordering used to create etree (and maybe jtree). + * @param buildJunctionTree Whether to build a junction tree or not. */ - DiscreteSearch(const DiscreteFactorGraph& factorGraph, - const Ordering& ordering, bool buildJunctionTree = false); + static DiscreteSearch FromFactorGraph(const DiscreteFactorGraph& factorGraph, + const Ordering& ordering, + bool buildJunctionTree = false); + + /** + * @brief Constructor from a DiscreteEliminationTree. + * + * @param etree The DiscreteEliminationTree to initialize from. + */ + DiscreteSearch(const DiscreteEliminationTree& etree); + + /** + * @brief Constructor from a DiscreteJunctionTree. + * + * @param junctionTree The DiscreteJunctionTree to initialize from. + */ + DiscreteSearch(const DiscreteJunctionTree& junctionTree); /** * Construct from a DiscreteBayesNet. @@ -85,6 +113,18 @@ class GTSAM_EXPORT DiscreteSearch { */ DiscreteSearch(const DiscreteBayesTree& bayesTree); + /// @} + /// @name Testable + /// @{ + + /** Print the tree to cout */ + void print(const std::string& name = "DiscreteSearch: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const; + + /// @} + /// @name Standard API + /// @{ + /** * @brief Search for the K best solutions. * @@ -97,6 +137,8 @@ class GTSAM_EXPORT DiscreteSearch { */ std::vector run(size_t K = 1) const; + /// @} + private: /// Compute the cumulative lower-bound cost-to-go after each slot is filled. /// @return the estimated lower bound of the cost for *all* slots. From 8d6b6055fe3254104ac69b0027db527d0d513e5a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 19:06:09 -0500 Subject: [PATCH 28/43] Loop over all variants --- gtsam/discrete/tests/testDiscreteSearch.cpp | 66 +++++++-------------- 1 file changed, 23 insertions(+), 43 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 0f424bd45..dd5389558 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -41,11 +41,6 @@ static const DiscreteBayesTree bayesTree = *factorGraph.eliminateMultifrontal(ordering); } // namespace asia -/* ************************************************************************* */ -TEST(DiscreteBayesNet, AsiaFactorGraphKBest) { - DiscreteSearch search(asia::factorGraph, asia::ordering); -} - /* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors @@ -56,29 +51,6 @@ TEST(DiscreteBayesNet, EmptyKBest) { EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); } -/* ************************************************************************* */ -TEST(DiscreteBayesNet, AsiaKBest) { - const DiscreteSearch search(asia::bayesNet); - - // Ask for the MPE - auto mpe = search.run(); - - EXPECT_LONGS_EQUAL(1, mpe.size()); - // Regression test: check the MPE solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); - - // Check it is equal to MPE via inference - EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); - - // Ask for top 4 solutions - auto solutions = search.run(4); - - EXPECT_LONGS_EQUAL(4, solutions.size()); - // Regression test: check the first and last solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); - EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); -} - /* ************************************************************************* */ TEST(DiscreteBayesTree, EmptyTree) { DiscreteBayesTree bt; @@ -92,26 +64,34 @@ TEST(DiscreteBayesTree, EmptyTree) { } /* ************************************************************************* */ -TEST(DiscreteBayesTree, AsiaTreeKBest) { - DiscreteSearch search(asia::bayesTree); +TEST(DiscreteBayesNet, AsiaKBest) { + auto fromETree = + DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering); + auto fromJunctionTree = + DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering, true); + const DiscreteSearch fromBayesNet(asia::bayesNet); + const DiscreteSearch fromBayesTree(asia::bayesTree); - // Ask for MPE - auto mpe = search.run(); + for (auto& search : + {fromETree, fromJunctionTree, fromBayesNet, fromBayesTree}) { + // Ask for the MPE + auto mpe = search.run(); - EXPECT_LONGS_EQUAL(1, mpe.size()); - // Regression test: check the MPE solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); - // Check it is equal to MPE via inference - EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); + // Check it is equal to MPE via inference + EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); - // Ask for top 4 solutions - auto solutions = search.run(4); + // Ask for top 4 solutions + auto solutions = search.run(4); - EXPECT_LONGS_EQUAL(4, solutions.size()); - // Regression test: check the first and last solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); - EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); + EXPECT_LONGS_EQUAL(4, solutions.size()); + // Regression test: check the first and last solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); + } } /* ************************************************************************* */ From e6f0fdce4402bd45a4dbc09ff82be87c9962cb2c Mon Sep 17 00:00:00 2001 From: jmackay2 <1.732mackay@gmail.com> Date: Mon, 27 Jan 2025 21:23:09 -0500 Subject: [PATCH 29/43] fix hybrid uninitialized warning --- examples/Hybrid_City10000.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/Hybrid_City10000.cpp b/examples/Hybrid_City10000.cpp index 42f0169c8..2051a75d6 100644 --- a/examples/Hybrid_City10000.cpp +++ b/examples/Hybrid_City10000.cpp @@ -120,7 +120,7 @@ int main(int argc, char* argv[]) { SmootherUpdate(smoother, graph, init_values, maxNrHypotheses, &results); - size_t key_s, key_t; + size_t key_s, key_t{0}; clock_t start_time = clock(); std::string str; From af07409c10f9b3a1c6742a341f47d83c81cb802e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 23:35:31 -0500 Subject: [PATCH 30/43] Fix some CI issues --- gtsam/discrete/DiscreteJunctionTree.cpp | 2 +- gtsam/discrete/DiscreteSearch.cpp | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/gtsam/discrete/DiscreteJunctionTree.cpp b/gtsam/discrete/DiscreteJunctionTree.cpp index b4657ec8d..0c9eb10ef 100644 --- a/gtsam/discrete/DiscreteJunctionTree.cpp +++ b/gtsam/discrete/DiscreteJunctionTree.cpp @@ -51,7 +51,7 @@ struct PrintForestVisitorPre { void DiscreteJunctionTree::print(const std::string& s, const KeyFormatter& keyFormatter) const { PrintForestVisitorPre visitor(keyFormatter); - treeTraversal::DepthFirstForest(*this, s, visitor); + treeTraversal::DepthFirstForest(*this, std::string(s), visitor); } } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 4270dfbf7..71485795e 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -159,7 +159,8 @@ DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { : DiscreteFactorGraph(factors).product(); const size_t cardinality = factor->cardinality(node->key); std::vector> pairs{{node->key, cardinality}}; - slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + const Slot slot(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + slots_.emplace_back(std::move(slot)); return data + 1; }; @@ -180,7 +181,8 @@ DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { for (Key key : cluster->orderedFrontalKeys) { pairs.emplace_back(key, factor->cardinality(key)); } - slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + const Slot slot(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + slots_.emplace_back(std::move(slot)); return data + 1; }; @@ -205,7 +207,8 @@ DiscreteSearch DiscreteSearch::FromFactorGraph( DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { slots_.reserve(bayesNet.size()); for (auto& conditional : bayesNet) { - slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0); + const Slot slot(conditional, conditional->frontalAssignments(), 0.0); + slots_.emplace_back(std::move(slot)); } lowerBound_ = computeHeuristic(); } @@ -216,8 +219,8 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { if (!clique) return; for (const auto& child : clique->children) collectConditionals(child); auto conditional = clique->conditional(); - slots_.emplace_back(conditional, conditional->frontalAssignments(), - 0.0); + const Slot slot(conditional, conditional->frontalAssignments(), 0.0); + slots_.emplace_back(std::move(slot)); }; slots_.reserve(bayesTree.size()); @@ -300,11 +303,10 @@ std::vector DiscreteSearch::run(size_t K) const { // lower bound of the cost for *all* slots. double DiscreteSearch::computeHeuristic() { double error = 0.0; - for (size_t i = 0; i < slots_.size(); ++i) { - slots_[i].heuristic = error; - const auto& factor = slots_[i].factor; - Ordering ordering(factor->begin(), factor->end()); - auto maxx = factor->max(ordering); + for (auto& slot : slots_) { + slot.heuristic = error; + Ordering ordering(slot.factor->begin(), slot.factor->end()); + auto maxx = slot.factor->max(ordering); error -= std::log(maxx->evaluate({})); } return error; From 460a9a958e2655c70eeb04bcfd775ffc4cce5ec3 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 23:41:35 -0500 Subject: [PATCH 31/43] Fix compilation issue --- gtsam/discrete/DiscreteJunctionTree.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteJunctionTree.cpp b/gtsam/discrete/DiscreteJunctionTree.cpp index 0c9eb10ef..e1fc2af11 100644 --- a/gtsam/discrete/DiscreteJunctionTree.cpp +++ b/gtsam/discrete/DiscreteJunctionTree.cpp @@ -51,7 +51,8 @@ struct PrintForestVisitorPre { void DiscreteJunctionTree::print(const std::string& s, const KeyFormatter& keyFormatter) const { PrintForestVisitorPre visitor(keyFormatter); - treeTraversal::DepthFirstForest(*this, std::string(s), visitor); + std::string parentString = s; + treeTraversal::DepthFirstForest(*this, parentString, visitor); } } // namespace gtsam From b8f265d69f041f2fab32da147cc8f32358852802 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 23:50:08 -0500 Subject: [PATCH 32/43] Use brace init --- gtsam/discrete/DiscreteSearch.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 71485795e..569887d7f 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -159,7 +159,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { : DiscreteFactorGraph(factors).product(); const size_t cardinality = factor->cardinality(node->key); std::vector> pairs{{node->key, cardinality}}; - const Slot slot(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; slots_.emplace_back(std::move(slot)); return data + 1; }; @@ -181,7 +181,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { for (Key key : cluster->orderedFrontalKeys) { pairs.emplace_back(key, factor->cardinality(key)); } - const Slot slot(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; slots_.emplace_back(std::move(slot)); return data + 1; }; @@ -207,7 +207,7 @@ DiscreteSearch DiscreteSearch::FromFactorGraph( DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { slots_.reserve(bayesNet.size()); for (auto& conditional : bayesNet) { - const Slot slot(conditional, conditional->frontalAssignments(), 0.0); + const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; slots_.emplace_back(std::move(slot)); } lowerBound_ = computeHeuristic(); @@ -219,7 +219,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { if (!clique) return; for (const auto& child : clique->children) collectConditionals(child); auto conditional = clique->conditional(); - const Slot slot(conditional, conditional->frontalAssignments(), 0.0); + const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; slots_.emplace_back(std::move(slot)); }; From 9e98b805d6cfb07373dc8fa2d6a4ebd2de7492aa Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 00:17:45 -0500 Subject: [PATCH 33/43] Reversed slots so we start from zero --- gtsam/discrete/DiscreteSearch.cpp | 170 ++++++++------------ gtsam/discrete/DiscreteSearch.h | 6 + gtsam/discrete/tests/testDiscreteSearch.cpp | 11 ++ 3 files changed, 88 insertions(+), 99 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 569887d7f..43f321d4a 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -33,16 +33,16 @@ using Solution = DiscreteSearch::Solution; * conditional to be assigned. */ struct SearchNode { - DiscreteValues assignment; ///< Partial assignment of discrete variables. - double error; ///< Current error for the partial assignment. - double bound; ///< Lower bound on the final error for unassigned variables. - int nextConditional; ///< Index of the next conditional to be assigned. + DiscreteValues assignment; ///< Partial assignment of discrete variables. + double error; ///< Current error for the partial assignment. + double bound; ///< Lower bound on the final error + std::optional next; ///< Index of the next factor to be assigned. /** * @brief Construct the root node for the search. */ static SearchNode Root(size_t numSlots, double bound) { - return {DiscreteValues(), 0.0, bound, static_cast(numSlots) - 1}; + return {DiscreteValues(), 0.0, bound, 0}; } struct Compare { @@ -51,38 +51,22 @@ struct SearchNode { } }; - /** - * @brief Checks if the node represents a complete assignment. - * - * @return True if all variables have been assigned, false otherwise. - */ - inline bool isComplete() const { return nextConditional < 0; } + /// Checks if the node represents a complete assignment. + inline bool isComplete() const { return !next; } - /** - * @brief Expands the node by assigning the next variable. - * - * @param slot The slot to be filled. - * @param fa The frontal assignment for the next variable. - * @return A new SearchNode representing the expanded state. - */ - SearchNode expand(const Slot& slot, const DiscreteValues& fa) const { + /// Expands the node by assigning the next variable(s). + SearchNode expand(const DiscreteValues& fa, const Slot& slot, + std::optional nextSlot) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; for (auto& [key, value] : fa) { newAssignment[key] = value; } double errorSoFar = error + slot.factor->error(newAssignment); - return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, - nextConditional - 1}; + return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot}; } - /** - * @brief Prints the SearchNode to an output stream. - * - * @param os The output stream. - * @param node The SearchNode to be printed. - * @return The output stream. - */ + /// Prints the SearchNode to an output stream. friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; return os; @@ -150,13 +134,18 @@ class Solutions { } }; +/// @brief Get the factor associated with a node, possibly product of factors. +template +static auto getFactor(const NodeType& node) { + const auto& factors = node->factors; + return factors.size() == 1 ? factors.back() + : DiscreteFactorGraph(factors).product(); +} + DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { using NodePtr = std::shared_ptr; auto visitor = [this](const NodePtr& node, int data) { - const auto& factors = node->factors; - const auto factor = factors.size() == 1 - ? factors.back() - : DiscreteFactorGraph(factors).product(); + const auto factor = getFactor(node); const size_t cardinality = factor->cardinality(node->key); std::vector> pairs{{node->key, cardinality}}; const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; @@ -164,19 +153,15 @@ DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { return data + 1; }; - const int data = 0; // unused + int data = 0; // unused treeTraversal::DepthFirstForest(etree, data, visitor); - std::reverse(slots_.begin(), slots_.end()); // reverse slots lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { using NodePtr = std::shared_ptr; auto visitor = [this](const NodePtr& cluster, int data) { - const auto& factors = cluster->factors; - const auto factor = factors.size() == 1 - ? factors.back() - : DiscreteFactorGraph(factors).product(); + const auto factor = getFactor(cluster); std::vector> pairs; for (Key key : cluster->orderedFrontalKeys) { pairs.emplace_back(key, factor->cardinality(key)); @@ -186,9 +171,8 @@ DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { return data + 1; }; - const int data = 0; // unused + int data = 0; // unused treeTraversal::DepthFirstForest(junctionTree, data, visitor); - std::reverse(slots_.begin(), slots_.end()); // reverse slots lowerBound_ = computeHeuristic(); } @@ -210,21 +194,21 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; slots_.emplace_back(std::move(slot)); } + std::reverse(slots_.begin(), slots_.end()); lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { - std::function - collectConditionals = [&](const auto& clique) { - if (!clique) return; - for (const auto& child : clique->children) collectConditionals(child); - auto conditional = clique->conditional(); - const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; - slots_.emplace_back(std::move(slot)); - }; + using NodePtr = DiscreteBayesTree::sharedClique; + auto visitor = [this](const NodePtr& clique, int data) { + auto conditional = clique->conditional(); + const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; + slots_.emplace_back(std::move(slot)); + return data + 1; + }; - slots_.reserve(bayesTree.size()); - for (const auto& root : bayesTree.roots()) collectConditionals(root); + int data = 0; // unused + treeTraversal::DepthFirstForest(bayesTree, data, visitor); lowerBound_ = computeHeuristic(); } @@ -236,59 +220,48 @@ void DiscreteSearch::print(const std::string& name, } } -struct SearchNodeQueue - : public std::priority_queue, - SearchNode::Compare> { - void expandNextNode(const SearchNode& current, const Slot& slot, - Solutions* solutions) { - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions->prune(current.bound)) { - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - solutions->maybeAdd(current.error, current.assignment); - return; - } - - for (auto& fa : slot.assignments) { - auto childNode = current.expand(slot, fa); - - // Again, prune if we cannot beat the worst solution - if (!solutions->prune(childNode.bound)) { - emplace(childNode); - } - } - } -}; +using SearchNodeQueue = std::priority_queue, + SearchNode::Compare>; std::vector DiscreteSearch::run(size_t K) const { + if (slots_.empty()) { + return {Solution(0.0, DiscreteValues())}; + } + Solutions solutions(K); SearchNodeQueue expansions; expansions.push(SearchNode::Root(slots_.size(), lowerBound_)); -#ifdef DISCRETE_SEARCH_DEBUG - size_t numExpansions = 0; -#endif - // Perform the search while (!expansions.empty()) { // Pop the partial assignment with the smallest bound SearchNode current = expansions.top(); expansions.pop(); - // Get the next slot to expand - const auto& slot = slots_[current.nextConditional]; - expansions.expandNextNode(current, slot, &solutions); -#ifdef DISCRETE_SEARCH_DEBUG - ++numExpansions; -#endif - } + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions.prune(current.bound)) { + continue; + } -#ifdef DISCRETE_SEARCH_DEBUG - std::cout << "Number of expansions: " << numExpansions << std::endl; -#endif + // Check if we have a complete assignment + if (current.isComplete()) { + solutions.maybeAdd(current.error, current.assignment); + continue; + } + + // Get the next slot to expand + const auto& slot = slots_[*current.next]; + std::optional nextSlot = *current.next + 1; + if (nextSlot == slots_.size()) nextSlot.reset(); + for (auto& fa : slot.assignments) { + auto childNode = current.expand(fa, slot, nextSlot); + + // Again, prune if we cannot beat the worst solution + if (!solutions.prune(childNode.bound)) { + expansions.emplace(childNode); + } + } + } // Extract solutions from bestSolutions in ascending order of error return solutions.extractSolutions(); @@ -296,17 +269,16 @@ std::vector DiscreteSearch::run(size_t K) const { // We have a number of factors, each with a max value, and we want to compute // a lower-bound on the cost-to-go for each slot, *not* including this factor. -// For the first slot, this is 0.0, as this is the last slot to be filled, so -// the cost after that is zero. For the second slot, it is h0 = -// -log(max(factor[0])), because after we assign slot[1] we still need to -// assign slot[0], which will cost *at least* h0. We return the estimated -// lower bound of the cost for *all* slots. +// For the last slot, this is 0.0, as the cost after that is zero. +// For the second-to-last slot, it is -log(max(factor[0])), because after we +// assign slot[1] we still need to assign slot[0], which will cost *at least* +// h0. We return the estimated lower bound of the cost for *all* slots. double DiscreteSearch::computeHeuristic() { double error = 0.0; - for (auto& slot : slots_) { - slot.heuristic = error; - Ordering ordering(slot.factor->begin(), slot.factor->end()); - auto maxx = slot.factor->max(ordering); + for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) { + it->heuristic = error; + Ordering ordering(it->factor->begin(), it->factor->end()); + auto maxx = it->factor->max(ordering); error -= std::log(maxx->evaluate({})); } return error; diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 44e605e34..700e41392 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -125,6 +125,12 @@ class GTSAM_EXPORT DiscreteSearch { /// @name Standard API /// @{ + /// Return lower bound on the cost-to-go for the entire search + double lowerBound() const { return lowerBound_; } + + /// Read access to the slots + const std::vector& slots() const { return slots_; } + /** * @brief Search for the K best solutions. * diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index dd5389558..cebddfe8d 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -77,6 +77,17 @@ TEST(DiscreteBayesNet, AsiaKBest) { // Ask for the MPE auto mpe = search.run(); + // Regression on error lower bound + EXPECT_DOUBLES_EQUAL(1.205536, search.lowerBound(), 1e-5); + + // Check that the cost-to-go heuristic decreases from there + auto slots = search.slots(); + double previousHeuristic = search.lowerBound(); + for (auto&& slot : slots) { + EXPECT(slot.heuristic <= previousHeuristic); + previousHeuristic = slot.heuristic; + } + EXPECT_LONGS_EQUAL(1, mpe.size()); // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); From 8c7e75bb25d0be16efada3c5ac7733ca12e58367 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 14:50:11 -0500 Subject: [PATCH 34/43] Review comments --- gtsam/discrete/DiscreteJunctionTree.cpp | 26 +++++++++---------------- gtsam/discrete/tests/AsiaExample.h | 2 +- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/gtsam/discrete/DiscreteJunctionTree.cpp b/gtsam/discrete/DiscreteJunctionTree.cpp index e1fc2af11..bf9f9fe18 100644 --- a/gtsam/discrete/DiscreteJunctionTree.cpp +++ b/gtsam/discrete/DiscreteJunctionTree.cpp @@ -31,26 +31,18 @@ DiscreteJunctionTree::DiscreteJunctionTree( const DiscreteEliminationTree& eliminationTree) : Base(eliminationTree) {} /* ************************************************************************* */ -namespace { -struct PrintForestVisitorPre { - const KeyFormatter& formatter; - PrintForestVisitorPre(const KeyFormatter& formatter) : formatter(formatter) {} - std::string operator()( - const std::shared_ptr& node, - const std::string& parentString) { - // Print the current node - node->print(parentString + "-", formatter); - node->factors.print(parentString + "-", formatter); - std::cout << std::endl; - // Increment the indentation - return parentString + "| "; - } -}; -} // namespace void DiscreteJunctionTree::print(const std::string& s, const KeyFormatter& keyFormatter) const { - PrintForestVisitorPre visitor(keyFormatter); + auto visitor = [&keyFormatter]( + const std::shared_ptr& node, + const std::string& parentString) { + // Print the current node + node->print(parentString + "-", keyFormatter); + node->factors.print(parentString + "-", keyFormatter); + std::cout << std::endl; + return parentString + "| "; // Increment the indentation + }; std::string parentString = s; treeTraversal::DepthFirstForest(*this, parentString, visitor); } diff --git a/gtsam/discrete/tests/AsiaExample.h b/gtsam/discrete/tests/AsiaExample.h index 6c327daec..ff6c4ea99 100644 --- a/gtsam/discrete/tests/AsiaExample.h +++ b/gtsam/discrete/tests/AsiaExample.h @@ -58,4 +58,4 @@ DiscreteBayesNet createAsiaExample() { return asia; } } // namespace asia_example -} // namespace gtsam \ No newline at end of file +} // namespace gtsam From 1afb0891437898a7fbace791a60792100402fa1c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 14:50:39 -0500 Subject: [PATCH 35/43] Better and more consistent documentation. --- gtsam/discrete/DiscreteSearch.cpp | 73 +++++++++++++++--------------- gtsam/discrete/DiscreteSearch.h | 74 +++++++++++++++++-------------- 2 files changed, 78 insertions(+), 69 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 43f321d4a..c046f508f 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -9,7 +9,7 @@ * -------------------------------------------------------------------------- */ -/* +/** * DiscreteSearch.cpp * * @date January, 2025 @@ -25,22 +25,19 @@ namespace gtsam { using Slot = DiscreteSearch::Slot; using Solution = DiscreteSearch::Solution; -/** - * @brief Represents a node in the search tree for discrete search algorithms. - * - * @details Each SearchNode contains a partial assignment of discrete variables, - * the current error, a bound on the final error, and the index of the next - * conditional to be assigned. +/* + * A SearchNode represents a node in the search tree for the search algorithm. + * Each SearchNode contains a partial assignment of discrete variables, the + * current error, a bound on the final error, and the index of the next + * slot to be assigned. */ struct SearchNode { - DiscreteValues assignment; ///< Partial assignment of discrete variables. - double error; ///< Current error for the partial assignment. - double bound; ///< Lower bound on the final error - std::optional next; ///< Index of the next factor to be assigned. + DiscreteValues assignment; // Partial assignment of discrete variables. + double error; // Current error for the partial assignment. + double bound; // Lower bound on the final error + std::optional next; // Index of the next slot to be assigned. - /** - * @brief Construct the root node for the search. - */ + // Construct the root node for the search. static SearchNode Root(size_t numSlots, double bound) { return {DiscreteValues(), 0.0, bound, 0}; } @@ -51,10 +48,10 @@ struct SearchNode { } }; - /// Checks if the node represents a complete assignment. + // Checks if the node represents a complete assignment. inline bool isComplete() const { return !next; } - /// Expands the node by assigning the next variable(s). + // Expands the node by assigning the next variable(s). SearchNode expand(const DiscreteValues& fa, const Slot& slot, std::optional nextSlot) const { // Combine the new frontal assignment with the current partial assignment @@ -66,7 +63,7 @@ struct SearchNode { return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot}; } - /// Prints the SearchNode to an output stream. + // Prints the SearchNode to an output stream. friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; return os; @@ -79,17 +76,20 @@ struct CompareSolution { } }; -// Define the Solutions class +/* + * A Solutions object maintains a priority queue of the best solutions found + * during the search. The priority queue is limited to a maximum size, and + * solutions are only added if they are better than the worst solution. + */ class Solutions { - private: - size_t maxSize_; + size_t maxSize_; // Maximum number of solutions to keep std::priority_queue, CompareSolution> pq_; public: Solutions(size_t maxSize) : maxSize_(maxSize) {} - /// Add a solution to the priority queue, possibly evicting the worst one. - /// Return true if we added the solution. + // Add a solution to the priority queue, possibly evicting the worst one. + // Return true if we added the solution. bool maybeAdd(double error, const DiscreteValues& assignment) { const bool full = pq_.size() == maxSize_; if (full && error >= pq_.top().error) return false; @@ -98,7 +98,7 @@ class Solutions { return true; } - /// Check if we have any solutions + // Check if we have any solutions bool empty() const { return pq_.empty(); } // Method to print all solutions @@ -112,9 +112,9 @@ class Solutions { return os; } - /// Check if (partial) solution with given bound can be pruned. If we have - /// room, we never prune. Otherwise, prune if lower bound on error is worse - /// than our current worst error. + // Check if (partial) solution with given bound can be pruned. If we have + // room, we never prune. Otherwise, prune if lower bound on error is worse + // than our current worst error. bool prune(double bound) const { if (pq_.size() < maxSize_) return false; return bound >= pq_.top().error; @@ -134,9 +134,9 @@ class Solutions { } }; -/// @brief Get the factor associated with a node, possibly product of factors. +// Get the factor associated with a node, possibly product of factors. template -static auto getFactor(const NodeType& node) { +static DiscreteFactor::shared_ptr getFactor(const NodeType& node) { const auto& factors = node->factors; return factors.size() == 1 ? factors.back() : DiscreteFactorGraph(factors).product(); @@ -145,7 +145,7 @@ static auto getFactor(const NodeType& node) { DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { using NodePtr = std::shared_ptr; auto visitor = [this](const NodePtr& node, int data) { - const auto factor = getFactor(node); + const DiscreteFactor::shared_ptr factor = getFactor(node); const size_t cardinality = factor->cardinality(node->key); std::vector> pairs{{node->key, cardinality}}; const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; @@ -266,13 +266,14 @@ std::vector DiscreteSearch::run(size_t K) const { // Extract solutions from bestSolutions in ascending order of error return solutions.extractSolutions(); } - -// We have a number of factors, each with a max value, and we want to compute -// a lower-bound on the cost-to-go for each slot, *not* including this factor. -// For the last slot, this is 0.0, as the cost after that is zero. -// For the second-to-last slot, it is -log(max(factor[0])), because after we -// assign slot[1] we still need to assign slot[0], which will cost *at least* -// h0. We return the estimated lower bound of the cost for *all* slots. +/* + * We have a number of factors, each with a max value, and we want to compute + * a lower-bound on the cost-to-go for each slot, *not* including this factor. + * For the last slot[n-1], this is 0.0, as the cost after that is zero. + * For the second-to-last slot, it is h = -log(max(factor[n-1])), because after + * we assign slot[n-2] we still need to assign slot[n-1], which will cost *at + * least* h. We return the estimated lower bound of the cost for *all* slots. + */ double DiscreteSearch::computeHeuristic() { double error = 0.0; for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) { diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 700e41392..b610955b2 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -9,7 +9,7 @@ * -------------------------------------------------------------------------- */ -/* +/** * @file DiscreteSearch.h * @brief Defines the DiscreteSearch class for discrete search algorithms. * @@ -28,24 +28,40 @@ namespace gtsam { /** - * DiscreteSearch: Search for the K best solutions. + * @brief DiscreteSearch: Search for the K best solutions. + * + * This class is used to search for the K best solutions in a DiscreteBayesNet. + * This is implemented with a modified A* search algorithm that uses a priority + * queue to manage the search nodes. That machinery is defined in the .cpp file. + * The heuristic we use is the sum of the log-probabilities of the + * maximum-probability assignments for each slot, for all slots to the right of + * the current slot. + * + * TODO: The heuristic could be refined by using the partial assignment in + * search node to refine the max-probability assignment for the remaining slots. + * This would incur more computation but will lead to fewer expansions. */ class GTSAM_EXPORT DiscreteSearch { public: - /// We structure the search as a set of slots, each with a factor and - /// a set of variable assignments that need to be chosen. In addition, each - /// slot has a heuristic associated with it. + /** + * We structure the search as a set of slots, each with a factor and + * a set of variable assignments that need to be chosen. In addition, each + * slot has a heuristic associated with it. + * + * Example: + * The factors in the search problem (always parents before descendents!): + * [P(A), P(B|A), P(C|A,B)] + * The assignments for each factor. + * [[A0,A1], [B0,B1], [C0,C1,C2]] + * A lower bound on the cost-to-go after each slot, e.g., + * [-log(max_B P(B|A)) -log(max_C P(C|A,B)), -log(max_C P(C|A,B)), 0.0] + * Note that these decrease as we move from right to left. + * We keep the global lower bound as lowerBound_. In the example, it is: + * -log(max_B P(B|A)) -log(max_C P(C|A,B)) -log(max_C P(C|A,B)) + */ struct Slot { - /// The factors in the search problem, - /// e.g., [P(B|A),P(A)] DiscreteFactor::shared_ptr factor; - - /// The assignments for each factor, - /// e.g., [[B0,B1] [A0,A1]] std::vector assignments; - - /// A lower bound on the cost-to-go for each slot, e.g., - /// [-log(max_B P(B|A)), -log(max_A P(A))] double heuristic; friend std::ostream& operator<<(std::ostream& os, const Slot& slot) { @@ -56,8 +72,10 @@ class GTSAM_EXPORT DiscreteSearch { } }; - /// A solution is then a set of assignments, covering all the slots. - /// as well as an associated error = -log(probability) + /** + * A solution is a set of assignments, covering all the slots. + * as well as an associated error = -log(probability) + */ struct Solution { double error; DiscreteValues assignment; @@ -89,28 +107,16 @@ class GTSAM_EXPORT DiscreteSearch { const Ordering& ordering, bool buildJunctionTree = false); - /** - * @brief Constructor from a DiscreteEliminationTree. - * - * @param etree The DiscreteEliminationTree to initialize from. - */ + /// Construct from a DiscreteEliminationTree. DiscreteSearch(const DiscreteEliminationTree& etree); - /** - * @brief Constructor from a DiscreteJunctionTree. - * - * @param junctionTree The DiscreteJunctionTree to initialize from. - */ + /// Construct from a DiscreteJunctionTree. DiscreteSearch(const DiscreteJunctionTree& junctionTree); - /** - * Construct from a DiscreteBayesNet. - */ + //// Construct from a DiscreteBayesNet. DiscreteSearch(const DiscreteBayesNet& bayesNet); - /** - * Construct from a DiscreteBayesTree. - */ + /// Construct from a DiscreteBayesTree. DiscreteSearch(const DiscreteBayesTree& bayesTree); /// @} @@ -146,8 +152,10 @@ class GTSAM_EXPORT DiscreteSearch { /// @} private: - /// Compute the cumulative lower-bound cost-to-go after each slot is filled. - /// @return the estimated lower bound of the cost for *all* slots. + /** + * Compute the cumulative lower-bound cost-to-go after each slot is filled. + * @return the estimated lower bound of the cost for *all* slots. + */ double computeHeuristic(); double lowerBound_; ///< Lower bound on the cost-to-go for the entire search. From 19773153bc24107d744611fb71c953f2abf3ef45 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 15:57:56 -0500 Subject: [PATCH 36/43] Wrapper --- gtsam/discrete/DiscreteSearch.h | 2 ++ gtsam/discrete/discrete.i | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index b610955b2..db3dd5f03 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -161,4 +161,6 @@ class GTSAM_EXPORT DiscreteSearch { double lowerBound_; ///< Lower bound on the cost-to-go for the entire search. std::vector slots_; ///< The slots to fill in the search. }; + +using DiscreteSearchSolution = DiscreteSearch::Solution; // for wrapping } // namespace gtsam diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index b84ac69a0..5e4d8d22d 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -464,4 +464,29 @@ class DiscreteJunctionTree { const gtsam::DiscreteCluster& operator[](size_t i) const; }; +#include +class DiscreteSearchSolution { + double error; + gtsam::DiscreteValues assignment; + DiscreteSearchSolution(double error, const gtsam::DiscreteValues& assignment); +}; + +class DiscreteSearch { + static DiscreteSearch FromFactorGraph(const gtsam::DiscreteFactorGraph& factorGraph, + const gtsam::Ordering& ordering, + bool buildJunctionTree = false); + + DiscreteSearch(const gtsam::DiscreteEliminationTree& etree); + DiscreteSearch(const gtsam::DiscreteJunctionTree& junctionTree); + DiscreteSearch(const gtsam::DiscreteBayesNet& bayesNet); + DiscreteSearch(const gtsam::DiscreteBayesTree& bayesTree); + + void print(string name = "DiscreteSearch: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + + double lowerBound() const; + + std::vector run(size_t K = 1) const; +}; + } // namespace gtsam From 615196e41504961b792c1630986c1fdf8f229ade Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 15:58:15 -0500 Subject: [PATCH 37/43] Avoid some copy/paste --- python/gtsam/tests/dfg_utils.py | 35 +++++++ .../gtsam/tests/test_DiscreteFactorGraph.py | 92 +++++++------------ 2 files changed, 69 insertions(+), 58 deletions(-) create mode 100644 python/gtsam/tests/dfg_utils.py diff --git a/python/gtsam/tests/dfg_utils.py b/python/gtsam/tests/dfg_utils.py new file mode 100644 index 000000000..9ad521fd4 --- /dev/null +++ b/python/gtsam/tests/dfg_utils.py @@ -0,0 +1,35 @@ +import numpy as np +from gtsam import Symbol + + +def make_key(character, index, cardinality): + """ + Helper function to mimic the behavior of gtbook.Variables discrete_series function. + """ + symbol = Symbol(character, index) + key = symbol.key() + return (key, cardinality) + + +def generate_transition_cpt(num_states, transitions=None): + """ + Generate a row-wise CPT for a transition matrix. + """ + if transitions is None: + # Default to identity matrix with slight regularization + transitions = np.eye(num_states) + 0.1 / num_states + + # Ensure transitions sum to 1 if not already normalized + transitions /= np.sum(transitions, axis=1, keepdims=True) + return " ".join(["/".join(map(str, row)) for row in transitions]) + + +def generate_observation_cpt(num_states, num_obs, desired_state): + """ + Generate a row-wise CPT for observations with contrived probabilities. + """ + obs = np.zeros((num_states, num_obs + 1)) + obs[:, -1] = 1 # All states default to measurement num_obs + obs[desired_state, 0:-1] = 1 + obs[desired_state, -1] = 0 + return " ".join(["/".join(map(str, row)) for row in obs]) diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 3053087b4..521eeefa6 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -15,10 +15,16 @@ import unittest import numpy as np from gtsam.utils.test_case import GtsamTestCase +from dfg_utils import make_key, generate_transition_cpt, generate_observation_cpt -from gtsam import (DecisionTreeFactor, DiscreteConditional, - DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, - Symbol) +from gtsam import ( + DecisionTreeFactor, + DiscreteConditional, + DiscreteFactorGraph, + DiscreteKeys, + DiscreteValues, + Ordering, +) OrderingType = Ordering.OrderingType @@ -50,7 +56,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): assignment[1] = 1 # Check if graph evaluation works ( 0.3*0.6*4 ) - self.assertAlmostEqual(.72, graph(assignment)) + self.assertAlmostEqual(0.72, graph(assignment)) # Create a new test with third node and adding unary and ternary factor graph.add(P3, "0.9 0.2 0.5") @@ -100,8 +106,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): expectedValues[1] = 0 expectedValues[2] = 0 actualValues = graph.optimize() - self.assertEqual(list(actualValues.items()), - list(expectedValues.items())) + self.assertEqual(list(actualValues.items()), list(expectedValues.items())) def test_MPE(self): """Test maximum probable explanation (MPE): same as optimize.""" @@ -123,13 +128,11 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Use maxProduct dag = graph.maxProduct(OrderingType.COLAMD) actualMPE = dag.argmax() - self.assertEqual(list(actualMPE.items()), - list(mpe.items())) + self.assertEqual(list(actualMPE.items()), list(mpe.items())) # All in one actualMPE2 = graph.optimize() - self.assertEqual(list(actualMPE2.items()), - list(mpe.items())) + self.assertEqual(list(actualMPE2.items()), list(mpe.items())) def test_sumProduct(self): """Test sumProduct.""" @@ -154,11 +157,17 @@ class TestDiscreteFactorGraph(GtsamTestCase): self.assertAlmostEqual(mpeProbability, 0.36) # regression # Use sumProduct - for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL, - OrderingType.CUSTOM]: + for ordering_type in [ + OrderingType.COLAMD, + OrderingType.METIS, + OrderingType.NATURAL, + OrderingType.CUSTOM, + ]: bayesNet = graph.sumProduct(ordering_type) self.assertEqual(bayesNet(mpe), mpeProbability) + +class TestChains(GtsamTestCase): def test_MPE_chain(self): """ Test for numerical underflow in EliminateMPE on long chains. @@ -170,46 +179,22 @@ class TestDiscreteFactorGraph(GtsamTestCase): desired_state = 1 states = list(range(num_states)) - # Helper function to mimic the behavior of gtbook.Variables discrete_series function - def make_key(character, index, cardinality): - symbol = Symbol(character, index) - key = symbol.key() - return (key, cardinality) - X = {index: make_key("X", index, len(states)) for index in range(num_obs)} Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)} graph = DiscreteFactorGraph() - # Mostly identity transition matrix - transitions = np.eye(num_states) - - # Needed otherwise mpe is always state 0? - transitions += 0.1/(num_states) - - transition_cpt = [] - for i in range(0, num_states): - transition_row = "/".join([str(x) for x in transitions[i]]) - transition_cpt.append(transition_row) - transition_cpt = " ".join(transition_cpt) - + transition_cpt = generate_transition_cpt(num_states) for i in reversed(range(1, num_obs)): - transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) + transition_conditional = DiscreteConditional( + X[i], [X[i - 1]], transition_cpt + ) graph.push_back(transition_conditional) # Contrived example such that the desired state gives measurements [0, num_obs) with equal probability # but all other states always give measurement num_obs - obs = np.zeros((num_states, num_obs+1)) - obs[:,-1] = 1 - obs[desired_state,0: -1] = 1 - obs[desired_state,-1] = 0 - obs_cpt_list = [] - for i in range(0, num_states): - obs_row = "/".join([str(z) for z in obs[i]]) - obs_cpt_list.append(obs_row) - obs_cpt = " ".join(obs_cpt_list) - + obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state) # Contrived example where each measurement is its own index - for i in range(0, num_obs): + for i in range(num_obs): obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt) factor = obs_conditional.likelihood(i) graph.push_back(factor) @@ -217,7 +202,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): mpe = graph.optimize() vals = [mpe[X[i][0]] for i in range(num_obs)] - self.assertEqual(vals, [desired_state]*num_obs) + self.assertEqual(vals, [desired_state] * num_obs) def test_sumProduct_chain(self): """ @@ -227,15 +212,8 @@ class TestDiscreteFactorGraph(GtsamTestCase): """ num_states = 3 chain_length = 400 - desired_state = 1 states = list(range(num_states)) - # Helper function to mimic the behavior of gtbook.Variables discrete_series function - def make_key(character, index, cardinality): - symbol = Symbol(character, index) - key = symbol.key() - return (key, cardinality) - X = {index: make_key("X", index, len(states)) for index in range(chain_length)} graph = DiscreteFactorGraph() @@ -253,18 +231,15 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Ensure that the stationary distribution is positive and normalized stationary_dist /= np.sum(stationary_dist) - expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten()) + expected = DecisionTreeFactor(X[chain_length - 1], stationary_dist.ravel()) # The transition matrix parsed by DiscreteConditional is a row-wise CPT - transitions = transitions.T - transition_cpt = [] - for i in range(0, num_states): - transition_row = "/".join([str(x) for x in transitions[i]]) - transition_cpt.append(transition_row) - transition_cpt = " ".join(transition_cpt) + transition_cpt = generate_transition_cpt(num_states, transitions.T) for i in reversed(range(1, chain_length)): - transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) + transition_conditional = DiscreteConditional( + X[i], [X[i - 1]], transition_cpt + ) graph.push_back(transition_conditional) # Run sum product using natural ordering so the resulting Bayes net has the form: @@ -277,5 +252,6 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Ensure marginal probabilities are close to the stationary distribution self.gtsamAssertEquals(expected, last_marginal) + if __name__ == "__main__": unittest.main() From 5e5a67d85316a6f41df542d8940f802c663bd4d9 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 15:58:33 -0500 Subject: [PATCH 38/43] Test search with long chain --- python/gtsam/tests/test_DiscreteSearch.py | 84 +++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 python/gtsam/tests/test_DiscreteSearch.py diff --git a/python/gtsam/tests/test_DiscreteSearch.py b/python/gtsam/tests/test_DiscreteSearch.py new file mode 100644 index 000000000..d0077f6db --- /dev/null +++ b/python/gtsam/tests/test_DiscreteSearch.py @@ -0,0 +1,84 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Search. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from dfg_utils import generate_observation_cpt, generate_transition_cpt, make_key +from gtsam.utils.test_case import GtsamTestCase + +from gtsam import ( + DiscreteConditional, + DiscreteFactorGraph, + DiscreteSearch, + Ordering, + DefaultKeyFormatter, +) + +OrderingType = Ordering.OrderingType + + +class TestDiscreteSearch(GtsamTestCase): + """Tests for Discrete Factor Graphs.""" + + def test_MPE_chain(self): + """ + Test for numerical underflow in EliminateMPE on long chains. + Adapted from the toy problem of @pcl15423 + Ref: https://github.com/borglab/gtsam/issues/1448 + """ + num_states = 3 + num_obs = 200 + desired_state = 1 + states = list(range(num_states)) + + X = {index: make_key("X", index, len(states)) for index in range(num_obs)} + Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)} + graph = DiscreteFactorGraph() + + transition_cpt = generate_transition_cpt(num_states) + for i in reversed(range(1, num_obs)): + transition_conditional = DiscreteConditional( + X[i], [X[i - 1]], transition_cpt + ) + graph.push_back(transition_conditional) + + # Contrived example such that the desired state gives measurements [0, num_obs) with equal + # probability but all other states always give measurement num_obs + obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state) + # Contrived example where each measurement is its own index + for i in range(num_obs): + obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt) + factor = obs_conditional.likelihood(i) + graph.push_back(factor) + + # Check MPE + mpe = graph.optimize() + vals = [mpe[X[i][0]] for i in range(num_obs)] + self.assertEqual(vals, [desired_state] * num_obs) + + # Create an ordering: + ordering = Ordering() + for i in reversed(range(num_obs)): + ordering.push_back(X[i][0]) + + # Now do Search + search = DiscreteSearch.FromFactorGraph(graph, ordering) + solutions = search.run(K=1) + mpe2 = solutions[0].assignment + # print({DefaultKeyFormatter(key): value for key, value in mpe2.items()}) + vals = [mpe2[X[i][0]] for i in range(num_obs)] + self.assertEqual(vals, [desired_state] * num_obs) + + +if __name__ == "__main__": + unittest.main() From d57994d279bad587af87bc717ef03f5b6cee515c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 28 Jan 2025 16:11:32 -0500 Subject: [PATCH 39/43] kill expNormalize since it isn't used anywhere --- gtsam/discrete/DiscreteFactor.cpp | 45 ------------------------------- gtsam/discrete/DiscreteFactor.h | 18 ------------- 2 files changed, 63 deletions(-) diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index 68cc1df7d..c8e2e0fd2 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -71,49 +71,4 @@ AlgebraicDecisionTree DiscreteFactor::errorTree() const { return AlgebraicDecisionTree(dkeys, errors); } -/* ************************************************************************* */ -std::vector expNormalize(const std::vector& logProbs) { - double maxLogProb = -std::numeric_limits::infinity(); - for (size_t i = 0; i < logProbs.size(); i++) { - double logProb = logProbs[i]; - if ((logProb != std::numeric_limits::infinity()) && - logProb > maxLogProb) { - maxLogProb = logProb; - } - } - - // After computing the max = "Z" of the log probabilities L_i, we compute - // the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z). - double total = 0.0; - for (size_t i = 0; i < logProbs.size(); i++) { - double probPrime = exp(logProbs[i] - maxLogProb); - total += probPrime; - } - double logTotal = log(total); - - // Now we compute the (normalized) probability (for each i): - // p_i = exp(L_i - Z - log S) - double checkNormalization = 0.0; - std::vector probs; - for (size_t i = 0; i < logProbs.size(); i++) { - double prob = exp(logProbs[i] - maxLogProb - logTotal); - probs.push_back(prob); - checkNormalization += prob; - } - - // Numerical tolerance for floating point comparisons - double tol = 1e-9; - - if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) { - std::string errMsg = - std::string("expNormalize failed to normalize probabilities. ") + - std::string("Expected normalization constant = 1.0. Got value: ") + - std::to_string(checkNormalization) + - std::string( - "\n This could have resulted from numerical overflow/underflow."); - throw std::logic_error(errMsg); - } - return probs; -} - } // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 6cbc00d09..92bbd18dd 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -212,22 +212,4 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { template <> struct traits : public Testable {}; -/** - * @brief Normalize a set of log probabilities. - * - * Normalizing a set of log probabilities in a numerically stable way is - * tricky. To avoid overflow/underflow issues, we compute the largest - * (finite) log probability and subtract it from each log probability before - * normalizing. This comes from the observation that if: - * p_i = exp(L_i) / ( sum_j exp(L_j) ), - * Then, - * p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)), - * = exp(L_i - Z) / ( sum_j exp(L_j - Z) ) - * - * Setting Z = max_j L_j, we can avoid numerical issues that arise when all - * of the (unnormalized) log probabilities are either very large or very - * small. - */ -std::vector expNormalize(const std::vector& logProbs); - } // namespace gtsam From c8bafab4307d7b2da98bf50b4f85e04cac53ba9e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 28 Jan 2025 16:21:26 -0500 Subject: [PATCH 40/43] add scale() method in DiscreteFactor --- gtsam/discrete/DiscreteFactor.cpp | 8 ++++++++ gtsam/discrete/DiscreteFactor.h | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index c8e2e0fd2..faae02af2 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -71,4 +71,12 @@ AlgebraicDecisionTree DiscreteFactor::errorTree() const { return AlgebraicDecisionTree(dkeys, errors); } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr DiscreteFactor::scale() const { + // Max over all the potentials by pretending all keys are frontal: + shared_ptr denominator = this->max(this->size()); + // Normalize the product factor to prevent underflow. + return this->operator/(denominator); +} + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 92bbd18dd..fafb4dbf5 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -158,6 +158,14 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// Create new factor by maximizing over all values with the same separator. virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0; + /** + * @brief Scale the factor values by the maximum + * to prevent underflow/overflow. + * + * @return DiscreteFactor::shared_ptr + */ + DiscreteFactor::shared_ptr scale() const; + /** * Get the number of non-zero values contained in this factor. * It could be much smaller than `prod_{key}(cardinality(key))`. From d4d95e2342e5c1d9759bc2e068d0d2049e97df70 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 17:22:44 -0500 Subject: [PATCH 41/43] scale product and sum so we don't get all 0s --- gtsam/discrete/DiscreteFactorGraph.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 7e059c5e5..1fb353423 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -125,11 +125,8 @@ namespace gtsam { DiscreteFactor::shared_ptr product = this->product(); gttoc(product); - // Max over all the potentials by pretending all keys are frontal: - auto denominator = product->max(product->size()); - // Normalize the product factor to prevent underflow. - product = product->operator/(denominator); + product = product->scale(); return product; } @@ -222,6 +219,8 @@ namespace gtsam { // sum out frontals, this is the factor on the separator gttic(sum); DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); + // Normalize/scale to prevent underflow. + sum = sum->scale(); gttoc(sum); // Ordering keys for the conditional so that frontalKeys are really in front From 2702854c188dc24324b89bf2ed23e2bb3c37c2b4 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 21:15:35 -0500 Subject: [PATCH 42/43] Fix bug --- gtsam/discrete/DiscreteFactorGraph.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 1fb353423..321ec7147 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -214,14 +214,21 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DiscreteFactor::shared_ptr product = factors.scaledProduct(); + gttic(product); + DiscreteFactor::shared_ptr product = factors.product(); + gttoc(product); // sum out frontals, this is the factor on the separator gttic(sum); DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); - // Normalize/scale to prevent underflow. - sum = sum->scale(); gttoc(sum); + + // Normalize/scale to prevent underflow. + gttic(scale); + DiscreteFactor::shared_ptr denominator = sum->max(sum->size()); + product = product->operator/(denominator); + sum = sum->operator/(denominator); + gttoc(scale); // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; From 2e4ae8ca0a5c520c9072bf0a6bbc9563e349d82c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 28 Jan 2025 22:08:17 -0500 Subject: [PATCH 43/43] add explanatory comments --- gtsam/discrete/DiscreteFactorGraph.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 321ec7147..f2bae4b9b 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -215,6 +215,7 @@ namespace gtsam { EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { gttic(product); + // `product` is scaled later to prevent underflow. DiscreteFactor::shared_ptr product = factors.product(); gttoc(product); @@ -222,8 +223,11 @@ namespace gtsam { gttic(sum); DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); gttoc(sum); - + // Normalize/scale to prevent underflow. + // We divide both `product` and `sum` by `max(sum)` + // since it is faster to compute and when the conditional + // is formed by `product/sum`, the scaling term cancels out. gttic(scale); DiscreteFactor::shared_ptr denominator = sum->max(sum->size()); product = product->operator/(denominator);