commit
						3f6ae48dfb
					
				|  | @ -0,0 +1,246 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /*
 | ||||
|  * DiscreteSearch.cpp | ||||
|  * | ||||
|  * @date January, 2025 | ||||
|  * @author Frank Dellaert | ||||
|  */ | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteSearch.h> | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| using Solution = DiscreteSearch::Solution; | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Represents a node in the search tree for discrete search algorithms. | ||||
|  * | ||||
|  * @details Each SearchNode contains a partial assignment of discrete variables, | ||||
|  * the current error, a bound on the final error, and the index of the next | ||||
|  * conditional to be assigned. | ||||
|  */ | ||||
| struct SearchNode { | ||||
|   DiscreteValues assignment;  ///< Partial assignment of discrete variables.
 | ||||
|   double error;               ///< Current error for the partial assignment.
 | ||||
|   double bound;  ///< Lower bound on the final error for unassigned variables.
 | ||||
|   int nextConditional;  ///< Index of the next conditional to be assigned.
 | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Construct the root node for the search. | ||||
|    */ | ||||
|   static SearchNode Root(size_t numConditionals, double bound) { | ||||
|     return {DiscreteValues(), 0.0, bound, | ||||
|             static_cast<int>(numConditionals) - 1}; | ||||
|   } | ||||
| 
 | ||||
|   struct Compare { | ||||
|     bool operator()(const SearchNode& a, const SearchNode& b) const { | ||||
|       return a.bound > b.bound;  // smallest bound -> highest priority
 | ||||
|     } | ||||
|   }; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Checks if the node represents a complete assignment. | ||||
|    * | ||||
|    * @return True if all variables have been assigned, false otherwise. | ||||
|    */ | ||||
|   inline bool isComplete() const { return nextConditional < 0; } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Expands the node by assigning the next variable. | ||||
|    * | ||||
|    * @param conditional The discrete conditional representing the next variable | ||||
|    * to be assigned. | ||||
|    * @param fa The frontal assignment for the next variable. | ||||
|    * @return A new SearchNode representing the expanded state. | ||||
|    */ | ||||
|   SearchNode expand(const DiscreteConditional& conditional, | ||||
|                     const DiscreteValues& fa) const { | ||||
|     // Combine the new frontal assignment with the current partial assignment
 | ||||
|     DiscreteValues newAssignment = assignment; | ||||
|     for (auto& [key, value] : fa) { | ||||
|       newAssignment[key] = value; | ||||
|     } | ||||
| 
 | ||||
|     return {newAssignment, error + conditional.error(newAssignment), 0.0, | ||||
|             nextConditional - 1}; | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Prints the SearchNode to an output stream. | ||||
|    * | ||||
|    * @param os The output stream. | ||||
|    * @param node The SearchNode to be printed. | ||||
|    * @return The output stream. | ||||
|    */ | ||||
|   friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { | ||||
|     os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; | ||||
|     return os; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| struct CompareSolution { | ||||
|   bool operator()(const Solution& a, const Solution& b) const { | ||||
|     return a.error < b.error; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // Define the Solutions class
 | ||||
| class Solutions { | ||||
|  private: | ||||
|   size_t maxSize_; | ||||
|   std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_; | ||||
| 
 | ||||
|  public: | ||||
|   Solutions(size_t maxSize) : maxSize_(maxSize) {} | ||||
| 
 | ||||
|   /// Add a solution to the priority queue, possibly evicting the worst one.
 | ||||
|   /// Return true if we added the solution.
 | ||||
|   bool maybeAdd(double error, const DiscreteValues& assignment) { | ||||
|     const bool full = pq_.size() == maxSize_; | ||||
|     if (full && error >= pq_.top().error) return false; | ||||
|     if (full) pq_.pop(); | ||||
|     pq_.emplace(error, assignment); | ||||
|     return true; | ||||
|   } | ||||
| 
 | ||||
|   /// Check if we have any solutions
 | ||||
|   bool empty() const { return pq_.empty(); } | ||||
| 
 | ||||
|   // Method to print all solutions
 | ||||
|   friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { | ||||
|     os << "Solutions (top " << sn.pq_.size() << "):\n"; | ||||
|     auto pq = sn.pq_; | ||||
|     while (!pq.empty()) { | ||||
|       os << pq.top() << "\n"; | ||||
|       pq.pop(); | ||||
|     } | ||||
|     return os; | ||||
|   } | ||||
| 
 | ||||
|   /// Check if (partial) solution with given bound can be pruned. If we have
 | ||||
|   /// room, we never prune. Otherwise, prune if lower bound on error is worse
 | ||||
|   /// than our current worst error.
 | ||||
|   bool prune(double bound) const { | ||||
|     if (pq_.size() < maxSize_) return false; | ||||
|     return bound >= pq_.top().error; | ||||
|   } | ||||
| 
 | ||||
|   // Method to extract solutions in ascending order of error
 | ||||
|   std::vector<Solution> extractSolutions() { | ||||
|     std::vector<Solution> result; | ||||
|     while (!pq_.empty()) { | ||||
|       result.push_back(pq_.top()); | ||||
|       pq_.pop(); | ||||
|     } | ||||
|     std::sort( | ||||
|         result.begin(), result.end(), | ||||
|         [](const Solution& a, const Solution& b) { return a.error < b.error; }); | ||||
|     return result; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { | ||||
|   std::vector<DiscreteConditional::shared_ptr> conditionals; | ||||
|   for (auto& factor : bayesNet) conditionals_.push_back(factor); | ||||
|   costToGo_ = computeCostToGo(conditionals_); | ||||
| } | ||||
| 
 | ||||
| DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { | ||||
|   std::function<void(const DiscreteBayesTree::sharedClique&)> | ||||
|       collectConditionals = [&](const auto& clique) { | ||||
|         if (!clique) return; | ||||
|         for (const auto& child : clique->children) collectConditionals(child); | ||||
|         conditionals_.push_back(clique->conditional()); | ||||
|       }; | ||||
|   for (const auto& root : bayesTree.roots()) collectConditionals(root); | ||||
|   costToGo_ = computeCostToGo(conditionals_); | ||||
| } | ||||
| 
 | ||||
| struct SearchNodeQueue | ||||
|     : public std::priority_queue<SearchNode, std::vector<SearchNode>, | ||||
|                                  SearchNode::Compare> { | ||||
|   void expandNextNode( | ||||
|       const std::vector<DiscreteConditional::shared_ptr>& conditionals, | ||||
|       const std::vector<double>& costToGo, Solutions* solutions) { | ||||
|     // Pop the partial assignment with the smallest bound
 | ||||
|     SearchNode current = top(); | ||||
|     pop(); | ||||
| 
 | ||||
|     // If we already have K solutions, prune if we cannot beat the worst one.
 | ||||
|     if (solutions->prune(current.bound)) { | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     // Check if we have a complete assignment
 | ||||
|     if (current.isComplete()) { | ||||
|       solutions->maybeAdd(current.error, current.assignment); | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     // Expand on the next factor
 | ||||
|     const auto& conditional = conditionals[current.nextConditional]; | ||||
| 
 | ||||
|     for (auto& fa : conditional->frontalAssignments()) { | ||||
|       auto childNode = current.expand(*conditional, fa); | ||||
|       if (childNode.nextConditional >= 0) | ||||
|         childNode.bound = childNode.error + costToGo[childNode.nextConditional]; | ||||
| 
 | ||||
|       // Again, prune if we cannot beat the worst solution
 | ||||
|       if (!solutions->prune(childNode.bound)) { | ||||
|         emplace(childNode); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| std::vector<Solution> DiscreteSearch::run(size_t K) const { | ||||
|   Solutions solutions(K); | ||||
|   SearchNodeQueue expansions; | ||||
|   expansions.push(SearchNode::Root(conditionals_.size(), | ||||
|                                    costToGo_.empty() ? 0.0 : costToGo_.back())); | ||||
| 
 | ||||
| #ifdef DISCRETE_SEARCH_DEBUG | ||||
|   size_t numExpansions = 0; | ||||
| #endif | ||||
| 
 | ||||
|   // Perform the search
 | ||||
|   while (!expansions.empty()) { | ||||
|     expansions.expandNextNode(conditionals_, costToGo_, &solutions); | ||||
| #ifdef DISCRETE_SEARCH_DEBUG | ||||
|     ++numExpansions; | ||||
| #endif | ||||
|   } | ||||
| 
 | ||||
| #ifdef DISCRETE_SEARCH_DEBUG | ||||
|   std::cout << "Number of expansions: " << numExpansions << std::endl; | ||||
| #endif | ||||
| 
 | ||||
|   // Extract solutions from bestSolutions in ascending order of error
 | ||||
|   return solutions.extractSolutions(); | ||||
| } | ||||
| 
 | ||||
| std::vector<double> DiscreteSearch::computeCostToGo( | ||||
|     const std::vector<DiscreteConditional::shared_ptr>& conditionals) { | ||||
|   std::vector<double> costToGo; | ||||
|   double error = 0.0; | ||||
|   for (const auto& conditional : conditionals) { | ||||
|     Ordering ordering(conditional->begin(), conditional->end()); | ||||
|     auto maxx = conditional->max(ordering); | ||||
|     error -= std::log(maxx->evaluate({})); | ||||
|     costToGo.push_back(error); | ||||
|   } | ||||
|   return costToGo; | ||||
| } | ||||
| 
 | ||||
| }  // namespace gtsam
 | ||||
|  | @ -0,0 +1,78 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /*
 | ||||
|  * DiscreteSearch.cpp | ||||
|  * | ||||
|  * @date January, 2025 | ||||
|  * @author Frank Dellaert | ||||
|  */ | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteBayesNet.h> | ||||
| #include <gtsam/discrete/DiscreteBayesTree.h> | ||||
| 
 | ||||
| #include <queue> | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| /**
 | ||||
|  * DiscreteSearch: Search for the K best solutions. | ||||
|  */ | ||||
| class GTSAM_EXPORT DiscreteSearch { | ||||
|  public: | ||||
|   /**
 | ||||
|    * @brief A solution to a discrete search problem. | ||||
|    */ | ||||
|   struct Solution { | ||||
|     double error; | ||||
|     DiscreteValues assignment; | ||||
|     Solution(double err, const DiscreteValues& assign) | ||||
|         : error(err), assignment(assign) {} | ||||
|     friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { | ||||
|       os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; | ||||
|       return os; | ||||
|     } | ||||
|   }; | ||||
| 
 | ||||
|   /**
 | ||||
|    * Construct from a DiscreteBayesNet and K. | ||||
|    */ | ||||
|   DiscreteSearch(const DiscreteBayesNet& bayesNet); | ||||
| 
 | ||||
|   /**
 | ||||
|    * Construct from a DiscreteBayesTree and K. | ||||
|    */ | ||||
|   DiscreteSearch(const DiscreteBayesTree& bayesTree); | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Search for the K best solutions. | ||||
|    * | ||||
|    * This method performs a search to find the K best solutions for the given | ||||
|    * DiscreteBayesNet. It uses a priority queue to manage the search nodes, | ||||
|    * expanding nodes with the smallest bound first. The search continues until | ||||
|    * all possible nodes have been expanded or pruned. | ||||
|    * | ||||
|    * @return A vector of the K best solutions found during the search. | ||||
|    */ | ||||
|   std::vector<Solution> run(size_t K = 1) const; | ||||
| 
 | ||||
|  private: | ||||
|   /// Compute the cumulative cost-to-go for each conditional slot.
 | ||||
|   static std::vector<double> computeCostToGo( | ||||
|       const std::vector<DiscreteConditional::shared_ptr>& conditionals); | ||||
| 
 | ||||
|   /// Expand the next node in the search tree.
 | ||||
|   void expandNextNode() const; | ||||
| 
 | ||||
|   std::vector<DiscreteConditional::shared_ptr> conditionals_; | ||||
|   std::vector<double> costToGo_; | ||||
| }; | ||||
| }  // namespace gtsam
 | ||||
|  | @ -26,12 +26,24 @@ using std::stringstream; | |||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| static void stream(std::ostream& os, const DiscreteValues& x, | ||||
|                    const KeyFormatter& keyFormatter) { | ||||
|   for (const auto& kv : x) | ||||
|     os << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| std::ostream& operator<<(std::ostream& os, const DiscreteValues& x) { | ||||
|   stream(os, x, DefaultKeyFormatter); | ||||
|   return os; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| void DiscreteValues::print(const string& s, | ||||
|                            const KeyFormatter& keyFormatter) const { | ||||
|   cout << s << ": "; | ||||
|   for (auto&& kv : *this) | ||||
|     cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; | ||||
|   stream(cout, *this, keyFormatter); | ||||
|   cout << endl; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -64,6 +64,9 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> { | |||
|   /// @name Standard Interface
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /// ostream operator:
 | ||||
|   friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x); | ||||
| 
 | ||||
|   // insert in base class;
 | ||||
|   std::pair<iterator, bool> insert( const value_type& value ){ | ||||
|     return Base::insert(value); | ||||
|  |  | |||
|  | @ -0,0 +1,61 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /*
 | ||||
|  * AsiaExample.h | ||||
|  * | ||||
|  *  @date Jan, 2025 | ||||
|  *  @author Frank Dellaert | ||||
|  */ | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteBayesNet.h> | ||||
| #include <gtsam/inference/Symbol.h> | ||||
| 
 | ||||
| namespace gtsam { | ||||
| namespace asia_example { | ||||
| 
 | ||||
| static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), | ||||
|                  B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), | ||||
|                  S = Symbol('S', 7), A = Symbol('A', 8); | ||||
| 
 | ||||
| static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), | ||||
|     Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), | ||||
|     Asia(A, 2); | ||||
| 
 | ||||
| // Function to construct the Asia priors
 | ||||
| DiscreteBayesNet createPriors() { | ||||
|   DiscreteBayesNet priors; | ||||
|   priors.add(Smoking % "50/50"); | ||||
|   priors.add(Asia, "99/1"); | ||||
|   return priors; | ||||
| } | ||||
| 
 | ||||
| // Function to construct the incomplete Asia example
 | ||||
| DiscreteBayesNet createFragment() { | ||||
|   DiscreteBayesNet fragment; | ||||
|   fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); | ||||
|   fragment.add(LungCancer | Smoking = "99/1 90/10"); | ||||
|   fragment.add(Tuberculosis | Asia = "99/1 95/5"); | ||||
|   for (const auto& factor : createPriors()) fragment.push_back(factor); | ||||
|   return fragment; | ||||
| } | ||||
| 
 | ||||
| // Function to construct the Asia example
 | ||||
| DiscreteBayesNet createAsiaExample() { | ||||
|   DiscreteBayesNet asia; | ||||
|   asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); | ||||
|   asia.add(XRay | Either = "95/5 2/98"); | ||||
|   asia.add(Bronchitis | Smoking = "70/30 40/60"); | ||||
|   for (const auto& factor : createFragment()) asia.push_back(factor); | ||||
|   return asia; | ||||
| } | ||||
| }  // namespace asia_example
 | ||||
| }  // namespace gtsam
 | ||||
|  | @ -23,40 +23,19 @@ | |||
| #include <gtsam/discrete/DiscreteBayesNet.h> | ||||
| #include <gtsam/discrete/DiscreteFactorGraph.h> | ||||
| #include <gtsam/discrete/DiscreteMarginals.h> | ||||
| #include <gtsam/inference/Symbol.h> | ||||
| 
 | ||||
| #include <iostream> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| using namespace std; | ||||
| #include "AsiaExample.h" | ||||
| 
 | ||||
| using namespace gtsam; | ||||
| 
 | ||||
| static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), | ||||
|     LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); | ||||
| 
 | ||||
| using ADT = AlgebraicDecisionTree<Key>; | ||||
| 
 | ||||
| // Function to construct the Asia example
 | ||||
| DiscreteBayesNet constructAsiaExample() { | ||||
|   DiscreteBayesNet asia; | ||||
| 
 | ||||
|   asia.add(Asia, "99/1"); | ||||
|   asia.add(Smoking % "50/50");  // Signature version
 | ||||
| 
 | ||||
|   asia.add(Tuberculosis | Asia = "99/1 95/5"); | ||||
|   asia.add(LungCancer | Smoking = "99/1 90/10"); | ||||
|   asia.add(Bronchitis | Smoking = "70/30 40/60"); | ||||
| 
 | ||||
|   asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); | ||||
| 
 | ||||
|   asia.add(XRay | Either = "95/5 2/98"); | ||||
|   asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); | ||||
| 
 | ||||
|   return asia; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesNet, bayesNet) { | ||||
|   using ADT = AlgebraicDecisionTree<Key>; | ||||
|   DiscreteBayesNet bayesNet; | ||||
|   DiscreteKey Parent(0, 2), Child(1, 2); | ||||
| 
 | ||||
|  | @ -86,11 +65,12 @@ TEST(DiscreteBayesNet, bayesNet) { | |||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesNet, Asia) { | ||||
|   DiscreteBayesNet asia = constructAsiaExample(); | ||||
|   using namespace asia_example; | ||||
|   const DiscreteBayesNet asia = createAsiaExample(); | ||||
| 
 | ||||
|   // Convert to factor graph
 | ||||
|   DiscreteFactorGraph fg(asia); | ||||
|   LONGS_EQUAL(3, fg.back()->size()); | ||||
|   LONGS_EQUAL(1, fg.back()->size()); | ||||
| 
 | ||||
|   // Check the marginals we know (of the parent-less nodes)
 | ||||
|   DiscreteMarginals marginals(fg); | ||||
|  | @ -99,7 +79,7 @@ TEST(DiscreteBayesNet, Asia) { | |||
|   EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); | ||||
| 
 | ||||
|   // Create solver and eliminate
 | ||||
|   const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7}; | ||||
|   const Ordering ordering{A, D, T, X, S, E, L, B}; | ||||
|   DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); | ||||
|   DiscreteConditional expected2(Bronchitis % "11/9"); | ||||
|   EXPECT(assert_equal(expected2, *chordal->back())); | ||||
|  | @ -144,55 +124,50 @@ TEST(DiscreteBayesNet, Sugar) { | |||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesNet, Dot) { | ||||
|   DiscreteBayesNet fragment; | ||||
|   fragment.add(Asia % "99/1"); | ||||
|   fragment.add(Smoking % "50/50"); | ||||
|   using namespace asia_example; | ||||
|   const DiscreteBayesNet fragment = createFragment(); | ||||
| 
 | ||||
|   fragment.add(Tuberculosis | Asia = "99/1 95/5"); | ||||
|   fragment.add(LungCancer | Smoking = "99/1 90/10"); | ||||
|   fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); | ||||
| 
 | ||||
|   string actual = fragment.dot(); | ||||
|   EXPECT(actual == | ||||
|          "digraph {\n" | ||||
|          "  size=\"5,5\";\n" | ||||
|          "\n" | ||||
|          "  var0[label=\"0\"];\n" | ||||
|          "  var3[label=\"3\"];\n" | ||||
|          "  var4[label=\"4\"];\n" | ||||
|          "  var5[label=\"5\"];\n" | ||||
|          "  var6[label=\"6\"];\n" | ||||
|          "\n" | ||||
|          "  var3->var5\n" | ||||
|          "  var6->var5\n" | ||||
|          "  var4->var6\n" | ||||
|          "  var0->var3\n" | ||||
|          "}"); | ||||
|   std::string expected = | ||||
|       "digraph {\n" | ||||
|       "  size=\"5,5\";\n" | ||||
|       "\n" | ||||
|       "  var4683743612465315848[label=\"A8\"];\n" | ||||
|       "  var4971973988617027587[label=\"E3\"];\n" | ||||
|       "  var5476377146882523141[label=\"L5\"];\n" | ||||
|       "  var5980780305148018695[label=\"S7\"];\n" | ||||
|       "  var6052837899185946630[label=\"T6\"];\n" | ||||
|       "\n" | ||||
|       "  var4683743612465315848->var6052837899185946630\n" | ||||
|       "  var5980780305148018695->var5476377146882523141\n" | ||||
|       "  var6052837899185946630->var4971973988617027587\n" | ||||
|       "  var5476377146882523141->var4971973988617027587\n" | ||||
|       "}"; | ||||
|   std::string actual = fragment.dot(); | ||||
|   EXPECT(actual.compare(expected) == 0); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check markdown representation looks as expected.
 | ||||
| TEST(DiscreteBayesNet, markdown) { | ||||
|   DiscreteBayesNet fragment; | ||||
|   fragment.add(Asia % "99/1"); | ||||
|   fragment.add(Smoking | Asia = "8/2 7/3"); | ||||
|   using namespace asia_example; | ||||
|   DiscreteBayesNet priors = createPriors(); | ||||
| 
 | ||||
|   string expected = | ||||
|   std::string expected = | ||||
|       "`DiscreteBayesNet` of size 2\n" | ||||
|       "\n" | ||||
|       " *P(Smoking):*\n\n" | ||||
|       "|Smoking|value|\n" | ||||
|       "|:-:|:-:|\n" | ||||
|       "|0|0.5|\n" | ||||
|       "|1|0.5|\n" | ||||
|       "\n" | ||||
|       " *P(Asia):*\n\n" | ||||
|       "|Asia|value|\n" | ||||
|       "|:-:|:-:|\n" | ||||
|       "|0|0.99|\n" | ||||
|       "|1|0.01|\n" | ||||
|       "\n" | ||||
|       " *P(Smoking|Asia):*\n\n" | ||||
|       "|*Asia*|0|1|\n" | ||||
|       "|:-:|:-:|:-:|\n" | ||||
|       "|0|0.8|0.2|\n" | ||||
|       "|1|0.7|0.3|\n\n"; | ||||
|   auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; | ||||
|   string actual = fragment.markdown(formatter); | ||||
|       "|1|0.01|\n\n"; | ||||
|   auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; }; | ||||
|   std::string actual = priors.markdown(formatter); | ||||
|   EXPECT(actual == expected); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,111 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /*
 | ||||
|  * testDiscreteSearch.cpp | ||||
|  * | ||||
|  *  @date January, 2025 | ||||
|  *  @author Frank Dellaert | ||||
|  */ | ||||
| 
 | ||||
| #include <CppUnitLite/TestHarness.h> | ||||
| #include <gtsam/base/Testable.h> | ||||
| #include <gtsam/discrete/DiscreteSearch.h> | ||||
| 
 | ||||
| #include "AsiaExample.h" | ||||
| 
 | ||||
| using namespace gtsam; | ||||
| 
 | ||||
| // Create Asia Bayes net, FG, and Bayes tree once
 | ||||
| namespace asia { | ||||
| using namespace asia_example; | ||||
| static const DiscreteBayesNet bayesNet = createAsiaExample(); | ||||
| static const DiscreteFactorGraph factorGraph(bayesNet); | ||||
| static const DiscreteValues mpe = factorGraph.optimize(); | ||||
| static const Ordering ordering{D, X, B, E, L, T, S, A}; | ||||
| static const DiscreteBayesTree bayesTree = | ||||
|     *factorGraph.eliminateMultifrontal(ordering); | ||||
| }  // namespace asia
 | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesNet, EmptyKBest) { | ||||
|   DiscreteBayesNet net;  // no factors
 | ||||
|   DiscreteSearch search(net); | ||||
|   auto solutions = search.run(3); | ||||
|   // Expect one solution with empty assignment, error=0
 | ||||
|   EXPECT_LONGS_EQUAL(1, solutions.size()); | ||||
|   EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesNet, AsiaKBest) { | ||||
|   const DiscreteSearch search(asia::bayesNet); | ||||
| 
 | ||||
|   // Ask for the MPE
 | ||||
|   auto mpe = search.run(); | ||||
| 
 | ||||
|   EXPECT_LONGS_EQUAL(1, mpe.size()); | ||||
|   // Regression test: check the MPE solution
 | ||||
|   EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); | ||||
| 
 | ||||
|   // Check it is equal to MPE via inference
 | ||||
|   EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); | ||||
| 
 | ||||
|   // Ask for top 4 solutions
 | ||||
|   auto solutions = search.run(4); | ||||
| 
 | ||||
|   EXPECT_LONGS_EQUAL(4, solutions.size()); | ||||
|   // Regression test: check the first and last solution
 | ||||
|   EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); | ||||
|   EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesTree, EmptyTree) { | ||||
|   DiscreteBayesTree bt; | ||||
| 
 | ||||
|   DiscreteSearch search(bt); | ||||
|   auto solutions = search.run(3); | ||||
| 
 | ||||
|   // We expect exactly 1 solution with error = 0.0 (the empty assignment).
 | ||||
|   EXPECT_LONGS_EQUAL(1, solutions.size()); | ||||
|   EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesTree, AsiaTreeKBest) { | ||||
|   DiscreteSearch search(asia::bayesTree); | ||||
| 
 | ||||
|   // Ask for MPE
 | ||||
|   auto mpe = search.run(); | ||||
| 
 | ||||
|   EXPECT_LONGS_EQUAL(1, mpe.size()); | ||||
|   // Regression test: check the MPE solution
 | ||||
|   EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); | ||||
| 
 | ||||
|   // Check it is equal to MPE via inference
 | ||||
|   EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); | ||||
| 
 | ||||
|   // Ask for top 4 solutions
 | ||||
|   auto solutions = search.run(4); | ||||
| 
 | ||||
|   EXPECT_LONGS_EQUAL(4, solutions.size()); | ||||
|   // Regression test: check the first and last solution
 | ||||
|   EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); | ||||
|   EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { | ||||
|   TestResult tr; | ||||
|   return TestRegistry::runAllTests(tr); | ||||
| } | ||||
| /* ************************************************************************* */ | ||||
		Loading…
	
		Reference in New Issue