Made queue a class

release/4.3a0
Frank Dellaert 2025-01-27 00:37:17 -05:00
parent cee7769595
commit 9dcb52b697
1 changed files with 24 additions and 24 deletions

View File

@ -171,48 +171,48 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
costToGo_ = computeCostToGo(conditionals_);
};
using SearchNodeQueue = std::priority_queue<SearchNode, std::vector<SearchNode>,
SearchNode::Compare>;
std::vector<Solution> DiscreteSearch::run(size_t K) const {
SearchNodeQueue expansions;
expansions.push(SearchNode::Root(conditionals_.size(),
costToGo_.empty() ? 0.0 : costToGo_.back()));
Solutions solutions(K);
auto expandNextNode = [&] {
struct SearchNodeQueue
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
SearchNode::Compare> {
void expandNextNode(
const std::vector<DiscreteConditional::shared_ptr>& conditionals,
const std::vector<double>& costToGo, Solutions* solutions) {
// Pop the partial assignment with the smallest bound
SearchNode current = expansions.top();
expansions.pop();
SearchNode current = top();
pop();
// If we already have K solutions, prune if we cannot beat the worst
// one.
if (solutions.prune(current.bound)) {
// If we already have K solutions, prune if we cannot beat the worst one.
if (solutions->prune(current.bound)) {
return;
}
// Check if we have a complete assignment
if (current.isComplete()) {
solutions.maybeAdd(current.error, current.assignment);
solutions->maybeAdd(current.error, current.assignment);
return;
}
// Expand on the next factor
const auto& conditional = conditionals_[current.nextConditional];
const auto& conditional = conditionals[current.nextConditional];
for (auto& fa : conditional->frontalAssignments()) {
auto childNode = current.expand(*conditional, fa);
if (childNode.nextConditional >= 0)
childNode.bound =
childNode.error + costToGo_[childNode.nextConditional];
childNode.bound = childNode.error + costToGo[childNode.nextConditional];
// Again, prune if we cannot beat the worst solution
if (!solutions.prune(childNode.bound)) {
expansions.emplace(childNode);
if (!solutions->prune(childNode.bound)) {
emplace(childNode);
}
}
};
}
};
std::vector<Solution> DiscreteSearch::run(size_t K) const {
Solutions solutions(K);
SearchNodeQueue expansions;
expansions.push(SearchNode::Root(conditionals_.size(),
costToGo_.empty() ? 0.0 : costToGo_.back()));
#ifdef DISCRETE_SEARCH_DEBUG
size_t numExpansions = 0;
@ -220,7 +220,7 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
// Perform the search
while (!expansions.empty()) {
expandNextNode();
expansions.expandNextNode(conditionals_, costToGo_, &solutions);
#ifdef DISCRETE_SEARCH_DEBUG
++numExpansions;
#endif