Separate class
parent
74b7a962d5
commit
1f4d9bbd7e
|
|
@ -346,97 +346,119 @@ class Solutions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
/**
|
||||||
// 5) The main branch-and-bound routine for DiscreteBayesNet
|
* BestKSearch: Search for the K best solutions.
|
||||||
// ----------------------------------------------------------------------------
|
*/
|
||||||
std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
|
class BestKSearch {
|
||||||
size_t K) {
|
public:
|
||||||
// Copy out the conditionals
|
/**
|
||||||
std::vector<std::shared_ptr<DiscreteConditional>> conditionals;
|
* Construct from a DiscreteBayesNet and K.
|
||||||
for (auto& factor : bayesNet) {
|
*/
|
||||||
conditionals.push_back(factor);
|
BestKSearch(const DiscreteBayesNet& bayesNet, size_t K)
|
||||||
|
: bayesNet_(bayesNet), solutions_(K) {
|
||||||
|
// Copy out the conditionals
|
||||||
|
for (auto& factor : bayesNet_) {
|
||||||
|
conditionals_.push_back(factor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the root node: no variables assigned, nextConditional = last.
|
||||||
|
SearchNode root{
|
||||||
|
.assignment = DiscreteValues(),
|
||||||
|
.error = 0.0,
|
||||||
|
.nextConditional = static_cast<int>(conditionals_.size()) - 1};
|
||||||
|
root.bound = root.computeBound();
|
||||||
|
std::cout << "Root: " << root << std::endl;
|
||||||
|
expansions_.push(root);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Min-heap of expansions by bound
|
/**
|
||||||
std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound>
|
* @brief Search for the K best solutions.
|
||||||
expansions;
|
*
|
||||||
|
* This method performs a search to find the K best solutions for the given
|
||||||
|
* DiscreteBayesNet. It uses a priority queue to manage the search nodes,
|
||||||
|
* expanding nodes with the smallest bound first. The search continues until
|
||||||
|
* all possible nodes have been expanded or pruned.
|
||||||
|
*
|
||||||
|
* @return A vector of the K best solutions found during the search.
|
||||||
|
*/
|
||||||
|
std::vector<Solution> run() {
|
||||||
|
size_t numExpansions = 0;
|
||||||
|
while (!expansions_.empty()) {
|
||||||
|
expandNextNode();
|
||||||
|
numExpansions++;
|
||||||
|
}
|
||||||
|
|
||||||
// Max-heap of best solutions found so far, keyed by error.
|
std::cout << "Expansions: " << numExpansions << std::endl;
|
||||||
// We keep the largest error on top.
|
|
||||||
Solutions solutions(K);
|
|
||||||
|
|
||||||
// 1) Create the root node: no variables assigned, nextConditional = last.
|
// Extract solutions from bestSolutions in ascending order of error
|
||||||
SearchNode root{.assignment = DiscreteValues(),
|
return solutions_.extractSolutions();
|
||||||
.error = 0.0,
|
}
|
||||||
.nextConditional = static_cast<int>(conditionals.size()) - 1};
|
|
||||||
root.bound = root.computeBound();
|
|
||||||
std::cout << "Root: " << root << std::endl;
|
|
||||||
expansions.push(root);
|
|
||||||
|
|
||||||
// 2) Main loop
|
private:
|
||||||
size_t numExpansions = 0;
|
//
|
||||||
while (!expansions.empty()) {
|
void expandNextNode() {
|
||||||
// Pop the partial assignment with the smallest bound
|
// Pop the partial assignment with the smallest bound
|
||||||
SearchNode current = expansions.top();
|
SearchNode current = expansions_.top();
|
||||||
expansions.pop();
|
expansions_.pop();
|
||||||
numExpansions++;
|
|
||||||
std::cout << "Expanding: " << current << std::endl;
|
std::cout << "Expanding: " << current << std::endl;
|
||||||
|
|
||||||
// If we already have K solutions, prune if we cannot beat the worst one.
|
// If we already have K solutions, prune if we cannot beat the worst one.
|
||||||
if (solutions.prune(current.bound)) {
|
if (solutions_.prune(current.bound)) {
|
||||||
std::cout << "Pruning: bound=" << current.bound << std::endl;
|
std::cout << "Pruning: bound=" << current.bound << std::endl;
|
||||||
continue;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we have a complete assignment
|
// Check if we have a complete assignment
|
||||||
if (current.isComplete()) {
|
if (current.isComplete()) {
|
||||||
const bool added = solutions.maybeAdd(current.error, current.assignment);
|
const bool added = solutions_.maybeAdd(current.error, current.assignment);
|
||||||
if (added) {
|
if (added) {
|
||||||
std::cout << "Best solutions so far:" << std::endl;
|
std::cout << "Best solutions so far:" << std::endl;
|
||||||
solutions.print();
|
solutions_.print();
|
||||||
}
|
}
|
||||||
continue;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expand on the next factor
|
// Expand on the next factor
|
||||||
const auto& conditional = conditionals[current.nextConditional];
|
const auto& conditional = conditionals_[current.nextConditional];
|
||||||
|
|
||||||
for (auto& fa : conditional->frontalAssignments()) {
|
for (auto& fa : conditional->frontalAssignments()) {
|
||||||
std::cout << "Frontal assignment: " << fa << std::endl;
|
std::cout << "Frontal assignment: " << fa << std::endl;
|
||||||
auto childNode = current.expand(*conditional, fa);
|
auto childNode = current.expand(*conditional, fa);
|
||||||
|
|
||||||
// Again, prune if we cannot beat the worst solution
|
// Again, prune if we cannot beat the worst solution
|
||||||
if (solutions.prune(current.bound)) {
|
if (solutions_.prune(current.bound)) {
|
||||||
std::cout << "Pruning: bound=" << childNode.bound << std::endl;
|
std::cout << "Pruning: bound=" << childNode.bound << std::endl;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
expansions.push(childNode);
|
expansions_.push(childNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::cout << "Expansions: " << numExpansions << std::endl;
|
const DiscreteBayesNet& bayesNet_;
|
||||||
|
std::vector<std::shared_ptr<DiscreteConditional>> conditionals_;
|
||||||
// 3) Extract solutions from bestSolutions in ascending order of error
|
std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound>
|
||||||
return solutions.extractSolutions();
|
expansions_;
|
||||||
}
|
Solutions solutions_;
|
||||||
|
};
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Example “Unit Tests” (trivial stubs)
|
// Example “Unit Tests” (trivial stubs)
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
TEST_DISABLED(DiscreteBayesNet, EmptyKBest) {
|
TEST(DiscreteBayesNet, EmptyKBest) {
|
||||||
DiscreteBayesNet net; // no factors
|
DiscreteBayesNet net; // no factors
|
||||||
auto solutions = branchAndBoundKSolutions(net, 3);
|
BestKSearch search(net, 3);
|
||||||
|
auto solutions = search.run();
|
||||||
// Expect one solution with empty assignment, error=0
|
// Expect one solution with empty assignment, error=0
|
||||||
EXPECT_LONGS_EQUAL(1, solutions.size());
|
EXPECT_LONGS_EQUAL(1, solutions.size());
|
||||||
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DiscreteBayesNet, AsiaKBest) {
|
TEST(DiscreteBayesNet, AsiaKBest) {
|
||||||
DiscreteBayesNet net = constructAsiaExample();
|
DiscreteBayesNet asia = constructAsiaExample();
|
||||||
|
BestKSearch search(asia, 4);
|
||||||
auto solutions = branchAndBoundKSolutions(net, 4);
|
auto solutions = search.run();
|
||||||
EXPECT(!solutions.empty());
|
EXPECT(!solutions.empty());
|
||||||
// Regression test: check the first solution
|
// Regression test: check the first solution
|
||||||
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue