From 10f30e1ca93f7c335684f814b64944c3a5b7a51e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 10 Jun 2023 13:56:14 -0700 Subject: [PATCH] multifrontal MPE in python --- gtsam/discrete/DiscreteFactorGraph.h | 39 +++++++---- gtsam/discrete/discrete.i | 52 +++++++++++---- .../discrete/tests/testDiscreteBayesTree.cpp | 3 +- gtsam/inference/EliminateableFactorGraph.h | 10 +-- gtsam/inference/inference.i | 1 + python/gtsam/tests/test_DiscreteBayesTree.py | 64 ++++++++++++++++--- 6 files changed, 130 insertions(+), 39 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 3fbb64d50..68b7a85a7 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -42,16 +42,30 @@ class DiscreteJunctionTree; /** * @brief Main elimination function for DiscreteFactorGraph. - * - * @param factors - * @param keys - * @return GTSAM_EXPORT + * + * @param factors The factor graph to eliminate. + * @param frontalKeys An ordering for which variables to eliminate. + * @return A pair of the resulting conditional and the separator factor. * @ingroup discrete */ -GTSAM_EXPORT std::pair, DecisionTreeFactor::shared_ptr> -EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys); +GTSAM_EXPORT +std::pair +EliminateDiscrete(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys); + +/** + * @brief Alternate elimination function for that creates non-normalized lookup tables. + * + * @param factors The factor graph to eliminate. + * @param frontalKeys An ordering for which variables to eliminate. + * @return A pair of the resulting lookup table and the separator factor. + * @ingroup discrete + */ +GTSAM_EXPORT +std::pair +EliminateForMPE(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys); -/* ************************************************************************* */ template<> struct EliminationTraits { typedef DiscreteFactor FactorType; ///< Type of factors in factor graph @@ -61,12 +75,14 @@ template<> struct EliminationTraits typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree + /// The default dense elimination function static std::pair, boost::shared_ptr > DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) { return EliminateDiscrete(factors, keys); } + /// The default ordering generation function static Ordering DefaultOrderingFunc( const FactorGraphType& graph, @@ -75,7 +91,6 @@ template<> struct EliminationTraits } }; -/* ************************************************************************* */ /** * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. * Factor == DiscreteFactor @@ -109,8 +124,8 @@ class GTSAM_EXPORT DiscreteFactorGraph /** Implicit copy/downcast constructor to override explicit template container * constructor */ - template - DiscreteFactorGraph(const FactorGraph& graph) : Base(graph) {} + template + DiscreteFactorGraph(const FactorGraph& graph) : Base(graph) {} /// Destructor virtual ~DiscreteFactorGraph() {} @@ -231,10 +246,6 @@ class GTSAM_EXPORT DiscreteFactorGraph /// @} }; // \ DiscreteFactorGraph -std::pair // -EliminateForMPE(const DiscreteFactorGraph& factors, - const Ordering& frontalKeys); - /// traits template <> struct traits : public Testable {}; diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 32f0439ea..b66069340 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -275,6 +275,14 @@ class DiscreteLookupDAG { }; #include +std::pair +EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors, + const gtsam::Ordering& frontalKeys); + +std::pair +EliminateForMPE(const gtsam::DiscreteFactorGraph& factors, + const gtsam::Ordering& frontalKeys); + class DiscreteFactorGraph { DiscreteFactorGraph(); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); @@ -289,6 +297,7 @@ class DiscreteFactorGraph { void add(const gtsam::DiscreteKey& j, const std::vector& spec); void add(const gtsam::DiscreteKeys& keys, string spec); void add(const std::vector& keys, string spec); + void add(const std::vector& keys, const std::vector& spec); bool empty() const; size_t size() const; @@ -302,25 +311,46 @@ class DiscreteFactorGraph { double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; - gtsam::DiscreteBayesNet sumProduct(); - gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteBayesNet sumProduct( + gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering); - gtsam::DiscreteLookupDAG maxProduct(); - gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteLookupDAG maxProduct( + gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); - gtsam::DiscreteBayesNet* eliminateSequential(); - gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type); + gtsam::DiscreteBayesNet* eliminateSequential( + gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); + gtsam::DiscreteBayesNet* eliminateSequential( + gtsam::Ordering::OrderingType type, + const gtsam::DiscreteFactorGraph::Eliminate& function); gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesNet* eliminateSequential( + const gtsam::Ordering& ordering, + const gtsam::DiscreteFactorGraph::Eliminate& function); pair - eliminatePartialSequential(const gtsam::Ordering& ordering); + eliminatePartialSequential(const gtsam::Ordering& ordering); + pair + eliminatePartialSequential( + const gtsam::Ordering& ordering, + const gtsam::DiscreteFactorGraph::Eliminate& function); - gtsam::DiscreteBayesTree* eliminateMultifrontal(); - gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type); - gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesTree* eliminateMultifrontal( + gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); + gtsam::DiscreteBayesTree* eliminateMultifrontal( + gtsam::Ordering::OrderingType type, + const gtsam::DiscreteFactorGraph::Eliminate& function); + gtsam::DiscreteBayesTree* eliminateMultifrontal( + const gtsam::Ordering& ordering); + gtsam::DiscreteBayesTree* eliminateMultifrontal( + const gtsam::Ordering& ordering, + const gtsam::DiscreteFactorGraph::Eliminate& function); pair - eliminatePartialMultifrontal(const gtsam::Ordering& ordering); + eliminatePartialMultifrontal(const gtsam::Ordering& ordering); + pair + eliminatePartialMultifrontal( + const gtsam::Ordering& ordering, + const gtsam::DiscreteFactorGraph::Eliminate& function); string dot( const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 00020f567..dc0b69011 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -323,10 +323,11 @@ TEST(DiscreteBayesTree, Lookup) { DiscreteFactorGraph graph; const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3}; const DiscreteKey a1{A(1), 2}, a2{A(2), 2}; - const DiscreteKeys keys{x1, x2, x3, a1, a2}; + // Constraint on start and goal graph.add(DiscreteKeys{x1}, std::vector{1, 0, 0}); graph.add(DiscreteKeys{x3}, std::vector{0, 0, 1}); + // Should I stay or should I go? // "Reward" (exp(-cost)) for an action is 10, and rewards multiply: const double r = 10; diff --git a/gtsam/inference/EliminateableFactorGraph.h b/gtsam/inference/EliminateableFactorGraph.h index 900346f7f..b8d255f5a 100644 --- a/gtsam/inference/EliminateableFactorGraph.h +++ b/gtsam/inference/EliminateableFactorGraph.h @@ -52,12 +52,12 @@ namespace gtsam { * algorithms. Any factor graph holding eliminateable factors can derive from this class to * expose functions for computing marginals, conditional marginals, doing multifrontal and * sequential elimination, etc. */ - template + template class EliminateableFactorGraph { private: - typedef EliminateableFactorGraph This; ///< Typedef to this class. - typedef FACTORGRAPH FactorGraphType; ///< Typedef to factor graph type + typedef EliminateableFactorGraph This; ///< Typedef to this class. + typedef FACTOR_GRAPH FactorGraphType; ///< Typedef to factor graph type // Base factor type stored in this graph (private because derived classes will get this from // their FactorGraph base class) typedef typename EliminationTraits::FactorType _FactorType; @@ -139,7 +139,7 @@ namespace gtsam { OptionalVariableIndex variableIndex = boost::none) const; /** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not - * provided, the ordering will be computed using either COLAMD or METIS, dependeing on + * provided, the ordering will be computed using either COLAMD or METIS, depending on * the parameter orderingType (Ordering::COLAMD or Ordering::METIS) * * Example - Full Cholesky elimination in COLAMD order: @@ -160,7 +160,7 @@ namespace gtsam { OptionalVariableIndex variableIndex = boost::none) const; /** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not - * provided, the ordering will be computed using either COLAMD or METIS, dependeing on + * provided, the ordering will be computed using either COLAMD or METIS, depending on * the parameter orderingType (Ordering::COLAMD or Ordering::METIS) * * Example - Full QR elimination in specified order: diff --git a/gtsam/inference/inference.i b/gtsam/inference/inference.i index 17ea117c3..39039926e 100644 --- a/gtsam/inference/inference.i +++ b/gtsam/inference/inference.i @@ -104,6 +104,7 @@ class Ordering { // Standard Constructors and Named Constructors Ordering(); Ordering(const gtsam::Ordering& other); + Ordering(const std::vector& keys); template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index e1754ca64..720f884e0 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -13,16 +13,14 @@ Author: Frank Dellaert import unittest +import numpy as np +from gtsam.symbol_shorthand import A, X from gtsam.utils.test_case import GtsamTestCase -from gtsam import ( - DiscreteBayesNet, - DiscreteBayesTreeClique, - DiscreteConditional, - DiscreteFactorGraph, - DiscreteValues, - Ordering, -) +import gtsam +from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, + DiscreteConditional, DiscreteFactorGraph, + DiscreteKeys, DiscreteValues, Ordering) class TestDiscreteBayesNet(GtsamTestCase): @@ -100,6 +98,56 @@ class TestDiscreteBayesNet(GtsamTestCase): self.assertFalse(bayesTree.empty()) self.assertEqual(12, bayesTree.size()) + def test_discrete_bayes_tree_lookup(self): + """Check that we can have a multi-frontal lookup table.""" + # Make a small planning-like graph: 3 states, 2 actions + graph = DiscreteFactorGraph() + x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3) + a1, a2 = (A(1), 2), (A(2), 2) + + # Constraint on start and goal + graph.add([x1], np.array([1, 0, 0])) + graph.add([x3], np.array([0, 0, 1])) + + # Should I stay or should I go? + # "Reward" (exp(-cost)) for an action is 10, and rewards multiply: + r = 10 + table = np.array([ + r, 0, 0, 0, r, 0, # x1 = 0 + 0, r, 0, 0, 0, r, # x1 = 1 + 0, 0, r, 0, 0, r # x1 = 2 + ]) + graph.add([x1, a1, x2], table) + graph.add([x2, a2, x3], table) + + # Eliminate for MPE (maximum probable explanation). + ordering = Ordering([A(2), X(3), X(1), A(1), X(2)]) + lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE) + + # Check that the lookup table is correct + assert len(lookup) == 2 + lookup_x1_a1_x2 = lookup[X(1)].conditional() + assert len(lookup_x1_a1_x2.frontals()) == 3 + # Check that sum is 100 + empty = gtsam.DiscreteValues() + assert np.isclose(lookup_x1_a1_x2.sum(3)(empty), 100, atol=1e-9) + # And that only non-zero reward is for x1 a1 x2 == 0 1 1 + assert np.isclose(lookup_x1_a1_x2({X(1): 0, A(1): 1, X(2): 1}), 100, atol=1e-9) + + lookup_a2_x3 = lookup[X(3)].conditional() + # Check that the sum depends on x2 and is non-zero only for x2 in {1, 2} + sum_x2 = lookup_a2_x3.sum(2) + assert np.isclose(sum_x2({X(2): 0}), 0, atol=1e-9) + assert np.isclose(sum_x2({X(2): 1}), 10, atol=1e-9) + assert np.isclose(sum_x2({X(2): 2}), 20, atol=1e-9) + assert len(lookup_a2_x3.frontals()) == 2 + # And that the non-zero rewards are for + # x2 a2 x3 == 1 1 2 + assert np.isclose(lookup_a2_x3({X(2): 1, A(2): 1, X(3): 2}), 10, atol=1e-9) + # x2 a2 x3 == 2 0 2 + assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 0, X(3): 2}), 10, atol=1e-9) + # x2 a2 x3 == 2 1 2 + assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 1, X(3): 2}), 10, atol=1e-9) if __name__ == "__main__": unittest.main()