316 lines
10 KiB
C++
316 lines
10 KiB
C++
/* ----------------------------------------------------------------------------
|
|
|
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
* Atlanta, Georgia 30332-0415
|
|
* All Rights Reserved
|
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
|
|
* See LICENSE for the license information
|
|
|
|
* -------------------------------------------------------------------------- */
|
|
|
|
/*
|
|
* DiscreteSearch.cpp
|
|
*
|
|
* @date January, 2025
|
|
* @author Frank Dellaert
|
|
*/
|
|
|
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
|
#include <gtsam/discrete/DiscreteSearch.h>
|
|
|
|
namespace gtsam {
|
|
|
|
// Short file-local names for the nested types declared by DiscreteSearch.
using Solution = DiscreteSearch::Solution;
using Slot = DiscreteSearch::Slot;
|
|
|
|
/**
|
|
* @brief Represents a node in the search tree for discrete search algorithms.
|
|
*
|
|
* @details Each SearchNode contains a partial assignment of discrete variables,
|
|
* the current error, a bound on the final error, and the index of the next
|
|
* conditional to be assigned.
|
|
*/
|
|
struct SearchNode {
|
|
DiscreteValues assignment; ///< Partial assignment of discrete variables.
|
|
double error; ///< Current error for the partial assignment.
|
|
double bound; ///< Lower bound on the final error for unassigned variables.
|
|
int nextConditional; ///< Index of the next conditional to be assigned.
|
|
|
|
/**
|
|
* @brief Construct the root node for the search.
|
|
*/
|
|
static SearchNode Root(size_t numSlots, double bound) {
|
|
return {DiscreteValues(), 0.0, bound, static_cast<int>(numSlots) - 1};
|
|
}
|
|
|
|
struct Compare {
|
|
bool operator()(const SearchNode& a, const SearchNode& b) const {
|
|
return a.bound > b.bound; // smallest bound -> highest priority
|
|
}
|
|
};
|
|
|
|
/**
|
|
* @brief Checks if the node represents a complete assignment.
|
|
*
|
|
* @return True if all variables have been assigned, false otherwise.
|
|
*/
|
|
inline bool isComplete() const { return nextConditional < 0; }
|
|
|
|
/**
|
|
* @brief Expands the node by assigning the next variable.
|
|
*
|
|
* @param slot The slot to be filled.
|
|
* @param fa The frontal assignment for the next variable.
|
|
* @return A new SearchNode representing the expanded state.
|
|
*/
|
|
SearchNode expand(const Slot& slot, const DiscreteValues& fa) const {
|
|
// Combine the new frontal assignment with the current partial assignment
|
|
DiscreteValues newAssignment = assignment;
|
|
for (auto& [key, value] : fa) {
|
|
newAssignment[key] = value;
|
|
}
|
|
double errorSoFar = error + slot.factor->error(newAssignment);
|
|
return {newAssignment, errorSoFar, errorSoFar + slot.heuristic,
|
|
nextConditional - 1};
|
|
}
|
|
|
|
/**
|
|
* @brief Prints the SearchNode to an output stream.
|
|
*
|
|
* @param os The output stream.
|
|
* @param node The SearchNode to be printed.
|
|
* @return The output stream.
|
|
*/
|
|
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
|
|
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
|
|
return os;
|
|
}
|
|
};
|
|
|
|
struct CompareSolution {
|
|
bool operator()(const Solution& a, const Solution& b) const {
|
|
return a.error < b.error;
|
|
}
|
|
};
|
|
|
|
// Define the Solutions class
|
|
class Solutions {
|
|
private:
|
|
size_t maxSize_;
|
|
std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;
|
|
|
|
public:
|
|
Solutions(size_t maxSize) : maxSize_(maxSize) {}
|
|
|
|
/// Add a solution to the priority queue, possibly evicting the worst one.
|
|
/// Return true if we added the solution.
|
|
bool maybeAdd(double error, const DiscreteValues& assignment) {
|
|
const bool full = pq_.size() == maxSize_;
|
|
if (full && error >= pq_.top().error) return false;
|
|
if (full) pq_.pop();
|
|
pq_.emplace(error, assignment);
|
|
return true;
|
|
}
|
|
|
|
/// Check if we have any solutions
|
|
bool empty() const { return pq_.empty(); }
|
|
|
|
// Method to print all solutions
|
|
friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
|
|
os << "Solutions (top " << sn.pq_.size() << "):\n";
|
|
auto pq = sn.pq_;
|
|
while (!pq.empty()) {
|
|
os << pq.top() << "\n";
|
|
pq.pop();
|
|
}
|
|
return os;
|
|
}
|
|
|
|
/// Check if (partial) solution with given bound can be pruned. If we have
|
|
/// room, we never prune. Otherwise, prune if lower bound on error is worse
|
|
/// than our current worst error.
|
|
bool prune(double bound) const {
|
|
if (pq_.size() < maxSize_) return false;
|
|
return bound >= pq_.top().error;
|
|
}
|
|
|
|
// Method to extract solutions in ascending order of error
|
|
std::vector<Solution> extractSolutions() {
|
|
std::vector<Solution> result;
|
|
while (!pq_.empty()) {
|
|
result.push_back(pq_.top());
|
|
pq_.pop();
|
|
}
|
|
std::sort(
|
|
result.begin(), result.end(),
|
|
[](const Solution& a, const Solution& b) { return a.error < b.error; });
|
|
return result;
|
|
}
|
|
};
|
|
|
|
/**
 * Build one slot per elimination-tree node. A slot holds the product of the
 * node's factors and every value of the node's key. Slots are collected in
 * depth-first visit order, then reversed before computing the heuristic.
 */
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
  using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
  auto visitor = [this](const NodePtr& node, int data) {
    const auto& factors = node->factors;
    // Avoid building a product when the node carries a single factor.
    const auto factor = factors.size() == 1
                            ? factors.back()
                            : DiscreteFactorGraph(factors).product();
    const size_t cardinality = factor->cardinality(node->key);
    std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
    // Push a temporary so the assignment vector is moved. std::move on a
    // const local would silently copy it.
    slots_.push_back(
        Slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0});
    return data + 1;
  };

  const int data = 0;  // unused
  treeTraversal::DepthFirstForest(etree, data, visitor);
  std::reverse(slots_.begin(), slots_.end());  // reverse slots
  lowerBound_ = computeHeuristic();
}
|
|
|
|
/**
 * Build one slot per junction-tree cluster. A slot holds the product of the
 * cluster's factors and every joint value of its ordered frontal keys. Slots
 * are collected in depth-first visit order, then reversed before computing
 * the heuristic.
 */
DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) {
  using NodePtr = std::shared_ptr<DiscreteJunctionTree::Cluster>;
  auto visitor = [this](const NodePtr& cluster, int data) {
    const auto& factors = cluster->factors;
    // Avoid building a product when the cluster carries a single factor.
    const auto factor = factors.size() == 1
                            ? factors.back()
                            : DiscreteFactorGraph(factors).product();
    std::vector<std::pair<Key, size_t>> pairs;
    pairs.reserve(cluster->orderedFrontalKeys.size());
    for (Key key : cluster->orderedFrontalKeys) {
      pairs.emplace_back(key, factor->cardinality(key));
    }
    // Push a temporary so the assignment vector is moved. std::move on a
    // const local would silently copy it.
    slots_.push_back(
        Slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0});
    return data + 1;
  };

  const int data = 0;  // unused
  treeTraversal::DepthFirstForest(junctionTree, data, visitor);
  std::reverse(slots_.begin(), slots_.end());  // reverse slots
  lowerBound_ = computeHeuristic();
}
|
|
|
|
/**
 * Build a search from a factor graph. An elimination tree is always built
 * first. If buildJunctionTree is set, it is clustered into a junction tree
 * and the search runs over clusters instead of single variables.
 */
DiscreteSearch DiscreteSearch::FromFactorGraph(
    const DiscreteFactorGraph& factorGraph, const Ordering& ordering,
    bool buildJunctionTree) {
  const DiscreteEliminationTree eliminationTree(factorGraph, ordering);
  if (!buildJunctionTree) return DiscreteSearch(eliminationTree);

  const DiscreteJunctionTree clusteredTree(eliminationTree);
  return DiscreteSearch(clusteredTree);
}
|
|
|
|
/**
 * Build one slot per conditional of the Bayes net, in the net's order. Each
 * slot enumerates the conditional's frontal assignments.
 */
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
  slots_.reserve(bayesNet.size());
  for (const auto& conditional : bayesNet) {
    // Push a temporary so the assignment vector is moved, not copied.
    slots_.push_back(
        Slot{conditional, conditional->frontalAssignments(), 0.0});
  }
  lowerBound_ = computeHeuristic();
}
|
|
|
|
/**
 * Build one slot per clique conditional of the Bayes tree. Cliques are visited
 * in post-order (children before parent), starting from each root. Each slot
 * enumerates the conditional's frontal assignments.
 */
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
  std::function<void(const DiscreteBayesTree::sharedClique&)>
      collectConditionals = [&](const auto& clique) {
        if (!clique) return;
        for (const auto& child : clique->children) collectConditionals(child);
        // Bind by reference: no extra shared_ptr refcount traffic.
        const auto& conditional = clique->conditional();
        // Push a temporary so the assignment vector is moved, not copied.
        slots_.push_back(
            Slot{conditional, conditional->frontalAssignments(), 0.0});
      };

  slots_.reserve(bayesTree.size());
  for (const auto& root : bayesTree.roots()) collectConditionals(root);
  lowerBound_ = computeHeuristic();
}
|
|
|
|
void DiscreteSearch::print(const std::string& name,
|
|
const KeyFormatter& formatter) const {
|
|
std::cout << name << " with " << slots_.size() << " slots:\n";
|
|
for (size_t i = 0; i < slots_.size(); ++i) {
|
|
std::cout << i << ": " << slots_[i] << std::endl;
|
|
}
|
|
}
|
|
|
|
struct SearchNodeQueue
|
|
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
|
|
SearchNode::Compare> {
|
|
void expandNextNode(const SearchNode& current, const Slot& slot,
|
|
Solutions* solutions) {
|
|
// If we already have K solutions, prune if we cannot beat the worst one.
|
|
if (solutions->prune(current.bound)) {
|
|
return;
|
|
}
|
|
|
|
// Check if we have a complete assignment
|
|
if (current.isComplete()) {
|
|
solutions->maybeAdd(current.error, current.assignment);
|
|
return;
|
|
}
|
|
|
|
for (auto& fa : slot.assignments) {
|
|
auto childNode = current.expand(slot, fa);
|
|
|
|
// Again, prune if we cannot beat the worst solution
|
|
if (!solutions->prune(childNode.bound)) {
|
|
emplace(childNode);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
|
Solutions solutions(K);
|
|
SearchNodeQueue expansions;
|
|
expansions.push(SearchNode::Root(slots_.size(), lowerBound_));
|
|
|
|
#ifdef DISCRETE_SEARCH_DEBUG
|
|
size_t numExpansions = 0;
|
|
#endif
|
|
|
|
// Perform the search
|
|
while (!expansions.empty()) {
|
|
// Pop the partial assignment with the smallest bound
|
|
SearchNode current = expansions.top();
|
|
expansions.pop();
|
|
|
|
// Get the next slot to expand
|
|
const auto& slot = slots_[current.nextConditional];
|
|
expansions.expandNextNode(current, slot, &solutions);
|
|
#ifdef DISCRETE_SEARCH_DEBUG
|
|
++numExpansions;
|
|
#endif
|
|
}
|
|
|
|
#ifdef DISCRETE_SEARCH_DEBUG
|
|
std::cout << "Number of expansions: " << numExpansions << std::endl;
|
|
#endif
|
|
|
|
// Extract solutions from bestSolutions in ascending order of error
|
|
return solutions.extractSolutions();
|
|
}
|
|
|
|
// We have a number of factors, each with a max value, and we want to compute
|
|
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
|
|
// For the first slot, this is 0.0, as this is the last slot to be filled, so
|
|
// the cost after that is zero. For the second slot, it is h0 =
|
|
// -log(max(factor[0])), because after we assign slot[1] we still need to
|
|
// assign slot[0], which will cost *at least* h0. We return the estimated
|
|
// lower bound of the cost for *all* slots.
|
|
double DiscreteSearch::computeHeuristic() {
|
|
double error = 0.0;
|
|
for (auto& slot : slots_) {
|
|
slot.heuristic = error;
|
|
Ordering ordering(slot.factor->begin(), slot.factor->end());
|
|
auto maxx = slot.factor->max(ordering);
|
|
error -= std::log(maxx->evaluate({}));
|
|
}
|
|
return error;
|
|
}
|
|
|
|
} // namespace gtsam
|