Made queue a class
parent
cee7769595
commit
9dcb52b697
|
@ -171,48 +171,48 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
|||
costToGo_ = computeCostToGo(conditionals_);
|
||||
};
|
||||
|
||||
using SearchNodeQueue = std::priority_queue<SearchNode, std::vector<SearchNode>,
|
||||
SearchNode::Compare>;
|
||||
|
||||
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||
SearchNodeQueue expansions;
|
||||
expansions.push(SearchNode::Root(conditionals_.size(),
|
||||
costToGo_.empty() ? 0.0 : costToGo_.back()));
|
||||
|
||||
Solutions solutions(K);
|
||||
|
||||
auto expandNextNode = [&] {
|
||||
struct SearchNodeQueue
|
||||
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
|
||||
SearchNode::Compare> {
|
||||
void expandNextNode(
|
||||
const std::vector<DiscreteConditional::shared_ptr>& conditionals,
|
||||
const std::vector<double>& costToGo, Solutions* solutions) {
|
||||
// Pop the partial assignment with the smallest bound
|
||||
SearchNode current = expansions.top();
|
||||
expansions.pop();
|
||||
SearchNode current = top();
|
||||
pop();
|
||||
|
||||
// If we already have K solutions, prune if we cannot beat the worst
|
||||
// one.
|
||||
if (solutions.prune(current.bound)) {
|
||||
// If we already have K solutions, prune if we cannot beat the worst one.
|
||||
if (solutions->prune(current.bound)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if we have a complete assignment
|
||||
if (current.isComplete()) {
|
||||
solutions.maybeAdd(current.error, current.assignment);
|
||||
solutions->maybeAdd(current.error, current.assignment);
|
||||
return;
|
||||
}
|
||||
|
||||
// Expand on the next factor
|
||||
const auto& conditional = conditionals_[current.nextConditional];
|
||||
const auto& conditional = conditionals[current.nextConditional];
|
||||
|
||||
for (auto& fa : conditional->frontalAssignments()) {
|
||||
auto childNode = current.expand(*conditional, fa);
|
||||
if (childNode.nextConditional >= 0)
|
||||
childNode.bound =
|
||||
childNode.error + costToGo_[childNode.nextConditional];
|
||||
childNode.bound = childNode.error + costToGo[childNode.nextConditional];
|
||||
|
||||
// Again, prune if we cannot beat the worst solution
|
||||
if (!solutions.prune(childNode.bound)) {
|
||||
expansions.emplace(childNode);
|
||||
if (!solutions->prune(childNode.bound)) {
|
||||
emplace(childNode);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||
Solutions solutions(K);
|
||||
SearchNodeQueue expansions;
|
||||
expansions.push(SearchNode::Root(conditionals_.size(),
|
||||
costToGo_.empty() ? 0.0 : costToGo_.back()));
|
||||
|
||||
#ifdef DISCRETE_SEARCH_DEBUG
|
||||
size_t numExpansions = 0;
|
||||
|
@ -220,7 +220,7 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
|||
|
||||
// Perform the search
|
||||
while (!expansions.empty()) {
|
||||
expandNextNode();
|
||||
expansions.expandNextNode(conditionals_, costToGo_, &solutions);
|
||||
#ifdef DISCRETE_SEARCH_DEBUG
|
||||
++numExpansions;
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue