Separate class

release/4.3a0
Frank Dellaert 2025-01-26 17:59:31 -05:00
parent 74b7a962d5
commit 1f4d9bbd7e
1 changed files with 68 additions and 46 deletions

View File

@ -346,97 +346,119 @@ class Solutions {
}
};
// ----------------------------------------------------------------------------
// 5) The main branch-and-bound routine for DiscreteBayesNet
// ----------------------------------------------------------------------------
std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
size_t K) {
// Copy out the conditionals
std::vector<std::shared_ptr<DiscreteConditional>> conditionals;
for (auto& factor : bayesNet) {
conditionals.push_back(factor);
/**
* BestKSearch: Search for the K best solutions.
*/
class BestKSearch {
public:
/**
* Construct from a DiscreteBayesNet and K.
*/
BestKSearch(const DiscreteBayesNet& bayesNet, size_t K)
: bayesNet_(bayesNet), solutions_(K) {
// Copy out the conditionals
for (auto& factor : bayesNet_) {
conditionals_.push_back(factor);
}
// Create the root node: no variables assigned, nextConditional = last.
SearchNode root{
.assignment = DiscreteValues(),
.error = 0.0,
.nextConditional = static_cast<int>(conditionals_.size()) - 1};
root.bound = root.computeBound();
std::cout << "Root: " << root << std::endl;
expansions_.push(root);
}
// Min-heap of expansions by bound
std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound>
expansions;
/**
* @brief Search for the K best solutions.
*
* This method performs a search to find the K best solutions for the given
* DiscreteBayesNet. It uses a priority queue to manage the search nodes,
* expanding nodes with the smallest bound first. The search continues until
* all possible nodes have been expanded or pruned.
*
* @return A vector of the K best solutions found during the search.
*/
std::vector<Solution> run() {
size_t numExpansions = 0;
while (!expansions_.empty()) {
expandNextNode();
numExpansions++;
}
// Max-heap of best solutions found so far, keyed by error.
// We keep the largest error on top.
Solutions solutions(K);
std::cout << "Expansions: " << numExpansions << std::endl;
// 1) Create the root node: no variables assigned, nextConditional = last.
SearchNode root{.assignment = DiscreteValues(),
.error = 0.0,
.nextConditional = static_cast<int>(conditionals.size()) - 1};
root.bound = root.computeBound();
std::cout << "Root: " << root << std::endl;
expansions.push(root);
// Extract solutions from bestSolutions in ascending order of error
return solutions_.extractSolutions();
}
// 2) Main loop
size_t numExpansions = 0;
while (!expansions.empty()) {
private:
//
void expandNextNode() {
// Pop the partial assignment with the smallest bound
SearchNode current = expansions.top();
expansions.pop();
numExpansions++;
SearchNode current = expansions_.top();
expansions_.pop();
std::cout << "Expanding: " << current << std::endl;
// If we already have K solutions, prune if we cannot beat the worst one.
if (solutions.prune(current.bound)) {
if (solutions_.prune(current.bound)) {
std::cout << "Pruning: bound=" << current.bound << std::endl;
continue;
return;
}
// Check if we have a complete assignment
if (current.isComplete()) {
const bool added = solutions.maybeAdd(current.error, current.assignment);
const bool added = solutions_.maybeAdd(current.error, current.assignment);
if (added) {
std::cout << "Best solutions so far:" << std::endl;
solutions.print();
solutions_.print();
}
continue;
return;
}
// Expand on the next factor
const auto& conditional = conditionals[current.nextConditional];
const auto& conditional = conditionals_[current.nextConditional];
for (auto& fa : conditional->frontalAssignments()) {
std::cout << "Frontal assignment: " << fa << std::endl;
auto childNode = current.expand(*conditional, fa);
// Again, prune if we cannot beat the worst solution
if (solutions.prune(current.bound)) {
if (solutions_.prune(current.bound)) {
std::cout << "Pruning: bound=" << childNode.bound << std::endl;
continue;
}
expansions.push(childNode);
expansions_.push(childNode);
}
}
std::cout << "Expansions: " << numExpansions << std::endl;
// 3) Extract solutions from bestSolutions in ascending order of error
return solutions.extractSolutions();
}
const DiscreteBayesNet& bayesNet_;
std::vector<std::shared_ptr<DiscreteConditional>> conditionals_;
std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound>
expansions_;
Solutions solutions_;
};
// ----------------------------------------------------------------------------
// Example “Unit Tests” (trivial stubs)
// ----------------------------------------------------------------------------
TEST_DISABLED(DiscreteBayesNet, EmptyKBest) {
TEST(DiscreteBayesNet, EmptyKBest) {
DiscreteBayesNet net; // no factors
auto solutions = branchAndBoundKSolutions(net, 3);
BestKSearch search(net, 3);
auto solutions = search.run();
// Expect one solution with empty assignment, error=0
EXPECT_LONGS_EQUAL(1, solutions.size());
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
}
TEST(DiscreteBayesNet, AsiaKBest) {
DiscreteBayesNet net = constructAsiaExample();
auto solutions = branchAndBoundKSolutions(net, 4);
DiscreteBayesNet asia = constructAsiaExample();
BestKSearch search(asia, 4);
auto solutions = search.run();
EXPECT(!solutions.empty());
// Regression test: check the first solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);