parent
							
								
									bd4230baae
								
							
						
					
					
						commit
						3a9f5578d6
					
				|  | @ -41,16 +41,30 @@ class DiscreteJunctionTree; | |||
| 
 | ||||
| /**
 | ||||
|  * @brief Main elimination function for DiscreteFactorGraph. | ||||
|  *  | ||||
|  * @param factors  | ||||
|  * @param keys  | ||||
|  * @return GTSAM_EXPORT | ||||
|  * | ||||
|  * @param factors The factor graph to eliminate. | ||||
|  * @param frontalKeys An ordering for which variables to eliminate. | ||||
|  * @return A pair of the resulting conditional and the separator factor. | ||||
|  * @ingroup discrete | ||||
|  */ | ||||
| GTSAM_EXPORT std::pair<std::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr> | ||||
| EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys); | ||||
| GTSAM_EXPORT | ||||
| std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> | ||||
| EliminateDiscrete(const DiscreteFactorGraph& factors, | ||||
|                   const Ordering& frontalKeys); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Alternate elimination function for that creates non-normalized lookup tables. | ||||
|  * | ||||
|  * @param factors The factor graph to eliminate. | ||||
|  * @param frontalKeys An ordering for which variables to eliminate. | ||||
|  * @return A pair of the resulting lookup table and the separator factor. | ||||
|  * @ingroup discrete | ||||
|  */ | ||||
| GTSAM_EXPORT | ||||
| std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> | ||||
| EliminateForMPE(const DiscreteFactorGraph& factors, | ||||
|                 const Ordering& frontalKeys); | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| template<> struct EliminationTraits<DiscreteFactorGraph> | ||||
| { | ||||
|   typedef DiscreteFactor FactorType;                   ///< Type of factors in factor graph
 | ||||
|  | @ -60,12 +74,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph> | |||
|   typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
 | ||||
|   typedef DiscreteBayesTree BayesTreeType;             ///< Type of Bayes tree
 | ||||
|   typedef DiscreteJunctionTree JunctionTreeType;       ///< Type of Junction tree
 | ||||
|    | ||||
|   /// The default dense elimination function
 | ||||
|   static std::pair<std::shared_ptr<ConditionalType>, | ||||
|                    std::shared_ptr<FactorType> > | ||||
|   DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) { | ||||
|     return EliminateDiscrete(factors, keys); | ||||
|   } | ||||
|    | ||||
|   /// The default ordering generation function
 | ||||
|   static Ordering DefaultOrderingFunc( | ||||
|       const FactorGraphType& graph, | ||||
|  | @ -74,7 +90,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph> | |||
|   } | ||||
| }; | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| /**
 | ||||
|  * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. | ||||
|  *   Factor == DiscreteFactor | ||||
|  | @ -108,8 +123,8 @@ class GTSAM_EXPORT DiscreteFactorGraph | |||
| 
 | ||||
|   /** Implicit copy/downcast constructor to override explicit template container
 | ||||
|    * constructor */ | ||||
|   template <class DERIVEDFACTOR> | ||||
|   DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {} | ||||
|   template <class DERIVED_FACTOR> | ||||
|   DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {} | ||||
| 
 | ||||
|   /// @name Testable
 | ||||
|   /// @{
 | ||||
|  | @ -227,10 +242,6 @@ class GTSAM_EXPORT DiscreteFactorGraph | |||
|   /// @}
 | ||||
| };  // \ DiscreteFactorGraph
 | ||||
| 
 | ||||
| std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>  //
 | ||||
| EliminateForMPE(const DiscreteFactorGraph& factors, | ||||
|                 const Ordering& frontalKeys); | ||||
| 
 | ||||
| /// traits
 | ||||
| template <> | ||||
| struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {}; | ||||
|  |  | |||
|  | @ -275,6 +275,14 @@ class DiscreteLookupDAG { | |||
| }; | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteFactorGraph.h> | ||||
| std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*> | ||||
| EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors, | ||||
|                   const gtsam::Ordering& frontalKeys); | ||||
| 
 | ||||
| std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*> | ||||
| EliminateForMPE(const gtsam::DiscreteFactorGraph& factors, | ||||
|                 const gtsam::Ordering& frontalKeys); | ||||
| 
 | ||||
| class DiscreteFactorGraph { | ||||
|   DiscreteFactorGraph(); | ||||
|   DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); | ||||
|  | @ -289,6 +297,7 @@ class DiscreteFactorGraph { | |||
|   void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec); | ||||
|   void add(const gtsam::DiscreteKeys& keys, string spec); | ||||
|   void add(const std::vector<gtsam::DiscreteKey>& keys, string spec); | ||||
|   void add(const std::vector<gtsam::DiscreteKey>& keys, const std::vector<double>& spec); | ||||
| 
 | ||||
|   bool empty() const; | ||||
|   size_t size() const; | ||||
|  | @ -302,25 +311,46 @@ class DiscreteFactorGraph { | |||
|   double operator()(const gtsam::DiscreteValues& values) const; | ||||
|   gtsam::DiscreteValues optimize() const; | ||||
| 
 | ||||
|   gtsam::DiscreteBayesNet sumProduct(); | ||||
|   gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type); | ||||
|   gtsam::DiscreteBayesNet sumProduct( | ||||
|       gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); | ||||
|   gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering); | ||||
| 
 | ||||
|   gtsam::DiscreteLookupDAG maxProduct(); | ||||
|   gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); | ||||
|   gtsam::DiscreteLookupDAG maxProduct( | ||||
|       gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); | ||||
|   gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); | ||||
| 
 | ||||
|   gtsam::DiscreteBayesNet* eliminateSequential(); | ||||
|   gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type); | ||||
|   gtsam::DiscreteBayesNet* eliminateSequential( | ||||
|       gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); | ||||
|   gtsam::DiscreteBayesNet* eliminateSequential( | ||||
|       gtsam::Ordering::OrderingType type, | ||||
|       const gtsam::DiscreteFactorGraph::Eliminate& function); | ||||
|   gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering); | ||||
|   gtsam::DiscreteBayesNet* eliminateSequential( | ||||
|       const gtsam::Ordering& ordering, | ||||
|       const gtsam::DiscreteFactorGraph::Eliminate& function); | ||||
|   pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*> | ||||
|       eliminatePartialSequential(const gtsam::Ordering& ordering); | ||||
|   eliminatePartialSequential(const gtsam::Ordering& ordering); | ||||
|   pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*> | ||||
|   eliminatePartialSequential( | ||||
|       const gtsam::Ordering& ordering, | ||||
|       const gtsam::DiscreteFactorGraph::Eliminate& function); | ||||
| 
 | ||||
|   gtsam::DiscreteBayesTree* eliminateMultifrontal(); | ||||
|   gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type);   | ||||
|   gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering);   | ||||
|   gtsam::DiscreteBayesTree* eliminateMultifrontal( | ||||
|       gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); | ||||
|   gtsam::DiscreteBayesTree* eliminateMultifrontal( | ||||
|       gtsam::Ordering::OrderingType type, | ||||
|       const gtsam::DiscreteFactorGraph::Eliminate& function); | ||||
|   gtsam::DiscreteBayesTree* eliminateMultifrontal( | ||||
|       const gtsam::Ordering& ordering); | ||||
|   gtsam::DiscreteBayesTree* eliminateMultifrontal( | ||||
|       const gtsam::Ordering& ordering, | ||||
|       const gtsam::DiscreteFactorGraph::Eliminate& function); | ||||
|   pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*> | ||||
|       eliminatePartialMultifrontal(const gtsam::Ordering& ordering); | ||||
|   eliminatePartialMultifrontal(const gtsam::Ordering& ordering); | ||||
|   pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*> | ||||
|   eliminatePartialMultifrontal( | ||||
|       const gtsam::Ordering& ordering, | ||||
|       const gtsam::DiscreteFactorGraph::Eliminate& function); | ||||
| 
 | ||||
|   string dot( | ||||
|       const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, | ||||
|  |  | |||
|  | @ -323,10 +323,11 @@ TEST(DiscreteBayesTree, Lookup) { | |||
|   DiscreteFactorGraph graph; | ||||
|   const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3}; | ||||
|   const DiscreteKey a1{A(1), 2}, a2{A(2), 2}; | ||||
|   const DiscreteKeys keys{x1, x2, x3, a1, a2}; | ||||
| 
 | ||||
|   // Constraint on start and goal
 | ||||
|   graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0}); | ||||
|   graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1}); | ||||
| 
 | ||||
|   // Should I stay or should I go?
 | ||||
|   // "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
 | ||||
