diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index c5941862d..e3722c13e 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -20,6 +20,7 @@ namespace gtsam { +using Slot = DiscreteSearch::Slot; using Solution = DiscreteSearch::Solution; /** @@ -59,12 +60,12 @@ struct SearchNode { /** * @brief Expands the node by assigning the next variable. * - * @param conditional The discrete conditional representing the next variable + * @param factor The discrete factor associated with the next variable * to be assigned. * @param fa The frontal assignment for the next variable. * @return A new SearchNode representing the expanded state. */ - SearchNode expand(const DiscreteConditional& conditional, + SearchNode expand(const DiscreteFactor& factor, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; @@ -72,7 +73,7 @@ struct SearchNode { newAssignment[key] = value; } - return {newAssignment, error + conditional.error(newAssignment), 0.0, + return {newAssignment, error + factor.error(newAssignment), 0.0, nextConditional - 1}; } @@ -150,10 +151,20 @@ class Solutions { } }; +DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph) { + slots_.reserve(factorGraph.size()); + for (auto& factor : factorGraph) { + slots_.emplace_back(factor, std::vector{}, 0.0); + } + computeHeuristic(); +} + DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { - std::vector conditionals; - for (auto& factor : bayesNet) conditionals_.push_back(factor); - costToGo_ = computeCostToGo(conditionals_); + slots_.reserve(bayesNet.size()); + for (auto& conditional : bayesNet) { + slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0); + } + computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { @@ -161,22 +172,21 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { collectConditionals = [&](const auto& clique) { if (!clique) return; for (const auto& child : clique->children) collectConditionals(child); - conditionals_.push_back(clique->conditional()); + auto conditional = clique->conditional(); + slots_.emplace_back(conditional, conditional->frontalAssignments(), + 0.0); }; + + slots_.reserve(bayesTree.size()); for (const auto& root : bayesTree.roots()) collectConditionals(root); - costToGo_ = computeCostToGo(conditionals_); + computeHeuristic(); } struct SearchNodeQueue : public std::priority_queue, SearchNode::Compare> { - void expandNextNode( - const std::vector& conditionals, - const std::vector& costToGo, Solutions* solutions) { - // Pop the partial assignment with the smallest bound - SearchNode current = top(); - pop(); - + void expandNextNode(const SearchNode& current, const Slot& slot, + Solutions* solutions) { // If we already have K solutions, prune if we cannot beat the worst one. if (solutions->prune(current.bound)) { return; @@ -188,13 +198,11 @@ struct SearchNodeQueue return; } - // Expand on the next factor - const auto& conditional = conditionals[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - auto childNode = current.expand(*conditional, fa); + for (auto& fa : slot.assignments) { + auto childNode = current.expand(*slot.factor, fa); if (childNode.nextConditional >= 0) - childNode.bound = childNode.error + costToGo[childNode.nextConditional]; + // TODO(frank): this might be wrong ! + childNode.bound = childNode.error + slot.heuristic; // Again, prune if we cannot beat the worst solution if (!solutions->prune(childNode.bound)) { @@ -207,8 +215,7 @@ struct SearchNodeQueue std::vector DiscreteSearch::run(size_t K) const { Solutions solutions(K); SearchNodeQueue expansions; - expansions.push(SearchNode::Root(conditionals_.size(), - costToGo_.empty() ? 0.0 : costToGo_.back())); + expansions.push(SearchNode::Root(slots_.size(), slots_.back().heuristic)); #ifdef DISCRETE_SEARCH_DEBUG size_t numExpansions = 0; @@ -216,7 +223,13 @@ std::vector DiscreteSearch::run(size_t K) const { // Perform the search while (!expansions.empty()) { - expansions.expandNextNode(conditionals_, costToGo_, &solutions); + // Pop the partial assignment with the smallest bound + SearchNode current = expansions.top(); + expansions.pop(); + + // Get the next slot to expand + const auto& slot = slots_[current.nextConditional]; + expansions.expandNextNode(current, slot, &solutions); #ifdef DISCRETE_SEARCH_DEBUG ++numExpansions; #endif @@ -230,17 +243,19 @@ std::vector DiscreteSearch::run(size_t K) const { return solutions.extractSolutions(); } -std::vector DiscreteSearch::computeCostToGo( - const std::vector& conditionals) { - std::vector costToGo; +// We have a number of factors, each with a max value, and we want to compute +// the a lower-bound on the cost-to-go for each slot. For the first slot, this +// -log(max(factor[0])), as we only have one factor to resolve. For the second +// slot, we need to add -log(max(factor[1])) to it, etc... +void DiscreteSearch::computeHeuristic() { double error = 0.0; - for (const auto& conditional : conditionals) { - Ordering ordering(conditional->begin(), conditional->end()); - auto maxx = conditional->max(ordering); + for (size_t i = 0; i < slots_.size(); ++i) { + const auto& factor = slots_[i].factor; + Ordering ordering(factor->begin(), factor->end()); + auto maxx = factor->max(ordering); error -= std::log(maxx->evaluate({})); - costToGo.push_back(error); + slots_[i].heuristic = error; } - return costToGo; } } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 6202880b2..ddeaf38f1 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -28,9 +28,25 @@ namespace gtsam { */ class GTSAM_EXPORT DiscreteSearch { public: - /** - * @brief A solution to a discrete search problem. - */ + /// We structure the search as a set of slots, each with a factor and + /// a set of variable assignments that need to be chosen. In addition, each + /// slot has a heuristic associated with it. + struct Slot { + /// The factors in the search problem, + /// e.g., [P(B|A),P(A)] + DiscreteFactor::shared_ptr factor; + + /// The assignments for each factor, + /// e.g., [[B0,B1] [A0,A1]] + std::vector assignments; + + /// A lower bound on the cost-to-go for each slot, e.g., + /// [-log(max_B P(B|A)), -log(max_A P(A))] + double heuristic; + }; + + /// A solution is then a set of assignments, covering all the slots. + /// as well as an associated error = -log(probability) struct Solution { double error; DiscreteValues assignment; @@ -42,13 +58,19 @@ class GTSAM_EXPORT DiscreteSearch { } }; + public: /** - * Construct from a DiscreteBayesNet and K. + * Construct from a DiscreteFactorGraph. + */ + DiscreteSearch(const DiscreteFactorGraph& bayesNet); + + /** + * Construct from a DiscreteBayesNet. */ DiscreteSearch(const DiscreteBayesNet& bayesNet); /** - * Construct from a DiscreteBayesTree and K. + * Construct from a DiscreteBayesTree. */ DiscreteSearch(const DiscreteBayesTree& bayesTree); @@ -65,14 +87,9 @@ class GTSAM_EXPORT DiscreteSearch { std::vector run(size_t K = 1) const; private: - /// Compute the cumulative cost-to-go for each conditional slot. - static std::vector computeCostToGo( - const std::vector& conditionals); + /// Compute the cumulative lower-bound cost-to-go for each slot. + void computeHeuristic(); - /// Expand the next node in the search tree. - void expandNextNode() const; - - std::vector conditionals_; - std::vector costToGo_; + std::vector slots_; }; } // namespace gtsam