Methods in Solutions
parent
2808982903
commit
74b7a962d5
|
@ -96,7 +96,7 @@ TEST(DiscreteBayesNet, Asia) {
|
||||||
|
|
||||||
// Convert to factor graph
|
// Convert to factor graph
|
||||||
DiscreteFactorGraph fg(asia);
|
DiscreteFactorGraph fg(asia);
|
||||||
LONGS_EQUAL(3, fg.back()->size());
|
LONGS_EQUAL(1, fg.back()->size());
|
||||||
|
|
||||||
// Check the marginals we know (of the parent-less nodes)
|
// Check the marginals we know (of the parent-less nodes)
|
||||||
DiscreteMarginals marginals(fg);
|
DiscreteMarginals marginals(fg);
|
||||||
|
@ -164,16 +164,16 @@ TEST(DiscreteBayesNet, Dot) {
|
||||||
"digraph {\n"
|
"digraph {\n"
|
||||||
" size=\"5,5\";\n"
|
" size=\"5,5\";\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var4683743612465315840[label=\"A0\"];\n"
|
" var4683743612465315848[label=\"A8\"];\n"
|
||||||
" var4971973988617027589[label=\"E5\"];\n"
|
" var4971973988617027587[label=\"E3\"];\n"
|
||||||
" var5476377146882523142[label=\"L6\"];\n"
|
" var5476377146882523141[label=\"L5\"];\n"
|
||||||
" var5980780305148018692[label=\"S4\"];\n"
|
" var5980780305148018695[label=\"S7\"];\n"
|
||||||
" var6052837899185946627[label=\"T3\"];\n"
|
" var6052837899185946630[label=\"T6\"];\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var6052837899185946627->var4971973988617027589\n"
|
" var6052837899185946630->var4971973988617027587\n"
|
||||||
" var5476377146882523142->var4971973988617027589\n"
|
" var5476377146882523141->var4971973988617027587\n"
|
||||||
" var5980780305148018692->var5476377146882523142\n"
|
" var5980780305148018695->var5476377146882523141\n"
|
||||||
" var4683743612465315840->var6052837899185946627\n"
|
" var4683743612465315848->var6052837899185946630\n"
|
||||||
"}");
|
"}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -290,19 +290,61 @@ struct CompareByError {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
using Solutions =
|
// Define the Solutions class
|
||||||
std::priority_queue<Solution, std::vector<Solution>, CompareByError>;
|
class Solutions {
|
||||||
|
private:
|
||||||
|
size_t maxSize_;
|
||||||
|
std::priority_queue<Solution, std::vector<Solution>, CompareByError> pq_;
|
||||||
|
|
||||||
// Function to print the priority queue
|
public:
|
||||||
void printPriorityQueue(Solutions pq) {
|
Solutions(size_t maxSize) : maxSize_(maxSize) {}
|
||||||
size_t i = 0;
|
|
||||||
|
/// Add a solution to the priority queue, possibly evicting the worst one.
|
||||||
|
/// Return true if we added the solution.
|
||||||
|
bool maybeAdd(double error, const DiscreteValues& assignment) {
|
||||||
|
const bool full = pq_.size() == maxSize_;
|
||||||
|
if (full && error >= pq_.top().error) return false;
|
||||||
|
if (full) pq_.pop();
|
||||||
|
pq_.emplace(error, assignment);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if we have any solutions
|
||||||
|
bool empty() const { return pq_.empty(); }
|
||||||
|
|
||||||
|
// Method to print all solutions
|
||||||
|
void print() const {
|
||||||
|
auto pq = pq_;
|
||||||
while (!pq.empty()) {
|
while (!pq.empty()) {
|
||||||
const Solution& best = pq.top();
|
const Solution& best = pq.top();
|
||||||
cout << " " << ++i << " Error: " << best.error
|
std::cout << "Error: " << best.error << ", Values: " << best.assignment
|
||||||
<< ", Values: " << best.assignment << endl;
|
<< std::endl;
|
||||||
pq.pop();
|
pq.pop();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if (partial) solution with given bound can be pruned. If we have
|
||||||
|
/// room, we never prune. Otherwise, prune if lower bound on error is worse
|
||||||
|
/// than our current worst error.
|
||||||
|
bool prune(double bound) const {
|
||||||
|
if (pq_.size() < maxSize_) return false;
|
||||||
|
double worstError = pq_.top().error;
|
||||||
|
return (bound >= worstError);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method to extract solutions in ascending order of error
|
||||||
|
std::vector<Solution> extractSolutions() {
|
||||||
|
std::vector<Solution> result;
|
||||||
|
while (!pq_.empty()) {
|
||||||
|
result.push_back(pq_.top());
|
||||||
|
pq_.pop();
|
||||||
|
}
|
||||||
|
std::sort(
|
||||||
|
result.begin(), result.end(),
|
||||||
|
[](const Solution& a, const Solution& b) { return a.error < b.error; });
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// 5) The main branch-and-bound routine for DiscreteBayesNet
|
// 5) The main branch-and-bound routine for DiscreteBayesNet
|
||||||
|
@ -321,7 +363,7 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
|
||||||
|
|
||||||
// Max-heap of best solutions found so far, keyed by error.
|
// Max-heap of best solutions found so far, keyed by error.
|
||||||
// We keep the largest error on top.
|
// We keep the largest error on top.
|
||||||
Solutions solutions;
|
Solutions solutions(K);
|
||||||
|
|
||||||
// 1) Create the root node: no variables assigned, nextConditional = last.
|
// 1) Create the root node: no variables assigned, nextConditional = last.
|
||||||
SearchNode root{.assignment = DiscreteValues(),
|
SearchNode root{.assignment = DiscreteValues(),
|
||||||
|
@ -341,32 +383,17 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
|
||||||
std::cout << "Expanding: " << current << std::endl;
|
std::cout << "Expanding: " << current << std::endl;
|
||||||
|
|
||||||
// If we already have K solutions, prune if we cannot beat the worst one.
|
// If we already have K solutions, prune if we cannot beat the worst one.
|
||||||
if (solutions.size() == K) {
|
if (solutions.prune(current.bound)) {
|
||||||
double worstError = solutions.top().error;
|
std::cout << "Pruning: bound=" << current.bound << std::endl;
|
||||||
if (current.bound >= worstError) {
|
|
||||||
std::cout << "Pruning: bound=" << current.bound
|
|
||||||
<< " worstError=" << worstError << std::endl;
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have a complete assignment
|
// Check if we have a complete assignment
|
||||||
if (current.isComplete()) {
|
if (current.isComplete()) {
|
||||||
// We have a new candidate solution
|
const bool added = solutions.maybeAdd(current.error, current.assignment);
|
||||||
double finalError = current.error;
|
if (added) {
|
||||||
if (solutions.size() < K) {
|
|
||||||
std::cout << "Adding solution: " << current << std::endl;
|
|
||||||
solutions.emplace(finalError, current.assignment);
|
|
||||||
// print out all solutions:
|
|
||||||
std::cout << "Best solutions so far:" << std::endl;
|
std::cout << "Best solutions so far:" << std::endl;
|
||||||
printPriorityQueue(solutions);
|
solutions.print();
|
||||||
} else {
|
|
||||||
double worstError = solutions.top().error;
|
|
||||||
if (finalError < worstError) {
|
|
||||||
std::cout << "Replacing solution: " << current << std::endl;
|
|
||||||
solutions.pop();
|
|
||||||
solutions.emplace(finalError, current.assignment);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -379,14 +406,11 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
|
||||||
auto childNode = current.expand(*conditional, fa);
|
auto childNode = current.expand(*conditional, fa);
|
||||||
|
|
||||||
// Again, prune if we cannot beat the worst solution
|
// Again, prune if we cannot beat the worst solution
|
||||||
if (solutions.size() == K) {
|
if (solutions.prune(current.bound)) {
|
||||||
double worstError = solutions.top().error;
|
std::cout << "Pruning: bound=" << childNode.bound << std::endl;
|
||||||
if (childNode.bound >= worstError) {
|
|
||||||
std::cout << "Pruning: bound=" << childNode.bound
|
|
||||||
<< " worstError=" << worstError << std::endl;
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
expansions.push(childNode);
|
expansions.push(childNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -394,15 +418,7 @@ std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
|
||||||
std::cout << "Expansions: " << numExpansions << std::endl;
|
std::cout << "Expansions: " << numExpansions << std::endl;
|
||||||
|
|
||||||
// 3) Extract solutions from bestSolutions in ascending order of error
|
// 3) Extract solutions from bestSolutions in ascending order of error
|
||||||
std::vector<Solution> result;
|
return solutions.extractSolutions();
|
||||||
while (!solutions.empty()) {
|
|
||||||
result.push_back(solutions.top());
|
|
||||||
solutions.pop();
|
|
||||||
}
|
|
||||||
std::sort(result.begin(), result.end(),
|
|
||||||
[](auto& a, auto& b) { return a.error < b.error; });
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
@ -422,8 +438,8 @@ TEST(DiscreteBayesNet, AsiaKBest) {
|
||||||
|
|
||||||
auto solutions = branchAndBoundKSolutions(net, 4);
|
auto solutions = branchAndBoundKSolutions(net, 4);
|
||||||
EXPECT(!solutions.empty());
|
EXPECT(!solutions.empty());
|
||||||
// All solutions likely tie with error=0 in this stub
|
// Regression test: check the first solution
|
||||||
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue