Methods in Solutions

release/4.3a0
Frank Dellaert 2025-01-26 17:40:44 -05:00
parent 2808982903
commit 74b7a962d5
1 changed files with 77 additions and 61 deletions

View File

@ -96,7 +96,7 @@ TEST(DiscreteBayesNet, Asia) {
// Convert to factor graph // Convert to factor graph
DiscreteFactorGraph fg(asia); DiscreteFactorGraph fg(asia);
LONGS_EQUAL(3, fg.back()->size()); LONGS_EQUAL(1, fg.back()->size());
// Check the marginals we know (of the parent-less nodes) // Check the marginals we know (of the parent-less nodes)
DiscreteMarginals marginals(fg); DiscreteMarginals marginals(fg);
@ -164,16 +164,16 @@ TEST(DiscreteBayesNet, Dot) {
"digraph {\n" "digraph {\n"
" size=\"5,5\";\n" " size=\"5,5\";\n"
"\n" "\n"
" var4683743612465315840[label=\"A0\"];\n" " var4683743612465315848[label=\"A8\"];\n"
" var4971973988617027589[label=\"E5\"];\n" " var4971973988617027587[label=\"E3\"];\n"
" var5476377146882523142[label=\"L6\"];\n" " var5476377146882523141[label=\"L5\"];\n"
" var5980780305148018692[label=\"S4\"];\n" " var5980780305148018695[label=\"S7\"];\n"
" var6052837899185946627[label=\"T3\"];\n" " var6052837899185946630[label=\"T6\"];\n"
"\n" "\n"
" var6052837899185946627->var4971973988617027589\n" " var6052837899185946630->var4971973988617027587\n"
" var5476377146882523142->var4971973988617027589\n" " var5476377146882523141->var4971973988617027587\n"
" var5980780305148018692->var5476377146882523142\n" " var5980780305148018695->var5476377146882523141\n"
" var4683743612465315840->var6052837899185946627\n" " var4683743612465315848->var6052837899185946630\n"
"}"); "}");
} }
@ -290,20 +290,62 @@ struct CompareByError {
} }
}; };
using Solutions = // Define the Solutions class
std::priority_queue<Solution, std::vector<Solution>, CompareByError>; class Solutions {
private:
size_t maxSize_;
std::priority_queue<Solution, std::vector<Solution>, CompareByError> pq_;
// Function to print the priority queue public:
void printPriorityQueue(Solutions pq) { Solutions(size_t maxSize) : maxSize_(maxSize) {}
size_t i = 0;
/// Add a solution to the priority queue, possibly evicting the worst one.
/// Return true if we added the solution.
bool maybeAdd(double error, const DiscreteValues& assignment) {
const bool full = pq_.size() == maxSize_;
if (full && error >= pq_.top().error) return false;
if (full) pq_.pop();
pq_.emplace(error, assignment);
return true;
}
/// Check if we have any solutions
bool empty() const { return pq_.empty(); }
// Method to print all solutions
void print() const {
auto pq = pq_;
while (!pq.empty()) { while (!pq.empty()) {
const Solution& best = pq.top(); const Solution& best = pq.top();
cout << " " << ++i << " Error: " << best.error std::cout << "Error: " << best.error << ", Values: " << best.assignment
<< ", Values: " << best.assignment << endl; << std::endl;
pq.pop(); pq.pop();
} }
} }
/// Check if (partial) solution with given bound can be pruned. If we have
/// room, we never prune. Otherwise, prune if lower bound on error is worse
/// than our current worst error.
bool prune(double bound) const {
if (pq_.size() < maxSize_) return false;
double worstError = pq_.top().error;
return (bound >= worstError);
}
// Method to extract solutions in ascending order of error
std::vector<Solution> extractSolutions() {
std::vector<Solution> result;
while (!pq_.empty()) {
result.push_back(pq_.top());
pq_.pop();
}
std::sort(
result.begin(), result.end(),
[](const Solution& a, const Solution& b) { return a.error < b.error; });
return result;
}
};
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// 5) The main branch-and-bound routine for DiscreteBayesNet // 5) The main branch-and-bound routine for DiscreteBayesNet
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@ -321,7 +363,7 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
// Max-heap of best solutions found so far, keyed by error. // Max-heap of best solutions found so far, keyed by error.
// We keep the largest error on top. // We keep the largest error on top.
Solutions solutions; Solutions solutions(K);
// 1) Create the root node: no variables assigned, nextConditional = last. // 1) Create the root node: no variables assigned, nextConditional = last.
SearchNode root{.assignment = DiscreteValues(), SearchNode root{.assignment = DiscreteValues(),
@ -341,32 +383,17 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
std::cout << "Expanding: " << current << std::endl; std::cout << "Expanding: " << current << std::endl;
// If we already have K solutions, prune if we cannot beat the worst one. // If we already have K solutions, prune if we cannot beat the worst one.
if (solutions.size() == K) { if (solutions.prune(current.bound)) {
double worstError = solutions.top().error; std::cout << "Pruning: bound=" << current.bound << std::endl;
if (current.bound >= worstError) {
std::cout << "Pruning: bound=" << current.bound
<< " worstError=" << worstError << std::endl;
continue; continue;
} }
}
// Check if we have a complete assignment // Check if we have a complete assignment
if (current.isComplete()) { if (current.isComplete()) {
// We have a new candidate solution const bool added = solutions.maybeAdd(current.error, current.assignment);
double finalError = current.error; if (added) {
if (solutions.size() < K) {
std::cout << "Adding solution: " << current << std::endl;
solutions.emplace(finalError, current.assignment);
// print out all solutions:
std::cout << "Best solutions so far:" << std::endl; std::cout << "Best solutions so far:" << std::endl;
printPriorityQueue(solutions); solutions.print();
} else {
double worstError = solutions.top().error;
if (finalError < worstError) {
std::cout << "Replacing solution: " << current << std::endl;
solutions.pop();
solutions.emplace(finalError, current.assignment);
}
} }
continue; continue;
} }
@ -379,14 +406,11 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
auto childNode = current.expand(*conditional, fa); auto childNode = current.expand(*conditional, fa);
// Again, prune if we cannot beat the worst solution // Again, prune if we cannot beat the worst solution
if (solutions.size() == K) { if (solutions.prune(current.bound)) {
double worstError = solutions.top().error; std::cout << "Pruning: bound=" << childNode.bound << std::endl;
if (childNode.bound >= worstError) {
std::cout << "Pruning: bound=" << childNode.bound
<< " worstError=" << worstError << std::endl;
continue; continue;
} }
}
expansions.push(childNode); expansions.push(childNode);
} }
} }
@ -394,15 +418,7 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
std::cout << "Expansions: " << numExpansions << std::endl; std::cout << "Expansions: " << numExpansions << std::endl;
// 3) Extract solutions from bestSolutions in ascending order of error // 3) Extract solutions from bestSolutions in ascending order of error
std::vector<Solution> result; return solutions.extractSolutions();
while (!solutions.empty()) {
result.push_back(solutions.top());
solutions.pop();
}
std::sort(result.begin(), result.end(),
[](auto& a, auto& b) { return a.error < b.error; });
return result;
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@ -422,8 +438,8 @@ TEST(DiscreteBayesNet, AsiaKBest) {
auto solutions = branchAndBoundKSolutions(net, 4); auto solutions = branchAndBoundKSolutions(net, 4);
EXPECT(!solutions.empty()); EXPECT(!solutions.empty());
// All solutions likely tie with error=0 in this stub // Regression test: check the first solution
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
} }
/* ************************************************************************* */ /* ************************************************************************* */