Separate class
parent
74b7a962d5
commit
1f4d9bbd7e
|
|
@ -346,97 +346,119 @@ class Solutions {
|
|||
}
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// 5) The main branch-and-bound routine for DiscreteBayesNet
|
||||
// ----------------------------------------------------------------------------
|
||||
std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
|
||||
size_t K) {
|
||||
// Copy out the conditionals
|
||||
std::vector<std::shared_ptr<DiscreteConditional>> conditionals;
|
||||
for (auto& factor : bayesNet) {
|
||||
conditionals.push_back(factor);
|
||||
/**
|
||||
* BestKSearch: Search for the K best solutions.
|
||||
*/
|
||||
class BestKSearch {
|
||||
public:
|
||||
/**
|
||||
* Construct from a DiscreteBayesNet and K.
|
||||
*/
|
||||
BestKSearch(const DiscreteBayesNet& bayesNet, size_t K)
|
||||
: bayesNet_(bayesNet), solutions_(K) {
|
||||
// Copy out the conditionals
|
||||
for (auto& factor : bayesNet_) {
|
||||
conditionals_.push_back(factor);
|
||||
}
|
||||
|
||||
// Create the root node: no variables assigned, nextConditional = last.
|
||||
SearchNode root{
|
||||
.assignment = DiscreteValues(),
|
||||
.error = 0.0,
|
||||
.nextConditional = static_cast<int>(conditionals_.size()) - 1};
|
||||
root.bound = root.computeBound();
|
||||
std::cout << "Root: " << root << std::endl;
|
||||
expansions_.push(root);
|
||||
}
|
||||
|
||||
// Min-heap of expansions by bound
|
||||
std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound>
|
||||
expansions;
|
||||
/**
|
||||
* @brief Search for the K best solutions.
|
||||
*
|
||||
* This method performs a search to find the K best solutions for the given
|
||||
* DiscreteBayesNet. It uses a priority queue to manage the search nodes,
|
||||
* expanding nodes with the smallest bound first. The search continues until
|
||||
* all possible nodes have been expanded or pruned.
|
||||
*
|
||||
* @return A vector of the K best solutions found during the search.
|
||||
*/
|
||||
std::vector<Solution> run() {
|
||||
size_t numExpansions = 0;
|
||||
while (!expansions_.empty()) {
|
||||
expandNextNode();
|
||||
numExpansions++;
|
||||
}
|
||||
|
||||
// Max-heap of best solutions found so far, keyed by error.
|
||||
// We keep the largest error on top.
|
||||
Solutions solutions(K);
|
||||
std::cout << "Expansions: " << numExpansions << std::endl;
|
||||
|
||||
// 1) Create the root node: no variables assigned, nextConditional = last.
|
||||
SearchNode root{.assignment = DiscreteValues(),
|
||||
.error = 0.0,
|
||||
.nextConditional = static_cast<int>(conditionals.size()) - 1};
|
||||
root.bound = root.computeBound();
|
||||
std::cout << "Root: " << root << std::endl;
|
||||
expansions.push(root);
|
||||
// Extract solutions from bestSolutions in ascending order of error
|
||||
return solutions_.extractSolutions();
|
||||
}
|
||||
|
||||
// 2) Main loop
|
||||
size_t numExpansions = 0;
|
||||
while (!expansions.empty()) {
|
||||
private:
|
||||
//
|
||||
void expandNextNode() {
|
||||
// Pop the partial assignment with the smallest bound
|
||||
SearchNode current = expansions.top();
|
||||
expansions.pop();
|
||||
numExpansions++;
|
||||
SearchNode current = expansions_.top();
|
||||
expansions_.pop();
|
||||
std::cout << "Expanding: " << current << std::endl;
|
||||
|
||||
// If we already have K solutions, prune if we cannot beat the worst one.
|
||||
if (solutions.prune(current.bound)) {
|
||||
if (solutions_.prune(current.bound)) {
|
||||
std::cout << "Pruning: bound=" << current.bound << std::endl;
|
||||
continue;
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if we have a complete assignment
|
||||
if (current.isComplete()) {
|
||||
const bool added = solutions.maybeAdd(current.error, current.assignment);
|
||||
const bool added = solutions_.maybeAdd(current.error, current.assignment);
|
||||
if (added) {
|
||||
std::cout << "Best solutions so far:" << std::endl;
|
||||
solutions.print();
|
||||
solutions_.print();
|
||||
}
|
||||
continue;
|
||||
return;
|
||||
}
|
||||
|
||||
// Expand on the next factor
|
||||
const auto& conditional = conditionals[current.nextConditional];
|
||||
const auto& conditional = conditionals_[current.nextConditional];
|
||||
|
||||
for (auto& fa : conditional->frontalAssignments()) {
|
||||
std::cout << "Frontal assignment: " << fa << std::endl;
|
||||
auto childNode = current.expand(*conditional, fa);
|
||||
|
||||
// Again, prune if we cannot beat the worst solution
|
||||
if (solutions.prune(current.bound)) {
|
||||
if (solutions_.prune(current.bound)) {
|
||||
std::cout << "Pruning: bound=" << childNode.bound << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
expansions.push(childNode);
|
||||
expansions_.push(childNode);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Expansions: " << numExpansions << std::endl;
|
||||
|
||||
// 3) Extract solutions from bestSolutions in ascending order of error
|
||||
return solutions.extractSolutions();
|
||||
}
|
||||
const DiscreteBayesNet& bayesNet_;
|
||||
std::vector<std::shared_ptr<DiscreteConditional>> conditionals_;
|
||||
std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound>
|
||||
expansions_;
|
||||
Solutions solutions_;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Example “Unit Tests” (trivial stubs)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
TEST_DISABLED(DiscreteBayesNet, EmptyKBest) {
|
||||
TEST(DiscreteBayesNet, EmptyKBest) {
|
||||
DiscreteBayesNet net; // no factors
|
||||
auto solutions = branchAndBoundKSolutions(net, 3);
|
||||
BestKSearch search(net, 3);
|
||||
auto solutions = search.run();
|
||||
// Expect one solution with empty assignment, error=0
|
||||
EXPECT_LONGS_EQUAL(1, solutions.size());
|
||||
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
|
||||
}
|
||||
|
||||
TEST(DiscreteBayesNet, AsiaKBest) {
|
||||
DiscreteBayesNet net = constructAsiaExample();
|
||||
|
||||
auto solutions = branchAndBoundKSolutions(net, 4);
|
||||
DiscreteBayesNet asia = constructAsiaExample();
|
||||
BestKSearch search(asia, 4);
|
||||
auto solutions = search.run();
|
||||
EXPECT(!solutions.empty());
|
||||
// Regression test: check the first solution
|
||||
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
||||
|
|
|
|||
Loading…
Reference in New Issue