diff --git a/examples/Hybrid_City10000.cpp b/examples/Hybrid_City10000.cpp index 80480ca64..f79f88970 100644 --- a/examples/Hybrid_City10000.cpp +++ b/examples/Hybrid_City10000.cpp @@ -166,7 +166,7 @@ int main(int argc, char* argv[]) { clock_t after_update = clock(); smoother_update_times.push_back({index, after_update - before_update}); - size_t key_s, key_t; + size_t key_s, key_t{0}; clock_t start_time = clock(); std::string str; diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 1fb353423..133182a19 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -214,7 +214,10 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DiscreteFactor::shared_ptr product = factors.scaledProduct(); + gttic(product); + // `product` is scaled later to prevent underflow. + DiscreteFactor::shared_ptr product = factors.product(); + gttoc(product); // sum out frontals, this is the factor on the separator gttic(sum); @@ -223,6 +226,16 @@ namespace gtsam { sum = sum->scale(); gttoc(sum); + // Normalize/scale to prevent underflow. + // We divide both `product` and `sum` by `max(sum)` + // since it is faster to compute and when the conditional + // is formed by `product/sum`, the scaling term cancels out. + gttic(scale); + DiscreteFactor::shared_ptr denominator = sum->max(sum->size()); + product = product->operator/(denominator); + sum = sum->operator/(denominator); + gttoc(scale); + // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), diff --git a/gtsam/discrete/DiscreteJunctionTree.cpp b/gtsam/discrete/DiscreteJunctionTree.cpp index dc24860eb..bf9f9fe18 100644 --- a/gtsam/discrete/DiscreteJunctionTree.cpp +++ b/gtsam/discrete/DiscreteJunctionTree.cpp @@ -16,19 +16,35 @@ * @author Richard Roberts */ -#include -#include #include +#include +#include namespace gtsam { - // Instantiate base classes - template class EliminatableClusterTree; - template class JunctionTree; +// Instantiate base classes +template class EliminatableClusterTree; +template class JunctionTree; - /* ************************************************************************* */ - DiscreteJunctionTree::DiscreteJunctionTree( - const DiscreteEliminationTree& eliminationTree) : - Base(eliminationTree) {} +/* ************************************************************************* */ +DiscreteJunctionTree::DiscreteJunctionTree( + const DiscreteEliminationTree& eliminationTree) + : Base(eliminationTree) {} +/* ************************************************************************* */ +void DiscreteJunctionTree::print(const std::string& s, + const KeyFormatter& keyFormatter) const { + auto visitor = [&keyFormatter]( + const std::shared_ptr& node, + const std::string& parentString) { + // Print the current node + node->print(parentString + "-", keyFormatter); + node->factors.print(parentString + "-", keyFormatter); + std::cout << std::endl; + return parentString + "| "; // Increment the indentation + }; + std::string parentString = s; + treeTraversal::DepthFirstForest(*this, parentString, visitor); } + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteJunctionTree.h b/gtsam/discrete/DiscreteJunctionTree.h index f6171c672..4b9241036 100644 --- a/gtsam/discrete/DiscreteJunctionTree.h +++ b/gtsam/discrete/DiscreteJunctionTree.h @@ -18,54 +18,71 @@ #pragma once -#include #include 
+#include #include namespace gtsam { - // Forward declarations - class DiscreteEliminationTree; +// Forward declarations +class DiscreteEliminationTree; + +/** + * An EliminatableClusterTree, i.e., a set of variable clusters with factors, + * arranged in a tree, with the additional property that it represents the + * clique tree associated with a Bayes net. + * + * In GTSAM a junction tree is an intermediate data structure in multifrontal + * variable elimination. Each node is a cluster of factors, along with a + * clique of variables that are eliminated all at once. In detail, every node k + * represents a clique (maximal fully connected subset) of an associated chordal + * graph, such as a chordal Bayes net resulting from elimination. + * + * The difference with the BayesTree is that a JunctionTree stores factors, + * whereas a BayesTree stores conditionals, that are the product of eliminating + * the factors in the corresponding JunctionTree cliques. + * + * The tree structure and elimination method are exactly analogous to the + * EliminationTree, except that in the JunctionTree, at each node multiple + * variables are eliminated at a time. + * + * \ingroup Multifrontal + * @ingroup discrete + * \nosubgrouping + */ +class GTSAM_EXPORT DiscreteJunctionTree + : public JunctionTree { + public: + typedef JunctionTree + Base; ///< Base class + typedef DiscreteJunctionTree This; ///< This class + typedef std::shared_ptr shared_ptr; ///< Shared pointer to this class + + /// @name Constructors + /// @{ /** - * An EliminatableClusterTree, i.e., a set of variable clusters with factors, arranged in a tree, - * with the additional property that it represents the clique tree associated with a Bayes net. - * - * In GTSAM a junction tree is an intermediate data structure in multifrontal - * variable elimination. Each node is a cluster of factors, along with a - * clique of variables that are eliminated all at once. In detail, every node k represents - * a clique (maximal fully connected subset) of an associated chordal graph, such as a - * chordal Bayes net resulting from elimination. - * - * The difference with the BayesTree is that a JunctionTree stores factors, whereas a - * BayesTree stores conditionals, that are the product of eliminating the factors in the - * corresponding JunctionTree cliques. - * - * The tree structure and elimination method are exactly analogous to the EliminationTree, - * except that in the JunctionTree, at each node multiple variables are eliminated at a time. - * - * \ingroup Multifrontal - * @ingroup discrete - * \nosubgrouping + * Build the elimination tree of a factor graph using precomputed column + * structure. + * @param factorGraph The factor graph for which to build the elimination tree + * @param structure The set of factors involving each variable. If this is + * not precomputed, you can call the Create(const FactorGraph&) + * named constructor instead. + * @return The elimination tree */ - class GTSAM_EXPORT DiscreteJunctionTree : - public JunctionTree { - public: - typedef JunctionTree Base; ///< Base class - typedef DiscreteJunctionTree This; ///< This class - typedef std::shared_ptr shared_ptr; ///< Shared pointer to this class + DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); - /** - * Build the elimination tree of a factor graph using precomputed column structure. - * @param factorGraph The factor graph for which to build the elimination tree - * @param structure The set of factors involving each variable. 
If this is not - * precomputed, you can call the Create(const FactorGraph&) - * named constructor instead. - * @return The elimination tree - */ - DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); - }; + /// @} + /// @name Testable + /// @{ - /// typedef for wrapper: - using DiscreteCluster = DiscreteJunctionTree::Cluster; -} + /** Print the tree to cout */ + void print(const std::string& name = "DiscreteJunctionTree: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const; + + /// @} +}; + +/// typedef for wrapper: +using DiscreteCluster = DiscreteJunctionTree::Cluster; +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp new file mode 100644 index 000000000..c046f508f --- /dev/null +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -0,0 +1,288 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * DiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include +#include +#include + +namespace gtsam { + +using Slot = DiscreteSearch::Slot; +using Solution = DiscreteSearch::Solution; + +/* + * A SearchNode represents a node in the search tree for the search algorithm. + * Each SearchNode contains a partial assignment of discrete variables, the + * current error, a bound on the final error, and the index of the next + * slot to be assigned. + */ +struct SearchNode { + DiscreteValues assignment; // Partial assignment of discrete variables. + double error; // Current error for the partial assignment. + double bound; // Lower bound on the final error + std::optional next; // Index of the next slot to be assigned. + + // Construct the root node for the search. + static SearchNode Root(size_t numSlots, double bound) { + return {DiscreteValues(), 0.0, bound, 0}; + } + + struct Compare { + bool operator()(const SearchNode& a, const SearchNode& b) const { + return a.bound > b.bound; // smallest bound -> highest priority + } + }; + + // Checks if the node represents a complete assignment. + inline bool isComplete() const { return !next; } + + // Expands the node by assigning the next variable(s). + SearchNode expand(const DiscreteValues& fa, const Slot& slot, + std::optional nextSlot) const { + // Combine the new frontal assignment with the current partial assignment + DiscreteValues newAssignment = assignment; + for (auto& [key, value] : fa) { + newAssignment[key] = value; + } + double errorSoFar = error + slot.factor->error(newAssignment); + return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot}; + } + + // Prints the SearchNode to an output stream. + friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { + os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; + return os; + } +}; + +struct CompareSolution { + bool operator()(const Solution& a, const Solution& b) const { + return a.error < b.error; + } +}; + +/* + * A Solutions object maintains a priority queue of the best solutions found + * during the search. The priority queue is limited to a maximum size, and + * solutions are only added if they are better than the worst solution. 
+ */ +class Solutions { + size_t maxSize_; // Maximum number of solutions to keep + std::priority_queue, CompareSolution> pq_; + + public: + Solutions(size_t maxSize) : maxSize_(maxSize) {} + + // Add a solution to the priority queue, possibly evicting the worst one. + // Return true if we added the solution. + bool maybeAdd(double error, const DiscreteValues& assignment) { + const bool full = pq_.size() == maxSize_; + if (full && error >= pq_.top().error) return false; + if (full) pq_.pop(); + pq_.emplace(error, assignment); + return true; + } + + // Check if we have any solutions + bool empty() const { return pq_.empty(); } + + // Method to print all solutions + friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) { + os << "Solutions (top " << sn.pq_.size() << "):\n"; + auto pq = sn.pq_; + while (!pq.empty()) { + os << pq.top() << "\n"; + pq.pop(); + } + return os; + } + + // Check if (partial) solution with given bound can be pruned. If we have + // room, we never prune. Otherwise, prune if lower bound on error is worse + // than our current worst error. + bool prune(double bound) const { + if (pq_.size() < maxSize_) return false; + return bound >= pq_.top().error; + } + + // Method to extract solutions in ascending order of error + std::vector extractSolutions() { + std::vector result; + while (!pq_.empty()) { + result.push_back(pq_.top()); + pq_.pop(); + } + std::sort( + result.begin(), result.end(), + [](const Solution& a, const Solution& b) { return a.error < b.error; }); + return result; + } +}; + +// Get the factor associated with a node, possibly product of factors. +template +static DiscreteFactor::shared_ptr getFactor(const NodeType& node) { + const auto& factors = node->factors; + return factors.size() == 1 ? factors.back() + : DiscreteFactorGraph(factors).product(); +} + +DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { + using NodePtr = std::shared_ptr; + auto visitor = [this](const NodePtr& node, int data) { + const DiscreteFactor::shared_ptr factor = getFactor(node); + const size_t cardinality = factor->cardinality(node->key); + std::vector> pairs{{node->key, cardinality}}; + const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; + slots_.emplace_back(std::move(slot)); + return data + 1; + }; + + int data = 0; // unused + treeTraversal::DepthFirstForest(etree, data, visitor); + lowerBound_ = computeHeuristic(); +} + +DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { + using NodePtr = std::shared_ptr; + auto visitor = [this](const NodePtr& cluster, int data) { + const auto factor = getFactor(cluster); + std::vector> pairs; + for (Key key : cluster->orderedFrontalKeys) { + pairs.emplace_back(key, factor->cardinality(key)); + } + const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; + slots_.emplace_back(std::move(slot)); + return data + 1; + }; + + int data = 0; // unused + treeTraversal::DepthFirstForest(junctionTree, data, visitor); + lowerBound_ = computeHeuristic(); +} + +DiscreteSearch DiscreteSearch::FromFactorGraph( + const DiscreteFactorGraph& factorGraph, const Ordering& ordering, + bool buildJunctionTree) { + const DiscreteEliminationTree etree(factorGraph, ordering); + if (buildJunctionTree) { + const DiscreteJunctionTree junctionTree(etree); + return DiscreteSearch(junctionTree); + } else { + return DiscreteSearch(etree); + } +} + +DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { + slots_.reserve(bayesNet.size()); + for (auto& conditional : 
bayesNet) { + const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; + slots_.emplace_back(std::move(slot)); + } + std::reverse(slots_.begin(), slots_.end()); + lowerBound_ = computeHeuristic(); +} + +DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { + using NodePtr = DiscreteBayesTree::sharedClique; + auto visitor = [this](const NodePtr& clique, int data) { + auto conditional = clique->conditional(); + const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; + slots_.emplace_back(std::move(slot)); + return data + 1; + }; + + int data = 0; // unused + treeTraversal::DepthFirstForest(bayesTree, data, visitor); + lowerBound_ = computeHeuristic(); +} + +void DiscreteSearch::print(const std::string& name, + const KeyFormatter& formatter) const { + std::cout << name << " with " << slots_.size() << " slots:\n"; + for (size_t i = 0; i < slots_.size(); ++i) { + std::cout << i << ": " << slots_[i] << std::endl; + } +} + +using SearchNodeQueue = std::priority_queue, + SearchNode::Compare>; + +std::vector DiscreteSearch::run(size_t K) const { + if (slots_.empty()) { + return {Solution(0.0, DiscreteValues())}; + } + + Solutions solutions(K); + SearchNodeQueue expansions; + expansions.push(SearchNode::Root(slots_.size(), lowerBound_)); + + // Perform the search + while (!expansions.empty()) { + // Pop the partial assignment with the smallest bound + SearchNode current = expansions.top(); + expansions.pop(); + + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions.prune(current.bound)) { + continue; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions.maybeAdd(current.error, current.assignment); + continue; + } + + // Get the next slot to expand + const auto& slot = slots_[*current.next]; + std::optional nextSlot = *current.next + 1; + if (nextSlot == slots_.size()) nextSlot.reset(); + for (auto& fa : slot.assignments) { + auto childNode = current.expand(fa, slot, nextSlot); + + // Again, prune if we cannot beat the worst solution + if (!solutions.prune(childNode.bound)) { + expansions.emplace(childNode); + } + } + } + + // Extract solutions from bestSolutions in ascending order of error + return solutions.extractSolutions(); +} +/* + * We have a number of factors, each with a max value, and we want to compute + * a lower-bound on the cost-to-go for each slot, *not* including this factor. + * For the last slot[n-1], this is 0.0, as the cost after that is zero. + * For the second-to-last slot, it is h = -log(max(factor[n-1])), because after + * we assign slot[n-2] we still need to assign slot[n-1], which will cost *at + * least* h. We return the estimated lower bound of the cost for *all* slots. + */ +double DiscreteSearch::computeHeuristic() { + double error = 0.0; + for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) { + it->heuristic = error; + Ordering ordering(it->factor->begin(), it->factor->end()); + auto maxx = it->factor->max(ordering); + error -= std::log(maxx->evaluate({})); + } + return error; +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h new file mode 100644 index 000000000..db3dd5f03 --- /dev/null +++ b/gtsam/discrete/DiscreteSearch.h @@ -0,0 +1,166 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. 
(see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteSearch.h + * @brief Defines the DiscreteSearch class for discrete search algorithms. + * + * @details This file contains the definition of the DiscreteSearch class, which + * is used in discrete search algorithms to find the K best solutions. + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include +#include + +#include + +namespace gtsam { + +/** + * @brief DiscreteSearch: Search for the K best solutions. + * + * This class is used to search for the K best solutions in a DiscreteBayesNet. + * This is implemented with a modified A* search algorithm that uses a priority + * queue to manage the search nodes. That machinery is defined in the .cpp file. + * The heuristic we use is the sum of the log-probabilities of the + * maximum-probability assignments for each slot, for all slots to the right of + * the current slot. + * + * TODO: The heuristic could be refined by using the partial assignment in + * search node to refine the max-probability assignment for the remaining slots. + * This would incur more computation but will lead to fewer expansions. + */ +class GTSAM_EXPORT DiscreteSearch { + public: + /** + * We structure the search as a set of slots, each with a factor and + * a set of variable assignments that need to be chosen. In addition, each + * slot has a heuristic associated with it. + * + * Example: + * The factors in the search problem (always parents before descendents!): + * [P(A), P(B|A), P(C|A,B)] + * The assignments for each factor. + * [[A0,A1], [B0,B1], [C0,C1,C2]] + * A lower bound on the cost-to-go after each slot, e.g., + * [-log(max_B P(B|A)) -log(max_C P(C|A,B)), -log(max_C P(C|A,B)), 0.0] + * Note that these decrease as we move from right to left. + * We keep the global lower bound as lowerBound_. In the example, it is: + * -log(max_B P(B|A)) -log(max_C P(C|A,B)) -log(max_C P(C|A,B)) + */ + struct Slot { + DiscreteFactor::shared_ptr factor; + std::vector assignments; + double heuristic; + + friend std::ostream& operator<<(std::ostream& os, const Slot& slot) { + os << "Slot with " << slot.assignments.size() + << " assignments, heuristic=" << slot.heuristic; + os << ", factor:\n" << slot.factor->markdown() << std::endl; + return os; + } + }; + + /** + * A solution is a set of assignments, covering all the slots. + * as well as an associated error = -log(probability) + */ + struct Solution { + double error; + DiscreteValues assignment; + Solution(double err, const DiscreteValues& assign) + : error(err), assignment(assign) {} + friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { + os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; + return os; + } + }; + + public: + /// @name Standard Constructors + /// @{ + + /** + * Construct from a DiscreteFactorGraph. + * + * Internally creates either an elimination tree or a junction tree. The + * latter incurs more up-front computation but the search itself might be + * faster. Then again, for the elimination tree, the heuristic will be more + * fine-grained (more slots). + * + * @param factorGraph The factor graph to search over. + * @param ordering The ordering used to create etree (and maybe jtree). + * @param buildJunctionTree Whether to build a junction tree or not. 
+ */ + static DiscreteSearch FromFactorGraph(const DiscreteFactorGraph& factorGraph, + const Ordering& ordering, + bool buildJunctionTree = false); + + /// Construct from a DiscreteEliminationTree. + DiscreteSearch(const DiscreteEliminationTree& etree); + + /// Construct from a DiscreteJunctionTree. + DiscreteSearch(const DiscreteJunctionTree& junctionTree); + + //// Construct from a DiscreteBayesNet. + DiscreteSearch(const DiscreteBayesNet& bayesNet); + + /// Construct from a DiscreteBayesTree. + DiscreteSearch(const DiscreteBayesTree& bayesTree); + + /// @} + /// @name Testable + /// @{ + + /** Print the tree to cout */ + void print(const std::string& name = "DiscreteSearch: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const; + + /// @} + /// @name Standard API + /// @{ + + /// Return lower bound on the cost-to-go for the entire search + double lowerBound() const { return lowerBound_; } + + /// Read access to the slots + const std::vector& slots() const { return slots_; } + + /** + * @brief Search for the K best solutions. + * + * This method performs a search to find the K best solutions for the given + * DiscreteBayesNet. It uses a priority queue to manage the search nodes, + * expanding nodes with the smallest bound first. The search continues until + * all possible nodes have been expanded or pruned. + * + * @return A vector of the K best solutions found during the search. + */ + std::vector run(size_t K = 1) const; + + /// @} + + private: + /** + * Compute the cumulative lower-bound cost-to-go after each slot is filled. + * @return the estimated lower bound of the cost for *all* slots. + */ + double computeHeuristic(); + + double lowerBound_; ///< Lower bound on the cost-to-go for the entire search. + std::vector slots_; ///< The slots to fill in the search. 
+}; + +using DiscreteSearchSolution = DiscreteSearch::Solution; // for wrapping +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteValues.cpp b/gtsam/discrete/DiscreteValues.cpp index 416dfb888..3c3ed4468 100644 --- a/gtsam/discrete/DiscreteValues.cpp +++ b/gtsam/discrete/DiscreteValues.cpp @@ -26,12 +26,24 @@ using std::stringstream; namespace gtsam { +/* ************************************************************************ */ +static void stream(std::ostream& os, const DiscreteValues& x, + const KeyFormatter& keyFormatter) { + for (const auto& kv : x) + os << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; +} + +/* ************************************************************************ */ +std::ostream& operator<<(std::ostream& os, const DiscreteValues& x) { + stream(os, x, DefaultKeyFormatter); + return os; +} + /* ************************************************************************ */ void DiscreteValues::print(const string& s, const KeyFormatter& keyFormatter) const { cout << s << ": "; - for (auto&& kv : *this) - cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; + stream(cout, *this, keyFormatter); cout << endl; } diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index fa8a8a846..df4ecdbff 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -64,6 +64,9 @@ class GTSAM_EXPORT DiscreteValues : public Assignment { /// @name Standard Interface /// @{ + /// ostream operator: + friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x); + // insert in base class; std::pair insert( const value_type& value ){ return Base::insert(value); diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index b84ac69a0..5e4d8d22d 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -464,4 +464,29 @@ class DiscreteJunctionTree { const gtsam::DiscreteCluster& operator[](size_t i) const; }; +#include +class DiscreteSearchSolution { + double error; + gtsam::DiscreteValues assignment; + DiscreteSearchSolution(double error, const gtsam::DiscreteValues& assignment); +}; + +class DiscreteSearch { + static DiscreteSearch FromFactorGraph(const gtsam::DiscreteFactorGraph& factorGraph, + const gtsam::Ordering& ordering, + bool buildJunctionTree = false); + + DiscreteSearch(const gtsam::DiscreteEliminationTree& etree); + DiscreteSearch(const gtsam::DiscreteJunctionTree& junctionTree); + DiscreteSearch(const gtsam::DiscreteBayesNet& bayesNet); + DiscreteSearch(const gtsam::DiscreteBayesTree& bayesTree); + + void print(string name = "DiscreteSearch: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + + double lowerBound() const; + + std::vector run(size_t K = 1) const; +}; + } // namespace gtsam diff --git a/gtsam/discrete/tests/AsiaExample.h b/gtsam/discrete/tests/AsiaExample.h new file mode 100644 index 000000000..ff6c4ea99 --- /dev/null +++ b/gtsam/discrete/tests/AsiaExample.h @@ -0,0 +1,61 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. 
(see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * AsiaExample.h + * + * @date Jan, 2025 + * @author Frank Dellaert + */ + +#include +#include + +namespace gtsam { +namespace asia_example { + +static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), + B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), + S = Symbol('S', 7), A = Symbol('A', 8); + +static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), + Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), + Asia(A, 2); + +// Function to construct the Asia priors +DiscreteBayesNet createPriors() { + DiscreteBayesNet priors; + priors.add(Smoking % "50/50"); + priors.add(Asia, "99/1"); + return priors; +} + +// Function to construct the incomplete Asia example +DiscreteBayesNet createFragment() { + DiscreteBayesNet fragment; + fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); + fragment.add(LungCancer | Smoking = "99/1 90/10"); + fragment.add(Tuberculosis | Asia = "99/1 95/5"); + for (const auto& factor : createPriors()) fragment.push_back(factor); + return fragment; +} + +// Function to construct the Asia example +DiscreteBayesNet createAsiaExample() { + DiscreteBayesNet asia; + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + asia.add(XRay | Either = "95/5 2/98"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + for (const auto& factor : createFragment()) asia.push_back(factor); + return asia; +} +} // namespace asia_example +} // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index d2033909c..dd5d218f8 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -23,40 +23,19 @@ #include #include #include +#include #include #include #include -using namespace std; +#include "AsiaExample.h" + using namespace gtsam; -static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), - LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); - -using ADT = AlgebraicDecisionTree; - -// Function to construct the Asia example -DiscreteBayesNet constructAsiaExample() { - DiscreteBayesNet asia; - - asia.add(Asia, "99/1"); - asia.add(Smoking % "50/50"); // Signature version - - asia.add(Tuberculosis | Asia = "99/1 95/5"); - asia.add(LungCancer | Smoking = "99/1 90/10"); - asia.add(Bronchitis | Smoking = "70/30 40/60"); - - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - asia.add(XRay | Either = "95/5 2/98"); - asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); - - return asia; -} - /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { + using ADT = AlgebraicDecisionTree; DiscreteBayesNet bayesNet; DiscreteKey Parent(0, 2), Child(1, 2); @@ -86,11 +65,12 @@ TEST(DiscreteBayesNet, bayesNet) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { - DiscreteBayesNet asia = constructAsiaExample(); + using namespace asia_example; + const DiscreteBayesNet asia = createAsiaExample(); // Convert to factor graph DiscreteFactorGraph fg(asia); - LONGS_EQUAL(3, fg.back()->size()); + LONGS_EQUAL(1, fg.back()->size()); // Check the marginals we know (of the parent-less nodes) DiscreteMarginals marginals(fg); @@ -99,7 +79,7 @@ TEST(DiscreteBayesNet, Asia) { 
EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); // Create solver and eliminate - const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7}; + const Ordering ordering{A, D, T, X, S, E, L, B}; DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); @@ -144,55 +124,50 @@ TEST(DiscreteBayesNet, Sugar) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Dot) { - DiscreteBayesNet fragment; - fragment.add(Asia % "99/1"); - fragment.add(Smoking % "50/50"); + using namespace asia_example; + const DiscreteBayesNet fragment = createFragment(); - fragment.add(Tuberculosis | Asia = "99/1 95/5"); - fragment.add(LungCancer | Smoking = "99/1 90/10"); - fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - string actual = fragment.dot(); - EXPECT(actual == - "digraph {\n" - " size=\"5,5\";\n" - "\n" - " var0[label=\"0\"];\n" - " var3[label=\"3\"];\n" - " var4[label=\"4\"];\n" - " var5[label=\"5\"];\n" - " var6[label=\"6\"];\n" - "\n" - " var3->var5\n" - " var6->var5\n" - " var4->var6\n" - " var0->var3\n" - "}"); + std::string expected = + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var4683743612465315848[label=\"A8\"];\n" + " var4971973988617027587[label=\"E3\"];\n" + " var5476377146882523141[label=\"L5\"];\n" + " var5980780305148018695[label=\"S7\"];\n" + " var6052837899185946630[label=\"T6\"];\n" + "\n" + " var4683743612465315848->var6052837899185946630\n" + " var5980780305148018695->var5476377146882523141\n" + " var6052837899185946630->var4971973988617027587\n" + " var5476377146882523141->var4971973988617027587\n" + "}"; + std::string actual = fragment.dot(); + EXPECT(actual.compare(expected) == 0); } /* ************************************************************************* */ // Check markdown representation looks as expected. TEST(DiscreteBayesNet, markdown) { - DiscreteBayesNet fragment; - fragment.add(Asia % "99/1"); - fragment.add(Smoking | Asia = "8/2 7/3"); + using namespace asia_example; + DiscreteBayesNet priors = createPriors(); - string expected = + std::string expected = "`DiscreteBayesNet` of size 2\n" "\n" + " *P(Smoking):*\n\n" + "|Smoking|value|\n" + "|:-:|:-:|\n" + "|0|0.5|\n" + "|1|0.5|\n" + "\n" " *P(Asia):*\n\n" "|Asia|value|\n" "|:-:|:-:|\n" "|0|0.99|\n" - "|1|0.01|\n" - "\n" - " *P(Smoking|Asia):*\n\n" - "|*Asia*|0|1|\n" - "|:-:|:-:|:-:|\n" - "|0|0.8|0.2|\n" - "|1|0.7|0.3|\n\n"; - auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; - string actual = fragment.markdown(formatter); + "|1|0.01|\n\n"; + auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; }; + std::string actual = priors.markdown(formatter); EXPECT(actual == expected); } diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp new file mode 100644 index 000000000..cebddfe8d --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -0,0 +1,113 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. 
(see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteSearch.cpp + * + * @date January, 2025 + * @author Frank Dellaert + */ + +#include +#include +#include + +#include "AsiaExample.h" + +using namespace gtsam; + +// Create Asia Bayes net, FG, and Bayes tree once +namespace asia { +using namespace asia_example; +static const DiscreteBayesNet bayesNet = createAsiaExample(); + +// Create factor graph and optimize with max-product for MPE +static const DiscreteFactorGraph factorGraph(bayesNet); +static const DiscreteValues mpe = factorGraph.optimize(); + +// Create ordering +static const Ordering ordering{D, X, B, E, L, T, S, A}; + +// Create Bayes tree +static const DiscreteBayesTree bayesTree = + *factorGraph.eliminateMultifrontal(ordering); +} // namespace asia + +/* ************************************************************************* */ +TEST(DiscreteBayesNet, EmptyKBest) { + DiscreteBayesNet net; // no factors + DiscreteSearch search(net); + auto solutions = search.run(3); + // Expect one solution with empty assignment, error=0 + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscreteBayesTree, EmptyTree) { + DiscreteBayesTree bt; + + DiscreteSearch search(bt); + auto solutions = search.run(3); + + // We expect exactly 1 solution with error = 0.0 (the empty assignment). + EXPECT_LONGS_EQUAL(1, solutions.size()); + EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscreteBayesNet, AsiaKBest) { + auto fromETree = + DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering); + auto fromJunctionTree = + DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering, true); + const DiscreteSearch fromBayesNet(asia::bayesNet); + const DiscreteSearch fromBayesTree(asia::bayesTree); + + for (auto& search : + {fromETree, fromJunctionTree, fromBayesNet, fromBayesTree}) { + // Ask for the MPE + auto mpe = search.run(); + + // Regression on error lower bound + EXPECT_DOUBLES_EQUAL(1.205536, search.lowerBound(), 1e-5); + + // Check that the cost-to-go heuristic decreases from there + auto slots = search.slots(); + double previousHeuristic = search.lowerBound(); + for (auto&& slot : slots) { + EXPECT(slot.heuristic <= previousHeuristic); + previousHeuristic = slot.heuristic; + } + + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + + // Check it is equal to MPE via inference + EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); + + // Ask for top 4 solutions + auto solutions = search.run(4); + + EXPECT_LONGS_EQUAL(4, solutions.size()); + // Regression test: check the first and last solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); + } +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index c74930a8e..fbf892235 100644 --- 
a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -202,6 +202,11 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { /* ************************************************************************* */ void HybridBayesTree::prune(const size_t maxNrLeaves) { + if (!this->roots_.at(0)->conditional()->asDiscrete()) { + // Root of the BayesTree is not a discrete clique, so we do nothing. + return; + } + auto prunedDiscreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); diff --git a/gtsam/hybrid/HybridNonlinearFactor.cpp b/gtsam/hybrid/HybridNonlinearFactor.cpp index 376bc66f1..fa22051e5 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.cpp +++ b/gtsam/hybrid/HybridNonlinearFactor.cpp @@ -99,7 +99,8 @@ AlgebraicDecisionTree HybridNonlinearFactor::errorTree( auto errorFunc = [continuousValues](const std::pair& f) { auto [factor, val] = f; - return factor->error(continuousValues) + val; + return factor ? factor->error(continuousValues) + val + : std::numeric_limits::infinity(); }; return {factors_, errorFunc}; } diff --git a/python/gtsam/tests/dfg_utils.py b/python/gtsam/tests/dfg_utils.py new file mode 100644 index 000000000..9ad521fd4 --- /dev/null +++ b/python/gtsam/tests/dfg_utils.py @@ -0,0 +1,35 @@ +import numpy as np +from gtsam import Symbol + + +def make_key(character, index, cardinality): + """ + Helper function to mimic the behavior of gtbook.Variables discrete_series function. + """ + symbol = Symbol(character, index) + key = symbol.key() + return (key, cardinality) + + +def generate_transition_cpt(num_states, transitions=None): + """ + Generate a row-wise CPT for a transition matrix. + """ + if transitions is None: + # Default to identity matrix with slight regularization + transitions = np.eye(num_states) + 0.1 / num_states + + # Ensure transitions sum to 1 if not already normalized + transitions /= np.sum(transitions, axis=1, keepdims=True) + return " ".join(["/".join(map(str, row)) for row in transitions]) + + +def generate_observation_cpt(num_states, num_obs, desired_state): + """ + Generate a row-wise CPT for observations with contrived probabilities. 
+ """ + obs = np.zeros((num_states, num_obs + 1)) + obs[:, -1] = 1 # All states default to measurement num_obs + obs[desired_state, 0:-1] = 1 + obs[desired_state, -1] = 0 + return " ".join(["/".join(map(str, row)) for row in obs]) diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 3053087b4..521eeefa6 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -15,10 +15,16 @@ import unittest import numpy as np from gtsam.utils.test_case import GtsamTestCase +from dfg_utils import make_key, generate_transition_cpt, generate_observation_cpt -from gtsam import (DecisionTreeFactor, DiscreteConditional, - DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, - Symbol) +from gtsam import ( + DecisionTreeFactor, + DiscreteConditional, + DiscreteFactorGraph, + DiscreteKeys, + DiscreteValues, + Ordering, +) OrderingType = Ordering.OrderingType @@ -50,7 +56,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): assignment[1] = 1 # Check if graph evaluation works ( 0.3*0.6*4 ) - self.assertAlmostEqual(.72, graph(assignment)) + self.assertAlmostEqual(0.72, graph(assignment)) # Create a new test with third node and adding unary and ternary factor graph.add(P3, "0.9 0.2 0.5") @@ -100,8 +106,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): expectedValues[1] = 0 expectedValues[2] = 0 actualValues = graph.optimize() - self.assertEqual(list(actualValues.items()), - list(expectedValues.items())) + self.assertEqual(list(actualValues.items()), list(expectedValues.items())) def test_MPE(self): """Test maximum probable explanation (MPE): same as optimize.""" @@ -123,13 +128,11 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Use maxProduct dag = graph.maxProduct(OrderingType.COLAMD) actualMPE = dag.argmax() - self.assertEqual(list(actualMPE.items()), - list(mpe.items())) + self.assertEqual(list(actualMPE.items()), list(mpe.items())) # All in one actualMPE2 = graph.optimize() - self.assertEqual(list(actualMPE2.items()), - list(mpe.items())) + self.assertEqual(list(actualMPE2.items()), list(mpe.items())) def test_sumProduct(self): """Test sumProduct.""" @@ -154,11 +157,17 @@ class TestDiscreteFactorGraph(GtsamTestCase): self.assertAlmostEqual(mpeProbability, 0.36) # regression # Use sumProduct - for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL, - OrderingType.CUSTOM]: + for ordering_type in [ + OrderingType.COLAMD, + OrderingType.METIS, + OrderingType.NATURAL, + OrderingType.CUSTOM, + ]: bayesNet = graph.sumProduct(ordering_type) self.assertEqual(bayesNet(mpe), mpeProbability) + +class TestChains(GtsamTestCase): def test_MPE_chain(self): """ Test for numerical underflow in EliminateMPE on long chains. @@ -170,46 +179,22 @@ class TestDiscreteFactorGraph(GtsamTestCase): desired_state = 1 states = list(range(num_states)) - # Helper function to mimic the behavior of gtbook.Variables discrete_series function - def make_key(character, index, cardinality): - symbol = Symbol(character, index) - key = symbol.key() - return (key, cardinality) - X = {index: make_key("X", index, len(states)) for index in range(num_obs)} Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)} graph = DiscreteFactorGraph() - # Mostly identity transition matrix - transitions = np.eye(num_states) - - # Needed otherwise mpe is always state 0? 
- transitions += 0.1/(num_states) - - transition_cpt = [] - for i in range(0, num_states): - transition_row = "/".join([str(x) for x in transitions[i]]) - transition_cpt.append(transition_row) - transition_cpt = " ".join(transition_cpt) - + transition_cpt = generate_transition_cpt(num_states) for i in reversed(range(1, num_obs)): - transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) + transition_conditional = DiscreteConditional( + X[i], [X[i - 1]], transition_cpt + ) graph.push_back(transition_conditional) # Contrived example such that the desired state gives measurements [0, num_obs) with equal probability # but all other states always give measurement num_obs - obs = np.zeros((num_states, num_obs+1)) - obs[:,-1] = 1 - obs[desired_state,0: -1] = 1 - obs[desired_state,-1] = 0 - obs_cpt_list = [] - for i in range(0, num_states): - obs_row = "/".join([str(z) for z in obs[i]]) - obs_cpt_list.append(obs_row) - obs_cpt = " ".join(obs_cpt_list) - + obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state) # Contrived example where each measurement is its own index - for i in range(0, num_obs): + for i in range(num_obs): obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt) factor = obs_conditional.likelihood(i) graph.push_back(factor) @@ -217,7 +202,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): mpe = graph.optimize() vals = [mpe[X[i][0]] for i in range(num_obs)] - self.assertEqual(vals, [desired_state]*num_obs) + self.assertEqual(vals, [desired_state] * num_obs) def test_sumProduct_chain(self): """ @@ -227,15 +212,8 @@ class TestDiscreteFactorGraph(GtsamTestCase): """ num_states = 3 chain_length = 400 - desired_state = 1 states = list(range(num_states)) - # Helper function to mimic the behavior of gtbook.Variables discrete_series function - def make_key(character, index, cardinality): - symbol = Symbol(character, index) - key = symbol.key() - return (key, cardinality) - X = {index: make_key("X", index, len(states)) for index in range(chain_length)} graph = DiscreteFactorGraph() @@ -253,18 +231,15 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Ensure that the stationary distribution is positive and normalized stationary_dist /= np.sum(stationary_dist) - expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten()) + expected = DecisionTreeFactor(X[chain_length - 1], stationary_dist.ravel()) # The transition matrix parsed by DiscreteConditional is a row-wise CPT - transitions = transitions.T - transition_cpt = [] - for i in range(0, num_states): - transition_row = "/".join([str(x) for x in transitions[i]]) - transition_cpt.append(transition_row) - transition_cpt = " ".join(transition_cpt) + transition_cpt = generate_transition_cpt(num_states, transitions.T) for i in reversed(range(1, chain_length)): - transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) + transition_conditional = DiscreteConditional( + X[i], [X[i - 1]], transition_cpt + ) graph.push_back(transition_conditional) # Run sum product using natural ordering so the resulting Bayes net has the form: @@ -277,5 +252,6 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Ensure marginal probabilities are close to the stationary distribution self.gtsamAssertEquals(expected, last_marginal) + if __name__ == "__main__": unittest.main() diff --git a/python/gtsam/tests/test_DiscreteSearch.py b/python/gtsam/tests/test_DiscreteSearch.py new file mode 100644 index 000000000..d0077f6db --- /dev/null +++ b/python/gtsam/tests/test_DiscreteSearch.py @@ -0,0 +1,84 
@@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Search. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from dfg_utils import generate_observation_cpt, generate_transition_cpt, make_key +from gtsam.utils.test_case import GtsamTestCase + +from gtsam import ( + DiscreteConditional, + DiscreteFactorGraph, + DiscreteSearch, + Ordering, + DefaultKeyFormatter, +) + +OrderingType = Ordering.OrderingType + + +class TestDiscreteSearch(GtsamTestCase): + """Tests for Discrete Factor Graphs.""" + + def test_MPE_chain(self): + """ + Test for numerical underflow in EliminateMPE on long chains. + Adapted from the toy problem of @pcl15423 + Ref: https://github.com/borglab/gtsam/issues/1448 + """ + num_states = 3 + num_obs = 200 + desired_state = 1 + states = list(range(num_states)) + + X = {index: make_key("X", index, len(states)) for index in range(num_obs)} + Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)} + graph = DiscreteFactorGraph() + + transition_cpt = generate_transition_cpt(num_states) + for i in reversed(range(1, num_obs)): + transition_conditional = DiscreteConditional( + X[i], [X[i - 1]], transition_cpt + ) + graph.push_back(transition_conditional) + + # Contrived example such that the desired state gives measurements [0, num_obs) with equal + # probability but all other states always give measurement num_obs + obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state) + # Contrived example where each measurement is its own index + for i in range(num_obs): + obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt) + factor = obs_conditional.likelihood(i) + graph.push_back(factor) + + # Check MPE + mpe = graph.optimize() + vals = [mpe[X[i][0]] for i in range(num_obs)] + self.assertEqual(vals, [desired_state] * num_obs) + + # Create an ordering: + ordering = Ordering() + for i in reversed(range(num_obs)): + ordering.push_back(X[i][0]) + + # Now do Search + search = DiscreteSearch.FromFactorGraph(graph, ordering) + solutions = search.run(K=1) + mpe2 = solutions[0].assignment + # print({DefaultKeyFormatter(key): value for key, value in mpe2.items()}) + vals = [mpe2[X[i][0]] for i in range(num_obs)] + self.assertEqual(vals, [desired_state] * num_obs) + + +if __name__ == "__main__": + unittest.main()