diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index d2033909c..0bd81a197 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -31,8 +32,15 @@ using namespace std; using namespace gtsam; -static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), - LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); +namespace keys { +static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), + B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), + S = Symbol('S', 7), A = Symbol('A', 8); +} + +static const DiscreteKey Dyspnea(keys::D, 2), XRay(keys::X, 2), + Either(keys::E, 2), Bronchitis(keys::B, 2), LungCancer(keys::L, 2), + Tuberculosis(keys::T, 2), Smoking(keys::S, 2), Asia(keys::A, 2); using ADT = AlgebraicDecisionTree; @@ -40,17 +48,15 @@ using ADT = AlgebraicDecisionTree; DiscreteBayesNet constructAsiaExample() { DiscreteBayesNet asia; - asia.add(Asia, "99/1"); - asia.add(Smoking % "50/50"); // Signature version - - asia.add(Tuberculosis | Asia = "99/1 95/5"); - asia.add(LungCancer | Smoking = "99/1 90/10"); - asia.add(Bronchitis | Smoking = "70/30 40/60"); - - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - asia.add(XRay | Either = "95/5 2/98"); + // Add in topological sort order, parents last: asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + asia.add(XRay | Either = "95/5 2/98"); + asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + asia.add(LungCancer | Smoking = "99/1 90/10"); + asia.add(Tuberculosis | Asia = "99/1 95/5"); + asia.add(Smoking % "50/50"); // Signature version + asia.add(Asia, "99/1"); return asia; } @@ -99,7 +105,8 @@ TEST(DiscreteBayesNet, Asia) { EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); // Create solver and eliminate - const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7}; + const Ordering ordering{keys::A, keys::D, keys::T, keys::X, + keys::S, keys::E, keys::L, keys::B}; DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); @@ -157,16 +164,16 @@ TEST(DiscreteBayesNet, Dot) { "digraph {\n" " size=\"5,5\";\n" "\n" - " var0[label=\"0\"];\n" - " var3[label=\"3\"];\n" - " var4[label=\"4\"];\n" - " var5[label=\"5\"];\n" - " var6[label=\"6\"];\n" + " var4683743612465315840[label=\"A0\"];\n" + " var4971973988617027589[label=\"E5\"];\n" + " var5476377146882523142[label=\"L6\"];\n" + " var5980780305148018692[label=\"S4\"];\n" + " var6052837899185946627[label=\"T3\"];\n" "\n" - " var3->var5\n" - " var6->var5\n" - " var4->var6\n" - " var0->var3\n" + " var6052837899185946627->var4971973988617027589\n" + " var5476377146882523142->var4971973988617027589\n" + " var5980780305148018692->var5476377146882523142\n" + " var4683743612465315840->var6052837899185946627\n" "}"); } @@ -191,11 +198,234 @@ TEST(DiscreteBayesNet, markdown) { "|:-:|:-:|:-:|\n" "|0|0.8|0.2|\n" "|1|0.7|0.3|\n\n"; - auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; + auto formatter = [](Key key) { return key == keys::A ? "Asia" : "Smoking"; }; string actual = fragment.markdown(formatter); EXPECT(actual == expected); } +/* ************************************************************************* */ +#include +#include +#include +#include +#include +#include + +using Value = size_t; + +// ---------------------------------------------------------------------------- +// 1) SearchNode: store partial assignment and next factor to expand +// ---------------------------------------------------------------------------- +struct SearchNode { + DiscreteValues assignment; + double error; + double bound; + int nextConditional; // index into conditionals + + /// if nextConditional < 0, we've assigned everything. + bool isComplete() const { return nextConditional < 0; } + + /// lower bound on final error for unassigned variables. Stub=0. + double computeBound() const { + // Real code might do partial factor analysis or heuristics. + return 0.0; + } + + /// Expand this node by assigning the next variable + SearchNode expand(const DiscreteConditional& conditional, + const DiscreteValues& fa) const { + // Combine the new frontal assignment with the current partial assignment + SearchNode child; + child.assignment = assignment; + for (auto& kv : fa) { + child.assignment[kv.first] = kv.second; + } + + // Compute the incremental error for this factor + child.error = error + conditional.error(child.assignment); + + // Compute new bound + child.bound = child.error + computeBound(); + + // Next factor index + child.nextConditional = nextConditional - 1; + + return child; + } + + friend std::ostream& operator<<(std::ostream& os, const SearchNode& sn) { + os << "[ error=" << sn.error << " bound=" << sn.bound + << " nextConditional=" << sn.nextConditional << " assignment={" + << sn.assignment << "}]"; + return os; + } +}; + +// ---------------------------------------------------------------------------- +// 2) Priority functor to make a min-heap by bound +// ---------------------------------------------------------------------------- +struct CompareByBound { + bool operator()(const SearchNode& a, const SearchNode& b) const { + return a.bound > b.bound; // smallest bound -> highest priority + } +}; + +// ---------------------------------------------------------------------------- +// 4) A Solution +// ---------------------------------------------------------------------------- +struct Solution { + double error; + DiscreteValues assignment; + Solution(double err, const DiscreteValues& assign) + : error(err), assignment(assign) {} + friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { + os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; + return os; + } +}; + +struct CompareByError { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } +}; + +using Solutions = + std::priority_queue, CompareByError>; + +// Function to print the priority queue +void printPriorityQueue(Solutions pq) { + size_t i = 0; + while (!pq.empty()) { + const Solution& best = pq.top(); + cout << " " << ++i << " Error: " << best.error + << ", Values: " << best.assignment << endl; + pq.pop(); + } +} + +// ---------------------------------------------------------------------------- +// 5) The main branch-and-bound routine for DiscreteBayesNet +// ---------------------------------------------------------------------------- +std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, + size_t K) { + // Copy out the conditionals + std::vector> conditionals; + for (auto& factor : bayesNet) { + conditionals.push_back(factor); + } + + // Min-heap of expansions by bound + std::priority_queue, CompareByBound> + expansions; + + // Max-heap of best solutions found so far, keyed by error. + // We keep the largest error on top. + Solutions solutions; + + // 1) Create the root node: no variables assigned, nextConditional = last. + SearchNode root{.assignment = DiscreteValues(), + .error = 0.0, + .nextConditional = static_cast(conditionals.size()) - 1}; + root.bound = root.computeBound(); + std::cout << "Root: " << root << std::endl; + expansions.push(root); + + // 2) Main loop + size_t numExpansions = 0; + while (!expansions.empty()) { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions.top(); + expansions.pop(); + numExpansions++; + std::cout << "Expanding: " << current << std::endl; + + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions.size() == K) { + double worstError = solutions.top().error; + if (current.bound >= worstError) { + std::cout << "Pruning: bound=" << current.bound + << " worstError=" << worstError << std::endl; + continue; + } + } + + // Check if we have a complete assignment + if (current.isComplete()) { + // We have a new candidate solution + double finalError = current.error; + if (solutions.size() < K) { + std::cout << "Adding solution: " << current << std::endl; + solutions.emplace(finalError, current.assignment); + // print out all solutions: + std::cout << "Best solutions so far:" << std::endl; + printPriorityQueue(solutions); + } else { + double worstError = solutions.top().error; + if (finalError < worstError) { + std::cout << "Replacing solution: " << current << std::endl; + solutions.pop(); + solutions.emplace(finalError, current.assignment); + } + } + continue; + } + + // Expand on the next factor + const auto& conditional = conditionals[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + std::cout << "Frontal assignment: " << fa << std::endl; + auto childNode = current.expand(*conditional, fa); + + // Again, prune if we cannot beat the worst solution + if (solutions.size() == K) { + double worstError = solutions.top().error; + if (childNode.bound >= worstError) { + std::cout << "Pruning: bound=" << childNode.bound + << " worstError=" << worstError << std::endl; + continue; + } + } + expansions.push(childNode); + } + } + + std::cout << "Expansions: " << numExpansions << std::endl; + + // 3) Extract solutions from bestSolutions in ascending order of error + std::vector result; + while (!solutions.empty()) { + result.push_back(solutions.top()); + solutions.pop(); + } + std::sort(result.begin(), result.end(), + [](auto& a, auto& b) { return a.error < b.error; }); + + return result; +} + +// ---------------------------------------------------------------------------- +// Example “Unit Tests” (trivial stubs) +// ---------------------------------------------------------------------------- + +TEST_DISABLED(DiscreteBayesNet, EmptyKBest) { + DiscreteBayesNet net; // no factors + auto solutions = branchAndBoundKSolutions(net, 3); + // Expect one solution with empty assignment, error=0 + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +TEST(DiscreteBayesNet, AsiaKBest) { + DiscreteBayesNet net = constructAsiaExample(); + + auto solutions = branchAndBoundKSolutions(net, 4); + EXPECT(!solutions.empty()); + // All solutions likely tie with error=0 in this stub + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + /* ************************************************************************* */ int main() { TestResult tr;