Separate class

release/4.3a0
Frank Dellaert 2025-01-26 17:59:31 -05:00
parent 74b7a962d5
commit 1f4d9bbd7e
1 changed files with 68 additions and 46 deletions

View File

@ -346,97 +346,119 @@ class Solutions {
} }
}; };
// ---------------------------------------------------------------------------- /**
// 5) The main branch-and-bound routine for DiscreteBayesNet * BestKSearch: Search for the K best solutions.
// ---------------------------------------------------------------------------- */
std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, class BestKSearch {
size_t K) { public:
// Copy out the conditionals /**
std::vector<std::shared_ptr<DiscreteConditional>> conditionals; * Construct from a DiscreteBayesNet and K.
for (auto& factor : bayesNet) { */
conditionals.push_back(factor); BestKSearch(const DiscreteBayesNet& bayesNet, size_t K)
: bayesNet_(bayesNet), solutions_(K) {
// Copy out the conditionals
for (auto& factor : bayesNet_) {
conditionals_.push_back(factor);
}
// Create the root node: no variables assigned, nextConditional = last.
SearchNode root{
.assignment = DiscreteValues(),
.error = 0.0,
.nextConditional = static_cast<int>(conditionals_.size()) - 1};
root.bound = root.computeBound();
std::cout << "Root: " << root << std::endl;
expansions_.push(root);
} }
// Min-heap of expansions by bound /**
std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound> * @brief Search for the K best solutions.
expansions; *
* This method performs a search to find the K best solutions for the given
* DiscreteBayesNet. It uses a priority queue to manage the search nodes,
* expanding nodes with the smallest bound first. The search continues until
* all possible nodes have been expanded or pruned.
*
* @return A vector of the K best solutions found during the search.
*/
std::vector<Solution> run() {
size_t numExpansions = 0;
while (!expansions_.empty()) {
expandNextNode();
numExpansions++;
}
// Max-heap of best solutions found so far, keyed by error. std::cout << "Expansions: " << numExpansions << std::endl;
// We keep the largest error on top.
Solutions solutions(K);
// 1) Create the root node: no variables assigned, nextConditional = last. // Extract solutions from bestSolutions in ascending order of error
SearchNode root{.assignment = DiscreteValues(), return solutions_.extractSolutions();
.error = 0.0, }
.nextConditional = static_cast<int>(conditionals.size()) - 1};
root.bound = root.computeBound();
std::cout << "Root: " << root << std::endl;
expansions.push(root);
// 2) Main loop private:
size_t numExpansions = 0; //
while (!expansions.empty()) { void expandNextNode() {
// Pop the partial assignment with the smallest bound // Pop the partial assignment with the smallest bound
SearchNode current = expansions.top(); SearchNode current = expansions_.top();
expansions.pop(); expansions_.pop();
numExpansions++;
std::cout << "Expanding: " << current << std::endl; std::cout << "Expanding: " << current << std::endl;
// If we already have K solutions, prune if we cannot beat the worst one. // If we already have K solutions, prune if we cannot beat the worst one.
if (solutions.prune(current.bound)) { if (solutions_.prune(current.bound)) {
std::cout << "Pruning: bound=" << current.bound << std::endl; std::cout << "Pruning: bound=" << current.bound << std::endl;
continue; return;
} }
// Check if we have a complete assignment // Check if we have a complete assignment
if (current.isComplete()) { if (current.isComplete()) {
const bool added = solutions.maybeAdd(current.error, current.assignment); const bool added = solutions_.maybeAdd(current.error, current.assignment);
if (added) { if (added) {
std::cout << "Best solutions so far:" << std::endl; std::cout << "Best solutions so far:" << std::endl;
solutions.print(); solutions_.print();
} }
continue; return;
} }
// Expand on the next factor // Expand on the next factor
const auto& conditional = conditionals[current.nextConditional]; const auto& conditional = conditionals_[current.nextConditional];
for (auto& fa : conditional->frontalAssignments()) { for (auto& fa : conditional->frontalAssignments()) {
std::cout << "Frontal assignment: " << fa << std::endl; std::cout << "Frontal assignment: " << fa << std::endl;
auto childNode = current.expand(*conditional, fa); auto childNode = current.expand(*conditional, fa);
// Again, prune if we cannot beat the worst solution // Again, prune if we cannot beat the worst solution
if (solutions.prune(current.bound)) { if (solutions_.prune(current.bound)) {
std::cout << "Pruning: bound=" << childNode.bound << std::endl; std::cout << "Pruning: bound=" << childNode.bound << std::endl;
continue; continue;
} }
expansions.push(childNode); expansions_.push(childNode);
} }
} }
std::cout << "Expansions: " << numExpansions << std::endl; const DiscreteBayesNet& bayesNet_;
std::vector<std::shared_ptr<DiscreteConditional>> conditionals_;
// 3) Extract solutions from bestSolutions in ascending order of error std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound>
return solutions.extractSolutions(); expansions_;
} Solutions solutions_;
};
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Example “Unit Tests” (trivial stubs) // Example “Unit Tests” (trivial stubs)
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
TEST_DISABLED(DiscreteBayesNet, EmptyKBest) { TEST(DiscreteBayesNet, EmptyKBest) {
DiscreteBayesNet net; // no factors DiscreteBayesNet net; // no factors
auto solutions = branchAndBoundKSolutions(net, 3); BestKSearch search(net, 3);
auto solutions = search.run();
// Expect one solution with empty assignment, error=0 // Expect one solution with empty assignment, error=0
EXPECT_LONGS_EQUAL(1, solutions.size()); EXPECT_LONGS_EQUAL(1, solutions.size());
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
} }
TEST(DiscreteBayesNet, AsiaKBest) { TEST(DiscreteBayesNet, AsiaKBest) {
DiscreteBayesNet net = constructAsiaExample(); DiscreteBayesNet asia = constructAsiaExample();
BestKSearch search(asia, 4);
auto solutions = branchAndBoundKSolutions(net, 4); auto solutions = search.run();
EXPECT(!solutions.empty()); EXPECT(!solutions.empty());
// Regression test: check the first solution // Regression test: check the first solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);