Working etree and jtree versions
parent
c4870cc840
commit
0bc566f692
|
@ -150,21 +150,58 @@ class Solutions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph,
|
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
|
||||||
const Ordering& ordering,
|
using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
|
||||||
bool buildJunctionTree) {
|
auto visitor = [this](const NodePtr& node, int data) {
|
||||||
const DiscreteEliminationTree etree(factorGraph, ordering);
|
const auto& factors = node->factors;
|
||||||
const DiscreteJunctionTree junctionTree(etree);
|
const auto factor = factors.size() == 1
|
||||||
|
? factors.back()
|
||||||
|
: DiscreteFactorGraph(factors).product();
|
||||||
|
const size_t cardinality = factor->cardinality(node->key);
|
||||||
|
std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
|
||||||
|
slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0);
|
||||||
|
return data + 1;
|
||||||
|
};
|
||||||
|
|
||||||
// GTSAM_PRINT(asia::etree);
|
const int data = 0; // unused
|
||||||
// GTSAM_PRINT(asia::junctionTree);
|
treeTraversal::DepthFirstForest(etree, data, visitor);
|
||||||
slots_.reserve(factorGraph.size());
|
std::reverse(slots_.begin(), slots_.end()); // reverse slots
|
||||||
for (auto& factor : factorGraph) {
|
|
||||||
slots_.emplace_back(factor, std::vector<DiscreteValues>{}, 0.0);
|
|
||||||
}
|
|
||||||
lowerBound_ = computeHeuristic();
|
lowerBound_ = computeHeuristic();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) {
|
||||||
|
using NodePtr = std::shared_ptr<DiscreteJunctionTree::Cluster>;
|
||||||
|
auto visitor = [this](const NodePtr& cluster, int data) {
|
||||||
|
const auto& factors = cluster->factors;
|
||||||
|
const auto factor = factors.size() == 1
|
||||||
|
? factors.back()
|
||||||
|
: DiscreteFactorGraph(factors).product();
|
||||||
|
std::vector<std::pair<Key, size_t>> pairs;
|
||||||
|
for (Key key : cluster->orderedFrontalKeys) {
|
||||||
|
pairs.emplace_back(key, factor->cardinality(key));
|
||||||
|
}
|
||||||
|
slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0);
|
||||||
|
return data + 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
const int data = 0; // unused
|
||||||
|
treeTraversal::DepthFirstForest(junctionTree, data, visitor);
|
||||||
|
std::reverse(slots_.begin(), slots_.end()); // reverse slots
|
||||||
|
lowerBound_ = computeHeuristic();
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteSearch DiscreteSearch::FromFactorGraph(
|
||||||
|
const DiscreteFactorGraph& factorGraph, const Ordering& ordering,
|
||||||
|
bool buildJunctionTree) {
|
||||||
|
const DiscreteEliminationTree etree(factorGraph, ordering);
|
||||||
|
if (buildJunctionTree) {
|
||||||
|
const DiscreteJunctionTree junctionTree(etree);
|
||||||
|
return DiscreteSearch(junctionTree);
|
||||||
|
} else {
|
||||||
|
return DiscreteSearch(etree);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
||||||
slots_.reserve(bayesNet.size());
|
slots_.reserve(bayesNet.size());
|
||||||
for (auto& conditional : bayesNet) {
|
for (auto& conditional : bayesNet) {
|
||||||
|
@ -188,6 +225,14 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
||||||
lowerBound_ = computeHeuristic();
|
lowerBound_ = computeHeuristic();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DiscreteSearch::print(const std::string& name,
|
||||||
|
const KeyFormatter& formatter) const {
|
||||||
|
std::cout << name << " with " << slots_.size() << " slots:\n";
|
||||||
|
for (size_t i = 0; i < slots_.size(); ++i) {
|
||||||
|
std::cout << i << ": " << slots_[i] << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct SearchNodeQueue
|
struct SearchNodeQueue
|
||||||
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
|
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
|
||||||
SearchNode::Compare> {
|
SearchNode::Compare> {
|
||||||
|
@ -215,8 +260,6 @@ struct SearchNodeQueue
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#define DISCRETE_SEARCH_DEBUG
|
|
||||||
|
|
||||||
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
Solutions solutions(K);
|
Solutions solutions(K);
|
||||||
SearchNodeQueue expansions;
|
SearchNodeQueue expansions;
|
||||||
|
@ -252,9 +295,9 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
|
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
|
||||||
// For the first slot, this is 0.0, as this is the last slot to be filled, so
|
// For the first slot, this is 0.0, as this is the last slot to be filled, so
|
||||||
// the cost after that is zero. For the second slot, it is h0 =
|
// the cost after that is zero. For the second slot, it is h0 =
|
||||||
// -log(max(factor[0])), because after we assign slot[1] we still need to assign
|
// -log(max(factor[0])), because after we assign slot[1] we still need to
|
||||||
// slot[0], which will cost *at least* h0.
|
// assign slot[0], which will cost *at least* h0. We return the estimated
|
||||||
// We return the estimated lower bound of the cost for *all* slots.
|
// lower bound of the cost for *all* slots.
|
||||||
double DiscreteSearch::computeHeuristic() {
|
double DiscreteSearch::computeHeuristic() {
|
||||||
double error = 0.0;
|
double error = 0.0;
|
||||||
for (size_t i = 0; i < slots_.size(); ++i) {
|
for (size_t i = 0; i < slots_.size(); ++i) {
|
||||||
|
|
|
@ -10,7 +10,11 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* DiscreteSearch.cpp
|
* @file DiscreteSearch.h
|
||||||
|
* @brief Defines the DiscreteSearch class for discrete search algorithms.
|
||||||
|
*
|
||||||
|
* @details This file contains the definition of the DiscreteSearch class, which
|
||||||
|
* is used in discrete search algorithms to find the K best solutions.
|
||||||
*
|
*
|
||||||
* @date January, 2025
|
* @date January, 2025
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
|
@ -43,6 +47,13 @@ class GTSAM_EXPORT DiscreteSearch {
|
||||||
/// A lower bound on the cost-to-go for each slot, e.g.,
|
/// A lower bound on the cost-to-go for each slot, e.g.,
|
||||||
/// [-log(max_B P(B|A)), -log(max_A P(A))]
|
/// [-log(max_B P(B|A)), -log(max_A P(A))]
|
||||||
double heuristic;
|
double heuristic;
|
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, const Slot& slot) {
|
||||||
|
os << "Slot with " << slot.assignments.size()
|
||||||
|
<< " assignments, heuristic=" << slot.heuristic;
|
||||||
|
os << ", factor:\n" << slot.factor->markdown() << std::endl;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A solution is then a set of assignments, covering all the slots.
|
/// A solution is then a set of assignments, covering all the slots.
|
||||||
|
@ -59,6 +70,9 @@ class GTSAM_EXPORT DiscreteSearch {
|
||||||
};
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteFactorGraph.
|
* Construct from a DiscreteFactorGraph.
|
||||||
*
|
*
|
||||||
|
@ -68,12 +82,26 @@ class GTSAM_EXPORT DiscreteSearch {
|
||||||
* fine-grained (more slots).
|
* fine-grained (more slots).
|
||||||
*
|
*
|
||||||
* @param factorGraph The factor graph to search over.
|
* @param factorGraph The factor graph to search over.
|
||||||
* @param ordering The ordering of the variables to search over.
|
* @param ordering The ordering used to create etree (and maybe jtree).
|
||||||
* @param buildJunctionTree Whether to build a junction tree for the factor
|
* @param buildJunctionTree Whether to build a junction tree or not.
|
||||||
* graph.
|
|
||||||
*/
|
*/
|
||||||
DiscreteSearch(const DiscreteFactorGraph& factorGraph,
|
static DiscreteSearch FromFactorGraph(const DiscreteFactorGraph& factorGraph,
|
||||||
const Ordering& ordering, bool buildJunctionTree = false);
|
const Ordering& ordering,
|
||||||
|
bool buildJunctionTree = false);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Constructor from a DiscreteEliminationTree.
|
||||||
|
*
|
||||||
|
* @param etree The DiscreteEliminationTree to initialize from.
|
||||||
|
*/
|
||||||
|
DiscreteSearch(const DiscreteEliminationTree& etree);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Constructor from a DiscreteJunctionTree.
|
||||||
|
*
|
||||||
|
* @param junctionTree The DiscreteJunctionTree to initialize from.
|
||||||
|
*/
|
||||||
|
DiscreteSearch(const DiscreteJunctionTree& junctionTree);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteBayesNet.
|
* Construct from a DiscreteBayesNet.
|
||||||
|
@ -85,6 +113,18 @@ class GTSAM_EXPORT DiscreteSearch {
|
||||||
*/
|
*/
|
||||||
DiscreteSearch(const DiscreteBayesTree& bayesTree);
|
DiscreteSearch(const DiscreteBayesTree& bayesTree);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** Print the tree to cout */
|
||||||
|
void print(const std::string& name = "DiscreteSearch: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Standard API
|
||||||
|
/// @{
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Search for the K best solutions.
|
* @brief Search for the K best solutions.
|
||||||
*
|
*
|
||||||
|
@ -97,6 +137,8 @@ class GTSAM_EXPORT DiscreteSearch {
|
||||||
*/
|
*/
|
||||||
std::vector<Solution> run(size_t K = 1) const;
|
std::vector<Solution> run(size_t K = 1) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Compute the cumulative lower-bound cost-to-go after each slot is filled.
|
/// Compute the cumulative lower-bound cost-to-go after each slot is filled.
|
||||||
/// @return the estimated lower bound of the cost for *all* slots.
|
/// @return the estimated lower bound of the cost for *all* slots.
|
||||||
|
|
Loading…
Reference in New Issue