Asia example
							parent
							
								
									1f4d9bbd7e
								
							
						
					
					
						commit
						d879b156f8
					
				|  | @ -0,0 +1,61 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /*
 | ||||
|  * AsiaExample.h | ||||
|  * | ||||
|  *  @date Jan, 2025 | ||||
|  *  @author Frank Dellaert | ||||
|  */ | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteBayesNet.h> | ||||
| #include <gtsam/inference/Symbol.h> | ||||
| 
 | ||||
| namespace gtsam { | ||||
| namespace asia_example { | ||||
| 
 | ||||
| static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), | ||||
|                  B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), | ||||
|                  S = Symbol('S', 7), A = Symbol('A', 8); | ||||
| 
 | ||||
| static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), | ||||
|     Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), | ||||
|     Asia(A, 2); | ||||
| 
 | ||||
| // Function to construct the incomplete Asia example
 | ||||
| DiscreteBayesNet createPriors() { | ||||
|   DiscreteBayesNet priors; | ||||
|   priors.add(Smoking % "50/50"); | ||||
|   priors.add(Asia, "99/1"); | ||||
|   return priors; | ||||
| } | ||||
| 
 | ||||
| // Function to construct the incomplete Asia example
 | ||||
| DiscreteBayesNet createFragment() { | ||||
|   DiscreteBayesNet fragment; | ||||
|   fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); | ||||
|   fragment.add(LungCancer | Smoking = "99/1 90/10"); | ||||
|   fragment.add(Tuberculosis | Asia = "99/1 95/5"); | ||||
|   for (const auto& factor : createPriors()) fragment.push_back(factor); | ||||
|   return fragment; | ||||
| } | ||||
| 
 | ||||
| // Function to construct the Asia example
 | ||||
| DiscreteBayesNet createAsiaExample() { | ||||
|   DiscreteBayesNet asia; | ||||
|   asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); | ||||
|   asia.add(XRay | Either = "95/5 2/98"); | ||||
|   asia.add(Bronchitis | Smoking = "70/30 40/60"); | ||||
|   for (const auto& factor : createFragment()) asia.push_back(factor); | ||||
|   return asia; | ||||
| } | ||||
| }  // namespace asia_example
 | ||||
| }  // namespace gtsam
 | ||||
|  | @ -29,40 +29,13 @@ | |||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| using namespace std; | ||||
| #include "AsiaExample.h" | ||||
| 
 | ||||
| using namespace gtsam; | ||||
| 
 | ||||
| namespace keys { | ||||
| static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), | ||||
|                  B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), | ||||
|                  S = Symbol('S', 7), A = Symbol('A', 8); | ||||
| } | ||||
| 
 | ||||
| static const DiscreteKey Dyspnea(keys::D, 2), XRay(keys::X, 2), | ||||
|     Either(keys::E, 2), Bronchitis(keys::B, 2), LungCancer(keys::L, 2), | ||||
|     Tuberculosis(keys::T, 2), Smoking(keys::S, 2), Asia(keys::A, 2); | ||||
| 
 | ||||
| using ADT = AlgebraicDecisionTree<Key>; | ||||
| 
 | ||||
| // Function to construct the Asia example
 | ||||
| DiscreteBayesNet constructAsiaExample() { | ||||
|   DiscreteBayesNet asia; | ||||
| 
 | ||||
|   // Add in topological sort order, parents last:
 | ||||
|   asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); | ||||
|   asia.add(XRay | Either = "95/5 2/98"); | ||||
|   asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); | ||||
|   asia.add(Bronchitis | Smoking = "70/30 40/60"); | ||||
|   asia.add(LungCancer | Smoking = "99/1 90/10"); | ||||
|   asia.add(Tuberculosis | Asia = "99/1 95/5"); | ||||
|   asia.add(Smoking % "50/50");  // Signature version
 | ||||
|   asia.add(Asia, "99/1"); | ||||
| 
 | ||||
