From c4870cc840e32a444754cd2b61ef5d0a344099c0 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 14:49:54 -0500 Subject: [PATCH] Fix heuristic --- gtsam/discrete/DiscreteSearch.cpp | 55 ++++++++++++--------- gtsam/discrete/DiscreteSearch.h | 21 ++++++-- gtsam/discrete/tests/testDiscreteSearch.cpp | 11 +---- 3 files changed, 51 insertions(+), 36 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index e3722c13e..439331fa1 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -16,6 +16,8 @@ * @author Frank Dellaert */ +#include +#include #include namespace gtsam { @@ -39,9 +41,8 @@ struct SearchNode { /** * @brief Construct the root node for the search. */ - static SearchNode Root(size_t numConditionals, double bound) { - return {DiscreteValues(), 0.0, bound, - static_cast(numConditionals) - 1}; + static SearchNode Root(size_t numSlots, double bound) { + return {DiscreteValues(), 0.0, bound, static_cast(numSlots) - 1}; } struct Compare { @@ -60,20 +61,18 @@ struct SearchNode { /** * @brief Expands the node by assigning the next variable. * - * @param factor The discrete factor associated with the next variable - * to be assigned. + * @param slot The slot to be filled. * @param fa The frontal assignment for the next variable. * @return A new SearchNode representing the expanded state. */ - SearchNode expand(const DiscreteFactor& factor, - const DiscreteValues& fa) const { + SearchNode expand(const Slot& slot, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; for (auto& [key, value] : fa) { newAssignment[key] = value; } - - return {newAssignment, error + factor.error(newAssignment), 0.0, + double errorSoFar = error + slot.factor->error(newAssignment); + return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextConditional - 1}; } @@ -151,12 +150,19 @@ class Solutions { } }; -DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph) { +DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph, + const Ordering& ordering, + bool buildJunctionTree) { + const DiscreteEliminationTree etree(factorGraph, ordering); + const DiscreteJunctionTree junctionTree(etree); + + // GTSAM_PRINT(asia::etree); + // GTSAM_PRINT(asia::junctionTree); slots_.reserve(factorGraph.size()); for (auto& factor : factorGraph) { slots_.emplace_back(factor, std::vector{}, 0.0); } - computeHeuristic(); + lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { @@ -164,7 +170,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { for (auto& conditional : bayesNet) { slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0); } - computeHeuristic(); + lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { @@ -179,7 +185,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { slots_.reserve(bayesTree.size()); for (const auto& root : bayesTree.roots()) collectConditionals(root); - computeHeuristic(); + lowerBound_ = computeHeuristic(); } struct SearchNodeQueue @@ -199,10 +205,7 @@ struct SearchNodeQueue } for (auto& fa : slot.assignments) { - auto childNode = current.expand(*slot.factor, fa); - if (childNode.nextConditional >= 0) - // TODO(frank): this might be wrong ! - childNode.bound = childNode.error + slot.heuristic; + auto childNode = current.expand(slot, fa); // Again, prune if we cannot beat the worst solution if (!solutions->prune(childNode.bound)) { @@ -212,10 +215,12 @@ struct SearchNodeQueue } }; +#define DISCRETE_SEARCH_DEBUG + std::vector DiscreteSearch::run(size_t K) const { Solutions solutions(K); SearchNodeQueue expansions; - expansions.push(SearchNode::Root(slots_.size(), slots_.back().heuristic)); + expansions.push(SearchNode::Root(slots_.size(), lowerBound_)); #ifdef DISCRETE_SEARCH_DEBUG size_t numExpansions = 0; @@ -244,18 +249,22 @@ std::vector DiscreteSearch::run(size_t K) const { } // We have a number of factors, each with a max value, and we want to compute -// the a lower-bound on the cost-to-go for each slot. For the first slot, this -// -log(max(factor[0])), as we only have one factor to resolve. For the second -// slot, we need to add -log(max(factor[1])) to it, etc... -void DiscreteSearch::computeHeuristic() { +// a lower-bound on the cost-to-go for each slot, *not* including this factor. +// For the first slot, this is 0.0, as this is the last slot to be filled, so +// the cost after that is zero. For the second slot, it is h0 = +// -log(max(factor[0])), because after we assign slot[1] we still need to assign +// slot[0], which will cost *at least* h0. +// We return the estimated lower bound of the cost for *all* slots. +double DiscreteSearch::computeHeuristic() { double error = 0.0; for (size_t i = 0; i < slots_.size(); ++i) { + slots_[i].heuristic = error; const auto& factor = slots_[i].factor; Ordering ordering(factor->begin(), factor->end()); auto maxx = factor->max(ordering); error -= std::log(maxx->evaluate({})); - slots_[i].heuristic = error; } + return error; } } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index ddeaf38f1..3d5ee1d50 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -61,8 +61,19 @@ class GTSAM_EXPORT DiscreteSearch { public: /** * Construct from a DiscreteFactorGraph. + * + * Internally creates either an elimination tree or a junction tree. The + * latter incurs more up-front computation but the search itself might be + * faster. Then again, for the elimination tree, the heuristic will be more + * fine-grained (more slots). + * + * @param factorGraph The factor graph to search over. + * @param ordering The ordering of the variables to search over. + * @param buildJunctionTree Whether to build a junction tree for the factor + * graph. */ - DiscreteSearch(const DiscreteFactorGraph& bayesNet); + DiscreteSearch(const DiscreteFactorGraph& factorGraph, + const Ordering& ordering, bool buildJunctionTree = false); /** * Construct from a DiscreteBayesNet. @@ -87,9 +98,11 @@ class GTSAM_EXPORT DiscreteSearch { std::vector run(size_t K = 1) const; private: - /// Compute the cumulative lower-bound cost-to-go for each slot. - void computeHeuristic(); + /// Compute the cumulative lower-bound cost-to-go after each slot is filled. + /// @return the estimated lower bound of the cost for *all* slots. + double computeHeuristic(); - std::vector slots_; + double lowerBound_; ///< Lower bound on the cost-to-go for the entire search. + std::vector slots_; ///< The slots to fill in the search. }; } // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 7e715ca62..0f424bd45 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -18,8 +18,6 @@ #include #include -#include -#include #include #include "AsiaExample.h" @@ -35,12 +33,9 @@ static const DiscreteBayesNet bayesNet = createAsiaExample(); static const DiscreteFactorGraph factorGraph(bayesNet); static const DiscreteValues mpe = factorGraph.optimize(); -// Create junction tree +// Create ordering static const Ordering ordering{D, X, B, E, L, T, S, A}; -static const DiscreteEliminationTree etree(factorGraph, ordering); -static const DiscreteJunctionTree junctionTree(etree); - // Create Bayes tree static const DiscreteBayesTree bayesTree = *factorGraph.eliminateMultifrontal(ordering); @@ -48,9 +43,7 @@ static const DiscreteBayesTree bayesTree = /* ************************************************************************* */ TEST(DiscreteBayesNet, AsiaFactorGraphKBest) { - GTSAM_PRINT(asia::etree); - GTSAM_PRINT(asia::junctionTree); - DiscreteSearch search(asia::factorGraph); + DiscreteSearch search(asia::factorGraph, asia::ordering); } /* ************************************************************************* */