Working etree and jtree versions

release/4.3a0
Frank Dellaert 2025-01-27 19:06:01 -05:00
parent c4870cc840
commit 0bc566f692
2 changed files with 107 additions and 22 deletions

View File

@ -150,21 +150,58 @@ class Solutions {
}
};
DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph,
const Ordering& ordering,
bool buildJunctionTree) {
const DiscreteEliminationTree etree(factorGraph, ordering);
const DiscreteJunctionTree junctionTree(etree);
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
auto visitor = [this](const NodePtr& node, int data) {
const auto& factors = node->factors;
const auto factor = factors.size() == 1
? factors.back()
: DiscreteFactorGraph(factors).product();
const size_t cardinality = factor->cardinality(node->key);
std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0);
return data + 1;
};
// GTSAM_PRINT(asia::etree);
// GTSAM_PRINT(asia::junctionTree);
slots_.reserve(factorGraph.size());
for (auto& factor : factorGraph) {
slots_.emplace_back(factor, std::vector<DiscreteValues>{}, 0.0);
}
const int data = 0; // unused
treeTraversal::DepthFirstForest(etree, data, visitor);
std::reverse(slots_.begin(), slots_.end()); // reverse slots
lowerBound_ = computeHeuristic();
}
DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) {
using NodePtr = std::shared_ptr<DiscreteJunctionTree::Cluster>;
auto visitor = [this](const NodePtr& cluster, int data) {
const auto& factors = cluster->factors;
const auto factor = factors.size() == 1
? factors.back()
: DiscreteFactorGraph(factors).product();
std::vector<std::pair<Key, size_t>> pairs;
for (Key key : cluster->orderedFrontalKeys) {
pairs.emplace_back(key, factor->cardinality(key));
}
slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0);
return data + 1;
};
const int data = 0; // unused
treeTraversal::DepthFirstForest(junctionTree, data, visitor);
std::reverse(slots_.begin(), slots_.end()); // reverse slots
lowerBound_ = computeHeuristic();
}
DiscreteSearch DiscreteSearch::FromFactorGraph(
const DiscreteFactorGraph& factorGraph, const Ordering& ordering,
bool buildJunctionTree) {
const DiscreteEliminationTree etree(factorGraph, ordering);
if (buildJunctionTree) {
const DiscreteJunctionTree junctionTree(etree);
return DiscreteSearch(junctionTree);
} else {
return DiscreteSearch(etree);
}
}
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
slots_.reserve(bayesNet.size());
for (auto& conditional : bayesNet) {
@ -188,6 +225,14 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
lowerBound_ = computeHeuristic();
}
void DiscreteSearch::print(const std::string& name,
const KeyFormatter& formatter) const {
std::cout << name << " with " << slots_.size() << " slots:\n";
for (size_t i = 0; i < slots_.size(); ++i) {
std::cout << i << ": " << slots_[i] << std::endl;
}
}
struct SearchNodeQueue
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
SearchNode::Compare> {
@ -215,8 +260,6 @@ struct SearchNodeQueue
}
};
#define DISCRETE_SEARCH_DEBUG
std::vector<Solution> DiscreteSearch::run(size_t K) const {
Solutions solutions(K);
SearchNodeQueue expansions;
@ -252,9 +295,9 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
// For the first slot, this is 0.0, as this is the last slot to be filled, so
// the cost after that is zero. For the second slot, it is h0 =
// -log(max(factor[0])), because after we assign slot[1] we still need to assign
// slot[0], which will cost *at least* h0.
// We return the estimated lower bound of the cost for *all* slots.
// -log(max(factor[0])), because after we assign slot[1] we still need to
// assign slot[0], which will cost *at least* h0. We return the estimated
// lower bound of the cost for *all* slots.
double DiscreteSearch::computeHeuristic() {
double error = 0.0;
for (size_t i = 0; i < slots_.size(); ++i) {

View File

@ -10,7 +10,11 @@
* -------------------------------------------------------------------------- */
/*
* DiscreteSearch.cpp
* @file DiscreteSearch.h
* @brief Defines the DiscreteSearch class for discrete search algorithms.
*
* @details This file contains the definition of the DiscreteSearch class, which
* is used in discrete search algorithms to find the K best solutions.
*
* @date January, 2025
* @author Frank Dellaert
@ -43,6 +47,13 @@ class GTSAM_EXPORT DiscreteSearch {
/// A lower bound on the cost-to-go for each slot, e.g.,
/// [-log(max_B P(B|A)), -log(max_A P(A))]
double heuristic;
friend std::ostream& operator<<(std::ostream& os, const Slot& slot) {
os << "Slot with " << slot.assignments.size()
<< " assignments, heuristic=" << slot.heuristic;
os << ", factor:\n" << slot.factor->markdown() << std::endl;
return os;
}
};
/// A solution is then a set of assignments, covering all the slots.
@ -59,6 +70,9 @@ class GTSAM_EXPORT DiscreteSearch {
};
public:
/// @name Standard Constructors
/// @{
/**
* Construct from a DiscreteFactorGraph.
*
@ -68,12 +82,26 @@ class GTSAM_EXPORT DiscreteSearch {
* fine-grained (more slots).
*
* @param factorGraph The factor graph to search over.
* @param ordering The ordering of the variables to search over.
* @param buildJunctionTree Whether to build a junction tree for the factor
* graph.
* @param ordering The ordering used to create etree (and maybe jtree).
* @param buildJunctionTree Whether to build a junction tree or not.
*/
DiscreteSearch(const DiscreteFactorGraph& factorGraph,
const Ordering& ordering, bool buildJunctionTree = false);
static DiscreteSearch FromFactorGraph(const DiscreteFactorGraph& factorGraph,
const Ordering& ordering,
bool buildJunctionTree = false);
/**
* @brief Constructor from a DiscreteEliminationTree.
*
* @param etree The DiscreteEliminationTree to initialize from.
*/
DiscreteSearch(const DiscreteEliminationTree& etree);
/**
* @brief Constructor from a DiscreteJunctionTree.
*
* @param junctionTree The DiscreteJunctionTree to initialize from.
*/
DiscreteSearch(const DiscreteJunctionTree& junctionTree);
/**
* Construct from a DiscreteBayesNet.
@ -85,6 +113,18 @@ class GTSAM_EXPORT DiscreteSearch {
*/
DiscreteSearch(const DiscreteBayesTree& bayesTree);
/// @}
/// @name Testable
/// @{
/** Print the tree to cout */
void print(const std::string& name = "DiscreteSearch: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const;
/// @}
/// @name Standard API
/// @{
/**
* @brief Search for the K best solutions.
*
@ -97,6 +137,8 @@ class GTSAM_EXPORT DiscreteSearch {
*/
std::vector<Solution> run(size_t K = 1) const;
/// @}
private:
/// Compute the cumulative lower-bound cost-to-go after each slot is filled.
/// @return the estimated lower bound of the cost for *all* slots.