|   return asia; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesNet, bayesNet) { | ||||
|   using ADT = AlgebraicDecisionTree<Key>; | ||||
|   DiscreteBayesNet bayesNet; | ||||
|   DiscreteKey Parent(0, 2), Child(1, 2); | ||||
| 
 | ||||
|  | @ -92,7 +65,8 @@ TEST(DiscreteBayesNet, bayesNet) { | |||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesNet, Asia) { | ||||
|   DiscreteBayesNet asia = constructAsiaExample(); | ||||
|   using namespace asia_example; | ||||
|   const DiscreteBayesNet asia = createAsiaExample(); | ||||
| 
 | ||||
|   // Convert to factor graph
 | ||||
|   DiscreteFactorGraph fg(asia); | ||||
|  | @ -105,8 +79,7 @@ TEST(DiscreteBayesNet, Asia) { | |||
|   EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); | ||||
| 
 | ||||
|   // Create solver and eliminate
 | ||||
|   const Ordering ordering{keys::A, keys::D, keys::T, keys::X, | ||||
|                           keys::S, keys::E, keys::L, keys::B}; | ||||
|   const Ordering ordering{A, D, T, X, S, E, L, B}; | ||||
|   DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); | ||||
|   DiscreteConditional expected2(Bronchitis % "11/9"); | ||||
|   EXPECT(assert_equal(expected2, *chordal->back())); | ||||
|  | @ -151,319 +124,53 @@ TEST(DiscreteBayesNet, Sugar) { | |||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesNet, Dot) { | ||||
|   DiscreteBayesNet fragment; | ||||
|   fragment.add(Asia % "99/1"); | ||||
|   fragment.add(Smoking % "50/50"); | ||||
|   using namespace asia_example; | ||||
|   const DiscreteBayesNet fragment = createFragment(); | ||||
| 
 | ||||
|   fragment.add(Tuberculosis | Asia = "99/1 95/5"); | ||||
|   fragment.add(LungCancer | Smoking = "99/1 90/10"); | ||||
|   fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); | ||||
| 
 | ||||
|   string actual = fragment.dot(); | ||||
|   EXPECT(actual == | ||||
|          "digraph {\n" | ||||
|          "  size=\"5,5\";\n" | ||||
|          "\n" | ||||
|          "  var4683743612465315848[label=\"A8\"];\n" | ||||
|          "  var4971973988617027587[label=\"E3\"];\n" | ||||
|          "  var5476377146882523141[label=\"L5\"];\n" | ||||
|          "  var5980780305148018695[label=\"S7\"];\n" | ||||
|          "  var6052837899185946630[label=\"T6\"];\n" | ||||
|          "\n" | ||||
|          "  var6052837899185946630->var4971973988617027587\n" | ||||
|          "  var5476377146882523141->var4971973988617027587\n" | ||||
|          "  var5980780305148018695->var5476377146882523141\n" | ||||
|          "  var4683743612465315848->var6052837899185946630\n" | ||||
|          "}"); | ||||
|   std::string expected = | ||||
|       "digraph {\n" | ||||
|       "  size=\"5,5\";\n" | ||||
|       "\n" | ||||
|       "  var4683743612465315848[label=\"A8\"];\n" | ||||
|       "  var4971973988617027587[label=\"E3\"];\n" | ||||
|       "  var5476377146882523141[label=\"L5\"];\n" | ||||
|       "  var5980780305148018695[label=\"S7\"];\n" | ||||
|       "  var6052837899185946630[label=\"T6\"];\n" | ||||
|       "\n" | ||||
|       "  var4683743612465315848->var6052837899185946630\n" | ||||
|       "  var5980780305148018695->var5476377146882523141\n" | ||||
|       "  var6052837899185946630->var4971973988617027587\n" | ||||
|       "  var5476377146882523141->var4971973988617027587\n" | ||||
|       "}"; | ||||
|   std::string actual = fragment.dot(); | ||||
|   EXPECT(actual.compare(expected) == 0); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check markdown representation looks as expected.
 | ||||
| TEST(DiscreteBayesNet, markdown) { | ||||
|   DiscreteBayesNet fragment; | ||||
|   fragment.add(Asia % "99/1"); | ||||
|   fragment.add(Smoking | Asia = "8/2 7/3"); | ||||
|   using namespace asia_example; | ||||
|   DiscreteBayesNet priors = createPriors(); | ||||
| 
 | ||||
|   string expected = | ||||
|   std::string expected = | ||||
|       "`DiscreteBayesNet` of size 2\n" | ||||
|       "\n" | ||||
|       " *P(Smoking):*\n\n" | ||||
|       "|Smoking|value|\n" | ||||
|       "|:-:|:-:|\n" | ||||
|       "|0|0.5|\n" | ||||
|       "|1|0.5|\n" | ||||
|       "\n" | ||||
|       " *P(Asia):*\n\n" | ||||
|       "|Asia|value|\n" | ||||
|       "|:-:|:-:|\n" | ||||
|       "|0|0.99|\n" | ||||
|       "|1|0.01|\n" | ||||
|       "\n" | ||||
|       " *P(Smoking|Asia):*\n\n" | ||||
|       "|*Asia*|0|1|\n" | ||||
|       "|:-:|:-:|:-:|\n" | ||||
|       "|0|0.8|0.2|\n" | ||||
|       "|1|0.7|0.3|\n\n"; | ||||
|   auto formatter = [](Key key) { return key == keys::A ? "Asia" : "Smoking"; }; | ||||
|   string actual = fragment.markdown(formatter); | ||||
|       "|1|0.01|\n\n"; | ||||
|   auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; }; | ||||
|   std::string actual = priors.markdown(formatter); | ||||
|   EXPECT(actual == expected); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| #include <algorithm> | ||||
| #include <cmath> | ||||
| #include <iostream> | ||||
| #include <map> | ||||
| #include <queue> | ||||
| #include <vector> | ||||
| 
 | ||||
| using Value = size_t; | ||||
| 
 | ||||
| // ----------------------------------------------------------------------------
 | ||||
| // 1) SearchNode: store partial assignment and next factor to expand
 | ||||
| // ----------------------------------------------------------------------------
 | ||||
| struct SearchNode { | ||||
|   DiscreteValues assignment; | ||||
|   double error; | ||||
|   double bound; | ||||
|   int nextConditional;  // index into conditionals
 | ||||
| 
 | ||||
|   /// if nextConditional < 0, we've assigned everything.
 | ||||
|   bool isComplete() const { return nextConditional < 0; } | ||||
| 
 | ||||
|   /// lower bound on final error for unassigned variables. Stub=0.
 | ||||
|   double computeBound() const { | ||||
|     // Real code might do partial factor analysis or heuristics.
 | ||||
|     return 0.0; | ||||
|   } | ||||
| 
 | ||||
|   /// Expand this node by assigning the next variable
 | ||||
|   SearchNode expand(const DiscreteConditional& conditional, | ||||
|                     const DiscreteValues& fa) const { | ||||
|     // Combine the new frontal assignment with the current partial assignment
 | ||||
|     SearchNode child; | ||||
|     child.assignment = assignment; | ||||
|     for (auto& kv : fa) { | ||||
|       child.assignment[kv.first] = kv.second; | ||||
|     } | ||||
| 
 | ||||
|     // Compute the incremental error for this factor
 | ||||
|     child.error = error + conditional.error(child.assignment); | ||||
| 
 | ||||
|     // Compute new bound
 | ||||
|     child.bound = child.error + computeBound(); | ||||
| 
 | ||||
|     // Next factor index
 | ||||
|     child.nextConditional = nextConditional - 1; | ||||
| 
 | ||||
|     return child; | ||||
|   } | ||||
| 
 | ||||
|   friend std::ostream& operator<<(std::ostream& os, const SearchNode& sn) { | ||||
|     os << "[ error=" << sn.error << " bound=" << sn.bound | ||||
|        << " nextConditional=" << sn.nextConditional << " assignment={" | ||||
|        << sn.assignment << "}]"; | ||||
|     return os; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // ----------------------------------------------------------------------------
 | ||||
| // 2) Priority functor to make a min-heap by bound
 | ||||
| // ----------------------------------------------------------------------------
 | ||||
| struct CompareByBound { | ||||
|   bool operator()(const SearchNode& a, const SearchNode& b) const { | ||||
|     return a.bound > b.bound;  // smallest bound -> highest priority
 | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // ----------------------------------------------------------------------------
 | ||||
| // 4) A Solution
 | ||||
| // ----------------------------------------------------------------------------
 | ||||
| struct Solution { | ||||
|   double error; | ||||
|   DiscreteValues assignment; | ||||
|   Solution(double err, const DiscreteValues& assign) | ||||
|       : error(err), assignment(assign) {} | ||||
|   friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { | ||||
|     os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; | ||||
|     return os; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| struct CompareByError { | ||||
|   bool operator()(const Solution& a, const Solution& b) const { | ||||
|     return a.error < b.error; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // Define the Solutions class
 | ||||
| class Solutions { | ||||
|  private: | ||||
|   size_t maxSize_; | ||||
|   std::priority_queue<Solution, std::vector<Solution>, CompareByError> pq_; | ||||
| 
 | ||||
|  public: | ||||
|   Solutions(size_t maxSize) : maxSize_(maxSize) {} | ||||
| 
 | ||||
|   /// Add a solution to the priority queue, possibly evicting the worst one.
 | ||||
|   /// Return true if we added the solution.
 | ||||
|   bool maybeAdd(double error, const DiscreteValues& assignment) { | ||||
|     const bool full = pq_.size() == maxSize_; | ||||
|     if (full && error >= pq_.top().error) return false; | ||||
|     if (full) pq_.pop(); | ||||
|     pq_.emplace(error, assignment); | ||||
|     return true; | ||||
|   } | ||||
| 
 | ||||
|   /// Check if we have any solutions
 | ||||
|   bool empty() const { return pq_.empty(); } | ||||
| 
 | ||||
|   // Method to print all solutions
 | ||||
|   void print() const { | ||||
|     auto pq = pq_; | ||||
|     while (!pq.empty()) { | ||||
|       const Solution& best = pq.top(); | ||||
|       std::cout << "Error: " << best.error << ", Values: " << best.assignment | ||||
|                 << std::endl; | ||||
|       pq.pop(); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /// Check if (partial) solution with given bound can be pruned. If we have
 | ||||
|   /// room, we never prune. Otherwise, prune if lower bound on error is worse
 | ||||
|   /// than our current worst error.
 | ||||
|   bool prune(double bound) const { | ||||
|     if (pq_.size() < maxSize_) return false; | ||||
|     double worstError = pq_.top().error; | ||||
|     return (bound >= worstError); | ||||
|   } | ||||
| 
 | ||||
|   // Method to extract solutions in ascending order of error
 | ||||
|   std::vector<Solution> extractSolutions() { | ||||
|     std::vector<Solution> result; | ||||
|     while (!pq_.empty()) { | ||||
|       result.push_back(pq_.top()); | ||||
|       pq_.pop(); | ||||
|     } | ||||
|     std::sort( | ||||
|         result.begin(), result.end(), | ||||
|         [](const Solution& a, const Solution& b) { return a.error < b.error; }); | ||||
|     return result; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| /**
 | ||||
|  * BestKSearch: Search for the K best solutions. | ||||
|  */ | ||||
| class BestKSearch { | ||||
|  public: | ||||
|   /**
 | ||||
|    * Construct from a DiscreteBayesNet and K. | ||||
|    */ | ||||
|   BestKSearch(const DiscreteBayesNet& bayesNet, size_t K) | ||||
|       : bayesNet_(bayesNet), solutions_(K) { | ||||
|     // Copy out the conditionals
 | ||||
|     for (auto& factor : bayesNet_) { | ||||
|       conditionals_.push_back(factor); | ||||
|     } | ||||
| 
 | ||||
|     // Create the root node: no variables assigned, nextConditional = last.
 | ||||
|     SearchNode root{ | ||||
|         .assignment = DiscreteValues(), | ||||
|         .error = 0.0, | ||||
|         .nextConditional = static_cast<int>(conditionals_.size()) - 1}; | ||||
|     root.bound = root.computeBound(); | ||||
|     std::cout << "Root: " << root << std::endl; | ||||
|     expansions_.push(root); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Search for the K best solutions. | ||||
|    * | ||||
|    * This method performs a search to find the K best solutions for the given | ||||
|    * DiscreteBayesNet. It uses a priority queue to manage the search nodes, | ||||
|    * expanding nodes with the smallest bound first. The search continues until | ||||
|    * all possible nodes have been expanded or pruned. | ||||
|    * | ||||
|    * @return A vector of the K best solutions found during the search. | ||||
|    */ | ||||
|   std::vector<Solution> run() { | ||||
|     size_t numExpansions = 0; | ||||
|     while (!expansions_.empty()) { | ||||
|       expandNextNode(); | ||||
|       numExpansions++; | ||||
|     } | ||||
| 
 | ||||
|     std::cout << "Expansions: " << numExpansions << std::endl; | ||||
| 
 | ||||
|     // Extract solutions from bestSolutions in ascending order of error
 | ||||
|     return solutions_.extractSolutions(); | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   //
 | ||||
|   void expandNextNode() { | ||||
|     // Pop the partial assignment with the smallest bound
 | ||||
|     SearchNode current = expansions_.top(); | ||||
|     expansions_.pop(); | ||||
|     std::cout << "Expanding: " << current << std::endl; | ||||
| 
 | ||||
|     // If we already have K solutions, prune if we cannot beat the worst one.
 | ||||
|     if (solutions_.prune(current.bound)) { | ||||
|       std::cout << "Pruning: bound=" << current.bound << std::endl; | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     // Check if we have a complete assignment
 | ||||
|     if (current.isComplete()) { | ||||
|       const bool added = solutions_.maybeAdd(current.error, current.assignment); | ||||
|       if (added) { | ||||
|         std::cout << "Best solutions so far:" << std::endl; | ||||
|         solutions_.print(); | ||||
|       } | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     // Expand on the next factor
 | ||||
|     const auto& conditional = conditionals_[current.nextConditional]; | ||||
| 
 | ||||
|     for (auto& fa : conditional->frontalAssignments()) { | ||||
|       std::cout << "Frontal assignment: " << fa << std::endl; | ||||
|       auto childNode = current.expand(*conditional, fa); | ||||
| 
 | ||||
|       // Again, prune if we cannot beat the worst solution
 | ||||
|       if (solutions_.prune(current.bound)) { | ||||
|         std::cout << "Pruning: bound=" << childNode.bound << std::endl; | ||||
|         continue; | ||||
|       } | ||||
| 
 | ||||
|       expansions_.push(childNode); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   const DiscreteBayesNet& bayesNet_; | ||||
|   std::vector<std::shared_ptr<DiscreteConditional>> conditionals_; | ||||
|   std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound> | ||||
|       expansions_; | ||||
|   Solutions solutions_; | ||||
| }; | ||||
| 
 | ||||
| // ----------------------------------------------------------------------------
 | ||||
| // Example “Unit Tests” (trivial stubs)
 | ||||
| // ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
| TEST(DiscreteBayesNet, EmptyKBest) { | ||||
|   DiscreteBayesNet net;  // no factors
 | ||||
|   BestKSearch search(net, 3); | ||||
|   auto solutions = search.run(); | ||||
|   // Expect one solution with empty assignment, error=0
 | ||||
|   EXPECT_LONGS_EQUAL(1, solutions.size()); | ||||
|   EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); | ||||
| } | ||||
| 
 | ||||
| TEST(DiscreteBayesNet, AsiaKBest) { | ||||
|   DiscreteBayesNet asia = constructAsiaExample(); | ||||
|   BestKSearch search(asia, 4); | ||||
|   auto solutions = search.run(); | ||||
|   EXPECT(!solutions.empty()); | ||||
|   // Regression test: check the first solution
 | ||||
|   EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { | ||||
|   TestResult tr; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue