diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 0bd81a197..52fc660e5 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -96,7 +96,7 @@ TEST(DiscreteBayesNet, Asia) { // Convert to factor graph DiscreteFactorGraph fg(asia); - LONGS_EQUAL(3, fg.back()->size()); + LONGS_EQUAL(1, fg.back()->size()); // Check the marginals we know (of the parent-less nodes) DiscreteMarginals marginals(fg); @@ -164,16 +164,16 @@ TEST(DiscreteBayesNet, Dot) { "digraph {\n" " size=\"5,5\";\n" "\n" - " var4683743612465315840[label=\"A0\"];\n" - " var4971973988617027589[label=\"E5\"];\n" - " var5476377146882523142[label=\"L6\"];\n" - " var5980780305148018692[label=\"S4\"];\n" - " var6052837899185946627[label=\"T3\"];\n" + " var4683743612465315848[label=\"A8\"];\n" + " var4971973988617027587[label=\"E3\"];\n" + " var5476377146882523141[label=\"L5\"];\n" + " var5980780305148018695[label=\"S7\"];\n" + " var6052837899185946630[label=\"T6\"];\n" "\n" - " var6052837899185946627->var4971973988617027589\n" - " var5476377146882523142->var4971973988617027589\n" - " var5980780305148018692->var5476377146882523142\n" - " var4683743612465315840->var6052837899185946627\n" + " var6052837899185946630->var4971973988617027587\n" + " var5476377146882523141->var4971973988617027587\n" + " var5980780305148018695->var5476377146882523141\n" + " var4683743612465315848->var6052837899185946630\n" "}"); } @@ -290,19 +290,61 @@ struct CompareByError { } }; -using Solutions = - std::priority_queue, CompareByError>; +// Define the Solutions class +class Solutions { + private: + size_t maxSize_; + std::priority_queue, CompareByError> pq_; -// Function to print the priority queue -void printPriorityQueue(Solutions pq) { - size_t i = 0; - while (!pq.empty()) { - const Solution& best = pq.top(); - cout << " " << ++i << " Error: " << best.error - << ", Values: " << best.assignment << endl; - pq.pop(); + public: + Solutions(size_t maxSize) : maxSize_(maxSize) {} + + /// Add a solution to the priority queue, possibly evicting the worst one. + /// Return true if we added the solution. + bool maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; } -} + + /// Check if we have any solutions + bool empty() const { return pq_.empty(); } + + // Method to print all solutions + void print() const { + auto pq = pq_; + while (!pq.empty()) { + const Solution& best = pq.top(); + std::cout << "Error: " << best.error << ", Values: " << best.assignment + << std::endl; + pq.pop(); + } + } + + /// Check if (partial) solution with given bound can be pruned. If we have + /// room, we never prune. Otherwise, prune if lower bound on error is worse + /// than our current worst error. + bool prune(double bound) const { + if (pq_.size() < maxSize_) return false; + double worstError = pq_.top().error; + return (bound >= worstError); + } + + // Method to extract solutions in ascending order of error + std::vector extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; + } +}; // ---------------------------------------------------------------------------- // 5) The main branch-and-bound routine for DiscreteBayesNet @@ -321,7 +363,7 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, // Max-heap of best solutions found so far, keyed by error. // We keep the largest error on top. - Solutions solutions; + Solutions solutions(K); // 1) Create the root node: no variables assigned, nextConditional = last. SearchNode root{.assignment = DiscreteValues(), @@ -341,32 +383,17 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, std::cout << "Expanding: " << current << std::endl; // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions.size() == K) { - double worstError = solutions.top().error; - if (current.bound >= worstError) { - std::cout << "Pruning: bound=" << current.bound - << " worstError=" << worstError << std::endl; - continue; - } + if (solutions.prune(current.bound)) { + std::cout << "Pruning: bound=" << current.bound << std::endl; + continue; } // Check if we have a complete assignment if (current.isComplete()) { - // We have a new candidate solution - double finalError = current.error; - if (solutions.size() < K) { - std::cout << "Adding solution: " << current << std::endl; - solutions.emplace(finalError, current.assignment); - // print out all solutions: + const bool added = solutions.maybeAdd(current.error, current.assignment); + if (added) { std::cout << "Best solutions so far:" << std::endl; - printPriorityQueue(solutions); - } else { - double worstError = solutions.top().error; - if (finalError < worstError) { - std::cout << "Replacing solution: " << current << std::endl; - solutions.pop(); - solutions.emplace(finalError, current.assignment); - } + solutions.print(); } continue; } @@ -379,14 +406,11 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, auto childNode = current.expand(*conditional, fa); // Again, prune if we cannot beat the worst solution - if (solutions.size() == K) { - double worstError = solutions.top().error; - if (childNode.bound >= worstError) { - std::cout << "Pruning: bound=" << childNode.bound - << " worstError=" << worstError << std::endl; - continue; - } + if (solutions.prune(current.bound)) { + std::cout << "Pruning: bound=" << childNode.bound << std::endl; + continue; } + expansions.push(childNode); } } @@ -394,15 +418,7 @@ std::vector branchAndBoundKSolutions(const DiscreteBayesNet& bayesNet, std::cout << "Expansions: " << numExpansions << std::endl; // 3) Extract solutions from bestSolutions in ascending order of error - std::vector result; - while (!solutions.empty()) { - result.push_back(solutions.top()); - solutions.pop(); - } - std::sort(result.begin(), result.end(), - [](auto& a, auto& b) { return a.error < b.error; }); - - return result; + return solutions.extractSolutions(); } // ---------------------------------------------------------------------------- @@ -422,8 +438,8 @@ TEST(DiscreteBayesNet, AsiaKBest) { auto solutions = branchAndBoundKSolutions(net, 4); EXPECT(!solutions.empty()); - // All solutions likely tie with error=0 in this stub - EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); + // Regression test: check the first solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); } /* ************************************************************************* */