Made queue a class
parent
cee7769595
commit
9dcb52b697
|
@ -171,48 +171,48 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
||||||
costToGo_ = computeCostToGo(conditionals_);
|
costToGo_ = computeCostToGo(conditionals_);
|
||||||
};
|
};
|
||||||
|
|
||||||
using SearchNodeQueue = std::priority_queue<SearchNode, std::vector<SearchNode>,
|
struct SearchNodeQueue
|
||||||
SearchNode::Compare>;
|
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
|
||||||
|
SearchNode::Compare> {
|
||||||
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
void expandNextNode(
|
||||||
SearchNodeQueue expansions;
|
const std::vector<DiscreteConditional::shared_ptr>& conditionals,
|
||||||
expansions.push(SearchNode::Root(conditionals_.size(),
|
const std::vector<double>& costToGo, Solutions* solutions) {
|
||||||
costToGo_.empty() ? 0.0 : costToGo_.back()));
|
|
||||||
|
|
||||||
Solutions solutions(K);
|
|
||||||
|
|
||||||
auto expandNextNode = [&] {
|
|
||||||
// Pop the partial assignment with the smallest bound
|
// Pop the partial assignment with the smallest bound
|
||||||
SearchNode current = expansions.top();
|
SearchNode current = top();
|
||||||
expansions.pop();
|
pop();
|
||||||
|
|
||||||
// If we already have K solutions, prune if we cannot beat the worst
|
// If we already have K solutions, prune if we cannot beat the worst one.
|
||||||
// one.
|
if (solutions->prune(current.bound)) {
|
||||||
if (solutions.prune(current.bound)) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we have a complete assignment
|
// Check if we have a complete assignment
|
||||||
if (current.isComplete()) {
|
if (current.isComplete()) {
|
||||||
solutions.maybeAdd(current.error, current.assignment);
|
solutions->maybeAdd(current.error, current.assignment);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expand on the next factor
|
// Expand on the next factor
|
||||||
const auto& conditional = conditionals_[current.nextConditional];
|
const auto& conditional = conditionals[current.nextConditional];
|
||||||
|
|
||||||
for (auto& fa : conditional->frontalAssignments()) {
|
for (auto& fa : conditional->frontalAssignments()) {
|
||||||
auto childNode = current.expand(*conditional, fa);
|
auto childNode = current.expand(*conditional, fa);
|
||||||
if (childNode.nextConditional >= 0)
|
if (childNode.nextConditional >= 0)
|
||||||
childNode.bound =
|
childNode.bound = childNode.error + costToGo[childNode.nextConditional];
|
||||||
childNode.error + costToGo_[childNode.nextConditional];
|
|
||||||
|
|
||||||
// Again, prune if we cannot beat the worst solution
|
// Again, prune if we cannot beat the worst solution
|
||||||
if (!solutions.prune(childNode.bound)) {
|
if (!solutions->prune(childNode.bound)) {
|
||||||
expansions.emplace(childNode);
|
emplace(childNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
|
Solutions solutions(K);
|
||||||
|
SearchNodeQueue expansions;
|
||||||
|
expansions.push(SearchNode::Root(conditionals_.size(),
|
||||||
|
costToGo_.empty() ? 0.0 : costToGo_.back()));
|
||||||
|
|
||||||
#ifdef DISCRETE_SEARCH_DEBUG
|
#ifdef DISCRETE_SEARCH_DEBUG
|
||||||
size_t numExpansions = 0;
|
size_t numExpansions = 0;
|
||||||
|
@ -220,7 +220,7 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
|
|
||||||
// Perform the search
|
// Perform the search
|
||||||
while (!expansions.empty()) {
|
while (!expansions.empty()) {
|
||||||
expandNextNode();
|
expansions.expandNextNode(conditionals_, costToGo_, &solutions);
|
||||||
#ifdef DISCRETE_SEARCH_DEBUG
|
#ifdef DISCRETE_SEARCH_DEBUG
|
||||||
++numExpansions;
|
++numExpansions;
|
||||||
#endif
|
#endif
|
||||||
|
|
Loading…
Reference in New Issue