Moved data structures to cpp
parent
f174a38eed
commit
cee7769595
|
|
@ -20,61 +20,139 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
SearchNode SearchNode::Root(size_t numConditionals, double bound) {
|
using Value = size_t;
|
||||||
return {.assignment = DiscreteValues(),
|
|
||||||
.error = 0.0,
|
|
||||||
.bound = bound,
|
|
||||||
.nextConditional = static_cast<int>(numConditionals) - 1};
|
|
||||||
}
|
|
||||||
|
|
||||||
SearchNode SearchNode::expand(const DiscreteConditional& conditional,
|
/**
|
||||||
const DiscreteValues& fa) const {
|
* @brief Represents a node in the search tree for discrete search algorithms.
|
||||||
// Combine the new frontal assignment with the current partial assignment
|
*
|
||||||
DiscreteValues newAssignment = assignment;
|
* @details Each SearchNode contains a partial assignment of discrete variables,
|
||||||
for (auto& [key, value] : fa) {
|
* the current error, a bound on the final error, and the index of the next
|
||||||
newAssignment[key] = value;
|
* conditional to be assigned.
|
||||||
|
*/
|
||||||
|
struct SearchNode {
|
||||||
|
DiscreteValues assignment; ///< Partial assignment of discrete variables.
|
||||||
|
double error; ///< Current error for the partial assignment.
|
||||||
|
double bound; ///< Lower bound on the final error for unassigned variables.
|
||||||
|
int nextConditional; ///< Index of the next conditional to be assigned.
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct the root node for the search.
|
||||||
|
*/
|
||||||
|
static SearchNode Root(size_t numConditionals, double bound) {
|
||||||
|
return {.assignment = DiscreteValues(),
|
||||||
|
.error = 0.0,
|
||||||
|
.bound = bound,
|
||||||
|
.nextConditional = static_cast<int>(numConditionals) - 1};
|
||||||
}
|
}
|
||||||
|
|
||||||
return {.assignment = newAssignment,
|
struct Compare {
|
||||||
.error = error + conditional.error(newAssignment),
|
bool operator()(const SearchNode& a, const SearchNode& b) const {
|
||||||
.bound = 0.0,
|
return a.bound > b.bound; // smallest bound -> highest priority
|
||||||
.nextConditional = nextConditional - 1};
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) {
|
/**
|
||||||
const bool full = pq_.size() == maxSize_;
|
* @brief Checks if the node represents a complete assignment.
|
||||||
if (full && error >= pq_.top().error) return false;
|
*
|
||||||
if (full) pq_.pop();
|
* @return True if all variables have been assigned, false otherwise.
|
||||||
pq_.emplace(error, assignment);
|
*/
|
||||||
return true;
|
inline bool isComplete() const { return nextConditional < 0; }
|
||||||
}
|
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
|
/**
|
||||||
os << "Solutions (top " << sn.pq_.size() << "):\n";
|
* @brief Expands the node by assigning the next variable.
|
||||||
auto pq = sn.pq_;
|
*
|
||||||
while (!pq.empty()) {
|
* @param conditional The discrete conditional representing the next variable
|
||||||
os << pq.top() << "\n";
|
* to be assigned.
|
||||||
pq.pop();
|
* @param fa The frontal assignment for the next variable.
|
||||||
|
* @return A new SearchNode representing the expanded state.
|
||||||
|
*/
|
||||||
|
SearchNode expand(const DiscreteConditional& conditional,
|
||||||
|
const DiscreteValues& fa) const {
|
||||||
|
// Combine the new frontal assignment with the current partial assignment
|
||||||
|
DiscreteValues newAssignment = assignment;
|
||||||
|
for (auto& [key, value] : fa) {
|
||||||
|
newAssignment[key] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {.assignment = newAssignment,
|
||||||
|
.error = error + conditional.error(newAssignment),
|
||||||
|
.bound = 0.0,
|
||||||
|
.nextConditional = nextConditional - 1};
|
||||||
}
|
}
|
||||||
return os;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Solutions::prune(double bound) const {
|
/**
|
||||||
if (pq_.size() < maxSize_) return false;
|
* @brief Prints the SearchNode to an output stream.
|
||||||
return bound >= pq_.top().error;
|
*
|
||||||
}
|
* @param os The output stream.
|
||||||
|
* @param node The SearchNode to be printed.
|
||||||
std::vector<Solution> Solutions::extractSolutions() {
|
* @return The output stream.
|
||||||
std::vector<Solution> result;
|
*/
|
||||||
while (!pq_.empty()) {
|
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
|
||||||
result.push_back(pq_.top());
|
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
|
||||||
pq_.pop();
|
return os;
|
||||||
}
|
}
|
||||||
std::sort(
|
};
|
||||||
result.begin(), result.end(),
|
|
||||||
[](const Solution& a, const Solution& b) { return a.error < b.error; });
|
struct CompareSolution {
|
||||||
return result;
|
bool operator()(const Solution& a, const Solution& b) const {
|
||||||
}
|
return a.error < b.error;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Define the Solutions class
|
||||||
|
class Solutions {
|
||||||
|
private:
|
||||||
|
size_t maxSize_;
|
||||||
|
std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Solutions(size_t maxSize) : maxSize_(maxSize) {}
|
||||||
|
|
||||||
|
/// Add a solution to the priority queue, possibly evicting the worst one.
|
||||||
|
/// Return true if we added the solution.
|
||||||
|
bool maybeAdd(double error, const DiscreteValues& assignment) {
|
||||||
|
const bool full = pq_.size() == maxSize_;
|
||||||
|
if (full && error >= pq_.top().error) return false;
|
||||||
|
if (full) pq_.pop();
|
||||||
|
pq_.emplace(error, assignment);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if we have any solutions
|
||||||
|
bool empty() const { return pq_.empty(); }
|
||||||
|
|
||||||
|
// Method to print all solutions
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
|
||||||
|
os << "Solutions (top " << sn.pq_.size() << "):\n";
|
||||||
|
auto pq = sn.pq_;
|
||||||
|
while (!pq.empty()) {
|
||||||
|
os << pq.top() << "\n";
|
||||||
|
pq.pop();
|
||||||
|
}
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if (partial) solution with given bound can be pruned. If we have
|
||||||
|
/// room, we never prune. Otherwise, prune if lower bound on error is worse
|
||||||
|
/// than our current worst error.
|
||||||
|
bool prune(double bound) const {
|
||||||
|
if (pq_.size() < maxSize_) return false;
|
||||||
|
return bound >= pq_.top().error;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method to extract solutions in ascending order of error
|
||||||
|
std::vector<Solution> extractSolutions() {
|
||||||
|
std::vector<Solution> result;
|
||||||
|
while (!pq_.empty()) {
|
||||||
|
result.push_back(pq_.top());
|
||||||
|
pq_.pop();
|
||||||
|
}
|
||||||
|
std::sort(
|
||||||
|
result.begin(), result.end(),
|
||||||
|
[](const Solution& a, const Solution& b) { return a.error < b.error; });
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
||||||
std::vector<DiscreteConditional::shared_ptr> conditionals;
|
std::vector<DiscreteConditional::shared_ptr> conditionals;
|
||||||
|
|
|
||||||
|
|
@ -23,63 +23,9 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
using Value = size_t;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Represents a node in the search tree for discrete search algorithms.
|
* @brief A solution to a discrete search problem.
|
||||||
*
|
|
||||||
* @details Each SearchNode contains a partial assignment of discrete variables,
|
|
||||||
* the current error, a bound on the final error, and the index of the next
|
|
||||||
* conditional to be assigned.
|
|
||||||
*/
|
*/
|
||||||
struct SearchNode {
|
|
||||||
DiscreteValues assignment; ///< Partial assignment of discrete variables.
|
|
||||||
double error; ///< Current error for the partial assignment.
|
|
||||||
double bound; ///< Lower bound on the final error for unassigned variables.
|
|
||||||
int nextConditional; ///< Index of the next conditional to be assigned.
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Construct the root node for the search.
|
|
||||||
*/
|
|
||||||
static SearchNode Root(size_t numConditionals, double bound);
|
|
||||||
|
|
||||||
struct Compare {
|
|
||||||
bool operator()(const SearchNode& a, const SearchNode& b) const {
|
|
||||||
return a.bound > b.bound; // smallest bound -> highest priority
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Checks if the node represents a complete assignment.
|
|
||||||
*
|
|
||||||
* @return True if all variables have been assigned, false otherwise.
|
|
||||||
*/
|
|
||||||
inline bool isComplete() const { return nextConditional < 0; }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Expands the node by assigning the next variable.
|
|
||||||
*
|
|
||||||
* @param conditional The discrete conditional representing the next variable
|
|
||||||
* to be assigned.
|
|
||||||
* @param fa The frontal assignment for the next variable.
|
|
||||||
* @return A new SearchNode representing the expanded state.
|
|
||||||
*/
|
|
||||||
SearchNode expand(const DiscreteConditional& conditional,
|
|
||||||
const DiscreteValues& fa) const;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Prints the SearchNode to an output stream.
|
|
||||||
*
|
|
||||||
* @param os The output stream.
|
|
||||||
* @param node The SearchNode to be printed.
|
|
||||||
* @return The output stream.
|
|
||||||
*/
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
|
|
||||||
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
|
|
||||||
return os;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Solution {
|
struct Solution {
|
||||||
double error;
|
double error;
|
||||||
DiscreteValues assignment;
|
DiscreteValues assignment;
|
||||||
|
|
@ -89,40 +35,6 @@ struct Solution {
|
||||||
os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
|
os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Compare {
|
|
||||||
bool operator()(const Solution& a, const Solution& b) const {
|
|
||||||
return a.error < b.error;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
// Define the Solutions class
|
|
||||||
class Solutions {
|
|
||||||
private:
|
|
||||||
size_t maxSize_;
|
|
||||||
std::priority_queue<Solution, std::vector<Solution>, Solution::Compare> pq_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
Solutions(size_t maxSize) : maxSize_(maxSize) {}
|
|
||||||
|
|
||||||
/// Add a solution to the priority queue, possibly evicting the worst one.
|
|
||||||
/// Return true if we added the solution.
|
|
||||||
bool maybeAdd(double error, const DiscreteValues& assignment);
|
|
||||||
|
|
||||||
/// Check if we have any solutions
|
|
||||||
bool empty() const { return pq_.empty(); }
|
|
||||||
|
|
||||||
// Method to print all solutions
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const Solutions& sn);
|
|
||||||
|
|
||||||
/// Check if (partial) solution with given bound can be pruned. If we have
|
|
||||||
/// room, we never prune. Otherwise, prune if lower bound on error is worse
|
|
||||||
/// than our current worst error.
|
|
||||||
bool prune(double bound) const;
|
|
||||||
|
|
||||||
// Method to extract solutions in ascending order of error
|
|
||||||
std::vector<Solution> extractSolutions();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue