Fix heuristic

release/4.3a0
Frank Dellaert 2025-01-27 14:49:54 -05:00
parent 9800e110aa
commit c4870cc840
3 changed files with 51 additions and 36 deletions

View File

@ -16,6 +16,8 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteSearch.h> #include <gtsam/discrete/DiscreteSearch.h>
namespace gtsam { namespace gtsam {
@ -39,9 +41,8 @@ struct SearchNode {
/** /**
* @brief Construct the root node for the search. * @brief Construct the root node for the search.
*/ */
static SearchNode Root(size_t numConditionals, double bound) { static SearchNode Root(size_t numSlots, double bound) {
return {DiscreteValues(), 0.0, bound, return {DiscreteValues(), 0.0, bound, static_cast<int>(numSlots) - 1};
static_cast<int>(numConditionals) - 1};
} }
struct Compare { struct Compare {
@ -60,20 +61,18 @@ struct SearchNode {
/** /**
* @brief Expands the node by assigning the next variable. * @brief Expands the node by assigning the next variable.
* *
* @param factor The discrete factor associated with the next variable * @param slot The slot to be filled.
* to be assigned.
* @param fa The frontal assignment for the next variable. * @param fa The frontal assignment for the next variable.
* @return A new SearchNode representing the expanded state. * @return A new SearchNode representing the expanded state.
*/ */
SearchNode expand(const DiscreteFactor& factor, SearchNode expand(const Slot& slot, const DiscreteValues& fa) const {
const DiscreteValues& fa) const {
// Combine the new frontal assignment with the current partial assignment // Combine the new frontal assignment with the current partial assignment
DiscreteValues newAssignment = assignment; DiscreteValues newAssignment = assignment;
for (auto& [key, value] : fa) { for (auto& [key, value] : fa) {
newAssignment[key] = value; newAssignment[key] = value;
} }
double errorSoFar = error + slot.factor->error(newAssignment);
return {newAssignment, error + factor.error(newAssignment), 0.0, return {newAssignment, errorSoFar, errorSoFar + slot.heuristic,
nextConditional - 1}; nextConditional - 1};
} }
@ -151,12 +150,19 @@ class Solutions {
} }
}; };
DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph) { DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph,
const Ordering& ordering,
bool buildJunctionTree) {
const DiscreteEliminationTree etree(factorGraph, ordering);
const DiscreteJunctionTree junctionTree(etree);
// GTSAM_PRINT(asia::etree);
// GTSAM_PRINT(asia::junctionTree);
slots_.reserve(factorGraph.size()); slots_.reserve(factorGraph.size());
for (auto& factor : factorGraph) { for (auto& factor : factorGraph) {
slots_.emplace_back(factor, std::vector<DiscreteValues>{}, 0.0); slots_.emplace_back(factor, std::vector<DiscreteValues>{}, 0.0);
} }
computeHeuristic(); lowerBound_ = computeHeuristic();
} }
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
@ -164,7 +170,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
for (auto& conditional : bayesNet) { for (auto& conditional : bayesNet) {
slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0); slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0);
} }
computeHeuristic(); lowerBound_ = computeHeuristic();
} }
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
@ -179,7 +185,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
slots_.reserve(bayesTree.size()); slots_.reserve(bayesTree.size());
for (const auto& root : bayesTree.roots()) collectConditionals(root); for (const auto& root : bayesTree.roots()) collectConditionals(root);
computeHeuristic(); lowerBound_ = computeHeuristic();
} }
struct SearchNodeQueue struct SearchNodeQueue
@ -199,10 +205,7 @@ struct SearchNodeQueue
} }
for (auto& fa : slot.assignments) { for (auto& fa : slot.assignments) {
auto childNode = current.expand(*slot.factor, fa); auto childNode = current.expand(slot, fa);
if (childNode.nextConditional >= 0)
// TODO(frank): this might be wrong !
childNode.bound = childNode.error + slot.heuristic;
// Again, prune if we cannot beat the worst solution // Again, prune if we cannot beat the worst solution
if (!solutions->prune(childNode.bound)) { if (!solutions->prune(childNode.bound)) {
@ -212,10 +215,12 @@ struct SearchNodeQueue
} }
}; };
#define DISCRETE_SEARCH_DEBUG
std::vector<Solution> DiscreteSearch::run(size_t K) const { std::vector<Solution> DiscreteSearch::run(size_t K) const {
Solutions solutions(K); Solutions solutions(K);
SearchNodeQueue expansions; SearchNodeQueue expansions;
expansions.push(SearchNode::Root(slots_.size(), slots_.back().heuristic)); expansions.push(SearchNode::Root(slots_.size(), lowerBound_));
#ifdef DISCRETE_SEARCH_DEBUG #ifdef DISCRETE_SEARCH_DEBUG
size_t numExpansions = 0; size_t numExpansions = 0;
@ -244,18 +249,22 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
} }
// We have a number of factors, each with a max value, and we want to compute // We have a number of factors, each with a max value, and we want to compute
// the a lower-bound on the cost-to-go for each slot. For the first slot, this // a lower-bound on the cost-to-go for each slot, *not* including this factor.
// -log(max(factor[0])), as we only have one factor to resolve. For the second // For the first slot, this is 0.0, as this is the last slot to be filled, so
// slot, we need to add -log(max(factor[1])) to it, etc... // the cost after that is zero. For the second slot, it is h0 =
void DiscreteSearch::computeHeuristic() { // -log(max(factor[0])), because after we assign slot[1] we still need to assign
// slot[0], which will cost *at least* h0.
// We return the estimated lower bound of the cost for *all* slots.
double DiscreteSearch::computeHeuristic() {
double error = 0.0; double error = 0.0;
for (size_t i = 0; i < slots_.size(); ++i) { for (size_t i = 0; i < slots_.size(); ++i) {
slots_[i].heuristic = error;
const auto& factor = slots_[i].factor; const auto& factor = slots_[i].factor;
Ordering ordering(factor->begin(), factor->end()); Ordering ordering(factor->begin(), factor->end());
auto maxx = factor->max(ordering); auto maxx = factor->max(ordering);
error -= std::log(maxx->evaluate({})); error -= std::log(maxx->evaluate({}));
slots_[i].heuristic = error;
} }
return error;
} }
} // namespace gtsam } // namespace gtsam

