Best-K search, v0
parent
4027e86057
commit
2808982903
|
@ -23,6 +23,7 @@
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteMarginals.h>
|
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -31,8 +32,15 @@
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
|
namespace keys {
|
||||||
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
|
static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3),
|
||||||
|
B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6),
|
||||||
|
S = Symbol('S', 7), A = Symbol('A', 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
static const DiscreteKey Dyspnea(keys::D, 2), XRay(keys::X, 2),
|
||||||
|
Either(keys::E, 2), Bronchitis(keys::B, 2), LungCancer(keys::L, 2),
|
||||||
|
Tuberculosis(keys::T, 2), Smoking(keys::S, 2), Asia(keys::A, 2);
|
||||||
|
|
||||||
using ADT = AlgebraicDecisionTree<Key>;
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
|
|
||||||
|
@ -40,17 +48,15 @@ using ADT = AlgebraicDecisionTree<Key>;
|
||||||
DiscreteBayesNet constructAsiaExample() {
|
DiscreteBayesNet constructAsiaExample() {
|
||||||
DiscreteBayesNet asia;
|
DiscreteBayesNet asia;
|
||||||
|
|
||||||
asia.add(Asia, "99/1");
|
// Add in topological sort order, parents last:
|
||||||
asia.add(Smoking % "50/50"); // Signature version
|
|
||||||
|
|
||||||
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
|
||||||
asia.add(LungCancer | Smoking = "99/1 90/10");
|
|
||||||
asia.add(Bronchitis | Smoking = "70/30 40/60");
|
|
||||||
|
|
||||||
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
|
||||||
|
|
||||||
asia.add(XRay | Either = "95/5 2/98");
|
|
||||||
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
|
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
|
||||||
|
asia.add(XRay | Either = "95/5 2/98");
|
||||||
|
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||||
|
asia.add(Bronchitis | Smoking = "70/30 40/60");
|
||||||
|
asia.add(LungCancer | Smoking = "99/1 90/10");
|
||||||
|
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
||||||
|
asia.add(Smoking % "50/50"); // Signature version
|
||||||
|
asia.add(Asia, "99/1");
|
||||||
|
|
||||||
return asia;
|
return asia;
|
||||||
}
|
}
|
||||||
|
@ -99,7 +105,8 @@ TEST(DiscreteBayesNet, Asia) {
|
||||||
EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
|
EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
|
||||||
|
|
||||||
// Create solver and eliminate
|
// Create solver and eliminate
|
||||||
const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7};
|
const Ordering ordering{keys::A, keys::D, keys::T, keys::X,
|
||||||
|
keys::S, keys::E, keys::L, keys::B};
|
||||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||||
DiscreteConditional expected2(Bronchitis % "11/9");
|
DiscreteConditional expected2(Bronchitis % "11/9");
|
||||||
EXPECT(assert_equal(expected2, *chordal->back()));
|
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||||
|
@ -157,16 +164,16 @@ TEST(DiscreteBayesNet, Dot) {
|
||||||
"digraph {\n"
|
"digraph {\n"
|
||||||
" size=\"5,5\";\n"
|
" size=\"5,5\";\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var0[label=\"0\"];\n"
|
" var4683743612465315840[label=\"A0\"];\n"
|
||||||
" var3[label=\"3\"];\n"
|
" var4971973988617027589[label=\"E5\"];\n"
|
||||||
" var4[label=\"4\"];\n"
|
" var5476377146882523142[label=\"L6\"];\n"
|
||||||
" var5[label=\"5\"];\n"
|
" var5980780305148018692[label=\"S4\"];\n"
|
||||||
" var6[label=\"6\"];\n"
|
" var6052837899185946627[label=\"T3\"];\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var3->var5\n"
|
" var6052837899185946627->var4971973988617027589\n"
|
||||||
" var6->var5\n"
|
" var5476377146882523142->var4971973988617027589\n"
|
||||||
" var4->var6\n"
|
" var5980780305148018692->var5476377146882523142\n"
|
||||||
" var0->var3\n"
|
" var4683743612465315840->var6052837899185946627\n"
|
||||||
"}");
|
"}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,11 +198,234 @@ TEST(DiscreteBayesNet, markdown) {
|
||||||
"|:-:|:-:|:-:|\n"
|
"|:-:|:-:|:-:|\n"
|
||||||
"|0|0.8|0.2|\n"
|
"|0|0.8|0.2|\n"
|
||||||
"|1|0.7|0.3|\n\n";
|
"|1|0.7|0.3|\n\n";
|
||||||
auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; };
|
auto formatter = [](Key key) { return key == keys::A ? "Asia" : "Smoking"; };
|
||||||
string actual = fragment.markdown(formatter);
|
string actual = fragment.markdown(formatter);
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <iostream>
|
||||||
|
#include <map>
|
||||||
|
#include <queue>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
using Value = size_t;
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// 1) SearchNode: store partial assignment and next factor to expand
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
struct SearchNode {
|
||||||
|
DiscreteValues assignment;
|
||||||
|
double error;
|
||||||
|
double bound;
|
||||||
|
int nextConditional; // index into conditionals
|
||||||
|
|
||||||
|
/// if nextConditional < 0, we've assigned everything.
|
||||||
|
bool isComplete() const { return nextConditional < 0; }
|
||||||
|
|
||||||
|
/// lower bound on final error for unassigned variables. Stub=0.
|
||||||
|
double computeBound() const {
|
||||||
|
// Real code might do partial factor analysis or heuristics.
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Expand this node by assigning the next variable
|
||||||
|
SearchNode expand(const DiscreteConditional& conditional,
|
||||||
|
const DiscreteValues& fa) const {
|
||||||
|
// Combine the new frontal assignment with the current partial assignment
|
||||||
|
SearchNode child;
|
||||||
|
child.assignment = assignment;
|
||||||
|
for (auto& kv : fa) {
|
||||||
|
child.assignment[kv.first] = kv.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the incremental error for this factor
|
||||||
|
child.error = error + conditional.error(child.assignment);
|
||||||
|
|
||||||
|
// Compute new bound
|
||||||
|
child.bound = child.error + computeBound();
|
||||||
|
|
||||||
|
// Next factor index
|
||||||
|
child.nextConditional = nextConditional - 1;
|
||||||
|
|
||||||
|
return child;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, const SearchNode& sn) {
|
||||||
|
os << "[ error=" << sn.error << " bound=" << sn.bound
|
||||||
|
<< " nextConditional=" << sn.nextConditional << " assignment={"
|
||||||
|
<< sn.assignment << "}]";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// 2) Priority functor to make a min-heap by bound
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
struct CompareByBound {
|
||||||
|
bool operator()(const SearchNode& a, const SearchNode& b) const {
|
||||||
|
return a.bound > b.bound; // smallest bound -> highest priority
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// 4) A Solution
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
struct Solution {
|
||||||
|
double error;
|
||||||
|
DiscreteValues assignment;
|
||||||
|
Solution(double err, const DiscreteValues& assign)
|
||||||
|
: error(err), assignment(assign) {}
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, const Solution& sn) {
|
||||||
|
os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CompareByError {
|
||||||
|
bool operator()(const Solution& a, const Solution& b) const {
|
||||||
|
return a.error < b.error;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using Solutions =
|
||||||
|
std::priority_queue<Solution, std::vector<Solution>, CompareByError>;
|
||||||
|
|
||||||
|
// Function to print the priority queue
|
||||||
|
void printPriorityQueue(Solutions pq) {
|
||||||
|
size_t i = 0;
|
||||||
|
while (!pq.empty()) {
|
||||||
|
const Solution& best = pq.top();
|
||||||
|
cout << " " << ++i << " Error: " << best.error
|
||||||
|
<< ", Values: " << best.assignment << endl;
|
||||||
|
pq.pop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// 5) The main branch-and-bound routine for DiscreteBayesNet
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
std::vector<Solution> branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet,
|
||||||
|
size_t K) {
|
||||||
|
// Copy out the conditionals
|
||||||
|
std::vector<std::shared_ptr<DiscreteConditional>> conditionals;
|
||||||
|
for (auto& factor : bayesNet) {
|
||||||
|
conditionals.push_back(factor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Min-heap of expansions by bound
|
||||||
|
std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound>
|
||||||
|
expansions;
|
||||||
|
|
||||||
|
// Max-heap of best solutions found so far, keyed by error.
|
||||||
|
// We keep the largest error on top.
|
||||||
|
Solutions solutions;
|
||||||
|
|
||||||
|
// 1) Create the root node: no variables assigned, nextConditional = last.
|
||||||
|
SearchNode root{.assignment = DiscreteValues(),
|
||||||
|
.error = 0.0,
|
||||||
|
.nextConditional = static_cast<int>(conditionals.size()) - 1};
|
||||||
|
root.bound = root.computeBound();
|
||||||
|
std::cout << "Root: " << root << std::endl;
|
||||||
|
expansions.push(root);
|
||||||
|
|
||||||
|
// 2) Main loop
|
||||||
|
size_t numExpansions = 0;
|
||||||
|
while (!expansions.empty()) {
|
||||||
|
// Pop the partial assignment with the smallest bound
|
||||||
|
SearchNode current = expansions.top();
|
||||||
|
expansions.pop();
|
||||||
|
numExpansions++;
|
||||||
|
std::cout << "Expanding: " << current << std::endl;
|
||||||
|
|
||||||
|
// If we already have K solutions, prune if we cannot beat the worst one.
|
||||||
|
if (solutions.size() == K) {
|
||||||
|
double worstError = solutions.top().error;
|
||||||
|
if (current.bound >= worstError) {
|
||||||
|
std::cout << "Pruning: bound=" << current.bound
|
||||||
|
<< " worstError=" << worstError << std::endl;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have a complete assignment
|
||||||
|
if (current.isComplete()) {
|
||||||
|
// We have a new candidate solution
|
||||||
|
double finalError = current.error;
|
||||||
|
if (solutions.size() < K) {
|
||||||
|
std::cout << "Adding solution: " << current << std::endl;
|
||||||
|
solutions.emplace(finalError, current.assignment);
|
||||||
|
// print out all solutions:
|
||||||
|
std::cout << "Best solutions so far:" << std::endl;
|
||||||
|
printPriorityQueue(solutions);
|
||||||
|
} else {
|
||||||
|
double worstError = solutions.top().error;
|
||||||
|
if (finalError < worstError) {
|
||||||
|
std::cout << "Replacing solution: " << current << std::endl;
|
||||||
|
solutions.pop();
|
||||||
|
solutions.emplace(finalError, current.assignment);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expand on the next factor
|
||||||
|
const auto& conditional = conditionals[current.nextConditional];
|
||||||
|
|
||||||
|
for (auto& fa : conditional->frontalAssignments()) {
|
||||||
|
std::cout << "Frontal assignment: " << fa << std::endl;
|
||||||
|
auto childNode = current.expand(*conditional, fa);
|
||||||
|
|
||||||
|
// Again, prune if we cannot beat the worst solution
|
||||||
|
if (solutions.size() == K) {
|
||||||
|
double worstError = solutions.top().error;
|
||||||
|
if (childNode.bound >= worstError) {
|
||||||
|
std::cout << "Pruning: bound=" << childNode.bound
|
||||||
|
<< " worstError=" << worstError << std::endl;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expansions.push(childNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "Expansions: " << numExpansions << std::endl;
|
||||||
|
|
||||||
|
// 3) Extract solutions from bestSolutions in ascending order of error
|
||||||
|
std::vector<Solution> result;
|
||||||
|
while (!solutions.empty()) {
|
||||||
|
result.push_back(solutions.top());
|
||||||
|
solutions.pop();
|
||||||
|
}
|
||||||
|
std::sort(result.begin(), result.end(),
|
||||||
|
[](auto& a, auto& b) { return a.error < b.error; });
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Example “Unit Tests” (trivial stubs)
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
TEST_DISABLED(DiscreteBayesNet, EmptyKBest) {
|
||||||
|
DiscreteBayesNet net; // no factors
|
||||||
|
auto solutions = branchAndBoundKSolutions(net, 3);
|
||||||
|
// Expect one solution with empty assignment, error=0
|
||||||
|
EXPECT_LONGS_EQUAL(1, solutions.size());
|
||||||
|
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DiscreteBayesNet, AsiaKBest) {
|
||||||
|
DiscreteBayesNet net = constructAsiaExample();
|
||||||
|
|
||||||
|
auto solutions = branchAndBoundKSolutions(net, 4);
|
||||||
|
EXPECT(!solutions.empty());
|
||||||
|
// All solutions likely tie with error=0 in this stub
|
||||||
|
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue