diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index c267001b5..24517ee29 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -36,6 +36,16 @@ struct SearchNode { double bound; ///< Lower bound on the final error for unassigned variables. int nextConditional; ///< Index of the next conditional to be assigned. + /** + * @brief Construct the root node for the search. + */ + static SearchNode Root(size_t numConditionals, double bound) { + return {.assignment = DiscreteValues(), + .error = 0.0, + .bound = bound, + .nextConditional = static_cast(numConditionals) - 1}; + } + struct CompareByBound { bool operator()(const SearchNode& a, const SearchNode& b) const { return a.bound > b.bound; // smallest bound -> highest priority @@ -49,20 +59,6 @@ struct SearchNode { */ bool isComplete() const { return nextConditional < 0; } - /** - * @brief Computes a lower bound on the final error for unassigned variables. - * - * @details This is a stub implementation that returns 0. Real implementations - * might perform partial factor analysis or use heuristics to compute the - * bound. - * - * @return A lower bound on the final error. - */ - double computeBound() const { - // Real code might do partial factor analysis or heuristics. - return 0.0; - } - /** * @brief Expands the node by assigning the next variable. * @@ -74,22 +70,15 @@ struct SearchNode { SearchNode expand(const DiscreteConditional& conditional, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment - SearchNode child; - child.assignment = assignment; + DiscreteValues newAssignment = assignment; for (auto& kv : fa) { - child.assignment[kv.first] = kv.second; + newAssignment[kv.first] = kv.second; } - // Compute the incremental error for this factor - child.error = error + conditional.error(child.assignment); - - // Compute new bound - child.bound = computeBound(); - - // Update the index of the next conditional - child.nextConditional = nextConditional - 1; - - return child; + return {.assignment = newAssignment, + .error = error + conditional.error(newAssignment), + .bound = 0.0, + .nextConditional = nextConditional - 1}; } /** @@ -184,6 +173,8 @@ class Solutions { */ class DiscreteSearch { public: + size_t numExpansions = 0; + /** * Construct from a DiscreteBayesNet and K. */ @@ -193,33 +184,42 @@ class DiscreteSearch { conditionals_.push_back(factor); } - // Calculate the cost-to-go for each conditional. If there are n - // conditionals, we start with nextConditional = n-1, and the minimum error - // obtainable is the sum of all the minimum errors. We start with - // 0, and that is the minimum error of the conditional with that index: - double error = 0.0; - for (const auto& conditional : conditionals_) { - Ordering ordering(conditional->begin(), conditional->end()); - auto maxx = conditional->max(ordering); - assert(maxx->size() == 1); - error -= std::log(maxx->evaluate({})); - costToGo_.push_back(error); - } + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); - // Create the root node: no variables assigned, nextConditional = last. - SearchNode root{ - .assignment = DiscreteValues(), - .error = 0.0, - .nextConditional = static_cast(conditionals_.size()) - 1}; - if (!costToGo_.empty()) root.bound = costToGo_.back(); - expansions_.push(root); + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); } /** - * Construct from a DiscreteBayesNet and K. + * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) - : solutions_(K) {} + DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) { + using CliquePtr = DiscreteBayesTree::sharedClique; + std::function collectConditionals = + [&](const CliquePtr& clique) -> void { + if (!clique) return; + + // Recursive post-order traversal: process children first + for (const auto& child : clique->children) { + collectConditionals(child); + } + + // Then add the current clique's conditional + conditionals_.push_back(clique->conditional()); + }; + + // Start traversal from each root in the tree + for (const auto& root : bayesTree.roots()) collectConditionals(root); + + // Calculate the cost-to-go for each conditional + costToGo_ = computeCostToGo(conditionals_); + + // Create the root node and push it to the expansions queue + expansions_.push(SearchNode::Root( + conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); + } /** * @brief Search for the K best solutions. @@ -233,6 +233,7 @@ class DiscreteSearch { */ std::vector run() { while (!expansions_.empty()) { + numExpansions++; expandNextNode(); } @@ -241,6 +242,29 @@ class DiscreteSearch { } private: + /** + * @brief Compute the cost-to-go for each conditional. + * + * @param conditionals The conditionals of the DiscreteBayesNet. + * @return A vector of cost-to-go values. + */ + static std::vector computeCostToGo( + const std::vector& conditionals) { + std::vector costToGo; + double error = 0.0; + for (const auto& conditional : conditionals) { + Ordering ordering(conditional->begin(), conditional->end()); + auto maxx = conditional->max(ordering); + assert(maxx->size() == 1); + error -= std::log(maxx->evaluate({})); + costToGo.push_back(error); + } + return costToGo; + } + + /** + * @brief Expand the next node in the search tree. + */ void expandNextNode() { // Pop the partial assignment with the smallest bound SearchNode current = expansions_.top(); @@ -273,7 +297,7 @@ class DiscreteSearch { } } - std::vector> conditionals_; + std::vector conditionals_; std::vector costToGo_; std::priority_queue, SearchNode::CompareByBound> diff --git a/gtsam/discrete/tests/AsiaExample.h b/gtsam/discrete/tests/AsiaExample.h index 7413f6899..6c327daec 100644 --- a/gtsam/discrete/tests/AsiaExample.h +++ b/gtsam/discrete/tests/AsiaExample.h @@ -30,7 +30,7 @@ static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), Asia(A, 2); -// Function to construct the incomplete Asia example +// Function to construct the Asia priors DiscreteBayesNet createPriors() { DiscreteBayesNet priors; priors.add(Smoking % "50/50"); diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index b0f2b03e3..041c3b465 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -49,8 +49,9 @@ TEST(DiscreteBayesNet, AsiaKBest) { DiscreteSearch search(asia, 4); auto solutions = search.run(); EXPECT(!solutions.empty()); - // Regression test: check the first solution + // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); } /* ************************************************************************* */ @@ -70,16 +71,21 @@ TEST(DiscreteBayesTree, testEmptyTree) { TEST(DiscreteBayesTree, testTrivialOneClique) { using namespace asia_example; DiscreteFactorGraph asia(createAsiaExample()); - DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(); + const Ordering ordering{D, X, B, E, L, T, S, A}; + DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); GTSAM_PRINT(*bt); // Ask for top 4 solutions DiscreteSearch search(*bt, 4); auto solutions = search.run(); + // print numExpansions + std::cout << "Number of expansions: " << search.numExpansions << std::endl; + EXPECT(!solutions.empty()); - // Regression test: check the first solution + // Regression test: check the first and last solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); } /* ************************************************************************* */