Fix heuristic

release/4.3a0
Frank Dellaert 2025-01-27 14:49:54 -05:00
parent 9800e110aa
commit c4870cc840
3 changed files with 51 additions and 36 deletions

View File

@ -16,6 +16,8 @@
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteSearch.h>
namespace gtsam {
@ -39,9 +41,8 @@ struct SearchNode {
/**
* @brief Construct the root node for the search.
*/
static SearchNode Root(size_t numConditionals, double bound) {
return {DiscreteValues(), 0.0, bound,
static_cast<int>(numConditionals) - 1};
static SearchNode Root(size_t numSlots, double bound) {
return {DiscreteValues(), 0.0, bound, static_cast<int>(numSlots) - 1};
}
struct Compare {
@ -60,20 +61,18 @@ struct SearchNode {
/**
* @brief Expands the node by assigning the next variable.
*
* @param factor The discrete factor associated with the next variable
* to be assigned.
* @param slot The slot to be filled.
* @param fa The frontal assignment for the next variable.
* @return A new SearchNode representing the expanded state.
*/
SearchNode expand(const DiscreteFactor& factor,
const DiscreteValues& fa) const {
SearchNode expand(const Slot& slot, const DiscreteValues& fa) const {
// Combine the new frontal assignment with the current partial assignment
DiscreteValues newAssignment = assignment;
for (auto& [key, value] : fa) {
newAssignment[key] = value;
}
return {newAssignment, error + factor.error(newAssignment), 0.0,
double errorSoFar = error + slot.factor->error(newAssignment);
return {newAssignment, errorSoFar, errorSoFar + slot.heuristic,
nextConditional - 1};
}
@ -151,12 +150,19 @@ class Solutions {
}
};
DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph) {
DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph,
const Ordering& ordering,
bool buildJunctionTree) {
const DiscreteEliminationTree etree(factorGraph, ordering);
const DiscreteJunctionTree junctionTree(etree);
// GTSAM_PRINT(asia::etree);
// GTSAM_PRINT(asia::junctionTree);
slots_.reserve(factorGraph.size());
for (auto& factor : factorGraph) {
slots_.emplace_back(factor, std::vector<DiscreteValues>{}, 0.0);
}
computeHeuristic();
lowerBound_ = computeHeuristic();
}
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
@ -164,7 +170,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
for (auto& conditional : bayesNet) {
slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0);
}
computeHeuristic();
lowerBound_ = computeHeuristic();
}
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
@ -179,7 +185,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
slots_.reserve(bayesTree.size());
for (const auto& root : bayesTree.roots()) collectConditionals(root);
computeHeuristic();
lowerBound_ = computeHeuristic();
}
struct SearchNodeQueue
@ -199,10 +205,7 @@ struct SearchNodeQueue
}
for (auto& fa : slot.assignments) {
auto childNode = current.expand(*slot.factor, fa);
if (childNode.nextConditional >= 0)
// TODO(frank): this might be wrong !
childNode.bound = childNode.error + slot.heuristic;
auto childNode = current.expand(slot, fa);
// Again, prune if we cannot beat the worst solution
if (!solutions->prune(childNode.bound)) {
@ -212,10 +215,12 @@ struct SearchNodeQueue
}
};
#define DISCRETE_SEARCH_DEBUG
std::vector<Solution> DiscreteSearch::run(size_t K) const {
Solutions solutions(K);
SearchNodeQueue expansions;
expansions.push(SearchNode::Root(slots_.size(), slots_.back().heuristic));
expansions.push(SearchNode::Root(slots_.size(), lowerBound_));
#ifdef DISCRETE_SEARCH_DEBUG
size_t numExpansions = 0;
@ -244,18 +249,22 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
}
// We have a number of factors, each with a max value, and we want to compute
// the a lower-bound on the cost-to-go for each slot. For the first slot, this
// -log(max(factor[0])), as we only have one factor to resolve. For the second
// slot, we need to add -log(max(factor[1])) to it, etc...
void DiscreteSearch::computeHeuristic() {
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
// For the first slot, this is 0.0, as this is the last slot to be filled, so
// the cost after that is zero. For the second slot, it is h0 =
// -log(max(factor[0])), because after we assign slot[1] we still need to assign
// slot[0], which will cost *at least* h0.
// We return the estimated lower bound of the cost for *all* slots.
double DiscreteSearch::computeHeuristic() {
double error = 0.0;
for (size_t i = 0; i < slots_.size(); ++i) {
slots_[i].heuristic = error;
const auto& factor = slots_[i].factor;
Ordering ordering(factor->begin(), factor->end());
auto maxx = factor->max(ordering);
error -= std::log(maxx->evaluate({}));
slots_[i].heuristic = error;
}
return error;
}
} // namespace gtsam

View File

@ -61,8 +61,19 @@ class GTSAM_EXPORT DiscreteSearch {
public:
/**
* Construct from a DiscreteFactorGraph.
*
* Internally creates either an elimination tree or a junction tree. The
* latter incurs more up-front computation but the search itself might be
* faster. Then again, for the elimination tree, the heuristic will be more
* fine-grained (more slots).
*
* @param factorGraph The factor graph to search over.
* @param ordering The ordering of the variables to search over.
* @param buildJunctionTree Whether to build a junction tree for the factor
* graph.
*/
DiscreteSearch(const DiscreteFactorGraph& bayesNet);
DiscreteSearch(const DiscreteFactorGraph& factorGraph,
const Ordering& ordering, bool buildJunctionTree = false);
/**
* Construct from a DiscreteBayesNet.
@ -87,9 +98,11 @@ class GTSAM_EXPORT DiscreteSearch {
std::vector<Solution> run(size_t K = 1) const;
private:
/// Compute the cumulative lower-bound cost-to-go for each slot.
void computeHeuristic();
/// Compute the cumulative lower-bound cost-to-go after each slot is filled.
/// @return the estimated lower bound of the cost for *all* slots.
double computeHeuristic();
std::vector<Slot> slots_;
double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.
std::vector<Slot> slots_; ///< The slots to fill in the search.
};
} // namespace gtsam

View File

@ -18,8 +18,6 @@
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteSearch.h>
#include "AsiaExample.h"
@ -35,12 +33,9 @@ static const DiscreteBayesNet bayesNet = createAsiaExample();
static const DiscreteFactorGraph factorGraph(bayesNet);
static const DiscreteValues mpe = factorGraph.optimize();
// Create junction tree
// Create ordering
static const Ordering ordering{D, X, B, E, L, T, S, A};
static const DiscreteEliminationTree etree(factorGraph, ordering);
static const DiscreteJunctionTree junctionTree(etree);
// Create Bayes tree
static const DiscreteBayesTree bayesTree =
*factorGraph.eliminateMultifrontal(ordering);
@ -48,9 +43,7 @@ static const DiscreteBayesTree bayesTree =
/* ************************************************************************* */
TEST(DiscreteBayesNet, AsiaFactorGraphKBest) {
GTSAM_PRINT(asia::etree);
GTSAM_PRINT(asia::junctionTree);
DiscreteSearch search(asia::factorGraph);
DiscreteSearch search(asia::factorGraph, asia::ordering);
}
/* ************************************************************************* */