diff --git a/gtsam/discrete/tests/AsiaExample.h b/gtsam/discrete/tests/AsiaExample.h new file mode 100644 index 000000000..7413f6899 --- /dev/null +++ b/gtsam/discrete/tests/AsiaExample.h @@ -0,0 +1,61 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * AsiaExample.h + * + * @date Jan, 2025 + * @author Frank Dellaert + */ + +#include +#include + +namespace gtsam { +namespace asia_example { + +static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), + B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), + S = Symbol('S', 7), A = Symbol('A', 8); + +static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), + Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), + Asia(A, 2); + +// Function to construct the incomplete Asia example +DiscreteBayesNet createPriors() { + DiscreteBayesNet priors; + priors.add(Smoking % "50/50"); + priors.add(Asia, "99/1"); + return priors; +} + +// Function to construct the incomplete Asia example +DiscreteBayesNet createFragment() { + DiscreteBayesNet fragment; + fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); + fragment.add(LungCancer | Smoking = "99/1 90/10"); + fragment.add(Tuberculosis | Asia = "99/1 95/5"); + for (const auto& factor : createPriors()) fragment.push_back(factor); + return fragment; +} + +// Function to construct the Asia example +DiscreteBayesNet createAsiaExample() { + DiscreteBayesNet asia; + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + asia.add(XRay | Either = "95/5 2/98"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + for (const auto& factor : createFragment()) asia.push_back(factor); + return asia; +} +} // namespace asia_example +} // namespace gtsam \ No newline at end of file diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 648b26770..dd5d218f8 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -29,40 +29,13 @@ #include #include -using namespace std; +#include "AsiaExample.h" + using namespace gtsam; -namespace keys { -static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), - B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), - S = Symbol('S', 7), A = Symbol('A', 8); -} - -static const DiscreteKey Dyspnea(keys::D, 2), XRay(keys::X, 2), - Either(keys::E, 2), Bronchitis(keys::B, 2), LungCancer(keys::L, 2), - Tuberculosis(keys::T, 2), Smoking(keys::S, 2), Asia(keys::A, 2); - -using ADT = AlgebraicDecisionTree; - -// Function to construct the Asia example -DiscreteBayesNet constructAsiaExample() { - DiscreteBayesNet asia; - - // Add in topological sort order, parents last: - asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); - asia.add(XRay | Either = "95/5 2/98"); - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - asia.add(Bronchitis | Smoking = "70/30 40/60"); - asia.add(LungCancer | Smoking = "99/1 90/10"); - asia.add(Tuberculosis | Asia = "99/1 95/5"); - asia.add(Smoking % "50/50"); // Signature version - asia.add(Asia, "99/1"); - - return asia; -} - /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { + using ADT = AlgebraicDecisionTree; DiscreteBayesNet bayesNet; DiscreteKey Parent(0, 2), Child(1, 2); @@ -92,7 +65,8 @@ TEST(DiscreteBayesNet, bayesNet) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { - DiscreteBayesNet asia = constructAsiaExample(); + using namespace asia_example; + const DiscreteBayesNet asia = createAsiaExample(); // Convert to factor graph DiscreteFactorGraph fg(asia); @@ -105,8 +79,7 @@ TEST(DiscreteBayesNet, Asia) { EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); // Create solver and eliminate - const Ordering ordering{keys::A, keys::D, keys::T, keys::X, - keys::S, keys::E, keys::L, keys::B}; + const Ordering ordering{A, D, T, X, S, E, L, B}; DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); @@ -151,319 +124,53 @@ TEST(DiscreteBayesNet, Sugar) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Dot) { - DiscreteBayesNet fragment; - fragment.add(Asia % "99/1"); - fragment.add(Smoking % "50/50"); + using namespace asia_example; + const DiscreteBayesNet fragment = createFragment(); - fragment.add(Tuberculosis | Asia = "99/1 95/5"); - fragment.add(LungCancer | Smoking = "99/1 90/10"); - fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - string actual = fragment.dot(); - EXPECT(actual == - "digraph {\n" - " size=\"5,5\";\n" - "\n" - " var4683743612465315848[label=\"A8\"];\n" - " var4971973988617027587[label=\"E3\"];\n" - " var5476377146882523141[label=\"L5\"];\n" - " var5980780305148018695[label=\"S7\"];\n" - " var6052837899185946630[label=\"T6\"];\n" - "\n" - " var6052837899185946630->var4971973988617027587\n" - " var5476377146882523141->var4971973988617027587\n" - " var5980780305148018695->var5476377146882523141\n" - " var4683743612465315848->var6052837899185946630\n" - "}"); + std::string expected = + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var4683743612465315848[label=\"A8\"];\n" + " var4971973988617027587[label=\"E3\"];\n" + " var5476377146882523141[label=\"L5\"];\n" + " var5980780305148018695[label=\"S7\"];\n" + " var6052837899185946630[label=\"T6\"];\n" + "\n" + " var4683743612465315848->var6052837899185946630\n" + " var5980780305148018695->var5476377146882523141\n" + " var6052837899185946630->var4971973988617027587\n" + " var5476377146882523141->var4971973988617027587\n" + "}"; + std::string actual = fragment.dot(); + EXPECT(actual.compare(expected) == 0); } /* ************************************************************************* */ // Check markdown representation looks as expected. TEST(DiscreteBayesNet, markdown) { - DiscreteBayesNet fragment; - fragment.add(Asia % "99/1"); - fragment.add(Smoking | Asia = "8/2 7/3"); + using namespace asia_example; + DiscreteBayesNet priors = createPriors(); - string expected = + std::string expected = "`DiscreteBayesNet` of size 2\n" "\n" + " *P(Smoking):*\n\n" + "|Smoking|value|\n" + "|:-:|:-:|\n" + "|0|0.5|\n" + "|1|0.5|\n" + "\n" " *P(Asia):*\n\n" "|Asia|value|\n" "|:-:|:-:|\n" "|0|0.99|\n" - "|1|0.01|\n" - "\n" - " *P(Smoking|Asia):*\n\n" - "|*Asia*|0|1|\n" - "|:-:|:-:|:-:|\n" - "|0|0.8|0.2|\n" - "|1|0.7|0.3|\n\n"; - auto formatter = [](Key key) { return key == keys::A ? "Asia" : "Smoking"; }; - string actual = fragment.markdown(formatter); + "|1|0.01|\n\n"; + auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; }; + std::string actual = priors.markdown(formatter); EXPECT(actual == expected); } -/* ************************************************************************* */ -#include -#include -#include -#include -#include -#include - -using Value = size_t; - -// ---------------------------------------------------------------------------- -// 1) SearchNode: store partial assignment and next factor to expand -// ---------------------------------------------------------------------------- -struct SearchNode { - DiscreteValues assignment; - double error; - double bound; - int nextConditional; // index into conditionals - - /// if nextConditional < 0, we've assigned everything. - bool isComplete() const { return nextConditional < 0; } - - /// lower bound on final error for unassigned variables. Stub=0. - double computeBound() const { - // Real code might do partial factor analysis or heuristics. - return 0.0; - } - - /// Expand this node by assigning the next variable - SearchNode expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const { - // Combine the new frontal assignment with the current partial assignment - SearchNode child; - child.assignment = assignment; - for (auto& kv : fa) { - child.assignment[kv.first] = kv.second; - } - - // Compute the incremental error for this factor - child.error = error + conditional.error(child.assignment); - - // Compute new bound - child.bound = child.error + computeBound(); - - // Next factor index - child.nextConditional = nextConditional - 1; - - return child; - } - - friend std::ostream& operator<<(std::ostream& os, const SearchNode& sn) { - os << "[ error=" << sn.error << " bound=" << sn.bound - << " nextConditional=" << sn.nextConditional << " assignment={" - << sn.assignment << "}]"; - return os; - } -}; - -// ---------------------------------------------------------------------------- -// 2) Priority functor to make a min-heap by bound -// ---------------------------------------------------------------------------- -struct CompareByBound { - bool operator()(const SearchNode& a, const SearchNode& b) const { - return a.bound > b.bound; // smallest bound -> highest priority - } -}; - -// ---------------------------------------------------------------------------- -// 4) A Solution -// ---------------------------------------------------------------------------- -struct Solution { - double error; - DiscreteValues assignment; - Solution(double err, const DiscreteValues& assign) - : error(err), assignment(assign) {} - friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { - os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; - return os; - } -}; - -struct CompareByError { - bool operator()(const Solution& a, const Solution& b) const { - return a.error < b.error; - } -}; - -// Define the Solutions class -class Solutions { - private: - size_t maxSize_; - std::priority_queue, CompareByError> pq_; - - public: - Solutions(size_t maxSize) : maxSize_(maxSize) {} - - /// Add a solution to the priority queue, possibly evicting the worst one. - /// Return true if we added the solution. - bool maybeAdd(double error, const DiscreteValues& assignment) { - const bool full = pq_.size() == maxSize_; - if (full && error >= pq_.top().error) return false; - if (full) pq_.pop(); - pq_.emplace(error, assignment); - return true; - } - - /// Check if we have any solutions - bool empty() const { return pq_.empty(); } - - // Method to print all solutions - void print() const { - auto pq = pq_; - while (!pq.empty()) { - const Solution& best = pq.top(); - std::cout << "Error: " << best.error << ", Values: " << best.assignment - << std::endl; - pq.pop(); - } - } - - /// Check if (partial) solution with given bound can be pruned. If we have - /// room, we never prune. Otherwise, prune if lower bound on error is worse - /// than our current worst error. - bool prune(double bound) const { - if (pq_.size() < maxSize_) return false; - double worstError = pq_.top().error; - return (bound >= worstError); - } - - // Method to extract solutions in ascending order of error - std::vector extractSolutions() { - std::vector result; - while (!pq_.empty()) { - result.push_back(pq_.top()); - pq_.pop(); - } - std::sort( - result.begin(), result.end(), - [](const Solution& a, const Solution& b) { return a.error < b.error; }); - return result; - } -}; - -/** - * BestKSearch: Search for the K best solutions. - */ -class BestKSearch { - public: - /** - * Construct from a DiscreteBayesNet and K. - */ - BestKSearch(const DiscreteBayesNet& bayesNet, size_t K) - : bayesNet_(bayesNet), solutions_(K) { - // Copy out the conditionals - for (auto& factor : bayesNet_) { - conditionals_.push_back(factor); - } - - // Create the root node: no variables assigned, nextConditional = last. - SearchNode root{ - .assignment = DiscreteValues(), - .error = 0.0, - .nextConditional = static_cast(conditionals_.size()) - 1}; - root.bound = root.computeBound(); - std::cout << "Root: " << root << std::endl; - expansions_.push(root); - } - - /** - * @brief Search for the K best solutions. - * - * This method performs a search to find the K best solutions for the given - * DiscreteBayesNet. It uses a priority queue to manage the search nodes, - * expanding nodes with the smallest bound first. The search continues until - * all possible nodes have been expanded or pruned. - * - * @return A vector of the K best solutions found during the search. - */ - std::vector run() { - size_t numExpansions = 0; - while (!expansions_.empty()) { - expandNextNode(); - numExpansions++; - } - - std::cout << "Expansions: " << numExpansions << std::endl; - - // Extract solutions from bestSolutions in ascending order of error - return solutions_.extractSolutions(); - } - - private: - // - void expandNextNode() { - // Pop the partial assignment with the smallest bound - SearchNode current = expansions_.top(); - expansions_.pop(); - std::cout << "Expanding: " << current << std::endl; - - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions_.prune(current.bound)) { - std::cout << "Pruning: bound=" << current.bound << std::endl; - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - const bool added = solutions_.maybeAdd(current.error, current.assignment); - if (added) { - std::cout << "Best solutions so far:" << std::endl; - solutions_.print(); - } - return; - } - - // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - std::cout << "Frontal assignment: " << fa << std::endl; - auto childNode = current.expand(*conditional, fa); - - // Again, prune if we cannot beat the worst solution - if (solutions_.prune(current.bound)) { - std::cout << "Pruning: bound=" << childNode.bound << std::endl; - continue; - } - - expansions_.push(childNode); - } - } - - const DiscreteBayesNet& bayesNet_; - std::vector> conditionals_; - std::priority_queue, CompareByBound> - expansions_; - Solutions solutions_; -}; - -// ---------------------------------------------------------------------------- -// Example “Unit Tests” (trivial stubs) -// ---------------------------------------------------------------------------- - -TEST(DiscreteBayesNet, EmptyKBest) { - DiscreteBayesNet net; // no factors - BestKSearch search(net, 3); - auto solutions = search.run(); - // Expect one solution with empty assignment, error=0 - EXPECT_LONGS_EQUAL(1, solutions.size()); - EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); -} - -TEST(DiscreteBayesNet, AsiaKBest) { - DiscreteBayesNet asia = constructAsiaExample(); - BestKSearch search(asia, 4); - auto solutions = search.run(); - EXPECT(!solutions.empty()); - // Regression test: check the first solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); -} - /* ************************************************************************* */ int main() { TestResult tr;