Fix heuristic
parent
9800e110aa
commit
c4870cc840
|
@ -16,6 +16,8 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/discrete/DiscreteSearch.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -39,9 +41,8 @@ struct SearchNode {
|
|||
/**
|
||||
* @brief Construct the root node for the search.
|
||||
*/
|
||||
static SearchNode Root(size_t numConditionals, double bound) {
|
||||
return {DiscreteValues(), 0.0, bound,
|
||||
static_cast<int>(numConditionals) - 1};
|
||||
static SearchNode Root(size_t numSlots, double bound) {
|
||||
return {DiscreteValues(), 0.0, bound, static_cast<int>(numSlots) - 1};
|
||||
}
|
||||
|
||||
struct Compare {
|
||||
|
@ -60,20 +61,18 @@ struct SearchNode {
|
|||
/**
|
||||
* @brief Expands the node by assigning the next variable.
|
||||
*
|
||||
* @param factor The discrete factor associated with the next variable
|
||||
* to be assigned.
|
||||
* @param slot The slot to be filled.
|
||||
* @param fa The frontal assignment for the next variable.
|
||||
* @return A new SearchNode representing the expanded state.
|
||||
*/
|
||||
SearchNode expand(const DiscreteFactor& factor,
|
||||
const DiscreteValues& fa) const {
|
||||
SearchNode expand(const Slot& slot, const DiscreteValues& fa) const {
|
||||
// Combine the new frontal assignment with the current partial assignment
|
||||
DiscreteValues newAssignment = assignment;
|
||||
for (auto& [key, value] : fa) {
|
||||
newAssignment[key] = value;
|
||||
}
|
||||
|
||||
return {newAssignment, error + factor.error(newAssignment), 0.0,
|
||||
double errorSoFar = error + slot.factor->error(newAssignment);
|
||||
return {newAssignment, errorSoFar, errorSoFar + slot.heuristic,
|
||||
nextConditional - 1};
|
||||
}
|
||||
|
||||
|
@ -151,12 +150,19 @@ class Solutions {
|
|||
}
|
||||
};
|
||||
|
||||
DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph) {
|
||||
DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph,
|
||||
const Ordering& ordering,
|
||||
bool buildJunctionTree) {
|
||||
const DiscreteEliminationTree etree(factorGraph, ordering);
|
||||
const DiscreteJunctionTree junctionTree(etree);
|
||||
|
||||
// GTSAM_PRINT(asia::etree);
|
||||
// GTSAM_PRINT(asia::junctionTree);
|
||||
slots_.reserve(factorGraph.size());
|
||||
for (auto& factor : factorGraph) {
|
||||
slots_.emplace_back(factor, std::vector<DiscreteValues>{}, 0.0);
|
||||
}
|
||||
computeHeuristic();
|
||||
lowerBound_ = computeHeuristic();
|
||||
}
|
||||
|
||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
||||
|
@ -164,7 +170,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
|||
for (auto& conditional : bayesNet) {
|
||||
slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0);
|
||||
}
|
||||
computeHeuristic();
|
||||
lowerBound_ = computeHeuristic();
|
||||
}
|
||||
|
||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
||||
|
@ -179,7 +185,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
|||
|
||||
slots_.reserve(bayesTree.size());
|
||||
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
||||
computeHeuristic();
|
||||
lowerBound_ = computeHeuristic();
|
||||
}
|
||||
|
||||
struct SearchNodeQueue
|
||||
|
@ -199,10 +205,7 @@ struct SearchNodeQueue
|
|||
}
|
||||
|
||||
for (auto& fa : slot.assignments) {
|
||||
auto childNode = current.expand(*slot.factor, fa);
|
||||
if (childNode.nextConditional >= 0)
|
||||
// TODO(frank): this might be wrong !
|
||||
childNode.bound = childNode.error + slot.heuristic;
|
||||
auto childNode = current.expand(slot, fa);
|
||||
|
||||
// Again, prune if we cannot beat the worst solution
|
||||
if (!solutions->prune(childNode.bound)) {
|
||||
|
@ -212,10 +215,12 @@ struct SearchNodeQueue
|
|||
}
|
||||
};
|
||||
|
||||
#define DISCRETE_SEARCH_DEBUG
|
||||
|
||||
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||
Solutions solutions(K);
|
||||
SearchNodeQueue expansions;
|
||||
expansions.push(SearchNode::Root(slots_.size(), slots_.back().heuristic));
|
||||
expansions.push(SearchNode::Root(slots_.size(), lowerBound_));
|
||||
|
||||
#ifdef DISCRETE_SEARCH_DEBUG
|
||||
size_t numExpansions = 0;
|
||||
|
@ -244,18 +249,22 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
|||
}
|
||||
|
||||
// We have a number of factors, each with a max value, and we want to compute
|
||||
// the a lower-bound on the cost-to-go for each slot. For the first slot, this
|
||||
// -log(max(factor[0])), as we only have one factor to resolve. For the second
|
||||
// slot, we need to add -log(max(factor[1])) to it, etc...
|
||||
void DiscreteSearch::computeHeuristic() {
|
||||
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
|
||||
// For the first slot, this is 0.0, as this is the last slot to be filled, so
|
||||
// the cost after that is zero. For the second slot, it is h0 =
|
||||
// -log(max(factor[0])), because after we assign slot[1] we still need to assign
|
||||
// slot[0], which will cost *at least* h0.
|
||||
// We return the estimated lower bound of the cost for *all* slots.
|
||||
double DiscreteSearch::computeHeuristic() {
|
||||
double error = 0.0;
|
||||
for (size_t i = 0; i < slots_.size(); ++i) {
|
||||
slots_[i].heuristic = error;
|
||||
const auto& factor = slots_[i].factor;
|
||||
Ordering ordering(factor->begin(), factor->end());
|
||||
auto maxx = factor->max(ordering);
|
||||
error -= std::log(maxx->evaluate({}));
|
||||
slots_[i].heuristic = error;
|
||||
}
|
||||
return error;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -61,8 +61,19 @@ class GTSAM_EXPORT DiscreteSearch {
|
|||
public:
|
||||
/**
|
||||
* Construct from a DiscreteFactorGraph.
|
||||
*
|
||||
* Internally creates either an elimination tree or a junction tree. The
|
||||
* latter incurs more up-front computation but the search itself might be
|
||||
* faster. Then again, for the elimination tree, the heuristic will be more
|
||||
* fine-grained (more slots).
|
||||
*
|
||||
* @param factorGraph The factor graph to search over.
|
||||
* @param ordering The ordering of the variables to search over.
|
||||
* @param buildJunctionTree Whether to build a junction tree for the factor
|
||||
* graph.
|
||||
*/
|
||||
DiscreteSearch(const DiscreteFactorGraph& bayesNet);
|
||||
DiscreteSearch(const DiscreteFactorGraph& factorGraph,
|
||||
const Ordering& ordering, bool buildJunctionTree = false);
|
||||
|
||||
/**
|
||||
* Construct from a DiscreteBayesNet.
|
||||
|
@ -87,9 +98,11 @@ class GTSAM_EXPORT DiscreteSearch {
|
|||
std::vector<Solution> run(size_t K = 1) const;
|
||||
|
||||
private:
|
||||
/// Compute the cumulative lower-bound cost-to-go for each slot.
|
||||
void computeHeuristic();
|
||||
/// Compute the cumulative lower-bound cost-to-go after each slot is filled.
|
||||
/// @return the estimated lower bound of the cost for *all* slots.
|
||||
double computeHeuristic();
|
||||
|
||||
std::vector<Slot> slots_;
|
||||
double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.
|
||||
std::vector<Slot> slots_; ///< The slots to fill in the search.
|
||||
};
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -18,8 +18,6 @@
|
|||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/discrete/DiscreteSearch.h>
|
||||
|
||||
#include "AsiaExample.h"
|
||||
|
@ -35,12 +33,9 @@ static const DiscreteBayesNet bayesNet = createAsiaExample();
|
|||
static const DiscreteFactorGraph factorGraph(bayesNet);
|
||||
static const DiscreteValues mpe = factorGraph.optimize();
|
||||
|
||||
// Create junction tree
|
||||
// Create ordering
|
||||
static const Ordering ordering{D, X, B, E, L, T, S, A};
|
||||
|
||||
static const DiscreteEliminationTree etree(factorGraph, ordering);
|
||||
static const DiscreteJunctionTree junctionTree(etree);
|
||||
|
||||
// Create Bayes tree
|
||||
static const DiscreteBayesTree bayesTree =
|
||||
*factorGraph.eliminateMultifrontal(ordering);
|
||||
|
@ -48,9 +43,7 @@ static const DiscreteBayesTree bayesTree =
|
|||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesNet, AsiaFactorGraphKBest) {
|
||||
GTSAM_PRINT(asia::etree);
|
||||
GTSAM_PRINT(asia::junctionTree);
|
||||
DiscreteSearch search(asia::factorGraph);
|
||||
DiscreteSearch search(asia::factorGraph, asia::ordering);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue