Methods in Solutions

release/4.3a0
Frank Dellaert 2025-01-26 17:40:44 -05:00
parent 2808982903
commit 74b7a962d5
1 changed files with 77 additions and 61 deletions

View File

@ -96,7 +96,7 @@ TEST(DiscreteBayesNet, Asia) {
// Convert to factor graph
DiscreteFactorGraph fg(asia);
LONGS_EQUAL(3, fg.back()->size());
LONGS_EQUAL(1, fg.back()->size());
// Check the marginals we know (of the parent-less nodes)
DiscreteMarginals marginals(fg);
@ -164,16 +164,16 @@ TEST(DiscreteBayesNet, Dot) {
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var4683743612465315840[label=\"A0\"];\n"
" var4971973988617027589[label=\"E5\"];\n"
" var5476377146882523142[label=\"L6\"];\n"
" var5980780305148018692[label=\"S4\"];\n"
" var6052837899185946627[label=\"T3\"];\n"
" var4683743612465315848[label=\"A8\"];\n"
" var4971973988617027587[label=\"E3\"];\n"
" var5476377146882523141[label=\"L5\"];\n"
" var5980780305148018695[label=\"S7\"];\n"
" var6052837899185946630[label=\"T6\"];\n"
"\n"
" var6052837899185946627->var4971973988617027589\n"
" var5476377146882523142->var4971973988617027589\n"
" var5980780305148018692->var5476377146882523142\n"
" var4683743612465315840->var6052837899185946627\n"
" var6052837899185946630->var4971973988617027587\n"
" var5476377146882523141->var4971973988617027587\n"
" var5980780305148018695->var5476377146882523141\n"
" var4683743612465315848->var6052837899185946630\n"
"}");
}
@ -290,19 +290,61 @@ struct CompareByError {
}
};
using Solutions =
std::priority_queue<Solution, std::vector<Solution>, CompareByError>;
// Define the Solutions class
class Solutions {
private:
size_t maxSize_;
std::priority_queue<Solution, std::vector<Solution>, CompareByError> pq_;
// Function to print the priority queue
void printPriorityQueue(Solutions pq) {
size_t i = 0;
while (!pq.empty()) {
const Solution& best = pq.top();
cout << " " << ++i << " Error: " << best.error
<< ", Values: " << best.assignment << endl;
pq.pop();
public:
Solutions(size_t maxSize) : maxSize_(maxSize) {}
/// Add a solution to the priority queue, possibly evicting the worst one.
/// Return true if we added the solution.
bool maybeAdd(double error, const DiscreteValues& assignment) {
const bool full = pq_.size() == maxSize_;
if (full && error >= pq_.top().error) return false;
if (full) pq_.pop();
pq_.emplace(error, assignment);
return true;
}
}
/// Check if we have any solutions
bool empty() const { return pq_.empty(); }
// Method to print all solutions
void print() const {
auto pq = pq_;
while (!pq.empty()) {
const Solution& best = pq.top();
std::cout << "Error: " << best.error << ", Values: " << best.assignment
<< std::endl;
pq.pop();
}
}
/// Check if (partial) solution with given bound can be pruned. If we have
/// room, we never prune. Otherwise, prune if lower bound on error is worse
/// than our current worst error.
bool prune(double bound) const {
if (pq_.size() < maxSize_) return false;
double worstError = pq_.top().error;
return (bound >= worstError);
}
// Method to extract solutions in ascending order of error
std::vector<Solution> extractSolutions() {
std::vector<Solution> result;
while (!pq_.empty()) {
result.push_back(pq_.top());
pq_.pop();
}
std::sort(
result.begin(), result.end(),
[](const Solution& a, const Solution& b) { return a.error < b.error; });
return result;
}
};
// ----------------------------------------------------------------------------
// 5) The main branch-and-bound routine for DiscreteBayesNet
@ -321,7 +363,7 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
// Max-heap of best solutions found so far, keyed by error.
// We keep the largest error on top.
Solutions solutions;
Solutions solutions(K);
// 1) Create the root node: no variables assigned, nextConditional = last.
SearchNode root{.assignment = DiscreteValues(),
@ -341,32 +383,17 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
std::cout << "Expanding: " << current << std::endl;
// If we already have K solutions, prune if we cannot beat the worst one.
if (solutions.size() == K) {
double worstError = solutions.top().error;
if (current.bound >= worstError) {
std::cout << "Pruning: bound=" << current.bound
<< " worstError=" << worstError << std::endl;
continue;
}
if (solutions.prune(current.bound)) {
std::cout << "Pruning: bound=" << current.bound << std::endl;
continue;
}
// Check if we have a complete assignment
if (current.isComplete()) {
// We have a new candidate solution
double finalError = current.error;
if (solutions.size() < K) {
std::cout << "Adding solution: " << current << std::endl;
solutions.emplace(finalError, current.assignment);
// print out all solutions:
const bool added = solutions.maybeAdd(current.error, current.assignment);
if (added) {
std::cout << "Best solutions so far:" << std::endl;
printPriorityQueue(solutions);
} else {
double worstError = solutions.top().error;
if (finalError < worstError) {
std::cout << "Replacing solution: " << current << std::endl;
solutions.pop();
solutions.emplace(finalError, current.assignment);
}
solutions.print();
}
continue;
}
@ -379,14 +406,11 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
auto childNode = current.expand(*conditional, fa);
// Again, prune if we cannot beat the worst solution
if (solutions.size() == K) {
double worstError = solutions.top().error;
if (childNode.bound >= worstError) {
std::cout << "Pruning: bound=" << childNode.bound
<< " worstError=" << worstError << std::endl;
continue;
}
if (solutions.prune(current.bound)) {
std::cout << "Pruning: bound=" << childNode.bound << std::endl;
continue;
}
expansions.push(childNode);
}
}
@ -394,15 +418,7 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
std::cout << "Expansions: " << numExpansions << std::endl;
// 3) Extract solutions from bestSolutions in ascending order of error
std::vector<Solution> result;
while (!solutions.empty()) {
result.push_back(solutions.top());
solutions.pop();
}
std::sort(result.begin(), result.end(),
[](auto& a, auto& b) { return a.error < b.error; });
return result;
return solutions.extractSolutions();
}
// ----------------------------------------------------------------------------
@ -422,8 +438,8 @@ TEST(DiscreteBayesNet, AsiaKBest) {
auto solutions = branchAndBoundKSolutions(net, 4);
EXPECT(!solutions.empty());
// All solutions likely tie with error=0 in this stub
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
// Regression test: check the first solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
}
/* ************************************************************************* */