Moved data structures to cpp

release/4.3a0
Frank Dellaert 2025-01-27 00:32:03 -05:00
parent f174a38eed
commit cee7769595
2 changed files with 126 additions and 136 deletions

View File

@ -20,14 +20,53 @@
namespace gtsam { namespace gtsam {
SearchNode SearchNode::Root(size_t numConditionals, double bound) { using Value = size_t;
/**
* @brief Represents a node in the search tree for discrete search algorithms.
*
* @details Each SearchNode contains a partial assignment of discrete variables,
* the current error, a bound on the final error, and the index of the next
* conditional to be assigned.
*/
struct SearchNode {
DiscreteValues assignment; ///< Partial assignment of discrete variables.
double error; ///< Current error for the partial assignment.
double bound; ///< Lower bound on the final error for unassigned variables.
int nextConditional; ///< Index of the next conditional to be assigned.
/**
* @brief Construct the root node for the search.
*/
static SearchNode Root(size_t numConditionals, double bound) {
return {.assignment = DiscreteValues(), return {.assignment = DiscreteValues(),
.error = 0.0, .error = 0.0,
.bound = bound, .bound = bound,
.nextConditional = static_cast<int>(numConditionals) - 1}; .nextConditional = static_cast<int>(numConditionals) - 1};
} }
SearchNode SearchNode::expand(const DiscreteConditional& conditional, struct Compare {
bool operator()(const SearchNode& a, const SearchNode& b) const {
return a.bound > b.bound; // smallest bound -> highest priority
}
};
/**
* @brief Checks if the node represents a complete assignment.
*
* @return True if all variables have been assigned, false otherwise.
*/
inline bool isComplete() const { return nextConditional < 0; }
/**
* @brief Expands the node by assigning the next variable.
*
* @param conditional The discrete conditional representing the next variable
* to be assigned.
* @param fa The frontal assignment for the next variable.
* @return A new SearchNode representing the expanded state.
*/
SearchNode expand(const DiscreteConditional& conditional,
const DiscreteValues& fa) const { const DiscreteValues& fa) const {
// Combine the new frontal assignment with the current partial assignment // Combine the new frontal assignment with the current partial assignment
DiscreteValues newAssignment = assignment; DiscreteValues newAssignment = assignment;
@ -41,7 +80,37 @@ SearchNode SearchNode::expand(const DiscreteConditional& conditional,
.nextConditional = nextConditional - 1}; .nextConditional = nextConditional - 1};
} }
bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) { /**
* @brief Prints the SearchNode to an output stream.
*
* @param os The output stream.
* @param node The SearchNode to be printed.
* @return The output stream.
*/
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
return os;
}
};
struct CompareSolution {
bool operator()(const Solution& a, const Solution& b) const {
return a.error < b.error;
}
};
// Define the Solutions class
class Solutions {
private:
size_t maxSize_;
std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;
public:
Solutions(size_t maxSize) : maxSize_(maxSize) {}
/// Add a solution to the priority queue, possibly evicting the worst one.
/// Return true if we added the solution.
bool maybeAdd(double error, const DiscreteValues& assignment) {
const bool full = pq_.size() == maxSize_; const bool full = pq_.size() == maxSize_;
if (full && error >= pq_.top().error) return false; if (full && error >= pq_.top().error) return false;
if (full) pq_.pop(); if (full) pq_.pop();
@ -49,7 +118,11 @@ bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) {
return true; return true;
} }
std::ostream& operator<<(std::ostream& os, const Solutions& sn) { /// Check if we have any solutions
bool empty() const { return pq_.empty(); }
// Method to print all solutions
friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
os << "Solutions (top " << sn.pq_.size() << "):\n"; os << "Solutions (top " << sn.pq_.size() << "):\n";
auto pq = sn.pq_; auto pq = sn.pq_;
while (!pq.empty()) { while (!pq.empty()) {
@ -59,12 +132,16 @@ std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
return os; return os;
} }
bool Solutions::prune(double bound) const { /// Check if (partial) solution with given bound can be pruned. If we have
/// room, we never prune. Otherwise, prune if lower bound on error is worse
/// than our current worst error.
bool prune(double bound) const {
if (pq_.size() < maxSize_) return false; if (pq_.size() < maxSize_) return false;
return bound >= pq_.top().error; return bound >= pq_.top().error;
} }
std::vector<Solution> Solutions::extractSolutions() { // Method to extract solutions in ascending order of error
std::vector<Solution> extractSolutions() {
std::vector<Solution> result; std::vector<Solution> result;
while (!pq_.empty()) { while (!pq_.empty()) {
result.push_back(pq_.top()); result.push_back(pq_.top());
@ -75,6 +152,7 @@ std::vector<Solution> Solutions::extractSolutions() {
[](const Solution& a, const Solution& b) { return a.error < b.error; }); [](const Solution& a, const Solution& b) { return a.error < b.error; });
return result; return result;
} }
};
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
std::vector<DiscreteConditional::shared_ptr> conditionals; std::vector<DiscreteConditional::shared_ptr> conditionals;

View File

@ -23,63 +23,9 @@
namespace gtsam { namespace gtsam {
using Value = size_t;
/** /**
* @brief Represents a node in the search tree for discrete search algorithms. * @brief A solution to a discrete search problem.
*
* @details Each SearchNode contains a partial assignment of discrete variables,
* the current error, a bound on the final error, and the index of the next
* conditional to be assigned.
*/ */
struct SearchNode {
DiscreteValues assignment; ///< Partial assignment of discrete variables.
double error; ///< Current error for the partial assignment.
double bound; ///< Lower bound on the final error for unassigned variables.
int nextConditional; ///< Index of the next conditional to be assigned.
/**
* @brief Construct the root node for the search.
*/
static SearchNode Root(size_t numConditionals, double bound);
struct Compare {
bool operator()(const SearchNode& a, const SearchNode& b) const {
return a.bound > b.bound; // smallest bound -> highest priority
}
};
/**
* @brief Checks if the node represents a complete assignment.
*
* @return True if all variables have been assigned, false otherwise.
*/
inline bool isComplete() const { return nextConditional < 0; }
/**
* @brief Expands the node by assigning the next variable.
*
* @param conditional The discrete conditional representing the next variable
* to be assigned.
* @param fa The frontal assignment for the next variable.
* @return A new SearchNode representing the expanded state.
*/
SearchNode expand(const DiscreteConditional& conditional,
const DiscreteValues& fa) const;
/**
* @brief Prints the SearchNode to an output stream.
*
* @param os The output stream.
* @param node The SearchNode to be printed.
* @return The output stream.
*/
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
return os;
}
};
struct Solution { struct Solution {
double error; double error;
DiscreteValues assignment; DiscreteValues assignment;
@ -89,40 +35,6 @@ struct Solution {
os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
return os; return os;
} }
struct Compare {
bool operator()(const Solution& a, const Solution& b) const {
return a.error < b.error;
}
};
};
// Define the Solutions class
class Solutions {
private:
size_t maxSize_;
std::priority_queue<Solution, std::vector<Solution>, Solution::Compare> pq_;
public:
Solutions(size_t maxSize) : maxSize_(maxSize) {}
/// Add a solution to the priority queue, possibly evicting the worst one.
/// Return true if we added the solution.
bool maybeAdd(double error, const DiscreteValues& assignment);
/// Check if we have any solutions
bool empty() const { return pq_.empty(); }
// Method to print all solutions
friend std::ostream& operator<<(std::ostream& os, const Solutions& sn);
/// Check if (partial) solution with given bound can be pruned. If we have
/// room, we never prune. Otherwise, prune if lower bound on error is worse
/// than our current worst error.
bool prune(double bound) const;
// Method to extract solutions in ascending order of error
std::vector<Solution> extractSolutions();
}; };
/** /**