View File

@ -61,8 +61,19 @@ class GTSAM_EXPORT DiscreteSearch {
public: public:
/** /**
* Construct from a DiscreteFactorGraph. * Construct from a DiscreteFactorGraph.
*
* Internally creates either an elimination tree or a junction tree. The
* latter incurs more up-front computation but the search itself might be
* faster. Then again, for the elimination tree, the heuristic will be more
* fine-grained (more slots).
*
* @param factorGraph The factor graph to search over.
* @param ordering The ordering of the variables to search over.
* @param buildJunctionTree Whether to build a junction tree for the factor
* graph.
*/ */
DiscreteSearch(const DiscreteFactorGraph& bayesNet); DiscreteSearch(const DiscreteFactorGraph& factorGraph,
const Ordering& ordering, bool buildJunctionTree = false);
/** /**
* Construct from a DiscreteBayesNet. * Construct from a DiscreteBayesNet.
@ -87,9 +98,11 @@ class GTSAM_EXPORT DiscreteSearch {
std::vector<Solution> run(size_t K = 1) const; std::vector<Solution> run(size_t K = 1) const;
private: private:
/// Compute the cumulative lower-bound cost-to-go for each slot. /// Compute the cumulative lower-bound cost-to-go after each slot is filled.
void computeHeuristic(); /// @return the estimated lower bound of the cost for *all* slots.
double computeHeuristic();
std::vector<Slot> slots_; double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.
std::vector<Slot> slots_; ///< The slots to fill in the search.
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -18,8 +18,6 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteSearch.h> #include <gtsam/discrete/DiscreteSearch.h>
#include "AsiaExample.h" #include "AsiaExample.h"
@ -35,12 +33,9 @@ static const DiscreteBayesNet bayesNet = createAsiaExample();
static const DiscreteFactorGraph factorGraph(bayesNet); static const DiscreteFactorGraph factorGraph(bayesNet);
static const DiscreteValues mpe = factorGraph.optimize(); static const DiscreteValues mpe = factorGraph.optimize();
// Create junction tree // Create ordering
static const Ordering ordering{D, X, B, E, L, T, S, A}; static const Ordering ordering{D, X, B, E, L, T, S, A};
static const DiscreteEliminationTree etree(factorGraph, ordering);
static const DiscreteJunctionTree junctionTree(etree);
// Create Bayes tree // Create Bayes tree
static const DiscreteBayesTree bayesTree = static const DiscreteBayesTree bayesTree =
*factorGraph.eliminateMultifrontal(ordering); *factorGraph.eliminateMultifrontal(ordering);
@ -48,9 +43,7 @@ static const DiscreteBayesTree bayesTree =
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, AsiaFactorGraphKBest) { TEST(DiscreteBayesNet, AsiaFactorGraphKBest) {
GTSAM_PRINT(asia::etree); DiscreteSearch search(asia::factorGraph, asia::ordering);
GTSAM_PRINT(asia::junctionTree);
DiscreteSearch search(asia::factorGraph);
} }
/* ************************************************************************* */ /* ************************************************************************* */