diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 569887d7f..43f321d4a 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -33,16 +33,16 @@ using Solution = DiscreteSearch::Solution; * conditional to be assigned. */ struct SearchNode { - DiscreteValues assignment; ///< Partial assignment of discrete variables. - double error; ///< Current error for the partial assignment. - double bound; ///< Lower bound on the final error for unassigned variables. - int nextConditional; ///< Index of the next conditional to be assigned. + DiscreteValues assignment; ///< Partial assignment of discrete variables. + double error; ///< Current error for the partial assignment. + double bound; ///< Lower bound on the final error + std::optional next; ///< Index of the next factor to be assigned. /** * @brief Construct the root node for the search. */ static SearchNode Root(size_t numSlots, double bound) { - return {DiscreteValues(), 0.0, bound, static_cast(numSlots) - 1}; + return {DiscreteValues(), 0.0, bound, 0}; } struct Compare { @@ -51,38 +51,22 @@ struct SearchNode { } }; - /** - * @brief Checks if the node represents a complete assignment. - * - * @return True if all variables have been assigned, false otherwise. - */ - inline bool isComplete() const { return nextConditional < 0; } + /// Checks if the node represents a complete assignment. + inline bool isComplete() const { return !next; } - /** - * @brief Expands the node by assigning the next variable. - * - * @param slot The slot to be filled. - * @param fa The frontal assignment for the next variable. - * @return A new SearchNode representing the expanded state. - */ - SearchNode expand(const Slot& slot, const DiscreteValues& fa) const { + /// Expands the node by assigning the next variable(s). + SearchNode expand(const DiscreteValues& fa, const Slot& slot, + std::optional nextSlot) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; for (auto& [key, value] : fa) { newAssignment[key] = value; } double errorSoFar = error + slot.factor->error(newAssignment); - return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, - nextConditional - 1}; + return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot}; } - /** - * @brief Prints the SearchNode to an output stream. - * - * @param os The output stream. - * @param node The SearchNode to be printed. - * @return The output stream. - */ + /// Prints the SearchNode to an output stream. friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; return os; @@ -150,13 +134,18 @@ class Solutions { } }; +/// @brief Get the factor associated with a node, possibly product of factors. +template +static auto getFactor(const NodeType& node) { + const auto& factors = node->factors; + return factors.size() == 1 ? factors.back() + : DiscreteFactorGraph(factors).product(); +} + DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { using NodePtr = std::shared_ptr; auto visitor = [this](const NodePtr& node, int data) { - const auto& factors = node->factors; - const auto factor = factors.size() == 1 - ? factors.back() - : DiscreteFactorGraph(factors).product(); + const auto factor = getFactor(node); const size_t cardinality = factor->cardinality(node->key); std::vector> pairs{{node->key, cardinality}}; const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; @@ -164,19 +153,15 @@ DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { return data + 1; }; - const int data = 0; // unused + int data = 0; // unused treeTraversal::DepthFirstForest(etree, data, visitor); - std::reverse(slots_.begin(), slots_.end()); // reverse slots lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { using NodePtr = std::shared_ptr; auto visitor = [this](const NodePtr& cluster, int data) { - const auto& factors = cluster->factors; - const auto factor = factors.size() == 1 - ? factors.back() - : DiscreteFactorGraph(factors).product(); + const auto factor = getFactor(cluster); std::vector> pairs; for (Key key : cluster->orderedFrontalKeys) { pairs.emplace_back(key, factor->cardinality(key)); @@ -186,9 +171,8 @@ DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { return data + 1; }; - const int data = 0; // unused + int data = 0; // unused treeTraversal::DepthFirstForest(junctionTree, data, visitor); - std::reverse(slots_.begin(), slots_.end()); // reverse slots lowerBound_ = computeHeuristic(); } @@ -210,21 +194,21 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; slots_.emplace_back(std::move(slot)); } + std::reverse(slots_.begin(), slots_.end()); lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { - std::function - collectConditionals = [&](const auto& clique) { - if (!clique) return; - for (const auto& child : clique->children) collectConditionals(child); - auto conditional = clique->conditional(); - const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; - slots_.emplace_back(std::move(slot)); - }; + using NodePtr = DiscreteBayesTree::sharedClique; + auto visitor = [this](const NodePtr& clique, int data) { + auto conditional = clique->conditional(); + const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; + slots_.emplace_back(std::move(slot)); + return data + 1; + }; - slots_.reserve(bayesTree.size()); - for (const auto& root : bayesTree.roots()) collectConditionals(root); + int data = 0; // unused + treeTraversal::DepthFirstForest(bayesTree, data, visitor); lowerBound_ = computeHeuristic(); } @@ -236,59 +220,48 @@ void DiscreteSearch::print(const std::string& name, } } -struct SearchNodeQueue - : public std::priority_queue, - SearchNode::Compare> { - void expandNextNode(const SearchNode& current, const Slot& slot, - Solutions* solutions) { - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions->prune(current.bound)) { - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - solutions->maybeAdd(current.error, current.assignment); - return; - } - - for (auto& fa : slot.assignments) { - auto childNode = current.expand(slot, fa); - - // Again, prune if we cannot beat the worst solution - if (!solutions->prune(childNode.bound)) { - emplace(childNode); - } - } - } -}; +using SearchNodeQueue = std::priority_queue, + SearchNode::Compare>; std::vector DiscreteSearch::run(size_t K) const { + if (slots_.empty()) { + return {Solution(0.0, DiscreteValues())}; + } + Solutions solutions(K); SearchNodeQueue expansions; expansions.push(SearchNode::Root(slots_.size(), lowerBound_)); -#ifdef DISCRETE_SEARCH_DEBUG - size_t numExpansions = 0; -#endif - // Perform the search while (!expansions.empty()) { // Pop the partial assignment with the smallest bound SearchNode current = expansions.top(); expansions.pop(); - // Get the next slot to expand - const auto& slot = slots_[current.nextConditional]; - expansions.expandNextNode(current, slot, &solutions); -#ifdef DISCRETE_SEARCH_DEBUG - ++numExpansions; -#endif - } + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions.prune(current.bound)) { + continue; + } -#ifdef DISCRETE_SEARCH_DEBUG - std::cout << "Number of expansions: " << numExpansions << std::endl; -#endif + // Check if we have a complete assignment + if (current.isComplete()) { + solutions.maybeAdd(current.error, current.assignment); + continue; + } + + // Get the next slot to expand + const auto& slot = slots_[*current.next]; + std::optional nextSlot = *current.next + 1; + if (nextSlot == slots_.size()) nextSlot.reset(); + for (auto& fa : slot.assignments) { + auto childNode = current.expand(fa, slot, nextSlot); + + // Again, prune if we cannot beat the worst solution + if (!solutions.prune(childNode.bound)) { + expansions.emplace(childNode); + } + } + } // Extract solutions from bestSolutions in ascending order of error return solutions.extractSolutions(); @@ -296,17 +269,16 @@ std::vector DiscreteSearch::run(size_t K) const { // We have a number of factors, each with a max value, and we want to compute // a lower-bound on the cost-to-go for each slot, *not* including this factor. -// For the first slot, this is 0.0, as this is the last slot to be filled, so -// the cost after that is zero. For the second slot, it is h0 = -// -log(max(factor[0])), because after we assign slot[1] we still need to -// assign slot[0], which will cost *at least* h0. We return the estimated -// lower bound of the cost for *all* slots. +// For the last slot, this is 0.0, as the cost after that is zero. +// For the second-to-last slot, it is -log(max(factor[0])), because after we +// assign slot[1] we still need to assign slot[0], which will cost *at least* +// h0. We return the estimated lower bound of the cost for *all* slots. double DiscreteSearch::computeHeuristic() { double error = 0.0; - for (auto& slot : slots_) { - slot.heuristic = error; - Ordering ordering(slot.factor->begin(), slot.factor->end()); - auto maxx = slot.factor->max(ordering); + for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) { + it->heuristic = error; + Ordering ordering(it->factor->begin(), it->factor->end()); + auto maxx = it->factor->max(ordering); error -= std::log(maxx->evaluate({})); } return error; diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 44e605e34..700e41392 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -125,6 +125,12 @@ class GTSAM_EXPORT DiscreteSearch { /// @name Standard API /// @{ + /// Return lower bound on the cost-to-go for the entire search + double lowerBound() const { return lowerBound_; } + + /// Read access to the slots + const std::vector& slots() const { return slots_; } + /** * @brief Search for the K best solutions. * diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index dd5389558..cebddfe8d 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -77,6 +77,17 @@ TEST(DiscreteBayesNet, AsiaKBest) { // Ask for the MPE auto mpe = search.run(); + // Regression on error lower bound + EXPECT_DOUBLES_EQUAL(1.205536, search.lowerBound(), 1e-5); + + // Check that the cost-to-go heuristic decreases from there + auto slots = search.slots(); + double previousHeuristic = search.lowerBound(); + for (auto&& slot : slots) { + EXPECT(slot.heuristic <= previousHeuristic); + previousHeuristic = slot.heuristic; + } + EXPECT_LONGS_EQUAL(1, mpe.size()); // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);