Asia example

release/4.3a0
Frank Dellaert 2025-01-26 22:08:04 -05:00
parent 1f4d9bbd7e
commit d879b156f8
2 changed files with 98 additions and 330 deletions

View File

@ -0,0 +1,61 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* AsiaExample.h
*
* @date Jan, 2025
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/inference/Symbol.h>
namespace gtsam {
namespace asia_example {
static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3),
B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6),
S = Symbol('S', 7), A = Symbol('A', 8);
static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2),
Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2),
Asia(A, 2);
// Function to construct the incomplete Asia example
DiscreteBayesNet createPriors() {
DiscreteBayesNet priors;
priors.add(Smoking % "50/50");
priors.add(Asia, "99/1");
return priors;
}
// Function to construct the incomplete Asia example
DiscreteBayesNet createFragment() {
DiscreteBayesNet fragment;
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
fragment.add(LungCancer | Smoking = "99/1 90/10");
fragment.add(Tuberculosis | Asia = "99/1 95/5");
for (const auto& factor : createPriors()) fragment.push_back(factor);
return fragment;
}
// Function to construct the Asia example
DiscreteBayesNet createAsiaExample() {
DiscreteBayesNet asia;
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
asia.add(XRay | Either = "95/5 2/98");
asia.add(Bronchitis | Smoking = "70/30 40/60");
for (const auto& factor : createFragment()) asia.push_back(factor);
return asia;
}
} // namespace asia_example
} // namespace gtsam

View File

