Working etree and jtree versions
parent
c4870cc840
commit
0bc566f692
|
@ -150,21 +150,58 @@ class Solutions {
|
|||
}
|
||||
};
|
||||
|
||||
DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph,
|
||||
const Ordering& ordering,
|
||||
bool buildJunctionTree) {
|
||||
const DiscreteEliminationTree etree(factorGraph, ordering);
|
||||
const DiscreteJunctionTree junctionTree(etree);
|
||||
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
|
||||
using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
|
||||
auto visitor = [this](const NodePtr& node, int data) {
|
||||
const auto& factors = node->factors;
|
||||
const auto factor = factors.size() == 1
|
||||
? factors.back()
|
||||
: DiscreteFactorGraph(factors).product();
|
||||
const size_t cardinality = factor->cardinality(node->key);
|
||||
std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
|
||||
slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0);
|
||||
return data + 1;
|
||||
};
|
||||
|
||||
// GTSAM_PRINT(asia::etree);
|
||||
// GTSAM_PRINT(asia::junctionTree);
|
||||
slots_.reserve(factorGraph.size());
|
||||
for (auto& factor : factorGraph) {
|
||||
slots_.emplace_back(factor, std::vector<DiscreteValues>{}, 0.0);
|
||||
}
|
||||
const int data = 0; // unused
|
||||
treeTraversal::DepthFirstForest(etree, data, visitor);
|
||||
std::reverse(slots_.begin(), slots_.end()); // reverse slots
|
||||
lowerBound_ = computeHeuristic();
|
||||
}
|
||||
|
||||
DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) {
|
||||
using NodePtr = std::shared_ptr<DiscreteJunctionTree::Cluster>;
|
||||
auto visitor = [this](const NodePtr& cluster, int data) {
|
||||
const auto& factors = cluster->factors;
|
||||
const auto factor = factors.size() == 1
|
||||
? factors.back()
|
||||
: DiscreteFactorGraph(factors).product();
|
||||
std::vector<std::pair<Key, size_t>> pairs;
|
||||
for (Key key : cluster->orderedFrontalKeys) {
|
||||
pairs.emplace_back(key, factor->cardinality(key));
|
||||
}
|
||||
slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0);
|
||||
return data + 1;
|
||||
};
|
||||
|
||||
const int data = 0; // unused
|
||||
treeTraversal::DepthFirstForest(junctionTree, data, visitor);
|
||||
std::reverse(slots_.begin(), slots_.end()); // reverse slots
|
||||
lowerBound_ = computeHeuristic();
|
||||
}
|
||||
|
||||
DiscreteSearch DiscreteSearch::FromFactorGraph(
|
||||
const DiscreteFactorGraph& factorGraph, const Ordering& ordering,
|
||||
bool buildJunctionTree) {
|
||||
const DiscreteEliminationTree etree(factorGraph, ordering);
|
||||
if (buildJunctionTree) {
|
||||
const DiscreteJunctionTree junctionTree(etree);
|
||||
return DiscreteSearch(junctionTree);
|
||||
} else {
|
||||
return DiscreteSearch(etree);
|
||||
}
|
||||
}
|
||||
|
||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
||||
slots_.reserve(bayesNet.size());
|
||||
for (auto& conditional : bayesNet) {
|
||||
|
@ -188,6 +225,14 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
|||
lowerBound_ = computeHeuristic();
|
||||
}
|
||||
|
||||
void DiscreteSearch::print(const std::string& name,
|
||||
const KeyFormatter& formatter) const {
|
||||
std::cout << name << " with " << slots_.size() << " slots:\n";
|
||||
for (size_t i = 0; i < slots_.size(); ++i) {
|
||||
std::cout << i << ": " << slots_[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
struct SearchNodeQueue
|
||||
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
|
||||
SearchNode::Compare> {
|
||||
|
@ -215,8 +260,6 @@ struct SearchNodeQueue
|
|||
}
|
||||
};
|
||||
|
||||
#define DISCRETE_SEARCH_DEBUG
|
||||
|
||||
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||
Solutions solutions(K);
|
||||
SearchNodeQueue expansions;
|
||||
|
@ -252,9 +295,9 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
|||
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
|
||||
// For the first slot, this is 0.0, as this is the last slot to be filled, so
|
||||
// the cost after that is zero. For the second slot, it is h0 =
|
||||
// -log(max(factor[0])), because after we assign slot[1] we still need to assign
|
||||
// slot[0], which will cost *at least* h0.
|
||||
// We return the estimated lower bound of the cost for *all* slots.
|
||||
// -log(max(factor[0])), because after we assign slot[1] we still need to
|
||||
// assign slot[0], which will cost *at least* h0. We return the estimated
|
||||
// lower bound of the cost for *all* slots.
|
||||
double DiscreteSearch::computeHeuristic() {
|
||||
double error = 0.0;
|
||||
for (size_t i = 0; i < slots_.size(); ++i) {
|
||||
|
|
|
@ -10,7 +10,11 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* DiscreteSearch.cpp
|
||||
* @file DiscreteSearch.h
|
||||
* @brief Defines the DiscreteSearch class for discrete search algorithms.
|
||||
*
|
||||
* @details This file contains the definition of the DiscreteSearch class, which
|
||||
* is used in discrete search algorithms to find the K best solutions.
|
||||
*
|
||||
* @date January, 2025
|
||||
* @author Frank Dellaert
|
||||
|
@ -43,6 +47,13 @@ class GTSAM_EXPORT DiscreteSearch {
|
|||
/// A lower bound on the cost-to-go for each slot, e.g.,
|
||||
/// [-log(max_B P(B|A)), -log(max_A P(A))]
|
||||
double heuristic;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const Slot& slot) {
|
||||
os << "Slot with " << slot.assignments.size()
|
||||
<< " assignments, heuristic=" << slot.heuristic;
|
||||
os << ", factor:\n" << slot.factor->markdown() << std::endl;
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
/// A solution is then a set of assignments, covering all the slots.
|
||||
|
@ -59,6 +70,9 @@ class GTSAM_EXPORT DiscreteSearch {
|
|||
};
|
||||
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* Construct from a DiscreteFactorGraph.
|
||||
*
|
||||
|
@ -68,12 +82,26 @@ class GTSAM_EXPORT DiscreteSearch {
|
|||
* fine-grained (more slots).
|
||||
*
|
||||
* @param factorGraph The factor graph to search over.
|
||||
* @param ordering The ordering of the variables to search over.
|
||||
* @param buildJunctionTree Whether to build a junction tree for the factor
|
||||
* graph.
|
||||
* @param ordering The ordering used to create etree (and maybe jtree).
|
||||
* @param buildJunctionTree Whether to build a junction tree or not.
|
||||
*/
|
||||
DiscreteSearch(const DiscreteFactorGraph& factorGraph,
|
||||
const Ordering& ordering, bool buildJunctionTree = false);
|
||||
static DiscreteSearch FromFactorGraph(const DiscreteFactorGraph& factorGraph,
|
||||
const Ordering& ordering,
|
||||
bool buildJunctionTree = false);
|
||||
|
||||
/**
|
||||
* @brief Constructor from a DiscreteEliminationTree.
|
||||
*
|
||||
* @param etree The DiscreteEliminationTree to initialize from.
|
||||
*/
|
||||
DiscreteSearch(const DiscreteEliminationTree& etree);
|
||||
|
||||
/**
|
||||
* @brief Constructor from a DiscreteJunctionTree.
|
||||
*
|
||||
* @param junctionTree The DiscreteJunctionTree to initialize from.
|
||||
*/
|
||||
DiscreteSearch(const DiscreteJunctionTree& junctionTree);
|
||||
|
||||
/**
|
||||
* Construct from a DiscreteBayesNet.
|
||||
|
@ -85,6 +113,18 @@ class GTSAM_EXPORT DiscreteSearch {
|
|||
*/
|
||||
DiscreteSearch(const DiscreteBayesTree& bayesTree);
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** Print the tree to cout */
|
||||
void print(const std::string& name = "DiscreteSearch: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
/// @name Standard API
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* @brief Search for the K best solutions.
|
||||
*
|
||||
|
@ -97,6 +137,8 @@ class GTSAM_EXPORT DiscreteSearch {
|
|||
*/
|
||||
std::vector<Solution> run(size_t K = 1) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/// Compute the cumulative lower-bound cost-to-go after each slot is filled.
|
||||
/// @return the estimated lower bound of the cost for *all* slots.
|
||||
|
|
Loading…
Reference in New Issue