|   const double r = 10; | ||||
|  |  | |||
|  | @ -51,12 +51,12 @@ namespace gtsam { | |||
|    *  algorithms.  Any factor graph holding eliminateable factors can derive from this class to | ||||
|    *  expose functions for computing marginals, conditional marginals, doing multifrontal and | ||||
|    *  sequential elimination, etc. */ | ||||
|   template<class FACTORGRAPH> | ||||
|   template<class FACTOR_GRAPH> | ||||
|   class EliminateableFactorGraph | ||||
|   { | ||||
|   private: | ||||
|     typedef EliminateableFactorGraph<FACTORGRAPH> This; ///< Typedef to this class.
 | ||||
|     typedef FACTORGRAPH FactorGraphType; ///< Typedef to factor graph type
 | ||||
|     typedef EliminateableFactorGraph<FACTOR_GRAPH> This; ///< Typedef to this class.
 | ||||
|     typedef FACTOR_GRAPH FactorGraphType; ///< Typedef to factor graph type
 | ||||
|     // Base factor type stored in this graph (private because derived classes will get this from
 | ||||
|     // their FactorGraph base class)
 | ||||
|     typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType; | ||||
|  |  | |||
|  | @ -104,6 +104,7 @@ class Ordering { | |||
|   // Standard Constructors and Named Constructors | ||||
|   Ordering(); | ||||
|   Ordering(const gtsam::Ordering& other); | ||||
|   Ordering(const std::vector<size_t>& keys); | ||||
| 
 | ||||
|   template < | ||||
|       FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, | ||||
|  |  | |||
|  | @ -13,16 +13,14 @@ Author: Frank Dellaert | |||
| 
 | ||||
| import unittest | ||||
| 
 | ||||
| import numpy as np | ||||
| from gtsam.symbol_shorthand import A, X | ||||
| from gtsam.utils.test_case import GtsamTestCase | ||||
| 
 | ||||
| from gtsam import ( | ||||
|     DiscreteBayesNet, | ||||
|     DiscreteBayesTreeClique, | ||||
|     DiscreteConditional, | ||||
|     DiscreteFactorGraph, | ||||
|     DiscreteValues, | ||||
|     Ordering, | ||||
| ) | ||||
| import gtsam | ||||
| from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, | ||||
|                    DiscreteConditional, DiscreteFactorGraph, | ||||
|                    DiscreteKeys, DiscreteValues, Ordering) | ||||
| 
 | ||||
| 
 | ||||
| class TestDiscreteBayesNet(GtsamTestCase): | ||||
|  | @ -100,6 +98,56 @@ class TestDiscreteBayesNet(GtsamTestCase): | |||
|         self.assertFalse(bayesTree.empty()) | ||||
|         self.assertEqual(12, bayesTree.size()) | ||||
| 
 | ||||
|     def test_discrete_bayes_tree_lookup(self): | ||||
|         """Check that we can have a multi-frontal lookup table.""" | ||||
|         # Make a small planning-like graph: 3 states, 2 actions | ||||
|         graph = DiscreteFactorGraph() | ||||
|         x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3) | ||||
|         a1, a2 = (A(1), 2), (A(2), 2) | ||||
| 
 | ||||
|         # Constraint on start and goal | ||||
|         graph.add([x1], np.array([1, 0, 0])) | ||||
|         graph.add([x3], np.array([0, 0, 1])) | ||||
| 
 | ||||
|         # Should I stay or should I go? | ||||
|         # "Reward" (exp(-cost)) for an action is 10, and rewards multiply: | ||||
|         r = 10 | ||||
|         table = np.array([ | ||||
|             r, 0, 0, 0, r, 0,  # x1 = 0 | ||||
|             0, r, 0, 0, 0, r,  # x1 = 1 | ||||
|             0, 0, r, 0, 0, r   # x1 = 2 | ||||
|         ]) | ||||
|         graph.add([x1, a1, x2], table) | ||||
|         graph.add([x2, a2, x3], table) | ||||
| 
 | ||||
|         # Eliminate for MPE (maximum probable explanation). | ||||
|         ordering = Ordering([A(2), X(3), X(1), A(1), X(2)]) | ||||
|         lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE) | ||||
| 
 | ||||
|         # Check that the lookup table is correct | ||||
|         assert len(lookup) == 2 | ||||
|         lookup_x1_a1_x2 = lookup[X(1)].conditional() | ||||
|         assert len(lookup_x1_a1_x2.frontals()) == 3 | ||||
|         # Check that sum is 100 | ||||
|         empty = gtsam.DiscreteValues() | ||||
|         assert np.isclose(lookup_x1_a1_x2.sum(3)(empty), 100, atol=1e-9) | ||||
|         # And that only non-zero reward is for x1 a1 x2 == 0 1 1 | ||||
|         assert np.isclose(lookup_x1_a1_x2({X(1): 0, A(1): 1, X(2): 1}), 100, atol=1e-9) | ||||
| 
 | ||||
|         lookup_a2_x3 = lookup[X(3)].conditional() | ||||
|         # Check that the sum depends on x2 and is non-zero only for x2 in {1, 2} | ||||
|         sum_x2 = lookup_a2_x3.sum(2) | ||||
|         assert np.isclose(sum_x2({X(2): 0}), 0, atol=1e-9) | ||||
|         assert np.isclose(sum_x2({X(2): 1}), 10, atol=1e-9) | ||||
|         assert np.isclose(sum_x2({X(2): 2}), 20, atol=1e-9) | ||||
|         assert len(lookup_a2_x3.frontals()) == 2 | ||||
|         # And that the non-zero rewards are for | ||||
|         # x2 a2 x3 == 1 1 2 | ||||
|         assert np.isclose(lookup_a2_x3({X(2): 1, A(2): 1, X(3): 2}), 10, atol=1e-9) | ||||
|         # x2 a2 x3 == 2 0 2 | ||||
|         assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 0, X(3): 2}), 10, atol=1e-9) | ||||
|         # x2 a2 x3 == 2 1 2 | ||||
|         assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 1, X(3): 2}), 10, atol=1e-9) | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue