diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 81ebc6012..70f7f871a 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -76,34 +76,84 @@ std::vector Solutions::extractSolutions() { return result; } -DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) - : solutions_(K) { +DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { std::vector conditionals; - for (auto& factor : bayesNet) conditionals.push_back(factor); - initialize(conditionals); + for (auto& factor : bayesNet) conditionals_.push_back(factor); + costToGo_ = computeCostToGo(conditionals_); } -DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) - : solutions_(K) { - std::vector conditionals; +DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { std::function collectConditionals = [&](const auto& clique) { if (!clique) return; for (const auto& child : clique->children) collectConditionals(child); - conditionals.push_back(clique->conditional()); + conditionals_.push_back(clique->conditional()); }; for (const auto& root : bayesTree.roots()) collectConditionals(root); - initialize(conditionals); + costToGo_ = computeCostToGo(conditionals_); }; -std::vector DiscreteSearch::run() { - while (!expansions_.empty()) { - numExpansions++; +using SearchNodeQueue = std::priority_queue, + SearchNode::Compare>; + +std::vector DiscreteSearch::run(size_t K) const { + SearchNodeQueue expansions; + expansions.push(SearchNode::Root(conditionals_.size(), + costToGo_.empty() ? 0.0 : costToGo_.back())); + + Solutions solutions(K); + + auto expandNextNode = [&] { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions.top(); + expansions.pop(); + + // If we already have K solutions, prune if we cannot beat the worst + // one. + if (solutions.prune(current.bound)) { + return; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions.maybeAdd(current.error, current.assignment); + return; + } + + // Expand on the next factor + const auto& conditional = conditionals_[current.nextConditional]; + + for (auto& fa : conditional->frontalAssignments()) { + auto childNode = current.expand(*conditional, fa); + if (childNode.nextConditional >= 0) + childNode.bound = + childNode.error + costToGo_[childNode.nextConditional]; + + // Again, prune if we cannot beat the worst solution + if (!solutions.prune(childNode.bound)) { + expansions.emplace(childNode); + } + } + }; + +#ifdef DISCRETE_SEARCH_DEBUG + size_t numExpansions = 0; +#endif + + // Perform the search + while (!expansions.empty()) { expandNextNode(); +#ifdef DISCRETE_SEARCH_DEBUG + ++numExpansions; +#endif } +#ifdef DISCRETE_SEARCH_DEBUG + std::cout << "Number of expansions: " << numExpansions << std::endl; +#endif + // Extract solutions from bestSolutions in ascending order of error - return solutions_.extractSolutions(); + return solutions.extractSolutions(); } std::vector DiscreteSearch::computeCostToGo( @@ -120,35 +170,4 @@ std::vector DiscreteSearch::computeCostToGo( return costToGo; } -void DiscreteSearch::expandNextNode() { - // Pop the partial assignment with the smallest bound - SearchNode current = expansions_.top(); - expansions_.pop(); - - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions_.prune(current.bound)) { - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - solutions_.maybeAdd(current.error, current.assignment); - return; - } - - // Expand on the next factor - const auto& conditional = conditionals_[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - auto childNode = current.expand(*conditional, fa); - if (childNode.nextConditional >= 0) - childNode.bound = childNode.error + costToGo_[childNode.nextConditional]; - - // Again, prune if we cannot beat the worst solution - if (!solutions_.prune(childNode.bound)) { - expansions_.emplace(childNode); - } - } -} - } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index cbf258b3f..f5bd67bae 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -130,17 +130,15 @@ class Solutions { */ class DiscreteSearch { public: - size_t numExpansions = 0; - /** * Construct from a DiscreteBayesNet and K. */ - DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K = 1); + DiscreteSearch(const DiscreteBayesNet& bayesNet); /** * Construct from a DiscreteBayesTree and K. */ - DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K = 1); + DiscreteSearch(const DiscreteBayesTree& bayesTree); /** * @brief Search for the K best solutions. @@ -152,29 +150,17 @@ class DiscreteSearch { * * @return A vector of the K best solutions found during the search. */ - std::vector run(); + std::vector run(size_t K = 1) const; private: - /// Initialize the search with the given conditionals. - void initialize( - const std::vector& conditionals) { - conditionals_ = conditionals; - costToGo_ = computeCostToGo(conditionals_); - expansions_.push(SearchNode::Root( - conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back())); - } - /// Compute the cumulative cost-to-go for each conditional slot. static std::vector computeCostToGo( const std::vector& conditionals); /// Expand the next node in the search tree. - void expandNextNode(); + void expandNextNode() const; std::vector conditionals_; std::vector costToGo_; - std::priority_queue, SearchNode::Compare> - expansions_; - Solutions solutions_; }; } // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index e301ed3d4..a3b50add4 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -35,8 +35,8 @@ using namespace gtsam; /* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors - DiscreteSearch search(net, 3); - auto solutions = search.run(); + DiscreteSearch search(net); + auto solutions = search.run(3); // Expect one solution with empty assignment, error=0 EXPECT_LONGS_EQUAL(1, solutions.size()); EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); @@ -46,23 +46,17 @@ TEST(DiscreteBayesNet, EmptyKBest) { TEST(DiscreteBayesNet, AsiaKBest) { using namespace asia_example; DiscreteBayesNet asia = createAsiaExample(); + DiscreteSearch search(asia); // Ask for the MPE - DiscreteSearch search1(asia); - auto mpe = search1.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + auto mpe = search.run(); EXPECT_LONGS_EQUAL(1, mpe.size()); // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); - DiscreteSearch search(asia, 4); - auto solutions = search.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search.numExpansions << std::endl; + // Ask for top 4 solutions + auto solutions = search.run(4); EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution @@ -74,8 +68,8 @@ TEST(DiscreteBayesNet, AsiaKBest) { TEST(DiscreteBayesTree, EmptyTree) { DiscreteBayesTree bt; - DiscreteSearch search(bt, 3); - auto solutions = search.run(); + DiscreteSearch search(bt); + auto solutions = search.run(3); // We expect exactly 1 solution with error = 0.0 (the empty assignment). assert(solutions.size() == 1 && "There should be exactly one empty solution"); @@ -89,24 +83,17 @@ TEST(DiscreteBayesTree, AsiaTreeKBest) { DiscreteFactorGraph asia(createAsiaExample()); const Ordering ordering{D, X, B, E, L, T, S, A}; DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); + DiscreteSearch search(*bt); - // Ask for top 4 solutions - DiscreteSearch search1(*bt); - auto mpe = search1.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search1.numExpansions << std::endl; + // Ask for MPE + auto mpe = search.run(); EXPECT_LONGS_EQUAL(1, mpe.size()); // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); // Ask for top 4 solutions - DiscreteSearch search(*bt, 4); - auto solutions = search.run(); - - // print numExpansions - std::cout << "Number of expansions: " << search.numExpansions << std::endl; + auto solutions = search.run(4); EXPECT_LONGS_EQUAL(4, solutions.size()); // Regression test: check the first and last solution