@ -29,40 +29,13 @@
#include <string> #include <string>
#include <vector> #include <vector>
using namespace std; #include "AsiaExample.h"
using namespace gtsam; using namespace gtsam;
namespace keys {
static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3),
B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6),
S = Symbol('S', 7), A = Symbol('A', 8);
}
static const DiscreteKey Dyspnea(keys::D, 2), XRay(keys::X, 2),
Either(keys::E, 2), Bronchitis(keys::B, 2), LungCancer(keys::L, 2),
Tuberculosis(keys::T, 2), Smoking(keys::S, 2), Asia(keys::A, 2);
using ADT = AlgebraicDecisionTree<Key>;
// Function to construct the Asia example
DiscreteBayesNet constructAsiaExample() {
DiscreteBayesNet asia;
// Add in topological sort order, parents last:
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
asia.add(XRay | Either = "95/5 2/98");
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
asia.add(Bronchitis | Smoking = "70/30 40/60");
asia.add(LungCancer | Smoking = "99/1 90/10");
asia.add(Tuberculosis | Asia = "99/1 95/5");
asia.add(Smoking % "50/50"); // Signature version
asia.add(Asia, "99/1");
return asia;
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) { TEST(DiscreteBayesNet, bayesNet) {
using ADT = AlgebraicDecisionTree<Key>;
DiscreteBayesNet bayesNet; DiscreteBayesNet bayesNet;
DiscreteKey Parent(0, 2), Child(1, 2); DiscreteKey Parent(0, 2), Child(1, 2);
@ -92,7 +65,8 @@ TEST(DiscreteBayesNet, bayesNet) {
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, Asia) { TEST(DiscreteBayesNet, Asia) {
DiscreteBayesNet asia = constructAsiaExample(); using namespace asia_example;
const DiscreteBayesNet asia = createAsiaExample();
// Convert to factor graph // Convert to factor graph
DiscreteFactorGraph fg(asia); DiscreteFactorGraph fg(asia);
@ -105,8 +79,7 @@ TEST(DiscreteBayesNet, Asia) {
EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
// Create solver and eliminate // Create solver and eliminate
const Ordering ordering{keys::A, keys::D, keys::T, keys::X, const Ordering ordering{A, D, T, X, S, E, L, B};
keys::S, keys::E, keys::L, keys::B};
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
DiscreteConditional expected2(Bronchitis % "11/9"); DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back())); EXPECT(assert_equal(expected2, *chordal->back()));
@ -151,319 +124,53 @@ TEST(DiscreteBayesNet, Sugar) {
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, Dot) { TEST(DiscreteBayesNet, Dot) {
DiscreteBayesNet fragment; using namespace asia_example;
fragment.add(Asia % "99/1"); const DiscreteBayesNet fragment = createFragment();
fragment.add(Smoking % "50/50");
fragment.add(Tuberculosis | Asia = "99/1 95/5"); std::string expected =
fragment.add(LungCancer | Smoking = "99/1 90/10"); "digraph {\n"
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); " size=\"5,5\";\n"
"\n"
string actual = fragment.dot(); " var4683743612465315848[label=\"A8\"];\n"
EXPECT(actual == " var4971973988617027587[label=\"E3\"];\n"
"digraph {\n" " var5476377146882523141[label=\"L5\"];\n"
" size=\"5,5\";\n" " var5980780305148018695[label=\"S7\"];\n"
"\n" " var6052837899185946630[label=\"T6\"];\n"
" var4683743612465315848[label=\"A8\"];\n" "\n"
" var4971973988617027587[label=\"E3\"];\n" " var4683743612465315848->var6052837899185946630\n"
" var5476377146882523141[label=\"L5\"];\n" " var5980780305148018695->var5476377146882523141\n"
" var5980780305148018695[label=\"S7\"];\n" " var6052837899185946630->var4971973988617027587\n"
" var6052837899185946630[label=\"T6\"];\n" " var5476377146882523141->var4971973988617027587\n"
"\n" "}";
" var6052837899185946630->var4971973988617027587\n" std::string actual = fragment.dot();
" var5476377146882523141->var4971973988617027587\n" EXPECT(actual.compare(expected) == 0);
" var5980780305148018695->var5476377146882523141\n"
" var4683743612465315848->var6052837899185946630\n"
"}");
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Check markdown representation looks as expected. // Check markdown representation looks as expected.
TEST(DiscreteBayesNet, markdown) { TEST(DiscreteBayesNet, markdown) {
DiscreteBayesNet fragment; using namespace asia_example;
fragment.add(Asia % "99/1"); DiscreteBayesNet priors = createPriors();
fragment.add(Smoking | Asia = "8/2 7/3");
string expected = std::string expected =
"`DiscreteBayesNet` of size 2\n" "`DiscreteBayesNet` of size 2\n"
"\n" "\n"
" *P(Smoking):*\n\n"
"|Smoking|value|\n"
"|:-:|:-:|\n"
"|0|0.5|\n"
"|1|0.5|\n"
"\n"
" *P(Asia):*\n\n" " *P(Asia):*\n\n"
"|Asia|value|\n" "|Asia|value|\n"
"|:-:|:-:|\n" "|:-:|:-:|\n"
"|0|0.99|\n" "|0|0.99|\n"
"|1|0.01|\n" "|1|0.01|\n\n";
"\n" auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; };
" *P(Smoking|Asia):*\n\n" std::string actual = priors.markdown(formatter);
"|*Asia*|0|1|\n"
"|:-:|:-:|:-:|\n"
"|0|0.8|0.2|\n"
"|1|0.7|0.3|\n\n";
auto formatter = [](Key key) { return key == keys::A ? "Asia" : "Smoking"; };
string actual = fragment.markdown(formatter);
EXPECT(actual == expected); EXPECT(actual == expected);
} }
/* ************************************************************************* */
#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <queue>
#include <vector>
using Value = size_t;
// ----------------------------------------------------------------------------
// 1) SearchNode: store partial assignment and next factor to expand
// ----------------------------------------------------------------------------
struct SearchNode {
DiscreteValues assignment;
double error;
double bound;
int nextConditional; // index into conditionals
/// if nextConditional < 0, we've assigned everything.
bool isComplete() const { return nextConditional < 0; }
/// lower bound on final error for unassigned variables. Stub=0.
double computeBound() const {
// Real code might do partial factor analysis or heuristics.
return 0.0;
}
/// Expand this node by assigning the next variable
SearchNode expand(const DiscreteConditional& conditional,
const DiscreteValues& fa) const {
// Combine the new frontal assignment with the current partial assignment
SearchNode child;
child.assignment = assignment;
for (auto& kv : fa) {
child.assignment[kv.first] = kv.second;
}
// Compute the incremental error for this factor
child.error = error + conditional.error(child.assignment);
// Compute new bound
child.bound = child.error + computeBound();
// Next factor index
child.nextConditional = nextConditional - 1;
return child;
}
friend std::ostream& operator<<(std::ostream& os, const SearchNode& sn) {
os << "[ error=" << sn.error << " bound=" << sn.bound
<< " nextConditional=" << sn.nextConditional << " assignment={"
<< sn.assignment << "}]";
return os;
}
};
// ----------------------------------------------------------------------------
// 2) Priority functor to make a min-heap by bound
// ----------------------------------------------------------------------------
struct CompareByBound {
bool operator()(const SearchNode& a, const SearchNode& b) const {
return a.bound > b.bound; // smallest bound -> highest priority
}
};
// ----------------------------------------------------------------------------
// 4) A Solution
// ----------------------------------------------------------------------------
struct Solution {
double error;
DiscreteValues assignment;
Solution(double err, const DiscreteValues& assign)
: error(err), assignment(assign) {}
friend std::ostream& operator<<(std::ostream& os, const Solution& sn) {
os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
return os;
}
};
struct CompareByError {
bool operator()(const Solution& a, const Solution& b) const {
return a.error < b.error;
}
};
// Define the Solutions class
class Solutions {
private:
size_t maxSize_;
std::priority_queue<Solution, std::vector<Solution>, CompareByError> pq_;
public:
Solutions(size_t maxSize) : maxSize_(maxSize) {}
/// Add a solution to the priority queue, possibly evicting the worst one.
/// Return true if we added the solution.
bool maybeAdd(double error, const DiscreteValues& assignment) {
const bool full = pq_.size() == maxSize_;
if (full && error >= pq_.top().error) return false;
if (full) pq_.pop();
pq_.emplace(error, assignment);
return true;
}
/// Check if we have any solutions
bool empty() const { return pq_.empty(); }
// Method to print all solutions
void print() const {
auto pq = pq_;
while (!pq.empty()) {
const Solution& best = pq.top();
std::cout << "Error: " << best.error << ", Values: " << best.assignment
<< std::endl;
pq.pop();
}
}
/// Check if (partial) solution with given bound can be pruned. If we have
/// room, we never prune. Otherwise, prune if lower bound on error is worse
/// than our current worst error.
bool prune(double bound) const {
if (pq_.size() < maxSize_) return false;
double worstError = pq_.top().error;
return (bound >= worstError);
}
// Method to extract solutions in ascending order of error
std::vector<Solution> extractSolutions() {
std::vector<Solution> result;
while (!pq_.empty()) {
result.push_back(pq_.top());
pq_.pop();
}
std::sort(
result.begin(), result.end(),
[](const Solution& a, const Solution& b) { return a.error < b.error; });
return result;
}
};
/**
* BestKSearch: Search for the K best solutions.
*/
class BestKSearch {
public:
/**
* Construct from a DiscreteBayesNet and K.
*/
BestKSearch(const DiscreteBayesNet& bayesNet, size_t K)
: bayesNet_(bayesNet), solutions_(K) {
// Copy out the conditionals
for (auto& factor : bayesNet_) {
conditionals_.push_back(factor);
}
// Create the root node: no variables assigned, nextConditional = last.
SearchNode root{
.assignment = DiscreteValues(),
.error = 0.0,
.nextConditional = static_cast<int>(conditionals_.size()) - 1};
root.bound = root.computeBound();
std::cout << "Root: " << root << std::endl;
expansions_.push(root);
}
/**
* @brief Search for the K best solutions.
*
* This method performs a search to find the K best solutions for the given
* DiscreteBayesNet. It uses a priority queue to manage the search nodes,
* expanding nodes with the smallest bound first. The search continues until
* all possible nodes have been expanded or pruned.
*
* @return A vector of the K best solutions found during the search.
*/
std::vector<Solution> run() {
size_t numExpansions = 0;
while (!expansions_.empty()) {
expandNextNode();
numExpansions++;
}
std::cout << "Expansions: " << numExpansions << std::endl;
// Extract solutions from bestSolutions in ascending order of error
return solutions_.extractSolutions();
}
private:
//
void expandNextNode() {
// Pop the partial assignment with the smallest bound
SearchNode current = expansions_.top();
expansions_.pop();
std::cout << "Expanding: " << current << std::endl;
// If we already have K solutions, prune if we cannot beat the worst one.
if (solutions_.prune(current.bound)) {
std::cout << "Pruning: bound=" << current.bound << std::endl;
return;
}
// Check if we have a complete assignment
if (current.isComplete()) {
const bool added = solutions_.maybeAdd(current.error, current.assignment);
if (added) {
std::cout << "Best solutions so far:" << std::endl;
solutions_.print();
}
return;
}
// Expand on the next factor
const auto& conditional = conditionals_[current.nextConditional];
for (auto& fa : conditional->frontalAssignments()) {
std::cout << "Frontal assignment: " << fa << std::endl;
auto childNode = current.expand(*conditional, fa);
// Again, prune if we cannot beat the worst solution
if (solutions_.prune(current.bound)) {
std::cout << "Pruning: bound=" << childNode.bound << std::endl;
continue;
}
expansions_.push(childNode);
}
}
const DiscreteBayesNet& bayesNet_;
std::vector<std::shared_ptr<DiscreteConditional>> conditionals_;
std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound>
expansions_;
Solutions solutions_;
};
// ----------------------------------------------------------------------------
// Example “Unit Tests” (trivial stubs)
// ----------------------------------------------------------------------------
TEST(DiscreteBayesNet, EmptyKBest) {
DiscreteBayesNet net; // no factors
BestKSearch search(net, 3);
auto solutions = search.run();
// Expect one solution with empty assignment, error=0
EXPECT_LONGS_EQUAL(1, solutions.size());
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
}
TEST(DiscreteBayesNet, AsiaKBest) {
DiscreteBayesNet asia = constructAsiaExample();
BestKSearch search(asia, 4);
auto solutions = search.run();
EXPECT(!solutions.empty());
// Regression test: check the first solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;