diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 439331fa1..4270dfbf7 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -150,21 +150,58 @@ class Solutions { } }; -DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph, - const Ordering& ordering, - bool buildJunctionTree) { - const DiscreteEliminationTree etree(factorGraph, ordering); - const DiscreteJunctionTree junctionTree(etree); +DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { + using NodePtr = std::shared_ptr; + auto visitor = [this](const NodePtr& node, int data) { + const auto& factors = node->factors; + const auto factor = factors.size() == 1 + ? factors.back() + : DiscreteFactorGraph(factors).product(); + const size_t cardinality = factor->cardinality(node->key); + std::vector> pairs{{node->key, cardinality}}; + slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + return data + 1; + }; - // GTSAM_PRINT(asia::etree); - // GTSAM_PRINT(asia::junctionTree); - slots_.reserve(factorGraph.size()); - for (auto& factor : factorGraph) { - slots_.emplace_back(factor, std::vector{}, 0.0); - } + const int data = 0; // unused + treeTraversal::DepthFirstForest(etree, data, visitor); + std::reverse(slots_.begin(), slots_.end()); // reverse slots lowerBound_ = computeHeuristic(); } +DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { + using NodePtr = std::shared_ptr; + auto visitor = [this](const NodePtr& cluster, int data) { + const auto& factors = cluster->factors; + const auto factor = factors.size() == 1 + ? factors.back() + : DiscreteFactorGraph(factors).product(); + std::vector> pairs; + for (Key key : cluster->orderedFrontalKeys) { + pairs.emplace_back(key, factor->cardinality(key)); + } + slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + return data + 1; + }; + + const int data = 0; // unused + treeTraversal::DepthFirstForest(junctionTree, data, visitor); + std::reverse(slots_.begin(), slots_.end()); // reverse slots + lowerBound_ = computeHeuristic(); +} + +DiscreteSearch DiscreteSearch::FromFactorGraph( + const DiscreteFactorGraph& factorGraph, const Ordering& ordering, + bool buildJunctionTree) { + const DiscreteEliminationTree etree(factorGraph, ordering); + if (buildJunctionTree) { + const DiscreteJunctionTree junctionTree(etree); + return DiscreteSearch(junctionTree); + } else { + return DiscreteSearch(etree); + } +} + DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { slots_.reserve(bayesNet.size()); for (auto& conditional : bayesNet) { @@ -188,6 +225,14 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { lowerBound_ = computeHeuristic(); } +void DiscreteSearch::print(const std::string& name, + const KeyFormatter& formatter) const { + std::cout << name << " with " << slots_.size() << " slots:\n"; + for (size_t i = 0; i < slots_.size(); ++i) { + std::cout << i << ": " << slots_[i] << std::endl; + } +} + struct SearchNodeQueue : public std::priority_queue, SearchNode::Compare> { @@ -215,8 +260,6 @@ struct SearchNodeQueue } }; -#define DISCRETE_SEARCH_DEBUG - std::vector DiscreteSearch::run(size_t K) const { Solutions solutions(K); SearchNodeQueue expansions; @@ -252,9 +295,9 @@ std::vector DiscreteSearch::run(size_t K) const { // a lower-bound on the cost-to-go for each slot, *not* including this factor. // For the first slot, this is 0.0, as this is the last slot to be filled, so // the cost after that is zero. For the second slot, it is h0 = -// -log(max(factor[0])), because after we assign slot[1] we still need to assign -// slot[0], which will cost *at least* h0. -// We return the estimated lower bound of the cost for *all* slots. +// -log(max(factor[0])), because after we assign slot[1] we still need to +// assign slot[0], which will cost *at least* h0. We return the estimated +// lower bound of the cost for *all* slots. double DiscreteSearch::computeHeuristic() { double error = 0.0; for (size_t i = 0; i < slots_.size(); ++i) { diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 3d5ee1d50..44e605e34 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -10,7 +10,11 @@ * -------------------------------------------------------------------------- */ /* - * DiscreteSearch.cpp + * @file DiscreteSearch.h + * @brief Defines the DiscreteSearch class for discrete search algorithms. + * + * @details This file contains the definition of the DiscreteSearch class, which + * is used in discrete search algorithms to find the K best solutions. * * @date January, 2025 * @author Frank Dellaert @@ -43,6 +47,13 @@ class GTSAM_EXPORT DiscreteSearch { /// A lower bound on the cost-to-go for each slot, e.g., /// [-log(max_B P(B|A)), -log(max_A P(A))] double heuristic; + + friend std::ostream& operator<<(std::ostream& os, const Slot& slot) { + os << "Slot with " << slot.assignments.size() + << " assignments, heuristic=" << slot.heuristic; + os << ", factor:\n" << slot.factor->markdown() << std::endl; + return os; + } }; /// A solution is then a set of assignments, covering all the slots. @@ -59,6 +70,9 @@ class GTSAM_EXPORT DiscreteSearch { }; public: + /// @name Standard Constructors + /// @{ + /** * Construct from a DiscreteFactorGraph. * @@ -68,12 +82,26 @@ class GTSAM_EXPORT DiscreteSearch { * fine-grained (more slots). * * @param factorGraph The factor graph to search over. - * @param ordering The ordering of the variables to search over. - * @param buildJunctionTree Whether to build a junction tree for the factor - * graph. + * @param ordering The ordering used to create etree (and maybe jtree). + * @param buildJunctionTree Whether to build a junction tree or not. */ - DiscreteSearch(const DiscreteFactorGraph& factorGraph, - const Ordering& ordering, bool buildJunctionTree = false); + static DiscreteSearch FromFactorGraph(const DiscreteFactorGraph& factorGraph, + const Ordering& ordering, + bool buildJunctionTree = false); + + /** + * @brief Constructor from a DiscreteEliminationTree. + * + * @param etree The DiscreteEliminationTree to initialize from. + */ + DiscreteSearch(const DiscreteEliminationTree& etree); + + /** + * @brief Constructor from a DiscreteJunctionTree. + * + * @param junctionTree The DiscreteJunctionTree to initialize from. + */ + DiscreteSearch(const DiscreteJunctionTree& junctionTree); /** * Construct from a DiscreteBayesNet. @@ -85,6 +113,18 @@ class GTSAM_EXPORT DiscreteSearch { */ DiscreteSearch(const DiscreteBayesTree& bayesTree); + /// @} + /// @name Testable + /// @{ + + /** Print the tree to cout */ + void print(const std::string& name = "DiscreteSearch: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const; + + /// @} + /// @name Standard API + /// @{ + /** * @brief Search for the K best solutions. * @@ -97,6 +137,8 @@ class GTSAM_EXPORT DiscreteSearch { */ std::vector run(size_t K = 1) const; + /// @} + private: /// Compute the cumulative lower-bound cost-to-go after each slot is filled. /// @return the estimated lower bound of the cost for *all* slots.