diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp new file mode 100644 index 000000000..c5941862d --- /dev/null +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -0,0 +1,246 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * DiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include + +namespace gtsam { + +using Solution = DiscreteSearch::Solution; + +/** + * @brief Represents a node in the search tree for discrete search algorithms. + * + * @details Each SearchNode contains a partial assignment of discrete variables, + * the current error, a bound on the final error, and the index of the next + * conditional to be assigned. + */ +struct SearchNode { + DiscreteValues assignment; ///< Partial assignment of discrete variables. + double error; ///< Current error for the partial assignment. + double bound; ///< Lower bound on the final error for unassigned variables. + int nextConditional; ///< Index of the next conditional to be assigned. + + /** + * @brief Construct the root node for the search. + */ + static SearchNode Root(size_t numConditionals, double bound) { + return {DiscreteValues(), 0.0, bound, + static_cast(numConditionals) - 1}; + } + + struct Compare { + bool operator()(const SearchNode& a, const SearchNode& b) const { + return a.bound > b.bound; // smallest bound -> highest priority + } + }; + + /** + * @brief Checks if the node represents a complete assignment. + * + * @return True if all variables have been assigned, false otherwise. + */ + inline bool isComplete() const { return nextConditional < 0; } + + /** + * @brief Expands the node by assigning the next variable. + * + * @param conditional The discrete conditional representing the next variable + * to be assigned. + * @param fa The frontal assignment for the next variable. + * @return A new SearchNode representing the expanded state. + */ + SearchNode expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + DiscreteValues newAssignment = assignment; + for (auto& [key, value] : fa) { + newAssignment[key] = value; + } + + return {newAssignment, error + conditional.error(newAssignment), 0.0, + nextConditional - 1}; + } + + /** + * @brief Prints the SearchNode to an output stream. + * + * @param os The output stream. + * @param node The SearchNode to be printed. + * @return The output stream. + */ + friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { + os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; + return os; + } +}; + +struct CompareSolution { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } +}; + +// Define the Solutions class +class Solutions { + private: + size_t maxSize_; + std::priority_queue, CompareSolution> pq_; + + public: + Solutions(size_t maxSize) : maxSize_(maxSize) {} + + /// Add a solution to the priority queue, possibly evicting the worst one. + /// Return true if we added the solution. + bool maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; + } + + /// Check if we have any solutions + bool empty() const { return pq_.empty(); } + + // Method to print all solutions + friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + os << "Solutions (top " << sn.pq_.size() << "):\n"; + auto pq = sn.pq_; + while (!pq.empty()) { + os << pq.top() << "\n"; + pq.pop(); + } + return os; + } + + /// Check if (partial) solution with given bound can be pruned. If we have + /// room, we never prune. Otherwise, prune if lower bound on error is worse + /// than our current worst error. + bool prune(double bound) const { + if (pq_.size() < maxSize_) return false; + return bound >= pq_.top().error; + } + + // Method to extract solutions in ascending order of error + std::vector extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; + } +}; + +DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { + std::vector conditionals; + for (auto& factor : bayesNet) conditionals_.push_back(factor); + costToGo_ = computeCostToGo(conditionals_); +} + +DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { + std::function + collectConditionals = [&](const auto& clique) { + if (!clique) return; + for (const auto& child : clique->children) collectConditionals(child); + conditionals_.push_back(clique->conditional()); + }; + for (const auto& root : bayesTree.roots()) collectConditionals(root); + costToGo_ = computeCostToGo(conditionals_); +} + +struct SearchNodeQueue + : public std::priority_queue, + SearchNode::Compare> { + void expandNextNode( + const std::vector& conditionals, + const std::vector& costToGo, Solutions* solutions) { + // Pop the partial assignment with the smallest bound + SearchNode current = top(); + pop(); + + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions->prune(current.bound)) { + return; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions->maybeAdd(current.error, current.assignment); + return; + } + + // Expand on the next factor + const auto& conditional = conditionals[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + auto childNode = current.expand(*conditional, fa); + if (childNode.nextConditional >= 0) + childNode.bound = childNode.error + costToGo[childNode.nextConditional]; + + // Again, prune if we cannot beat the worst solution + if (!solutions->prune(childNode.bound)) { + emplace(childNode); + } + } + } +}; + +std::vector DiscreteSearch::run(size_t K) const { + Solutions solutions(K); + SearchNodeQueue expansions; + expansions.push(SearchNode::Root(conditionals_.size(), + costToGo_.empty() ? 0.0 : costToGo_.back())); + +#ifdef DISCRETE_SEARCH_DEBUG + size_t numExpansions = 0; +#endif + + // Perform the search + while (!expansions.empty()) { + expansions.expandNextNode(conditionals_, costToGo_, &solutions); +#ifdef DISCRETE_SEARCH_DEBUG + ++numExpansions; +#endif + } + +#ifdef DISCRETE_SEARCH_DEBUG + std::cout << "Number of expansions: " << numExpansions << std::endl; +#endif + + // Extract solutions from bestSolutions in ascending order of error + return solutions.extractSolutions(); +} + +std::vector DiscreteSearch::computeCostToGo( + const std::vector& conditionals) { + std::vector costToGo; + double error = 0.0; + for (const auto& conditional : conditionals) { + Ordering ordering(conditional->begin(), conditional->end()); + auto maxx = conditional->max(ordering); + error -= std::log(maxx->evaluate({})); + costToGo.push_back(error); + } + return costToGo; +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h new file mode 100644 index 000000000..6202880b2 --- /dev/null +++ b/gtsam/discrete/DiscreteSearch.h @@ -0,0 +1,78 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * DiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include +#include + +#include + +namespace gtsam { + +/** + * DiscreteSearch: Search for the K best solutions. + */ +class GTSAM_EXPORT DiscreteSearch { + public: + /** + * @brief A solution to a discrete search problem. + */ + struct Solution { + double error; + DiscreteValues assignment; + Solution(double err, const DiscreteValues& assign) + : error(err), assignment(assign) {} + friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { + os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; + return os; + } + }; + + /** + * Construct from a DiscreteBayesNet and K. + */ + DiscreteSearch(const DiscreteBayesNet& bayesNet); + + /** + * Construct from a DiscreteBayesTree and K. + */ + DiscreteSearch(const DiscreteBayesTree& bayesTree); + + /** + * @brief Search for the K best solutions. + * + * This method performs a search to find the K best solutions for the given + * DiscreteBayesNet. It uses a priority queue to manage the search nodes, + * expanding nodes with the smallest bound first. The search continues until + * all possible nodes have been expanded or pruned. + * + * @return A vector of the K best solutions found during the search. + */ + std::vector run(size_t K = 1) const; + + private: + /// Compute the cumulative cost-to-go for each conditional slot. + static std::vector computeCostToGo( + const std::vector& conditionals); + + /// Expand the next node in the search tree. + void expandNextNode() const; + + std::vector conditionals_; + std::vector costToGo_; +}; +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteValues.cpp b/gtsam/discrete/DiscreteValues.cpp index 416dfb888..3c3ed4468 100644 --- a/gtsam/discrete/DiscreteValues.cpp +++ b/gtsam/discrete/DiscreteValues.cpp @@ -26,12 +26,24 @@ using std::stringstream; namespace gtsam { +/* ************************************************************************ */ +static void stream(std::ostream& os, const DiscreteValues& x, + const KeyFormatter& keyFormatter) { + for (const auto& kv : x) + os << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; +} + +/* ************************************************************************ */ +std::ostream& operator<<(std::ostream& os, const DiscreteValues& x) { + stream(os, x, DefaultKeyFormatter); + return os; +} + /* ************************************************************************ */ void DiscreteValues::print(const string& s, const KeyFormatter& keyFormatter) const { cout << s << ": "; - for (auto&& kv : *this) - cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; + stream(cout, *this, keyFormatter); cout << endl; } diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index fa8a8a846..df4ecdbff 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -64,6 +64,9 @@ class GTSAM_EXPORT DiscreteValues : public Assignment { /// @name Standard Interface /// @{ + /// ostream operator: + friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x); + // insert in base class; std::pair insert( const value_type& value ){ return Base::insert(value); diff --git a/gtsam/discrete/tests/AsiaExample.h b/gtsam/discrete/tests/AsiaExample.h new file mode 100644 index 000000000..6c327daec --- /dev/null +++ b/gtsam/discrete/tests/AsiaExample.h @@ -0,0 +1,61 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * AsiaExample.h + * + * @date Jan, 2025 + * @author Frank Dellaert + */ + +#include +#include + +namespace gtsam { +namespace asia_example { + +static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), + B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), + S = Symbol('S', 7), A = Symbol('A', 8); + +static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), + Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), + Asia(A, 2); + +// Function to construct the Asia priors +DiscreteBayesNet createPriors() { + DiscreteBayesNet priors; + priors.add(Smoking % "50/50"); + priors.add(Asia, "99/1"); + return priors; +} + +// Function to construct the incomplete Asia example +DiscreteBayesNet createFragment() { + DiscreteBayesNet fragment; + fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); + fragment.add(LungCancer | Smoking = "99/1 90/10"); + fragment.add(Tuberculosis | Asia = "99/1 95/5"); + for (const auto& factor : createPriors()) fragment.push_back(factor); + return fragment; +} + +// Function to construct the Asia example +DiscreteBayesNet createAsiaExample() { + DiscreteBayesNet asia; + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + asia.add(XRay | Either = "95/5 2/98"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + for (const auto& factor : createFragment()) asia.push_back(factor); + return asia; +} +} // namespace asia_example +} // namespace gtsam \ No newline at end of file diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index d2033909c..dd5d218f8 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -23,40 +23,19 @@ #include #include #include +#include #include #include #include -using namespace std; +#include "AsiaExample.h" + using namespace gtsam; -static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), - LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); - -using ADT = AlgebraicDecisionTree; - -// Function to construct the Asia example -DiscreteBayesNet constructAsiaExample() { - DiscreteBayesNet asia; - - asia.add(Asia, "99/1"); - asia.add(Smoking % "50/50"); // Signature version - - asia.add(Tuberculosis | Asia = "99/1 95/5"); - asia.add(LungCancer | Smoking = "99/1 90/10"); - asia.add(Bronchitis | Smoking = "70/30 40/60"); - - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - asia.add(XRay | Either = "95/5 2/98"); - asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); - - return asia; -} - /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { + using ADT = AlgebraicDecisionTree; DiscreteBayesNet bayesNet; DiscreteKey Parent(0, 2), Child(1, 2); @@ -86,11 +65,12 @@ TEST(DiscreteBayesNet, bayesNet) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { - DiscreteBayesNet asia = constructAsiaExample(); + using namespace asia_example; + const DiscreteBayesNet asia = createAsiaExample(); // Convert to factor graph DiscreteFactorGraph fg(asia); - LONGS_EQUAL(3, fg.back()->size()); + LONGS_EQUAL(1, fg.back()->size()); // Check the marginals we know (of the parent-less nodes) DiscreteMarginals marginals(fg); @@ -99,7 +79,7 @@ TEST(DiscreteBayesNet, Asia) { EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); // Create solver and eliminate - const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7}; + const Ordering ordering{A, D, T, X, S, E, L, B}; DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); @@ -144,55 +124,50 @@ TEST(DiscreteBayesNet, Sugar) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Dot) { - DiscreteBayesNet fragment; - fragment.add(Asia % "99/1"); - fragment.add(Smoking % "50/50"); + using namespace asia_example; + const DiscreteBayesNet fragment = createFragment(); - fragment.add(Tuberculosis | Asia = "99/1 95/5"); - fragment.add(LungCancer | Smoking = "99/1 90/10"); - fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - string actual = fragment.dot(); - EXPECT(actual == - "digraph {\n" - " size=\"5,5\";\n" - "\n" - " var0[label=\"0\"];\n" - " var3[label=\"3\"];\n" - " var4[label=\"4\"];\n" - " var5[label=\"5\"];\n" - " var6[label=\"6\"];\n" - "\n" - " var3->var5\n" - " var6->var5\n" - " var4->var6\n" - " var0->var3\n" - "}"); + std::string expected = + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var4683743612465315848[label=\"A8\"];\n" + " var4971973988617027587[label=\"E3\"];\n" + " var5476377146882523141[label=\"L5\"];\n" + " var5980780305148018695[label=\"S7\"];\n" + " var6052837899185946630[label=\"T6\"];\n" + "\n" + " var4683743612465315848->var6052837899185946630\n" + " var5980780305148018695->var5476377146882523141\n" + " var6052837899185946630->var4971973988617027587\n" + " var5476377146882523141->var4971973988617027587\n" + "}"; + std::string actual = fragment.dot(); + EXPECT(actual.compare(expected) == 0); } /* ************************************************************************* */ // Check markdown representation looks as expected. TEST(DiscreteBayesNet, markdown) { - DiscreteBayesNet fragment; - fragment.add(Asia % "99/1"); - fragment.add(Smoking | Asia = "8/2 7/3"); + using namespace asia_example; + DiscreteBayesNet priors = createPriors(); - string expected = + std::string expected = "`DiscreteBayesNet` of size 2\n" "\n" + " *P(Smoking):*\n\n" + "|Smoking|value|\n" + "|:-:|:-:|\n" + "|0|0.5|\n" + "|1|0.5|\n" + "\n" " *P(Asia):*\n\n" "|Asia|value|\n" "|:-:|:-:|\n" "|0|0.99|\n" - "|1|0.01|\n" - "\n" - " *P(Smoking|Asia):*\n\n" - "|*Asia*|0|1|\n" - "|:-:|:-:|:-:|\n" - "|0|0.8|0.2|\n" - "|1|0.7|0.3|\n\n"; - auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; - string actual = fragment.markdown(formatter); + "|1|0.01|\n\n"; + auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; }; + std::string actual = priors.markdown(formatter); EXPECT(actual == expected); } diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp new file mode 100644 index 000000000..b537dd2f0 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -0,0 +1,111 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include +#include +#include + +#include "AsiaExample.h" + +using namespace gtsam; + +// Create Asia Bayes net, FG, and Bayes tree once +namespace asia { +using namespace asia_example; +static const DiscreteBayesNet bayesNet = createAsiaExample(); +static const DiscreteFactorGraph factorGraph(bayesNet); +static const DiscreteValues mpe = factorGraph.optimize(); +static const Ordering ordering{D, X, B, E, L, T, S, A}; +static const DiscreteBayesTree bayesTree = + *factorGraph.eliminateMultifrontal(ordering); +} // namespace asia + +/* ************************************************************************* */ +TEST(DiscreteBayesNet, EmptyKBest) { + DiscreteBayesNet net; // no factors + DiscreteSearch search(net); + auto solutions = search.run(3); + // Expect one solution with empty assignment, error=0 + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscreteBayesNet, AsiaKBest) { + const DiscreteSearch search(asia::bayesNet); + + // Ask for the MPE + auto mpe = search.run(); + + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + + // Check it is equal to MPE via inference + EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); + + // Ask for top 4 solutions + auto solutions = search.run(4); + + EXPECT_LONGS_EQUAL(4, solutions.size()); + // Regression test: check the first and last solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); +} + +/* ************************************************************************* */ +TEST(DiscreteBayesTree, EmptyTree) { + DiscreteBayesTree bt; + + DiscreteSearch search(bt); + auto solutions = search.run(3); + + // We expect exactly 1 solution with error = 0.0 (the empty assignment). + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscreteBayesTree, AsiaTreeKBest) { + DiscreteSearch search(asia::bayesTree); + + // Ask for MPE + auto mpe = search.run(); + + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + + // Check it is equal to MPE via inference + EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); + + // Ask for top 4 solutions + auto solutions = search.run(4); + + EXPECT_LONGS_EQUAL(4, solutions.size()); + // Regression test: check the first and last solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */