diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 52fc660e5..648b26770 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -346,97 +346,119 @@ class Solutions { } }; -// ---------------------------------------------------------------------------- -// 5) The main branch-and-bound routine for DiscreteBayesNet -// ---------------------------------------------------------------------------- -std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, - size_t K) { - // Copy out the conditionals - std::vector> conditionals; - for (auto& factor : bayesNet) { - conditionals.push_back(factor); +/** + * BestKSearch: Search for the K best solutions. + */ +class BestKSearch { + public: + /** + * Construct from a DiscreteBayesNet and K. + */ + BestKSearch(const DiscreteBayesNet& bayesNet, size_t K) + : bayesNet_(bayesNet), solutions_(K) { + // Copy out the conditionals + for (auto& factor : bayesNet_) { + conditionals_.push_back(factor); + } + + // Create the root node: no variables assigned, nextConditional = last. + SearchNode root{ + .assignment = DiscreteValues(), + .error = 0.0, + .nextConditional = static_cast(conditionals_.size()) - 1}; + root.bound = root.computeBound(); + std::cout << "Root: " << root << std::endl; + expansions_.push(root); } - // Min-heap of expansions by bound - std::priority_queue, CompareByBound> - expansions; + /** + * @brief Search for the K best solutions. + * + * This method performs a search to find the K best solutions for the given + * DiscreteBayesNet. It uses a priority queue to manage the search nodes, + * expanding nodes with the smallest bound first. The search continues until + * all possible nodes have been expanded or pruned. + * + * @return A vector of the K best solutions found during the search. + */ + std::vector run() { + size_t numExpansions = 0; + while (!expansions_.empty()) { + expandNextNode(); + numExpansions++; + } - // Max-heap of best solutions found so far, keyed by error. - // We keep the largest error on top. - Solutions solutions(K); + std::cout << "Expansions: " << numExpansions << std::endl; - // 1) Create the root node: no variables assigned, nextConditional = last. - SearchNode root{.assignment = DiscreteValues(), - .error = 0.0, - .nextConditional = static_cast(conditionals.size()) - 1}; - root.bound = root.computeBound(); - std::cout << "Root: " << root << std::endl; - expansions.push(root); + // Extract solutions from bestSolutions in ascending order of error + return solutions_.extractSolutions(); + } - // 2) Main loop - size_t numExpansions = 0; - while (!expansions.empty()) { + private: + // + void expandNextNode() { // Pop the partial assignment with the smallest bound - SearchNode current = expansions.top(); - expansions.pop(); - numExpansions++; + SearchNode current = expansions_.top(); + expansions_.pop(); std::cout << "Expanding: " << current << std::endl; // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions.prune(current.bound)) { + if (solutions_.prune(current.bound)) { std::cout << "Pruning: bound=" << current.bound << std::endl; - continue; + return; } // Check if we have a complete assignment if (current.isComplete()) { - const bool added = solutions.maybeAdd(current.error, current.assignment); + const bool added = solutions_.maybeAdd(current.error, current.assignment); if (added) { std::cout << "Best solutions so far:" << std::endl; - solutions.print(); + solutions_.print(); } - continue; + return; } // Expand on the next factor - const auto& conditional = conditionals[current.nextConditional]; + const auto& conditional = conditionals_[current.nextConditional]; for (auto& fa : conditional->frontalAssignments()) { std::cout << "Frontal assignment: " << fa << std::endl; auto childNode = current.expand(*conditional, fa); // Again, prune if we cannot beat the worst solution - if (solutions.prune(current.bound)) { + if (solutions_.prune(current.bound)) { std::cout << "Pruning: bound=" << childNode.bound << std::endl; continue; } - expansions.push(childNode); + expansions_.push(childNode); } } - std::cout << "Expansions: " << numExpansions << std::endl; - - // 3) Extract solutions from bestSolutions in ascending order of error - return solutions.extractSolutions(); -} + const DiscreteBayesNet& bayesNet_; + std::vector> conditionals_; + std::priority_queue, CompareByBound> + expansions_; + Solutions solutions_; +}; // ---------------------------------------------------------------------------- // Example “Unit Tests” (trivial stubs) // ---------------------------------------------------------------------------- -TEST_DISABLED(DiscreteBayesNet, EmptyKBest) { +TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors - auto solutions = branchAndBoundKSolutions(net, 3); + BestKSearch search(net, 3); + auto solutions = search.run(); // Expect one solution with empty assignment, error=0 EXPECT_LONGS_EQUAL(1, solutions.size()); EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); } TEST(DiscreteBayesNet, AsiaKBest) { - DiscreteBayesNet net = constructAsiaExample(); - - auto solutions = branchAndBoundKSolutions(net, 4); + DiscreteBayesNet asia = constructAsiaExample(); + BestKSearch search(asia, 4); + auto solutions = search.run(); EXPECT(!solutions.empty()); // Regression test: check the first solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);