Best-K search, v0

release/4.3a0
Frank Dellaert 2025-01-26 16:57:58 -05:00
parent 4027e86057
commit 2808982903
1 changed file with 253 additions and 23 deletions

View File

@ -23,6 +23,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/inference/Symbol.h>
#include <iostream>
#include <string>
@ -31,8 +32,15 @@
using namespace std;
using namespace gtsam;
// Symbol-based keys for the eight variables of the Asia network. Each key
// pairs a mnemonic character with a distinct index.
namespace keys {
static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3),
                 B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6),
                 S = Symbol('S', 7), A = Symbol('A', 8);
}  // namespace keys

// The eight discrete variables, each with cardinality 2. These are the only
// definitions of these names: the earlier integer-keyed definitions
// (Asia(0, 2), Smoking(4, 2), ...) declared the same identifiers a second time
// and have been dropped.
static const DiscreteKey Dyspnea(keys::D, 2), XRay(keys::X, 2),
    Either(keys::E, 2), Bronchitis(keys::B, 2), LungCancer(keys::L, 2),
    Tuberculosis(keys::T, 2), Smoking(keys::S, 2), Asia(keys::A, 2);
using ADT = AlgebraicDecisionTree<Key>;
@ -40,17 +48,15 @@ using ADT = AlgebraicDecisionTree<Key>;
/**
 * Build the eight-variable Asia Bayes net used by the tests in this file.
 *
 * Each variable gets exactly one conditional. Conditionals are added children
 * first, so every parent is added after the conditionals that depend on it.
 *
 * @return the populated DiscreteBayesNet (8 conditionals).
 */
