Move to cpp file
parent
14eeaf93db
commit
70089a0fd4
|
|
@ -0,0 +1,178 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* DiscreteSearch.cpp
|
||||||
|
*
|
||||||
|
* @date January, 2025
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteSearch.h>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
SearchNode SearchNode::Root(size_t numConditionals, double bound) {
|
||||||
|
return {.assignment = DiscreteValues(),
|
||||||
|
.error = 0.0,
|
||||||
|
.bound = bound,
|
||||||
|
.nextConditional = static_cast<int>(numConditionals) - 1};
|
||||||
|
}
|
||||||
|
|
||||||
|
SearchNode SearchNode::expand(const DiscreteConditional& conditional,
|
||||||
|
const DiscreteValues& fa) const {
|
||||||
|
// Combine the new frontal assignment with the current partial assignment
|
||||||
|
DiscreteValues newAssignment = assignment;
|
||||||
|
for (auto& kv : fa) {
|
||||||
|
newAssignment[kv.first] = kv.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {.assignment = newAssignment,
|
||||||
|
.error = error + conditional.error(newAssignment),
|
||||||
|
.bound = 0.0,
|
||||||
|
.nextConditional = nextConditional - 1};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) {
|
||||||
|
const bool full = pq_.size() == maxSize_;
|
||||||
|
if (full && error >= pq_.top().error) return false;
|
||||||
|
if (full) pq_.pop();
|
||||||
|
pq_.emplace(error, assignment);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
|
||||||
|
auto pq = sn.pq_;
|
||||||
|
while (!pq.empty()) {
|
||||||
|
const Solution& best = pq.top();
|
||||||
|
os << "Error: " << best.error << ", Values: " << best.assignment
|
||||||
|
<< std::endl;
|
||||||
|
pq.pop();
|
||||||
|
}
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Solutions::prune(double bound) const {
|
||||||
|
if (pq_.size() < maxSize_) return false;
|
||||||
|
double worstError = pq_.top().error;
|
||||||
|
return (bound >= worstError);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Solution> Solutions::extractSolutions() {
|
||||||
|
std::vector<Solution> result;
|
||||||
|
while (!pq_.empty()) {
|
||||||
|
result.push_back(pq_.top());
|
||||||
|
pq_.pop();
|
||||||
|
}
|
||||||
|
std::sort(
|
||||||
|
result.begin(), result.end(),
|
||||||
|
[](const Solution& a, const Solution& b) { return a.error < b.error; });
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K)
|
||||||
|
: solutions_(K) {
|
||||||
|
// Copy out the conditionals
|
||||||
|
for (auto& factor : bayesNet) {
|
||||||
|
conditionals_.push_back(factor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the cost-to-go for each conditional
|
||||||
|
costToGo_ = computeCostToGo(conditionals_);
|
||||||
|
|
||||||
|
// Create the root node and push it to the expansions queue
|
||||||
|
expansions_.push(SearchNode::Root(
|
||||||
|
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K)
|
||||||
|
: solutions_(K) {
|
||||||
|
using CliquePtr = DiscreteBayesTree::sharedClique;
|
||||||
|
std::function<void(const CliquePtr&)> collectConditionals =
|
||||||
|
[&](const CliquePtr& clique) -> void {
|
||||||
|
if (!clique) return;
|
||||||
|
|
||||||
|
// Recursive post-order traversal: process children first
|
||||||
|
for (const auto& child : clique->children) {
|
||||||
|
collectConditionals(child);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then add the current clique's conditional
|
||||||
|
conditionals_.push_back(clique->conditional());
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start traversal from each root in the tree
|
||||||
|
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
||||||
|
|
||||||
|
// Calculate the cost-to-go for each conditional
|
||||||
|
costToGo_ = computeCostToGo(conditionals_);
|
||||||
|
|
||||||
|
// Create the root node and push it to the expansions queue
|
||||||
|
expansions_.push(SearchNode::Root(
|
||||||
|
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Solution> DiscreteSearch::run() {
|
||||||
|
while (!expansions_.empty()) {
|
||||||
|
numExpansions++;
|
||||||
|
expandNextNode();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract solutions from bestSolutions in ascending order of error
|
||||||
|
return solutions_.extractSolutions();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<double> DiscreteSearch::computeCostToGo(
|
||||||
|
const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
|
||||||
|
std::vector<double> costToGo;
|
||||||
|
double error = 0.0;
|
||||||
|
for (const auto& conditional : conditionals) {
|
||||||
|
Ordering ordering(conditional->begin(), conditional->end());
|
||||||
|
auto maxx = conditional->max(ordering);
|
||||||
|
assert(maxx->size() == 1);
|
||||||
|
error -= std::log(maxx->evaluate({}));
|
||||||
|
costToGo.push_back(error);
|
||||||
|
}
|
||||||
|
return costToGo;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DiscreteSearch::expandNextNode() {
|
||||||
|
// Pop the partial assignment with the smallest bound
|
||||||
|
SearchNode current = expansions_.top();
|
||||||
|
expansions_.pop();
|
||||||
|
|
||||||
|
// If we already have K solutions, prune if we cannot beat the worst one.
|
||||||
|
if (solutions_.prune(current.bound)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have a complete assignment
|
||||||
|
if (current.isComplete()) {
|
||||||
|
solutions_.maybeAdd(current.error, current.assignment);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expand on the next factor
|
||||||
|
const auto& conditional = conditionals_[current.nextConditional];
|
||||||
|
|
||||||
|
for (auto& fa : conditional->frontalAssignments()) {
|
||||||
|
auto childNode = current.expand(*conditional, fa);
|
||||||
|
if (childNode.nextConditional >= 0)
|
||||||
|
childNode.bound = childNode.error + costToGo_[childNode.nextConditional];
|
||||||
|
|
||||||
|
// Again, prune if we cannot beat the worst solution
|
||||||
|
if (!solutions_.prune(childNode.bound)) {
|
||||||
|
expansions_.push(childNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -39,14 +39,9 @@ struct SearchNode {
|
||||||
/**
|
/**
|
||||||
* @brief Construct the root node for the search.
|
* @brief Construct the root node for the search.
|
||||||
*/
|
*/
|
||||||
static SearchNode Root(size_t numConditionals, double bound) {
|
static SearchNode Root(size_t numConditionals, double bound);
|
||||||
return {.assignment = DiscreteValues(),
|
|
||||||
.error = 0.0,
|
|
||||||
.bound = bound,
|
|
||||||
.nextConditional = static_cast<int>(numConditionals) - 1};
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CompareByBound {
|
struct Compare {
|
||||||
bool operator()(const SearchNode& a, const SearchNode& b) const {
|
bool operator()(const SearchNode& a, const SearchNode& b) const {
|
||||||
return a.bound > b.bound; // smallest bound -> highest priority
|
return a.bound > b.bound; // smallest bound -> highest priority
|
||||||
}
|
}
|
||||||
|
|
@ -68,18 +63,7 @@ struct SearchNode {
|
||||||
* @return A new SearchNode representing the expanded state.
|
* @return A new SearchNode representing the expanded state.
|
||||||
*/
|
*/
|
||||||
SearchNode expand(const DiscreteConditional& conditional,
|
SearchNode expand(const DiscreteConditional& conditional,
|
||||||
const DiscreteValues& fa) const {
|
const DiscreteValues& fa) const;
|
||||||
// Combine the new frontal assignment with the current partial assignment
|
|
||||||
DiscreteValues newAssignment = assignment;
|
|
||||||
for (auto& kv : fa) {
|
|
||||||
newAssignment[kv.first] = kv.second;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {.assignment = newAssignment,
|
|
||||||
.error = error + conditional.error(newAssignment),
|
|
||||||
.bound = 0.0,
|
|
||||||
.nextConditional = nextConditional - 1};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prints the SearchNode to an output stream.
|
* @brief Prints the SearchNode to an output stream.
|
||||||
|
|
@ -103,69 +87,40 @@ struct Solution {
|
||||||
os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
|
os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
struct CompareByError {
|
struct Compare {
|
||||||
bool operator()(const Solution& a, const Solution& b) const {
|
bool operator()(const Solution& a, const Solution& b) const {
|
||||||
return a.error < b.error;
|
return a.error < b.error;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
// Define the Solutions class
|
// Define the Solutions class
|
||||||
class Solutions {
|
class Solutions {
|
||||||
private:
|
private:
|
||||||
size_t maxSize_;
|
size_t maxSize_;
|
||||||
std::priority_queue<Solution, std::vector<Solution>, CompareByError> pq_;
|
std::priority_queue<Solution, std::vector<Solution>, Solution::Compare> pq_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Solutions(size_t maxSize) : maxSize_(maxSize) {}
|
Solutions(size_t maxSize) : maxSize_(maxSize) {}
|
||||||
|
|
||||||
/// Add a solution to the priority queue, possibly evicting the worst one.
|
/// Add a solution to the priority queue, possibly evicting the worst one.
|
||||||
/// Return true if we added the solution.
|
/// Return true if we added the solution.
|
||||||
bool maybeAdd(double error, const DiscreteValues& assignment) {
|
bool maybeAdd(double error, const DiscreteValues& assignment);
|
||||||
const bool full = pq_.size() == maxSize_;
|
|
||||||
if (full && error >= pq_.top().error) return false;
|
|
||||||
if (full) pq_.pop();
|
|
||||||
pq_.emplace(error, assignment);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if we have any solutions
|
/// Check if we have any solutions
|
||||||
bool empty() const { return pq_.empty(); }
|
bool empty() const { return pq_.empty(); }
|
||||||
|
|
||||||
// Method to print all solutions
|
// Method to print all solutions
|
||||||
friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
|
friend std::ostream& operator<<(std::ostream& os, const Solutions& sn);
|
||||||
auto pq = sn.pq_;
|
|
||||||
while (!pq.empty()) {
|
|
||||||
const Solution& best = pq.top();
|
|
||||||
os << "Error: " << best.error << ", Values: " << best.assignment
|
|
||||||
<< std::endl;
|
|
||||||
pq.pop();
|
|
||||||
}
|
|
||||||
return os;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if (partial) solution with given bound can be pruned. If we have
|
/// Check if (partial) solution with given bound can be pruned. If we have
|
||||||
/// room, we never prune. Otherwise, prune if lower bound on error is worse
|
/// room, we never prune. Otherwise, prune if lower bound on error is worse
|
||||||
/// than our current worst error.
|
/// than our current worst error.
|
||||||
bool prune(double bound) const {
|
bool prune(double bound) const;
|
||||||
if (pq_.size() < maxSize_) return false;
|
|
||||||
double worstError = pq_.top().error;
|
|
||||||
return (bound >= worstError);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Method to extract solutions in ascending order of error
|
// Method to extract solutions in ascending order of error
|
||||||
std::vector<Solution> extractSolutions() {
|
std::vector<Solution> extractSolutions();
|
||||||
std::vector<Solution> result;
|
|
||||||
while (!pq_.empty()) {
|
|
||||||
result.push_back(pq_.top());
|
|
||||||
pq_.pop();
|
|
||||||
}
|
|
||||||
std::sort(
|
|
||||||
result.begin(), result.end(),
|
|
||||||
[](const Solution& a, const Solution& b) { return a.error < b.error; });
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -178,48 +133,12 @@ class DiscreteSearch {
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteBayesNet and K.
|
* Construct from a DiscreteBayesNet and K.
|
||||||
*/
|
*/
|
||||||
DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) : solutions_(K) {
|
DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K);
|
||||||
// Copy out the conditionals
|
|
||||||
for (auto& factor : bayesNet) {
|
|
||||||
conditionals_.push_back(factor);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the cost-to-go for each conditional
|
|
||||||
costToGo_ = computeCostToGo(conditionals_);
|
|
||||||
|
|
||||||
// Create the root node and push it to the expansions queue
|
|
||||||
expansions_.push(SearchNode::Root(
|
|
||||||
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteBayesTree and K.
|
* Construct from a DiscreteBayesTree and K.
|
||||||
*/
|
*/
|
||||||
DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) {
|
DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K);
|
||||||
using CliquePtr = DiscreteBayesTree::sharedClique;
|
|
||||||
std::function<void(const CliquePtr&)> collectConditionals =
|
|
||||||
[&](const CliquePtr& clique) -> void {
|
|
||||||
if (!clique) return;
|
|
||||||
|
|
||||||
// Recursive post-order traversal: process children first
|
|
||||||
for (const auto& child : clique->children) {
|
|
||||||
collectConditionals(child);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then add the current clique's conditional
|
|
||||||
conditionals_.push_back(clique->conditional());
|
|
||||||
};
|
|
||||||
|
|
||||||
// Start traversal from each root in the tree
|
|
||||||
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
|
||||||
|
|
||||||
// Calculate the cost-to-go for each conditional
|
|
||||||
costToGo_ = computeCostToGo(conditionals_);
|
|
||||||
|
|
||||||
// Create the root node and push it to the expansions queue
|
|
||||||
expansions_.push(SearchNode::Root(
|
|
||||||
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Search for the K best solutions.
|
* @brief Search for the K best solutions.
|
||||||
|
|
@ -231,15 +150,7 @@ class DiscreteSearch {
|
||||||
*
|
*
|
||||||
* @return A vector of the K best solutions found during the search.
|
* @return A vector of the K best solutions found during the search.
|
||||||
*/
|
*/
|
||||||
std::vector<Solution> run() {
|
std::vector<Solution> run();
|
||||||
while (!expansions_.empty()) {
|
|
||||||
numExpansions++;
|
|
||||||
expandNextNode();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract solutions from bestSolutions in ascending order of error
|
|
||||||
return solutions_.extractSolutions();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
|
|
@ -249,58 +160,16 @@ class DiscreteSearch {
|
||||||
* @return A vector of cost-to-go values.
|
* @return A vector of cost-to-go values.
|
||||||
*/
|
*/
|
||||||
static std::vector<double> computeCostToGo(
|
static std::vector<double> computeCostToGo(
|
||||||
const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
|
const std::vector<DiscreteConditional::shared_ptr>& conditionals);
|
||||||
std::vector<double> costToGo;
|
|
||||||
double error = 0.0;
|
|
||||||
for (const auto& conditional : conditionals) {
|
|
||||||
Ordering ordering(conditional->begin(), conditional->end());
|
|
||||||
auto maxx = conditional->max(ordering);
|
|
||||||
assert(maxx->size() == 1);
|
|
||||||
error -= std::log(maxx->evaluate({}));
|
|
||||||
costToGo.push_back(error);
|
|
||||||
}
|
|
||||||
return costToGo;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Expand the next node in the search tree.
|
* @brief Expand the next node in the search tree.
|
||||||
*/
|
*/
|
||||||
void expandNextNode() {
|
void expandNextNode();
|
||||||
// Pop the partial assignment with the smallest bound
|
|
||||||
SearchNode current = expansions_.top();
|
|
||||||
expansions_.pop();
|
|
||||||
|
|
||||||
// If we already have K solutions, prune if we cannot beat the worst one.
|
|
||||||
if (solutions_.prune(current.bound)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have a complete assignment
|
|
||||||
if (current.isComplete()) {
|
|
||||||
solutions_.maybeAdd(current.error, current.assignment);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expand on the next factor
|
|
||||||
const auto& conditional = conditionals_[current.nextConditional];
|
|
||||||
|
|
||||||
for (auto& fa : conditional->frontalAssignments()) {
|
|
||||||
auto childNode = current.expand(*conditional, fa);
|
|
||||||
if (childNode.nextConditional >= 0)
|
|
||||||
childNode.bound =
|
|
||||||
childNode.error + costToGo_[childNode.nextConditional];
|
|
||||||
|
|
||||||
// Again, prune if we cannot beat the worst solution
|
|
||||||
if (!solutions_.prune(childNode.bound)) {
|
|
||||||
expansions_.push(childNode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<DiscreteConditional::shared_ptr> conditionals_;
|
std::vector<DiscreteConditional::shared_ptr> conditionals_;
|
||||||
std::vector<double> costToGo_;
|
std::vector<double> costToGo_;
|
||||||
std::priority_queue<SearchNode, std::vector<SearchNode>,
|
std::priority_queue<SearchNode, std::vector<SearchNode>, SearchNode::Compare>
|
||||||
SearchNode::CompareByBound>
|
|
||||||
expansions_;
|
expansions_;
|
||||||
Solutions solutions_;
|
Solutions solutions_;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,10 @@ TEST(DiscreteBayesNet, AsiaKBest) {
|
||||||
DiscreteBayesNet asia = createAsiaExample();
|
DiscreteBayesNet asia = createAsiaExample();
|
||||||
DiscreteSearch search(asia, 4);
|
DiscreteSearch search(asia, 4);
|
||||||
auto solutions = search.run();
|
auto solutions = search.run();
|
||||||
|
|
||||||
|
// print numExpansions
|
||||||
|
std::cout << "Number of expansions: " << search.numExpansions << std::endl;
|
||||||
|
|
||||||
EXPECT(!solutions.empty());
|
EXPECT(!solutions.empty());
|
||||||
// Regression test: check the first and last solution
|
// Regression test: check the first and last solution
|
||||||
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
||||||
|
|
@ -73,7 +77,6 @@ TEST(DiscreteBayesTree, testTrivialOneClique) {
|
||||||
DiscreteFactorGraph asia(createAsiaExample());
|
DiscreteFactorGraph asia(createAsiaExample());
|
||||||
const Ordering ordering{D, X, B, E, L, T, S, A};
|
const Ordering ordering{D, X, B, E, L, T, S, A};
|
||||||
DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering);
|
DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering);
|
||||||
GTSAM_PRINT(*bt);
|
|
||||||
|
|
||||||
// Ask for top 4 solutions
|
// Ask for top 4 solutions
|
||||||
DiscreteSearch search(*bt, 4);
|
DiscreteSearch search(*bt, 4);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue