Made queue a class

release/4.3a0
Frank Dellaert 2025-01-27 00:37:17 -05:00
parent cee7769595
commit 9dcb52b697
1 changed files with 24 additions and 24 deletions

View File

@ -171,48 +171,48 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
costToGo_ = computeCostToGo(conditionals_); costToGo_ = computeCostToGo(conditionals_);
}; };
using SearchNodeQueue = std::priority_queue<SearchNode, std::vector<SearchNode>, struct SearchNodeQueue
SearchNode::Compare>; : public std::priority_queue<SearchNode, std::vector<SearchNode>,
SearchNode::Compare> {
std::vector<Solution> DiscreteSearch::run(size_t K) const { void expandNextNode(
SearchNodeQueue expansions; const std::vector<DiscreteConditional::shared_ptr>& conditionals,
expansions.push(SearchNode::Root(conditionals_.size(), const std::vector<double>& costToGo, Solutions* solutions) {
costToGo_.empty() ? 0.0 : costToGo_.back()));
Solutions solutions(K);
auto expandNextNode = [&] {
// Pop the partial assignment with the smallest bound // Pop the partial assignment with the smallest bound
SearchNode current = expansions.top(); SearchNode current = top();
expansions.pop(); pop();
// If we already have K solutions, prune if we cannot beat the worst // If we already have K solutions, prune if we cannot beat the worst one.
// one. if (solutions->prune(current.bound)) {
if (solutions.prune(current.bound)) {
return; return;
} }
// Check if we have a complete assignment // Check if we have a complete assignment
if (current.isComplete()) { if (current.isComplete()) {
solutions.maybeAdd(current.error, current.assignment); solutions->maybeAdd(current.error, current.assignment);
return; return;
} }
// Expand on the next factor // Expand on the next factor
const auto& conditional = conditionals_[current.nextConditional]; const auto& conditional = conditionals[current.nextConditional];
for (auto& fa : conditional->frontalAssignments()) { for (auto& fa : conditional->frontalAssignments()) {
auto childNode = current.expand(*conditional, fa); auto childNode = current.expand(*conditional, fa);
if (childNode.nextConditional >= 0) if (childNode.nextConditional >= 0)
childNode.bound = childNode.bound = childNode.error + costToGo[childNode.nextConditional];
childNode.error + costToGo_[childNode.nextConditional];
// Again, prune if we cannot beat the worst solution // Again, prune if we cannot beat the worst solution
if (!solutions.prune(childNode.bound)) { if (!solutions->prune(childNode.bound)) {
expansions.emplace(childNode); emplace(childNode);
} }
} }
}; }
};
std::vector<Solution> DiscreteSearch::run(size_t K) const {
Solutions solutions(K);
SearchNodeQueue expansions;
expansions.push(SearchNode::Root(conditionals_.size(),
costToGo_.empty() ? 0.0 : costToGo_.back()));
#ifdef DISCRETE_SEARCH_DEBUG #ifdef DISCRETE_SEARCH_DEBUG
size_t numExpansions = 0; size_t numExpansions = 0;
@ -220,7 +220,7 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
// Perform the search // Perform the search
while (!expansions.empty()) { while (!expansions.empty()) {
expandNextNode(); expansions.expandNextNode(conditionals_, costToGo_, &solutions);
#ifdef DISCRETE_SEARCH_DEBUG #ifdef DISCRETE_SEARCH_DEBUG
++numExpansions; ++numExpansions;
#endif #endif