DiscreteBayesNet constructAsiaExample() {
  DiscreteBayesNet asia;

  // Add in topological sort order, parents last:
  asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
  asia.add(XRay | Either = "95/5 2/98");
  asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
  asia.add(Bronchitis | Smoking = "70/30 40/60");
  asia.add(LungCancer | Smoking = "99/1 90/10");
  asia.add(Tuberculosis | Asia = "99/1 95/5");
  asia.add(Smoking % "50/50");  // Signature version
  asia.add(Asia, "99/1");

  return asia;
}
@ -99,7 +105,8 @@ TEST(DiscreteBayesNet, Asia) {
EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
// Create solver and eliminate
const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7};
const Ordering ordering{keys::A, keys::D, keys::T, keys::X,
keys::S, keys::E, keys::L, keys::B};
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back()));
@ -157,16 +164,16 @@ TEST(DiscreteBayesNet, Dot) {
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"0\"];\n"
" var3[label=\"3\"];\n"
" var4[label=\"4\"];\n"
" var5[label=\"5\"];\n"
" var6[label=\"6\"];\n"
" var4683743612465315840[label=\"A0\"];\n"
" var4971973988617027589[label=\"E5\"];\n"
" var5476377146882523142[label=\"L6\"];\n"
" var5980780305148018692[label=\"S4\"];\n"
" var6052837899185946627[label=\"T3\"];\n"
"\n"
" var3->var5\n"
" var6->var5\n"
" var4->var6\n"
" var0->var3\n"
" var6052837899185946627->var4971973988617027589\n"
" var5476377146882523142->var4971973988617027589\n"
" var5980780305148018692->var5476377146882523142\n"
" var4683743612465315840->var6052837899185946627\n"
"}");
}
@ -191,11 +198,234 @@ TEST(DiscreteBayesNet, markdown) {
"|:-:|:-:|:-:|\n"
"|0|0.8|0.2|\n"
"|1|0.7|0.3|\n\n";
auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; };
auto formatter = [](Key key) { return key == keys::A ? "Asia" : "Smoking"; };
string actual = fragment.markdown(formatter);
EXPECT(actual == expected);
}
/* ************************************************************************* */
#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <queue>
#include <vector>
using Value = size_t;
// ----------------------------------------------------------------------------
// 1) SearchNode: store partial assignment and next factor to expand
// ----------------------------------------------------------------------------
struct SearchNode {
DiscreteValues assignment;
double error;
double bound;
int nextConditional; // index into conditionals
/// if nextConditional < 0, we've assigned everything.
bool isComplete() const { return nextConditional < 0; }
/// lower bound on final error for unassigned variables. Stub=0.
double computeBound() const {
// Real code might do partial factor analysis or heuristics.
return 0.0;
}
/// Expand this node by assigning the next variable
SearchNode expand(const DiscreteConditional& conditional,
const DiscreteValues& fa) const {
// Combine the new frontal assignment with the current partial assignment
SearchNode child;
child.assignment = assignment;
for (auto& kv : fa) {
child.assignment[kv.first] = kv.second;
}
// Compute the incremental error for this factor
child.error = error + conditional.error(child.assignment);
// Compute new bound
child.bound = child.error + computeBound();
// Next factor index
child.nextConditional = nextConditional - 1;
return child;
}
friend std::ostream& operator<<(std::ostream& os, const SearchNode& sn) {
os << "[ error=" << sn.error << " bound=" << sn.bound
<< " nextConditional=" << sn.nextConditional << " assignment={"
<< sn.assignment << "}]";
return os;
}
};
// ----------------------------------------------------------------------------
// 2) Priority functor to make a min-heap by bound
// ----------------------------------------------------------------------------
struct CompareByBound {
bool operator()(const SearchNode& a, const SearchNode& b) const {
return a.bound > b.bound; // smallest bound -> highest priority
}
};
// ----------------------------------------------------------------------------
// 4) A Solution
// ----------------------------------------------------------------------------
struct Solution {
double error;
DiscreteValues assignment;
Solution(double err, const DiscreteValues& assign)
: error(err), assignment(assign) {}
friend std::ostream& operator<<(std::ostream& os, const Solution& sn) {
os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
return os;
}
};
struct CompareByError {
bool operator()(const Solution& a, const Solution& b) const {
return a.error < b.error;
}
};
using Solutions =
std::priority_queue<Solution, std::vector<Solution>, CompareByError>;
// Function to print the priority queue
/// Print every solution in the queue to stdout, one per line, numbered from 1.
/// The queue is taken by value, so the caller's queue is left untouched.
/// Entries come out in heap order, which for Solutions means largest error
/// first.
void printPriorityQueue(Solutions pq) {
  for (size_t rank = 1; !pq.empty(); ++rank, pq.pop()) {
    const Solution& top = pq.top();
    std::cout << " " << rank << " Error: " << top.error
              << ", Values: " << top.assignment << std::endl;
  }
}
// ----------------------------------------------------------------------------
// 5) The main branch-and-bound routine for DiscreteBayesNet
// ----------------------------------------------------------------------------
std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
size_t K) {
// Copy out the conditionals
std::vector<std::shared_ptr<DiscreteConditional>> conditionals;
for (auto& factor : bayesNet) {
conditionals.push_back(factor);
}
// Min-heap of expansions by bound
std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound>
expansions;
// Max-heap of best solutions found so far, keyed by error.
// We keep the largest error on top.
Solutions solutions;
// 1) Create the root node: no variables assigned, nextConditional = last.
SearchNode root{.assignment = DiscreteValues(),
.error = 0.0,
.nextConditional = static_cast<int>(conditionals.size()) - 1};
root.bound = root.computeBound();
std::cout << "Root: " << root << std::endl;
expansions.push(root);
// 2) Main loop
size_t numExpansions = 0;
while (!expansions.empty()) {
// Pop the partial assignment with the smallest bound
SearchNode current = expansions.top();
expansions.pop();
numExpansions++;
std::cout << "Expanding: " << current << std::endl;
// If we already have K solutions, prune if we cannot beat the worst one.
if (solutions.size() == K) {
double worstError = solutions.top().error;
if (current.bound >= worstError) {
std::cout << "Pruning: bound=" << current.bound
<< " worstError=" << worstError << std::endl;
continue;
}
}
// Check if we have a complete assignment
if (current.isComplete()) {
// We have a new candidate solution
double finalError = current.error;
if (solutions.size() < K) {
std::cout << "Adding solution: " << current << std::endl;
solutions.emplace(finalError, current.assignment);
// print out all solutions:
std::cout << "Best solutions so far:" << std::endl;
printPriorityQueue(solutions);
} else {
double worstError = solutions.top().error;
if (finalError < worstError) {
std::cout << "Replacing solution: " << current << std::endl;
solutions.pop();
solutions.emplace(finalError, current.assignment);
}
}
continue;
}
// Expand on the next factor
const auto& conditional = conditionals[current.nextConditional];
for (auto& fa : conditional->frontalAssignments()) {
std::cout << "Frontal assignment: " << fa << std::endl;
auto childNode = current.expand(*conditional, fa);
// Again, prune if we cannot beat the worst solution
if (solutions.size() == K) {
double worstError = solutions.top().error;
if (childNode.bound >= worstError) {
std::cout << "Pruning: bound=" << childNode.bound
<< " worstError=" << worstError << std::endl;
continue;
}
}
expansions.push(childNode);
}
}
std::cout << "Expansions: " << numExpansions << std::endl;
// 3) Extract solutions from bestSolutions in ascending order of error
std::vector<Solution> result;
while (!solutions.empty()) {
result.push_back(solutions.top());
solutions.pop();
}
std::sort(result.begin(), result.end(),
[](auto& a, auto& b) { return a.error < b.error; });
return result;
}
// ----------------------------------------------------------------------------
// Example “Unit Tests” (trivial stubs)
// ----------------------------------------------------------------------------
// Searching a net with no conditionals: the root node is already complete, so
// the search should report a single solution with zero error.
TEST_DISABLED(DiscreteBayesNet, EmptyKBest) {
  DiscreteBayesNet emptyNet;  // no factors
  const auto kBest = branchAndBoundKSolutions(emptyNet, 3);

  // Expect one solution with empty assignment, error=0
  EXPECT_LONGS_EQUAL(1, kBest.size());
  EXPECT_DOUBLES_EQUAL(0, std::fabs(kBest[0].error), 1e-9);
}
// K-best search on the Asia network: ask for the four lowest-error
// assignments and check the best one.
TEST(DiscreteBayesNet, AsiaKBest) {
  DiscreteBayesNet net = constructAsiaExample();
  auto solutions = branchAndBoundKSolutions(net, 4);

  // The net has 2^8 complete assignments, so all four slots are filled.
  EXPECT(!solutions.empty());
  EXPECT_LONGS_EQUAL(4, solutions.size());

  // The search accumulates conditional.error() for every conditional, so the
  // best error is not zero. Assuming error() is the negative log of the
  // conditional probability (TODO confirm against the GTSAM API), the most
  // probable assignment is all-zeros with
  //   P = 0.99 * 0.5 * 0.99 * 0.99 * 0.7 * 1 * 0.95 * 0.9 = 0.290362
  // and therefore error = -log(P) = 1.236627.
  EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
}
/* ************************************************************************* */
int main() {
TestResult tr;