Move to cpp file

release/4.3a0
Frank Dellaert 2025-01-26 23:34:30 -05:00
parent 14eeaf93db
commit 70089a0fd4
3 changed files with 201 additions and 151 deletions

View File

@ -0,0 +1,178 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* DiscreteSearch.cpp
*
* @date January, 2025
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteSearch.h>
namespace gtsam {
SearchNode SearchNode::Root(size_t numConditionals, double bound) {
return {.assignment = DiscreteValues(),
.error = 0.0,
.bound = bound,
.nextConditional = static_cast<int>(numConditionals) - 1};
}
SearchNode SearchNode::expand(const DiscreteConditional& conditional,
const DiscreteValues& fa) const {
// Combine the new frontal assignment with the current partial assignment
DiscreteValues newAssignment = assignment;
for (auto& kv : fa) {
newAssignment[kv.first] = kv.second;
}
return {.assignment = newAssignment,
.error = error + conditional.error(newAssignment),
.bound = 0.0,
.nextConditional = nextConditional - 1};
}
bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) {
const bool full = pq_.size() == maxSize_;
if (full && error >= pq_.top().error) return false;
if (full) pq_.pop();
pq_.emplace(error, assignment);
return true;
}
std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
auto pq = sn.pq_;
while (!pq.empty()) {
const Solution& best = pq.top();
os << "Error: " << best.error << ", Values: " << best.assignment
<< std::endl;
pq.pop();
}
return os;
}
bool Solutions::prune(double bound) const {
if (pq_.size() < maxSize_) return false;
double worstError = pq_.top().error;
return (bound >= worstError);
}
std::vector<Solution> Solutions::extractSolutions() {
std::vector<Solution> result;
while (!pq_.empty()) {
result.push_back(pq_.top());
pq_.pop();
}
std::sort(
result.begin(), result.end(),
[](const Solution& a, const Solution& b) { return a.error < b.error; });
return result;
}
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K)
: solutions_(K) {
// Copy out the conditionals
for (auto& factor : bayesNet) {
conditionals_.push_back(factor);
}
// Calculate the cost-to-go for each conditional
costToGo_ = computeCostToGo(conditionals_);
// Create the root node and push it to the expansions queue
expansions_.push(SearchNode::Root(
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
}
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K)
: solutions_(K) {
using CliquePtr = DiscreteBayesTree::sharedClique;
std::function<void(const CliquePtr&)> collectConditionals =
[&](const CliquePtr& clique) -> void {
if (!clique) return;
// Recursive post-order traversal: process children first
for (const auto& child : clique->children) {
collectConditionals(child);
}
// Then add the current clique's conditional
conditionals_.push_back(clique->conditional());
};
// Start traversal from each root in the tree
for (const auto& root : bayesTree.roots()) collectConditionals(root);
// Calculate the cost-to-go for each conditional
costToGo_ = computeCostToGo(conditionals_);
// Create the root node and push it to the expansions queue
expansions_.push(SearchNode::Root(
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
}
std::vector<Solution> DiscreteSearch::run() {
while (!expansions_.empty()) {
numExpansions++;
expandNextNode();
}
// Extract solutions from bestSolutions in ascending order of error
return solutions_.extractSolutions();
}
std::vector<double> DiscreteSearch::computeCostToGo(
const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
std::vector<double> costToGo;
double error = 0.0;
for (const auto& conditional : conditionals) {
Ordering ordering(conditional->begin(), conditional->end());
auto maxx = conditional->max(ordering);
assert(maxx->size() == 1);
error -= std::log(maxx->evaluate({}));
costToGo.push_back(error);
}
return costToGo;
}
void DiscreteSearch::expandNextNode() {
// Pop the partial assignment with the smallest bound
SearchNode current = expansions_.top();
expansions_.pop();
// If we already have K solutions, prune if we cannot beat the worst one.
if (solutions_.prune(current.bound)) {
return;
}
// Check if we have a complete assignment
if (current.isComplete()) {
solutions_.maybeAdd(current.error, current.assignment);
return;
}
// Expand on the next factor
const auto& conditional = conditionals_[current.nextConditional];
for (auto& fa : conditional->frontalAssignments()) {
auto childNode = current.expand(*conditional, fa);
if (childNode.nextConditional >= 0)
childNode.bound = childNode.error + costToGo_[childNode.nextConditional];
// Again, prune if we cannot beat the worst solution
if (!solutions_.prune(childNode.bound)) {
expansions_.push(childNode);
}
}
}
} // namespace gtsam

View File

@ -39,14 +39,9 @@ struct SearchNode {
/** /**
* @brief Construct the root node for the search. * @brief Construct the root node for the search.
*/ */
static SearchNode Root(size_t numConditionals, double bound) { static SearchNode Root(size_t numConditionals, double bound);
return {.assignment = DiscreteValues(),
.error = 0.0,
.bound = bound,
.nextConditional = static_cast<int>(numConditionals) - 1};
}
struct CompareByBound { struct Compare {
bool operator()(const SearchNode& a, const SearchNode& b) const { bool operator()(const SearchNode& a, const SearchNode& b) const {
return a.bound > b.bound; // smallest bound -> highest priority return a.bound > b.bound; // smallest bound -> highest priority
} }
@ -68,18 +63,7 @@ struct SearchNode {
* @return A new SearchNode representing the expanded state. * @return A new SearchNode representing the expanded state.
*/ */
SearchNode expand(const DiscreteConditional& conditional, SearchNode expand(const DiscreteConditional& conditional,
const DiscreteValues& fa) const { const DiscreteValues& fa) const;
// Combine the new frontal assignment with the current partial assignment
DiscreteValues newAssignment = assignment;
for (auto& kv : fa) {
newAssignment[kv.first] = kv.second;
}
return {.assignment = newAssignment,
.error = error + conditional.error(newAssignment),
.bound = 0.0,
.nextConditional = nextConditional - 1};
}
/** /**
* @brief Prints the SearchNode to an output stream. * @brief Prints the SearchNode to an output stream.
@ -103,69 +87,40 @@ struct Solution {
os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
return os; return os;
} }
};
struct CompareByError { struct Compare {
bool operator()(const Solution& a, const Solution& b) const { bool operator()(const Solution& a, const Solution& b) const {
return a.error < b.error; return a.error < b.error;
} }
};
}; };
// Define the Solutions class // Define the Solutions class
class Solutions { class Solutions {
private: private:
size_t maxSize_; size_t maxSize_;
std::priority_queue<Solution, std::vector<Solution>, CompareByError> pq_; std::priority_queue<Solution, std::vector<Solution>, Solution::Compare> pq_;
public: public:
Solutions(size_t maxSize) : maxSize_(maxSize) {} Solutions(size_t maxSize) : maxSize_(maxSize) {}
/// Add a solution to the priority queue, possibly evicting the worst one. /// Add a solution to the priority queue, possibly evicting the worst one.
/// Return true if we added the solution. /// Return true if we added the solution.
bool maybeAdd(double error, const DiscreteValues& assignment) { bool maybeAdd(double error, const DiscreteValues& assignment);
const bool full = pq_.size() == maxSize_;
if (full && error >= pq_.top().error) return false;
if (full) pq_.pop();
pq_.emplace(error, assignment);
return true;
}
/// Check if we have any solutions /// Check if we have any solutions
bool empty() const { return pq_.empty(); } bool empty() const { return pq_.empty(); }
// Method to print all solutions // Method to print all solutions
friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { friend std::ostream& operator<<(std::ostream& os, const Solutions& sn);
auto pq = sn.pq_;
while (!pq.empty()) {
const Solution& best = pq.top();
os << "Error: " << best.error << ", Values: " << best.assignment
<< std::endl;
pq.pop();
}
return os;
}
/// Check if (partial) solution with given bound can be pruned. If we have /// Check if (partial) solution with given bound can be pruned. If we have
/// room, we never prune. Otherwise, prune if lower bound on error is worse /// room, we never prune. Otherwise, prune if lower bound on error is worse
/// than our current worst error. /// than our current worst error.
bool prune(double bound) const { bool prune(double bound) const;
if (pq_.size() < maxSize_) return false;
double worstError = pq_.top().error;
return (bound >= worstError);
}
// Method to extract solutions in ascending order of error // Method to extract solutions in ascending order of error
std::vector<Solution> extractSolutions() { std::vector<Solution> extractSolutions();
std::vector<Solution> result;
while (!pq_.empty()) {
result.push_back(pq_.top());
pq_.pop();
}
std::sort(
result.begin(), result.end(),
[](const Solution& a, const Solution& b) { return a.error < b.error; });
return result;
}
}; };
/** /**
@ -178,48 +133,12 @@ class DiscreteSearch {
/** /**
* Construct from a DiscreteBayesNet and K. * Construct from a DiscreteBayesNet and K.
*/ */
DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) : solutions_(K) { DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K);
// Copy out the conditionals
for (auto& factor : bayesNet) {
conditionals_.push_back(factor);
}
// Calculate the cost-to-go for each conditional
costToGo_ = computeCostToGo(conditionals_);
// Create the root node and push it to the expansions queue
expansions_.push(SearchNode::Root(
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
}
/** /**
* Construct from a DiscreteBayesTree and K. * Construct from a DiscreteBayesTree and K.
*/ */
DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) { DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K);
using CliquePtr = DiscreteBayesTree::sharedClique;
std::function<void(const CliquePtr&)> collectConditionals =
[&](const CliquePtr& clique) -> void {
if (!clique) return;
// Recursive post-order traversal: process children first
for (const auto& child : clique->children) {
collectConditionals(child);
}
// Then add the current clique's conditional
conditionals_.push_back(clique->conditional());
};
// Start traversal from each root in the tree
for (const auto& root : bayesTree.roots()) collectConditionals(root);
// Calculate the cost-to-go for each conditional
costToGo_ = computeCostToGo(conditionals_);
// Create the root node and push it to the expansions queue
expansions_.push(SearchNode::Root(
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
}
/** /**
* @brief Search for the K best solutions. * @brief Search for the K best solutions.
@ -231,15 +150,7 @@ class DiscreteSearch {
* *
* @return A vector of the K best solutions found during the search. * @return A vector of the K best solutions found during the search.
*/ */
std::vector<Solution> run() { std::vector<Solution> run();
while (!expansions_.empty()) {
numExpansions++;
expandNextNode();
}
// Extract solutions from bestSolutions in ascending order of error
return solutions_.extractSolutions();
}
private: private:
/** /**
@ -249,58 +160,16 @@ class DiscreteSearch {
* @return A vector of cost-to-go values. * @return A vector of cost-to-go values.
*/ */
static std::vector<double> computeCostToGo( static std::vector<double> computeCostToGo(
const std::vector<DiscreteConditional::shared_ptr>& conditionals) { const std::vector<DiscreteConditional::shared_ptr>& conditionals);
std::vector<double> costToGo;
double error = 0.0;
for (const auto& conditional : conditionals) {
Ordering ordering(conditional->begin(), conditional->end());
auto maxx = conditional->max(ordering);
assert(maxx->size() == 1);
error -= std::log(maxx->evaluate({}));
costToGo.push_back(error);
}
return costToGo;
}
/** /**
* @brief Expand the next node in the search tree. * @brief Expand the next node in the search tree.
*/ */
void expandNextNode() { void expandNextNode();
// Pop the partial assignment with the smallest bound
SearchNode current = expansions_.top();
expansions_.pop();
// If we already have K solutions, prune if we cannot beat the worst one.
if (solutions_.prune(current.bound)) {
return;
}
// Check if we have a complete assignment
if (current.isComplete()) {
solutions_.maybeAdd(current.error, current.assignment);
return;
}
// Expand on the next factor
const auto& conditional = conditionals_[current.nextConditional];
for (auto& fa : conditional->frontalAssignments()) {
auto childNode = current.expand(*conditional, fa);
if (childNode.nextConditional >= 0)
childNode.bound =
childNode.error + costToGo_[childNode.nextConditional];
// Again, prune if we cannot beat the worst solution
if (!solutions_.prune(childNode.bound)) {
expansions_.push(childNode);
}
}
}
std::vector<DiscreteConditional::shared_ptr> conditionals_; std::vector<DiscreteConditional::shared_ptr> conditionals_;
std::vector<double> costToGo_; std::vector<double> costToGo_;
std::priority_queue<SearchNode, std::vector<SearchNode>, std::priority_queue<SearchNode, std::vector<SearchNode>, SearchNode::Compare>
SearchNode::CompareByBound>
expansions_; expansions_;
Solutions solutions_; Solutions solutions_;
}; };

View File

@ -48,6 +48,10 @@ TEST(DiscreteBayesNet, AsiaKBest) {
DiscreteBayesNet asia = createAsiaExample(); DiscreteBayesNet asia = createAsiaExample();
DiscreteSearch search(asia, 4); DiscreteSearch search(asia, 4);
auto solutions = search.run(); auto solutions = search.run();
// print numExpansions
std::cout << "Number of expansions: " << search.numExpansions << std::endl;
EXPECT(!solutions.empty()); EXPECT(!solutions.empty());
// Regression test: check the first and last solution // Regression test: check the first and last solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
@ -73,7 +77,6 @@ TEST(DiscreteBayesTree, testTrivialOneClique) {
DiscreteFactorGraph asia(createAsiaExample()); DiscreteFactorGraph asia(createAsiaExample());
const Ordering ordering{D, X, B, E, L, T, S, A}; const Ordering ordering{D, X, B, E, L, T, S, A};
DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering);
GTSAM_PRINT(*bt);
// Ask for top 4 solutions // Ask for top 4 solutions
DiscreteSearch search(*bt, 4); DiscreteSearch search(*bt, 4);