diff --git a/gtsam/discrete/DiscreteJunctionTree.cpp b/gtsam/discrete/DiscreteJunctionTree.cpp index dc24860eb..bf9f9fe18 100644 --- a/gtsam/discrete/DiscreteJunctionTree.cpp +++ b/gtsam/discrete/DiscreteJunctionTree.cpp @@ -16,19 +16,35 @@ * @author Richard Roberts */ -#include -#include #include +#include +#include namespace gtsam { - // Instantiate base classes - template class EliminatableClusterTree; - template class JunctionTree; +// Instantiate base classes +template class EliminatableClusterTree; +template class JunctionTree; - /* ************************************************************************* */ - DiscreteJunctionTree::DiscreteJunctionTree( - const DiscreteEliminationTree& eliminationTree) : - Base(eliminationTree) {} +/* ************************************************************************* */ +DiscreteJunctionTree::DiscreteJunctionTree( + const DiscreteEliminationTree& eliminationTree) + : Base(eliminationTree) {} +/* ************************************************************************* */ +void DiscreteJunctionTree::print(const std::string& s, + const KeyFormatter& keyFormatter) const { + auto visitor = [&keyFormatter]( + const std::shared_ptr& node, + const std::string& parentString) { + // Print the current node + node->print(parentString + "-", keyFormatter); + node->factors.print(parentString + "-", keyFormatter); + std::cout << std::endl; + return parentString + "| "; // Increment the indentation + }; + std::string parentString = s; + treeTraversal::DepthFirstForest(*this, parentString, visitor); } + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteJunctionTree.h b/gtsam/discrete/DiscreteJunctionTree.h index f6171c672..4b9241036 100644 --- a/gtsam/discrete/DiscreteJunctionTree.h +++ b/gtsam/discrete/DiscreteJunctionTree.h @@ -18,54 +18,71 @@ #pragma once -#include #include +#include #include namespace gtsam { - // Forward declarations - class DiscreteEliminationTree; +// Forward declarations +class DiscreteEliminationTree; + +/** + * An EliminatableClusterTree, i.e., a set of variable clusters with factors, + * arranged in a tree, with the additional property that it represents the + * clique tree associated with a Bayes net. + * + * In GTSAM a junction tree is an intermediate data structure in multifrontal + * variable elimination. Each node is a cluster of factors, along with a + * clique of variables that are eliminated all at once. In detail, every node k + * represents a clique (maximal fully connected subset) of an associated chordal + * graph, such as a chordal Bayes net resulting from elimination. + * + * The difference with the BayesTree is that a JunctionTree stores factors, + * whereas a BayesTree stores conditionals, that are the product of eliminating + * the factors in the corresponding JunctionTree cliques. + * + * The tree structure and elimination method are exactly analogous to the + * EliminationTree, except that in the JunctionTree, at each node multiple + * variables are eliminated at a time. + * + * \ingroup Multifrontal + * @ingroup discrete + * \nosubgrouping + */ +class GTSAM_EXPORT DiscreteJunctionTree + : public JunctionTree { + public: + typedef JunctionTree + Base; ///< Base class + typedef DiscreteJunctionTree This; ///< This class + typedef std::shared_ptr shared_ptr; ///< Shared pointer to this class + + /// @name Constructors + /// @{ /** - * An EliminatableClusterTree, i.e., a set of variable clusters with factors, arranged in a tree, - * with the additional property that it represents the clique tree associated with a Bayes net. - * - * In GTSAM a junction tree is an intermediate data structure in multifrontal - * variable elimination. Each node is a cluster of factors, along with a - * clique of variables that are eliminated all at once. In detail, every node k represents - * a clique (maximal fully connected subset) of an associated chordal graph, such as a - * chordal Bayes net resulting from elimination. - * - * The difference with the BayesTree is that a JunctionTree stores factors, whereas a - * BayesTree stores conditionals, that are the product of eliminating the factors in the - * corresponding JunctionTree cliques. - * - * The tree structure and elimination method are exactly analogous to the EliminationTree, - * except that in the JunctionTree, at each node multiple variables are eliminated at a time. - * - * \ingroup Multifrontal - * @ingroup discrete - * \nosubgrouping + * Build the elimination tree of a factor graph using precomputed column + * structure. + * @param factorGraph The factor graph for which to build the elimination tree + * @param structure The set of factors involving each variable. If this is + * not precomputed, you can call the Create(const FactorGraph&) + * named constructor instead. + * @return The elimination tree */ - class GTSAM_EXPORT DiscreteJunctionTree : - public JunctionTree { - public: - typedef JunctionTree Base; ///< Base class - typedef DiscreteJunctionTree This; ///< This class - typedef std::shared_ptr shared_ptr; ///< Shared pointer to this class + DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); - /** - * Build the elimination tree of a factor graph using precomputed column structure. - * @param factorGraph The factor graph for which to build the elimination tree - * @param structure The set of factors involving each variable. If this is not - * precomputed, you can call the Create(const FactorGraph&) - * named constructor instead. - * @return The elimination tree - */ - DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); - }; + /// @} + /// @name Testable + /// @{ - /// typedef for wrapper: - using DiscreteCluster = DiscreteJunctionTree::Cluster; -} + /** Print the tree to cout */ + void print(const std::string& name = "DiscreteJunctionTree: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const; + + /// @} +}; + +/// typedef for wrapper: +using DiscreteCluster = DiscreteJunctionTree::Cluster; +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index c5941862d..c046f508f 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -9,38 +9,37 @@ * -------------------------------------------------------------------------- */ -/* +/** * DiscreteSearch.cpp * * @date January, 2025 * @author Frank Dellaert */ +#include +#include #include namespace gtsam { +using Slot = DiscreteSearch::Slot; using Solution = DiscreteSearch::Solution; -/** - * @brief Represents a node in the search tree for discrete search algorithms. - * - * @details Each SearchNode contains a partial assignment of discrete variables, - * the current error, a bound on the final error, and the index of the next - * conditional to be assigned. +/* + * A SearchNode represents a node in the search tree for the search algorithm. + * Each SearchNode contains a partial assignment of discrete variables, the + * current error, a bound on the final error, and the index of the next + * slot to be assigned. */ struct SearchNode { - DiscreteValues assignment; ///< Partial assignment of discrete variables. - double error; ///< Current error for the partial assignment. - double bound; ///< Lower bound on the final error for unassigned variables. - int nextConditional; ///< Index of the next conditional to be assigned. + DiscreteValues assignment; // Partial assignment of discrete variables. + double error; // Current error for the partial assignment. + double bound; // Lower bound on the final error + std::optional next; // Index of the next slot to be assigned. - /** - * @brief Construct the root node for the search. - */ - static SearchNode Root(size_t numConditionals, double bound) { - return {DiscreteValues(), 0.0, bound, - static_cast(numConditionals) - 1}; + // Construct the root node for the search. + static SearchNode Root(size_t numSlots, double bound) { + return {DiscreteValues(), 0.0, bound, 0}; } struct Compare { @@ -49,40 +48,22 @@ struct SearchNode { } }; - /** - * @brief Checks if the node represents a complete assignment. - * - * @return True if all variables have been assigned, false otherwise. - */ - inline bool isComplete() const { return nextConditional < 0; } + // Checks if the node represents a complete assignment. + inline bool isComplete() const { return !next; } - /** - * @brief Expands the node by assigning the next variable. - * - * @param conditional The discrete conditional representing the next variable - * to be assigned. - * @param fa The frontal assignment for the next variable. - * @return A new SearchNode representing the expanded state. - */ - SearchNode expand(const DiscreteConditional& conditional, - const DiscreteValues& fa) const { + // Expands the node by assigning the next variable(s). + SearchNode expand(const DiscreteValues& fa, const Slot& slot, + std::optional nextSlot) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; for (auto& [key, value] : fa) { newAssignment[key] = value; } - - return {newAssignment, error + conditional.error(newAssignment), 0.0, - nextConditional - 1}; + double errorSoFar = error + slot.factor->error(newAssignment); + return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot}; } - /** - * @brief Prints the SearchNode to an output stream. - * - * @param os The output stream. - * @param node The SearchNode to be printed. - * @return The output stream. - */ + // Prints the SearchNode to an output stream. friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; return os; @@ -95,17 +76,20 @@ struct CompareSolution { } }; -// Define the Solutions class +/* + * A Solutions object maintains a priority queue of the best solutions found + * during the search. The priority queue is limited to a maximum size, and + * solutions are only added if they are better than the worst solution. + */ class Solutions { - private: - size_t maxSize_; + size_t maxSize_; // Maximum number of solutions to keep std::priority_queue, CompareSolution> pq_; public: Solutions(size_t maxSize) : maxSize_(maxSize) {} - /// Add a solution to the priority queue, possibly evicting the worst one. - /// Return true if we added the solution. + // Add a solution to the priority queue, possibly evicting the worst one. + // Return true if we added the solution. bool maybeAdd(double error, const DiscreteValues& assignment) { const bool full = pq_.size() == maxSize_; if (full && error >= pq_.top().error) return false; @@ -114,7 +98,7 @@ class Solutions { return true; } - /// Check if we have any solutions + // Check if we have any solutions bool empty() const { return pq_.empty(); } // Method to print all solutions @@ -128,9 +112,9 @@ class Solutions { return os; } - /// Check if (partial) solution with given bound can be pruned. If we have - /// room, we never prune. Otherwise, prune if lower bound on error is worse - /// than our current worst error. + // Check if (partial) solution with given bound can be pruned. If we have + // room, we never prune. Otherwise, prune if lower bound on error is worse + // than our current worst error. bool prune(double bound) const { if (pq_.size() < maxSize_) return false; return bound >= pq_.top().error; @@ -150,97 +134,155 @@ class Solutions { } }; +// Get the factor associated with a node, possibly product of factors. +template +static DiscreteFactor::shared_ptr getFactor(const NodeType& node) { + const auto& factors = node->factors; + return factors.size() == 1 ? factors.back() + : DiscreteFactorGraph(factors).product(); +} + +DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { + using NodePtr = std::shared_ptr; + auto visitor = [this](const NodePtr& node, int data) { + const DiscreteFactor::shared_ptr factor = getFactor(node); + const size_t cardinality = factor->cardinality(node->key); + std::vector> pairs{{node->key, cardinality}}; + const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; + slots_.emplace_back(std::move(slot)); + return data + 1; + }; + + int data = 0; // unused + treeTraversal::DepthFirstForest(etree, data, visitor); + lowerBound_ = computeHeuristic(); +} + +DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { + using NodePtr = std::shared_ptr; + auto visitor = [this](const NodePtr& cluster, int data) { + const auto factor = getFactor(cluster); + std::vector> pairs; + for (Key key : cluster->orderedFrontalKeys) { + pairs.emplace_back(key, factor->cardinality(key)); + } + const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; + slots_.emplace_back(std::move(slot)); + return data + 1; + }; + + int data = 0; // unused + treeTraversal::DepthFirstForest(junctionTree, data, visitor); + lowerBound_ = computeHeuristic(); +} + +DiscreteSearch DiscreteSearch::FromFactorGraph( + const DiscreteFactorGraph& factorGraph, const Ordering& ordering, + bool buildJunctionTree) { + const DiscreteEliminationTree etree(factorGraph, ordering); + if (buildJunctionTree) { + const DiscreteJunctionTree junctionTree(etree); + return DiscreteSearch(junctionTree); + } else { + return DiscreteSearch(etree); + } +} + DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { - std::vector conditionals; - for (auto& factor : bayesNet) conditionals_.push_back(factor); - costToGo_ = computeCostToGo(conditionals_); + slots_.reserve(bayesNet.size()); + for (auto& conditional : bayesNet) { + const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; + slots_.emplace_back(std::move(slot)); + } + std::reverse(slots_.begin(), slots_.end()); + lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { - std::function - collectConditionals = [&](const auto& clique) { - if (!clique) return; - for (const auto& child : clique->children) collectConditionals(child); - conditionals_.push_back(clique->conditional()); - }; - for (const auto& root : bayesTree.roots()) collectConditionals(root); - costToGo_ = computeCostToGo(conditionals_); + using NodePtr = DiscreteBayesTree::sharedClique; + auto visitor = [this](const NodePtr& clique, int data) { + auto conditional = clique->conditional(); + const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; + slots_.emplace_back(std::move(slot)); + return data + 1; + }; + + int data = 0; // unused + treeTraversal::DepthFirstForest(bayesTree, data, visitor); + lowerBound_ = computeHeuristic(); } -struct SearchNodeQueue - : public std::priority_queue, - SearchNode::Compare> { - void expandNextNode( - const std::vector& conditionals, - const std::vector& costToGo, Solutions* solutions) { +void DiscreteSearch::print(const std::string& name, + const KeyFormatter& formatter) const { + std::cout << name << " with " << slots_.size() << " slots:\n"; + for (size_t i = 0; i < slots_.size(); ++i) { + std::cout << i << ": " << slots_[i] << std::endl; + } +} + +using SearchNodeQueue = std::priority_queue, + SearchNode::Compare>; + +std::vector DiscreteSearch::run(size_t K) const { + if (slots_.empty()) { + return {Solution(0.0, DiscreteValues())}; + } + + Solutions solutions(K); + SearchNodeQueue expansions; + expansions.push(SearchNode::Root(slots_.size(), lowerBound_)); + + // Perform the search + while (!expansions.empty()) { // Pop the partial assignment with the smallest bound - SearchNode current = top(); - pop(); + SearchNode current = expansions.top(); + expansions.pop(); // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions->prune(current.bound)) { - return; + if (solutions.prune(current.bound)) { + continue; } // Check if we have a complete assignment if (current.isComplete()) { - solutions->maybeAdd(current.error, current.assignment); - return; + solutions.maybeAdd(current.error, current.assignment); + continue; } - // Expand on the next factor - const auto& conditional = conditionals[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - auto childNode = current.expand(*conditional, fa); - if (childNode.nextConditional >= 0) - childNode.bound = childNode.error + costToGo[childNode.nextConditional]; + // Get the next slot to expand + const auto& slot = slots_[*current.next]; + std::optional nextSlot = *current.next + 1; + if (nextSlot == slots_.size()) nextSlot.reset(); + for (auto& fa : slot.assignments) { + auto childNode = current.expand(fa, slot, nextSlot); // Again, prune if we cannot beat the worst solution - if (!solutions->prune(childNode.bound)) { - emplace(childNode); + if (!solutions.prune(childNode.bound)) { + expansions.emplace(childNode); } } } -}; - -std::vector DiscreteSearch::run(size_t K) const { - Solutions solutions(K); - SearchNodeQueue expansions; - expansions.push(SearchNode::Root(conditionals_.size(), - costToGo_.empty() ? 0.0 : costToGo_.back())); - -#ifdef DISCRETE_SEARCH_DEBUG - size_t numExpansions = 0; -#endif - - // Perform the search - while (!expansions.empty()) { - expansions.expandNextNode(conditionals_, costToGo_, &solutions); -#ifdef DISCRETE_SEARCH_DEBUG - ++numExpansions; -#endif - } - -#ifdef DISCRETE_SEARCH_DEBUG - std::cout << "Number of expansions: " << numExpansions << std::endl; -#endif // Extract solutions from bestSolutions in ascending order of error return solutions.extractSolutions(); } - -std::vector DiscreteSearch::computeCostToGo( - const std::vector& conditionals) { - std::vector costToGo; +/* + * We have a number of factors, each with a max value, and we want to compute + * a lower-bound on the cost-to-go for each slot, *not* including this factor. + * For the last slot[n-1], this is 0.0, as the cost after that is zero. + * For the second-to-last slot, it is h = -log(max(factor[n-1])), because after + * we assign slot[n-2] we still need to assign slot[n-1], which will cost *at + * least* h. We return the estimated lower bound of the cost for *all* slots. + */ +double DiscreteSearch::computeHeuristic() { double error = 0.0; - for (const auto& conditional : conditionals) { - Ordering ordering(conditional->begin(), conditional->end()); - auto maxx = conditional->max(ordering); + for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) { + it->heuristic = error; + Ordering ordering(it->factor->begin(), it->factor->end()); + auto maxx = it->factor->max(ordering); error -= std::log(maxx->evaluate({})); - costToGo.push_back(error); } - return costToGo; + return error; } } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 6202880b2..b610955b2 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -9,8 +9,12 @@ * -------------------------------------------------------------------------- */ -/* - * DiscreteSearch.cpp +/** + * @file DiscreteSearch.h + * @brief Defines the DiscreteSearch class for discrete search algorithms. + * + * @details This file contains the definition of the DiscreteSearch class, which + * is used in discrete search algorithms to find the K best solutions. * * @date January, 2025 * @author Frank Dellaert @@ -24,12 +28,53 @@ namespace gtsam { /** - * DiscreteSearch: Search for the K best solutions. + * @brief DiscreteSearch: Search for the K best solutions. + * + * This class is used to search for the K best solutions in a DiscreteBayesNet. + * This is implemented with a modified A* search algorithm that uses a priority + * queue to manage the search nodes. That machinery is defined in the .cpp file. + * The heuristic we use is the sum of the log-probabilities of the + * maximum-probability assignments for each slot, for all slots to the right of + * the current slot. + * + * TODO: The heuristic could be refined by using the partial assignment in + * search node to refine the max-probability assignment for the remaining slots. + * This would incur more computation but will lead to fewer expansions. */ class GTSAM_EXPORT DiscreteSearch { public: /** - * @brief A solution to a discrete search problem. + * We structure the search as a set of slots, each with a factor and + * a set of variable assignments that need to be chosen. In addition, each + * slot has a heuristic associated with it. + * + * Example: + * The factors in the search problem (always parents before descendents!): + * [P(A), P(B|A), P(C|A,B)] + * The assignments for each factor. + * [[A0,A1], [B0,B1], [C0,C1,C2]] + * A lower bound on the cost-to-go after each slot, e.g., + * [-log(max_B P(B|A)) -log(max_C P(C|A,B)), -log(max_C P(C|A,B)), 0.0] + * Note that these decrease as we move from right to left. + * We keep the global lower bound as lowerBound_. In the example, it is: + * -log(max_B P(B|A)) -log(max_C P(C|A,B)) -log(max_C P(C|A,B)) + */ + struct Slot { + DiscreteFactor::shared_ptr factor; + std::vector assignments; + double heuristic; + + friend std::ostream& operator<<(std::ostream& os, const Slot& slot) { + os << "Slot with " << slot.assignments.size() + << " assignments, heuristic=" << slot.heuristic; + os << ", factor:\n" << slot.factor->markdown() << std::endl; + return os; + } + }; + + /** + * A solution is a set of assignments, covering all the slots. + * as well as an associated error = -log(probability) */ struct Solution { double error; @@ -42,16 +87,56 @@ class GTSAM_EXPORT DiscreteSearch { } }; - /** - * Construct from a DiscreteBayesNet and K. - */ - DiscreteSearch(const DiscreteBayesNet& bayesNet); + public: + /// @name Standard Constructors + /// @{ /** - * Construct from a DiscreteBayesTree and K. + * Construct from a DiscreteFactorGraph. + * + * Internally creates either an elimination tree or a junction tree. The + * latter incurs more up-front computation but the search itself might be + * faster. Then again, for the elimination tree, the heuristic will be more + * fine-grained (more slots). + * + * @param factorGraph The factor graph to search over. + * @param ordering The ordering used to create etree (and maybe jtree). + * @param buildJunctionTree Whether to build a junction tree or not. */ + static DiscreteSearch FromFactorGraph(const DiscreteFactorGraph& factorGraph, + const Ordering& ordering, + bool buildJunctionTree = false); + + /// Construct from a DiscreteEliminationTree. + DiscreteSearch(const DiscreteEliminationTree& etree); + + /// Construct from a DiscreteJunctionTree. + DiscreteSearch(const DiscreteJunctionTree& junctionTree); + + //// Construct from a DiscreteBayesNet. + DiscreteSearch(const DiscreteBayesNet& bayesNet); + + /// Construct from a DiscreteBayesTree. DiscreteSearch(const DiscreteBayesTree& bayesTree); + /// @} + /// @name Testable + /// @{ + + /** Print the tree to cout */ + void print(const std::string& name = "DiscreteSearch: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const; + + /// @} + /// @name Standard API + /// @{ + + /// Return lower bound on the cost-to-go for the entire search + double lowerBound() const { return lowerBound_; } + + /// Read access to the slots + const std::vector& slots() const { return slots_; } + /** * @brief Search for the K best solutions. * @@ -64,15 +149,16 @@ class GTSAM_EXPORT DiscreteSearch { */ std::vector run(size_t K = 1) const; + /// @} + private: - /// Compute the cumulative cost-to-go for each conditional slot. - static std::vector computeCostToGo( - const std::vector& conditionals); + /** + * Compute the cumulative lower-bound cost-to-go after each slot is filled. + * @return the estimated lower bound of the cost for *all* slots. + */ + double computeHeuristic(); - /// Expand the next node in the search tree. - void expandNextNode() const; - - std::vector conditionals_; - std::vector costToGo_; + double lowerBound_; ///< Lower bound on the cost-to-go for the entire search. + std::vector slots_; ///< The slots to fill in the search. }; } // namespace gtsam diff --git a/gtsam/discrete/tests/AsiaExample.h b/gtsam/discrete/tests/AsiaExample.h index 6c327daec..ff6c4ea99 100644 --- a/gtsam/discrete/tests/AsiaExample.h +++ b/gtsam/discrete/tests/AsiaExample.h @@ -58,4 +58,4 @@ DiscreteBayesNet createAsiaExample() { return asia; } } // namespace asia_example -} // namespace gtsam \ No newline at end of file +} // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index b537dd2f0..cebddfe8d 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -28,9 +28,15 @@ using namespace gtsam; namespace asia { using namespace asia_example; static const DiscreteBayesNet bayesNet = createAsiaExample(); + +// Create factor graph and optimize with max-product for MPE static const DiscreteFactorGraph factorGraph(bayesNet); static const DiscreteValues mpe = factorGraph.optimize(); + +// Create ordering static const Ordering ordering{D, X, B, E, L, T, S, A}; + +// Create Bayes tree static const DiscreteBayesTree bayesTree = *factorGraph.eliminateMultifrontal(ordering); } // namespace asia @@ -45,29 +51,6 @@ TEST(DiscreteBayesNet, EmptyKBest) { EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); } -/* ************************************************************************* */ -TEST(DiscreteBayesNet, AsiaKBest) { - const DiscreteSearch search(asia::bayesNet); - - // Ask for the MPE - auto mpe = search.run(); - - EXPECT_LONGS_EQUAL(1, mpe.size()); - // Regression test: check the MPE solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); - - // Check it is equal to MPE via inference - EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); - - // Ask for top 4 solutions - auto solutions = search.run(4); - - EXPECT_LONGS_EQUAL(4, solutions.size()); - // Regression test: check the first and last solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); - EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); -} - /* ************************************************************************* */ TEST(DiscreteBayesTree, EmptyTree) { DiscreteBayesTree bt; @@ -81,26 +64,45 @@ TEST(DiscreteBayesTree, EmptyTree) { } /* ************************************************************************* */ -TEST(DiscreteBayesTree, AsiaTreeKBest) { - DiscreteSearch search(asia::bayesTree); +TEST(DiscreteBayesNet, AsiaKBest) { + auto fromETree = + DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering); + auto fromJunctionTree = + DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering, true); + const DiscreteSearch fromBayesNet(asia::bayesNet); + const DiscreteSearch fromBayesTree(asia::bayesTree); - // Ask for MPE - auto mpe = search.run(); + for (auto& search : + {fromETree, fromJunctionTree, fromBayesNet, fromBayesTree}) { + // Ask for the MPE + auto mpe = search.run(); - EXPECT_LONGS_EQUAL(1, mpe.size()); - // Regression test: check the MPE solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + // Regression on error lower bound + EXPECT_DOUBLES_EQUAL(1.205536, search.lowerBound(), 1e-5); - // Check it is equal to MPE via inference - EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); + // Check that the cost-to-go heuristic decreases from there + auto slots = search.slots(); + double previousHeuristic = search.lowerBound(); + for (auto&& slot : slots) { + EXPECT(slot.heuristic <= previousHeuristic); + previousHeuristic = slot.heuristic; + } - // Ask for top 4 solutions - auto solutions = search.run(4); + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); - EXPECT_LONGS_EQUAL(4, solutions.size()); - // Regression test: check the first and last solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); - EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); + // Check it is equal to MPE via inference + EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); + + // Ask for top 4 solutions + auto solutions = search.run(4); + + EXPECT_LONGS_EQUAL(4, solutions.size()); + // Regression test: check the first and last solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); + } } /* ************************************************************************* */