Works for Bayes trees now as well
							parent
							
								
									54f493358d
								
							
						
					
					
						commit
						14eeaf93db
					
				|  | @ -36,6 +36,16 @@ struct SearchNode { | |||
|   double bound;  ///< Lower bound on the final error for unassigned variables.
 | ||||
|   int nextConditional;  ///< Index of the next conditional to be assigned.
 | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Construct the root node for the search. | ||||
|    */ | ||||
|   static SearchNode Root(size_t numConditionals, double bound) { | ||||
|     return {.assignment = DiscreteValues(), | ||||
|             .error = 0.0, | ||||
|             .bound = bound, | ||||
|             .nextConditional = static_cast<int>(numConditionals) - 1}; | ||||
|   } | ||||
| 
 | ||||
|   struct CompareByBound { | ||||
|     bool operator()(const SearchNode& a, const SearchNode& b) const { | ||||
|       return a.bound > b.bound;  // smallest bound -> highest priority
 | ||||
|  | @ -49,20 +59,6 @@ struct SearchNode { | |||
|    */ | ||||
|   bool isComplete() const { return nextConditional < 0; } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Computes a lower bound on the final error for unassigned variables. | ||||
|    * | ||||
|    * @details This is a stub implementation that returns 0. Real implementations | ||||
|    * might perform partial factor analysis or use heuristics to compute the | ||||
|    * bound. | ||||
|    * | ||||
|    * @return A lower bound on the final error. | ||||
|    */ | ||||
|   double computeBound() const { | ||||
|     // Real code might do partial factor analysis or heuristics.
 | ||||
|     return 0.0; | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Expands the node by assigning the next variable. | ||||
|    * | ||||
|  | @ -74,22 +70,15 @@ struct SearchNode { | |||
|   SearchNode expand(const DiscreteConditional& conditional, | ||||
|                     const DiscreteValues& fa) const { | ||||
|     // Combine the new frontal assignment with the current partial assignment
 | ||||
|     SearchNode child; | ||||
|     child.assignment = assignment; | ||||
|     DiscreteValues newAssignment = assignment; | ||||
|     for (auto& kv : fa) { | ||||
|       child.assignment[kv.first] = kv.second; | ||||
|       newAssignment[kv.first] = kv.second; | ||||
|     } | ||||
| 
 | ||||
|     // Compute the incremental error for this factor
 | ||||
|     child.error = error + conditional.error(child.assignment); | ||||
| 
 | ||||
|     // Compute new bound
 | ||||
|     child.bound = computeBound(); | ||||
| 
 | ||||
|     // Update the index of the next conditional
 | ||||
|     child.nextConditional = nextConditional - 1; | ||||
| 
 | ||||
|     return child; | ||||
|     return {.assignment = newAssignment, | ||||
|             .error = error + conditional.error(newAssignment), | ||||
|             .bound = 0.0, | ||||
|             .nextConditional = nextConditional - 1}; | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|  | @ -184,6 +173,8 @@ class Solutions { | |||
|  */ | ||||
| class DiscreteSearch { | ||||
|  public: | ||||
|   size_t numExpansions = 0; | ||||
| 
 | ||||
|   /**
 | ||||
|    * Construct from a DiscreteBayesNet and K. | ||||
|    */ | ||||
|  | @ -193,33 +184,42 @@ class DiscreteSearch { | |||
|       conditionals_.push_back(factor); | ||||
|     } | ||||
| 
 | ||||
|     // Calculate the cost-to-go for each conditional. If there are n
 | ||||
|     // conditionals, we start with nextConditional = n-1, and the minimum error
 | ||||
|     // obtainable is the sum of all the minimum errors. We start with
 | ||||
|     // 0, and that is the minimum error of the conditional with that index:
 | ||||
|     double error = 0.0; | ||||
|     for (const auto& conditional : conditionals_) { | ||||
|       Ordering ordering(conditional->begin(), conditional->end()); | ||||
|       auto maxx = conditional->max(ordering); | ||||
|       assert(maxx->size() == 1); | ||||
|       error -= std::log(maxx->evaluate({})); | ||||
|       costToGo_.push_back(error); | ||||
|     } | ||||
|     // Calculate the cost-to-go for each conditional
 | ||||
|     costToGo_ = computeCostToGo(conditionals_); | ||||
| 
 | ||||
|     // Create the root node: no variables assigned, nextConditional = last.
 | ||||
|     SearchNode root{ | ||||
|         .assignment = DiscreteValues(), | ||||
|         .error = 0.0, | ||||
|         .nextConditional = static_cast<int>(conditionals_.size()) - 1}; | ||||
|     if (!costToGo_.empty()) root.bound = costToGo_.back(); | ||||
|     expansions_.push(root); | ||||
|     // Create the root node and push it to the expansions queue
 | ||||
|     expansions_.push(SearchNode::Root( | ||||
|         conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Construct from a DiscreteBayesNet and K. | ||||
|    * Construct from a DiscreteBayesTree and K. | ||||
|    */ | ||||
|   DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) | ||||
|       : solutions_(K) {} | ||||
|   DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) { | ||||
|     using CliquePtr = DiscreteBayesTree::sharedClique; | ||||
|     std::function<void(const CliquePtr&)> collectConditionals = | ||||
|         [&](const CliquePtr& clique) -> void { | ||||
|       if (!clique) return; | ||||
| 
 | ||||
|       // Recursive post-order traversal: process children first
 | ||||
|       for (const auto& child : clique->children) { | ||||
|         collectConditionals(child); | ||||
|       } | ||||
| 
 | ||||
|       // Then add the current clique's conditional
 | ||||
|       conditionals_.push_back(clique->conditional()); | ||||
|     }; | ||||
| 
 | ||||
|     // Start traversal from each root in the tree
 | ||||
|     for (const auto& root : bayesTree.roots()) collectConditionals(root); | ||||
| 
 | ||||
|     // Calculate the cost-to-go for each conditional
 | ||||
|     costToGo_ = computeCostToGo(conditionals_); | ||||
| 
 | ||||
|     // Create the root node and push it to the expansions queue
 | ||||
|     expansions_.push(SearchNode::Root( | ||||
|         conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Search for the K best solutions. | ||||
|  | @ -233,6 +233,7 @@ class DiscreteSearch { | |||
|    */ | ||||
|   std::vector<Solution> run() { | ||||
|     while (!expansions_.empty()) { | ||||
|       numExpansions++; | ||||
|       expandNextNode(); | ||||
|     } | ||||
| 
 | ||||
|  | @ -241,6 +242,29 @@ class DiscreteSearch { | |||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   /**
 | ||||
|    * @brief Compute the cost-to-go for each conditional. | ||||
|    * | ||||
|    * @param conditionals The conditionals of the DiscreteBayesNet. | ||||
|    * @return A vector of cost-to-go values. | ||||
|    */ | ||||
|   static std::vector<double> computeCostToGo( | ||||
|       const std::vector<DiscreteConditional::shared_ptr>& conditionals) { | ||||
|     std::vector<double> costToGo; | ||||
|     double error = 0.0; | ||||
|     for (const auto& conditional : conditionals) { | ||||
|       Ordering ordering(conditional->begin(), conditional->end()); | ||||
|       auto maxx = conditional->max(ordering); | ||||
|       assert(maxx->size() == 1); | ||||
|       error -= std::log(maxx->evaluate({})); | ||||
|       costToGo.push_back(error); | ||||
|     } | ||||
|     return costToGo; | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Expand the next node in the search tree. | ||||
|    */ | ||||
|   void expandNextNode() { | ||||
|     // Pop the partial assignment with the smallest bound
 | ||||
|     SearchNode current = expansions_.top(); | ||||
|  | @ -273,7 +297,7 @@ class DiscreteSearch { | |||
|     } | ||||
|   } | ||||
| 
 | ||||
|   std::vector<std::shared_ptr<DiscreteConditional>> conditionals_; | ||||
|   std::vector<DiscreteConditional::shared_ptr> conditionals_; | ||||
|   std::vector<double> costToGo_; | ||||
|   std::priority_queue<SearchNode, std::vector<SearchNode>, | ||||
|                       SearchNode::CompareByBound> | ||||
|  |  | |||
|  | @ -30,7 +30,7 @@ static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), | |||
|     Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), | ||||
|     Asia(A, 2); | ||||
| 
 | ||||
| // Function to construct the incomplete Asia example
 | ||||
| // Function to construct the Asia priors
 | ||||
| DiscreteBayesNet createPriors() { | ||||
|   DiscreteBayesNet priors; | ||||
|   priors.add(Smoking % "50/50"); | ||||
|  |  | |||
|  | @ -49,8 +49,9 @@ TEST(DiscreteBayesNet, AsiaKBest) { | |||
|   DiscreteSearch search(asia, 4); | ||||
|   auto solutions = search.run(); | ||||
|   EXPECT(!solutions.empty()); | ||||
|   // Regression test: check the first solution
 | ||||
|   // Regression test: check the first and last solution
 | ||||
|   EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); | ||||
|   EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  | @ -70,16 +71,21 @@ TEST(DiscreteBayesTree, testEmptyTree) { | |||
| TEST(DiscreteBayesTree, testTrivialOneClique) { | ||||
|   using namespace asia_example; | ||||
|   DiscreteFactorGraph asia(createAsiaExample()); | ||||
|   DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(); | ||||
|   const Ordering ordering{D, X, B, E, L, T, S, A}; | ||||
|   DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); | ||||
|   GTSAM_PRINT(*bt); | ||||
| 
 | ||||
|   // Ask for top 4 solutions
 | ||||
|   DiscreteSearch search(*bt, 4); | ||||
|   auto solutions = search.run(); | ||||
| 
 | ||||
|   // print numExpansions
 | ||||
|   std::cout << "Number of expansions: " << search.numExpansions << std::endl; | ||||
| 
 | ||||
|   EXPECT(!solutions.empty()); | ||||
|   // Regression test: check the first solution
 | ||||
|   // Regression test: check the first and last solution
 | ||||
|   EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); | ||||
|   EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue