diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 5fca5c820..e3090bb50 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -171,48 +171,48 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { costToGo_ = computeCostToGo(conditionals_); }; -using SearchNodeQueue = std::priority_queue, - SearchNode::Compare>; - -std::vector DiscreteSearch::run(size_t K) const { - SearchNodeQueue expansions; - expansions.push(SearchNode::Root(conditionals_.size(), - costToGo_.empty() ? 0.0 : costToGo_.back())); - - Solutions solutions(K); - - auto expandNextNode = [&] { +struct SearchNodeQueue + : public std::priority_queue, + SearchNode::Compare> { + void expandNextNode( + const std::vector& conditionals, + const std::vector& costToGo, Solutions* solutions) { // Pop the partial assignment with the smallest bound - SearchNode current = expansions.top(); - expansions.pop(); + SearchNode current = top(); + pop(); - // If we already have K solutions, prune if we cannot beat the worst - // one. - if (solutions.prune(current.bound)) { + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions->prune(current.bound)) { return; } // Check if we have a complete assignment if (current.isComplete()) { - solutions.maybeAdd(current.error, current.assignment); + solutions->maybeAdd(current.error, current.assignment); return; } // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; + const auto& conditional = conditionals[current.nextConditional]; for (auto& fa : conditional->frontalAssignments()) { auto childNode = current.expand(*conditional, fa); if (childNode.nextConditional >= 0) - childNode.bound = - childNode.error + costToGo_[childNode.nextConditional]; + childNode.bound = childNode.error + costToGo[childNode.nextConditional]; // Again, prune if we cannot beat the worst solution - if (!solutions.prune(childNode.bound)) { - expansions.emplace(childNode); + if (!solutions->prune(childNode.bound)) { + emplace(childNode); } } - }; + } +}; + +std::vector DiscreteSearch::run(size_t K) const { + Solutions solutions(K); + SearchNodeQueue expansions; + expansions.push(SearchNode::Root(conditionals_.size(), + costToGo_.empty() ? 0.0 : costToGo_.back())); #ifdef DISCRETE_SEARCH_DEBUG size_t numExpansions = 0; @@ -220,7 +220,7 @@ std::vector DiscreteSearch::run(size_t K) const { // Perform the search while (!expansions.empty()) { - expandNextNode(); + expansions.expandNextNode(conditionals_, costToGo_, &solutions); #ifdef DISCRETE_SEARCH_DEBUG ++numExpansions; #endif