Merge pull request #1045 from borglab/feature/discrete_wrapping
						commit
						d8abdc280d
					
				|  | @ -143,67 +143,64 @@ void DiscreteConditional::print(const string& s, | |||
|     } | ||||
|   } | ||||
|   cout << "):\n"; | ||||
|   ADT::print(""); | ||||
|   ADT::print("", formatter); | ||||
|   cout << endl; | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| bool DiscreteConditional::equals(const DiscreteFactor& other, | ||||
|     double tol) const { | ||||
|   if (!dynamic_cast<const DecisionTreeFactor*>(&other)) | ||||
|                                  double tol) const { | ||||
|   if (!dynamic_cast<const DecisionTreeFactor*>(&other)) { | ||||
|     return false; | ||||
|   else { | ||||
|     const DecisionTreeFactor& f( | ||||
|         static_cast<const DecisionTreeFactor&>(other)); | ||||
|   } else { | ||||
|     const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other)); | ||||
|     return DecisionTreeFactor::equals(f, tol); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| /* ************************************************************************** */ | ||||
| static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, | ||||
|                                        const DiscreteValues& parentsValues) { | ||||
|                                        const DiscreteValues& given, | ||||
|                                        bool forceComplete = true) { | ||||
|   // Get the big decision tree with all the levels, and then go down the
 | ||||
|   // branches based on the value of the parent variables.
 | ||||
|   DiscreteConditional::ADT adt(conditional); | ||||
|   size_t value; | ||||
|   for (Key j : conditional.parents()) { | ||||
|     try { | ||||
|       value = parentsValues.at(j); | ||||
|       value = given.at(j); | ||||
|       adt = adt.choose(j, value);  // ADT keeps getting smaller.
 | ||||
|     } catch (std::out_of_range&) { | ||||
|       parentsValues.print("parentsValues: "); | ||||
|       throw runtime_error("DiscreteConditional::choose: parent value missing"); | ||||
|     }; | ||||
|       if (forceComplete) { | ||||
|         given.print("parentsValues: "); | ||||
|         throw runtime_error( | ||||
|             "DiscreteConditional::Choose: parent value missing"); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   return adt; | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| DecisionTreeFactor::shared_ptr DiscreteConditional::choose( | ||||
|     const DiscreteValues& parentsValues) const { | ||||
|   // Get the big decision tree with all the levels, and then go down the
 | ||||
|   // branches based on the value of the parent variables.
 | ||||
|   ADT adt(*this); | ||||
|   size_t value; | ||||
|   for (Key j : parents()) { | ||||
|     try { | ||||
|       value = parentsValues.at(j); | ||||
|       adt = adt.choose(j, value);  // ADT keeps getting smaller.
 | ||||
|     } catch (exception&) { | ||||
|       parentsValues.print("parentsValues: "); | ||||
|       throw runtime_error("DiscreteConditional::choose: parent value missing"); | ||||
|     }; | ||||
|   } | ||||
| /* ************************************************************************** */ | ||||
| DiscreteConditional::shared_ptr DiscreteConditional::choose( | ||||
|     const DiscreteValues& given) const { | ||||
|   ADT adt = Choose(*this, given, false);  // P(F|S=given)
 | ||||
| 
 | ||||
|   // Convert ADT to factor.
 | ||||
|   DiscreteKeys discreteKeys; | ||||
|   // Collect all keys not in given.
 | ||||
|   DiscreteKeys dKeys; | ||||
|   for (Key j : frontals()) { | ||||
|     discreteKeys.emplace_back(j, this->cardinality(j)); | ||||
|     dKeys.emplace_back(j, this->cardinality(j)); | ||||
|   } | ||||
|   return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); | ||||
|   for (size_t i = nrFrontals(); i < size(); i++) { | ||||
|     Key j = keys_[i]; | ||||
|     if (given.count(j) == 0) { | ||||
|       dKeys.emplace_back(j, this->cardinality(j)); | ||||
|     } | ||||
|   } | ||||
|   return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt); | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| /* ************************************************************************** */ | ||||
| DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( | ||||
|     const DiscreteValues& frontalValues) const { | ||||
|   // Get the big decision tree with all the levels, and then go down the
 | ||||
|  | @ -217,7 +214,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( | |||
|     } catch (exception&) { | ||||
|       frontalValues.print("frontalValues: "); | ||||
|       throw runtime_error("DiscreteConditional::choose: frontal value missing"); | ||||
|     }; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Convert ADT to factor.
 | ||||
|  | @ -242,7 +239,6 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| void DiscreteConditional::solveInPlace(DiscreteValues* values) const { | ||||
|   // TODO(Abhijit): is this really the fastest way? He thinks it is.
 | ||||
|   ADT pFS = Choose(*this, *values);  // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // Initialize
 | ||||
|  | @ -276,11 +272,9 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { | |||
|   (*values)[j] = sampled; // store result in partial solution
 | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| /* ************************************************************************** */ | ||||
| size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { | ||||
| 
 | ||||
|   // TODO: is this really the fastest way? I think it is.
 | ||||
|   ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
 | ||||
|   ADT pFS = Choose(*this, parentsValues);  // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // Then, find the max over all remaining
 | ||||
|   // TODO, only works for one key now, seems horribly slow this way
 | ||||
|  |  | |||
|  | @ -157,9 +157,20 @@ class GTSAM_EXPORT DiscreteConditional | |||
|     return ADT::operator()(values); | ||||
|   } | ||||
| 
 | ||||
|   /** Restrict to given parent values, returns DecisionTreeFactor */ | ||||
|   DecisionTreeFactor::shared_ptr choose( | ||||
|       const DiscreteValues& parentsValues) const; | ||||
|   /** 
 | ||||
|    * @brief restrict to given *parent* values. | ||||
|    *  | ||||
|    * Note: does not need be complete set. Examples: | ||||
|    *  | ||||
|    * P(C|D,E) + . -> P(C|D,E) | ||||
|    * P(C|D,E) + E -> P(C|D) | ||||
|    * P(C|D,E) + D -> P(C|E) | ||||
|    * P(C|D,E) + D,E -> P(C) | ||||
|    * P(C|D,E) + C -> error! | ||||
|    *  | ||||
|    * @return a shared_ptr to a new DiscreteConditional | ||||
|    */ | ||||
|   shared_ptr choose(const DiscreteValues& given) const; | ||||
| 
 | ||||
|   /** Convert to a likelihood factor by providing value before bar. */ | ||||
|   DecisionTreeFactor::shared_ptr likelihood( | ||||
|  |  | |||
|  | @ -64,33 +64,35 @@ template<> struct EliminationTraits<DiscreteFactorGraph> | |||
|  * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. | ||||
|  *   Factor == DiscreteFactor | ||||
|  */ | ||||
| class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>, | ||||
| public EliminateableFactorGraph<DiscreteFactorGraph> { | ||||
| public: | ||||
| class GTSAM_EXPORT DiscreteFactorGraph | ||||
|     : public FactorGraph<DiscreteFactor>, | ||||
|       public EliminateableFactorGraph<DiscreteFactorGraph> { | ||||
|  public: | ||||
|   using This = DiscreteFactorGraph;          ///< this class
 | ||||
|   using Base = FactorGraph<DiscreteFactor>;  ///< base factor graph type
 | ||||
|   using BaseEliminateable = | ||||
|       EliminateableFactorGraph<This>;          ///< for elimination
 | ||||
|   using shared_ptr = boost::shared_ptr<This>;  ///< shared_ptr to This
 | ||||
| 
 | ||||
|   typedef DiscreteFactorGraph This; ///< Typedef to this class
 | ||||
|   typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
 | ||||
|   typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
 | ||||
|   typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
 | ||||
|   using Values = DiscreteValues;  ///< backwards compatibility
 | ||||
| 
 | ||||
|   using Values = DiscreteValues; ///< backwards compatibility
 | ||||
| 
 | ||||
|   /** A map from keys to values */ | ||||
|   typedef KeyVector Indices; | ||||
|   using Indices = KeyVector;  ///> map from keys to values
 | ||||
| 
 | ||||
|   /** Default constructor */ | ||||
|   DiscreteFactorGraph() {} | ||||
| 
 | ||||
|   /** Construct from iterator over factors */ | ||||
|   template<typename ITERATOR> | ||||
|   DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {} | ||||
|   template <typename ITERATOR> | ||||
|   DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) | ||||
|       : Base(firstFactor, lastFactor) {} | ||||
| 
 | ||||
|   /** Construct from container of factors (shared_ptr or plain objects) */ | ||||
|   template<class CONTAINER> | ||||
|   template <class CONTAINER> | ||||
|   explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {} | ||||
| 
 | ||||
|   /** Implicit copy/downcast constructor to override explicit template container constructor */ | ||||
|   template<class DERIVEDFACTOR> | ||||
|   /** Implicit copy/downcast constructor to override explicit template container
 | ||||
|    * constructor */ | ||||
|   template <class DERIVEDFACTOR> | ||||
|   DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {} | ||||
| 
 | ||||
|   /// Destructor
 | ||||
|  | @ -108,7 +110,7 @@ public: | |||
|   void add(Args&&... args) { | ||||
|     emplace_shared<DecisionTreeFactor>(std::forward<Args>(args)...); | ||||
|   } | ||||
|        | ||||
| 
 | ||||
|   /** Return the set of variables involved in the factors (set union) */ | ||||
|   KeySet keys() const; | ||||
| 
 | ||||
|  | @ -163,9 +165,10 @@ public: | |||
|                    const DiscreteFactor::Names& names = {}) const; | ||||
| 
 | ||||
|   /// @}
 | ||||
| }; // \ DiscreteFactorGraph
 | ||||
| };  // \ DiscreteFactorGraph
 | ||||
| 
 | ||||
| /// traits
 | ||||
| template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {}; | ||||
| template <> | ||||
| struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {}; | ||||
| 
 | ||||
| } // \ namespace gtsam
 | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -107,8 +107,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { | |||
|   void printSignature( | ||||
|       string s = "Discrete Conditional: ", | ||||
|       const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; | ||||
|   gtsam::DecisionTreeFactor* choose( | ||||
|       const gtsam::DiscreteValues& parentsValues) const; | ||||
|   gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const; | ||||
|   gtsam::DecisionTreeFactor* likelihood( | ||||
|       const gtsam::DiscreteValues& frontalValues) const; | ||||
|   gtsam::DecisionTreeFactor* likelihood(size_t value) const; | ||||
|  | @ -230,11 +229,16 @@ class DiscreteFactorGraph { | |||
|   DiscreteFactorGraph(); | ||||
|   DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); | ||||
| 
 | ||||
|   void add(const gtsam::DiscreteKey& j, string table); | ||||
|   // Building the graph | ||||
|   void push_back(const gtsam::DiscreteFactor* factor); | ||||
|   void push_back(const gtsam::DiscreteConditional* conditional); | ||||
|   void push_back(const gtsam::DiscreteFactorGraph& graph); | ||||
|   void push_back(const gtsam::DiscreteBayesNet& bayesNet); | ||||
|   void push_back(const gtsam::DiscreteBayesTree& bayesTree); | ||||
|   void add(const gtsam::DiscreteKey& j, string spec); | ||||
|   void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec); | ||||
| 
 | ||||
|   void add(const gtsam::DiscreteKeys& keys, string table); | ||||
|   void add(const std::vector<gtsam::DiscreteKey>& keys, string table); | ||||
|   void add(const gtsam::DiscreteKeys& keys, string spec); | ||||
|   void add(const std::vector<gtsam::DiscreteKey>& keys, string spec); | ||||
| 
 | ||||
|   bool empty() const; | ||||
|   size_t size() const; | ||||
|  | @ -258,8 +262,12 @@ class DiscreteFactorGraph { | |||
| 
 | ||||
|   gtsam::DiscreteBayesNet eliminateSequential(); | ||||
|   gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); | ||||
|   std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph> | ||||
|       eliminatePartialSequential(const gtsam::Ordering& ordering); | ||||
|   gtsam::DiscreteBayesTree eliminateMultifrontal(); | ||||
|   gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering); | ||||
|   std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph> | ||||
|       eliminatePartialMultifrontal(const gtsam::Ordering& ordering); | ||||
| 
 | ||||
|   string markdown(const gtsam::KeyFormatter& keyFormatter = | ||||
|                  gtsam::DefaultKeyFormatter) const; | ||||
|  |  | |||
|  | @ -221,6 +221,34 @@ TEST(DiscreteConditional, likelihood) { | |||
|   EXPECT(assert_equal(expected1, *actual1, 1e-9)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check choose on P(C|D,E)
 | ||||
| TEST(DiscreteConditional, choose) { | ||||
|   DiscreteKey C(2, 2), D(4, 2), E(3, 2); | ||||
|   DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); | ||||
| 
 | ||||
|   // Case 1: no given values: no-op
 | ||||
|   DiscreteValues given; | ||||
|   auto actual1 = C_given_DE.choose(given); | ||||
|   EXPECT(assert_equal(C_given_DE, *actual1, 1e-9)); | ||||
| 
 | ||||
|   // Case 2: 1 given value
 | ||||
|   given[D.first] = 1; | ||||
|   auto actual2 = C_given_DE.choose(given); | ||||
|   EXPECT_LONGS_EQUAL(1, actual2->nrFrontals()); | ||||
|   EXPECT_LONGS_EQUAL(1, actual2->nrParents()); | ||||
|   DiscreteConditional expected2(C | E = "1/1 1/4"); | ||||
|   EXPECT(assert_equal(expected2, *actual2, 1e-9)); | ||||
| 
 | ||||
|   // Case 2: 2 given values
 | ||||
|   given[E.first] = 0; | ||||
|   auto actual3 = C_given_DE.choose(given); | ||||
|   EXPECT_LONGS_EQUAL(1, actual3->nrFrontals()); | ||||
|   EXPECT_LONGS_EQUAL(0, actual3->nrParents()); | ||||
|   DiscreteConditional expected3(C % "1/1"); | ||||
|   EXPECT(assert_equal(expected3, *actual3, 1e-9)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check markdown representation looks as expected, no parents.
 | ||||
| TEST(DiscreteConditional, markdown_prior) { | ||||
|  |  | |||
|  | @ -376,8 +376,12 @@ TEST(DiscreteFactorGraph, Dot) { | |||
|       "  var1[label=\"1\"];\n" | ||||
|       "  var2[label=\"2\"];\n" | ||||
|       "\n" | ||||
|       "  var0--var1;\n" | ||||
|       "  var0--var2;\n" | ||||
|       "  factor0[label=\"\", shape=point];\n" | ||||
|       "  var0--factor0;\n" | ||||
|       "  var1--factor0;\n" | ||||
|       "  factor1[label=\"\", shape=point];\n" | ||||
|       "  var0--factor1;\n" | ||||
|       "  var2--factor1;\n" | ||||
|       "}\n"; | ||||
|   EXPECT(actual == expected); | ||||
| } | ||||
|  | @ -397,12 +401,16 @@ TEST(DiscreteFactorGraph, DotWithNames) { | |||
|       "graph {\n" | ||||
|       "  size=\"5,5\";\n" | ||||
|       "\n" | ||||
|       "  var0[label=\"C\"];\n" | ||||
|       "  var1[label=\"A\"];\n" | ||||
|       "  var2[label=\"B\"];\n" | ||||
|       "  varC[label=\"C\"];\n" | ||||
|       "  varA[label=\"A\"];\n" | ||||
|       "  varB[label=\"B\"];\n" | ||||
|       "\n" | ||||
|       "  var0--var1;\n" | ||||
|       "  var0--var2;\n" | ||||
|       "  factor0[label=\"\", shape=point];\n" | ||||
|       "  varC--factor0;\n" | ||||
|       "  varA--factor0;\n" | ||||
|       "  factor1[label=\"\", shape=point];\n" | ||||
|       "  varC--factor1;\n" | ||||
|       "  varB--factor1;\n" | ||||
|       "}\n"; | ||||
|   EXPECT(actual == expected); | ||||
| } | ||||
|  |  | |||
|  | @ -35,7 +35,8 @@ void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter, | |||
|                              const boost::optional<Vector2>& position, | ||||
|                              ostream* os) { | ||||
|   // Label the node with the label from the KeyFormatter
 | ||||
|   *os << "  var" << key << "[label=\"" << keyFormatter(key) << "\""; | ||||
|   *os << "  var" << keyFormatter(key) << "[label=\"" << keyFormatter(key) | ||||
|       << "\""; | ||||
|   if (position) { | ||||
|     *os << ", pos=\"" << position->x() << "," << position->y() << "!\""; | ||||
|   } | ||||
|  | @ -51,22 +52,26 @@ void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position, | |||
|   *os << "];\n"; | ||||
| } | ||||
| 
 | ||||
| void DotWriter::ConnectVariables(Key key1, Key key2, ostream* os) { | ||||
|   *os << "  var" << key1 << "--" | ||||
|       << "var" << key2 << ";\n"; | ||||
| static void ConnectVariables(Key key1, Key key2, | ||||
|                                  const KeyFormatter& keyFormatter, | ||||
|                                  ostream* os) { | ||||
|   *os << "  var" << keyFormatter(key1) << "--" | ||||
|       << "var" << keyFormatter(key2) << ";\n"; | ||||
| } | ||||
| 
 | ||||
| void DotWriter::ConnectVariableFactor(Key key, size_t i, ostream* os) { | ||||
|   *os << "  var" << key << "--" | ||||
| static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter, | ||||
|                                       size_t i, ostream* os) { | ||||
|   *os << "  var" << keyFormatter(key) << "--" | ||||
|       << "factor" << i << ";\n"; | ||||
| } | ||||
| 
 | ||||
| void DotWriter::processFactor(size_t i, const KeyVector& keys, | ||||
|                               const KeyFormatter& keyFormatter, | ||||
|                               const boost::optional<Vector2>& position, | ||||
|                               ostream* os) const { | ||||
|   if (plotFactorPoints) { | ||||
|     if (binaryEdges && keys.size() == 2) { | ||||
|       ConnectVariables(keys[0], keys[1], os); | ||||
|       ConnectVariables(keys[0], keys[1], keyFormatter, os); | ||||
|     } else { | ||||
|       // Create dot for the factor.
 | ||||
|       DrawFactor(i, position, os); | ||||
|  | @ -74,7 +79,7 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys, | |||
|       // Make factor-variable connections
 | ||||
|       if (connectKeysToFactor) { | ||||
|         for (Key key : keys) { | ||||
|           ConnectVariableFactor(key, i, os); | ||||
|           ConnectVariableFactor(key, keyFormatter, i, os); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | @ -83,7 +88,7 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys, | |||
|     for (Key key1 : keys) { | ||||
|       for (Key key2 : keys) { | ||||
|         if (key2 > key1) { | ||||
|           ConnectVariables(key1, key2, os); | ||||
|           ConnectVariables(key1, key2, keyFormatter, os); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  |  | |||
|  | @ -38,7 +38,7 @@ struct GTSAM_EXPORT DotWriter { | |||
|   explicit DotWriter(double figureWidthInches = 5, | ||||
|                      double figureHeightInches = 5, | ||||
|                      bool plotFactorPoints = true, | ||||
|                      bool connectKeysToFactor = true, bool binaryEdges = true) | ||||
|                      bool connectKeysToFactor = true, bool binaryEdges = false) | ||||
|       : figureWidthInches(figureWidthInches), | ||||
|         figureHeightInches(figureHeightInches), | ||||
|         plotFactorPoints(plotFactorPoints), | ||||
|  | @ -57,14 +57,9 @@ struct GTSAM_EXPORT DotWriter { | |||
|   static void DrawFactor(size_t i, const boost::optional<Vector2>& position, | ||||
|                          std::ostream* os); | ||||
| 
 | ||||
|   /// Connect two variables.
 | ||||
|   static void ConnectVariables(Key key1, Key key2, std::ostream* os); | ||||
| 
 | ||||
|   /// Connect variable and factor.
 | ||||
|   static void ConnectVariableFactor(Key key, size_t i, std::ostream* os); | ||||
| 
 | ||||
|   /// Draw a single factor, specified by its index i and its variable keys.
 | ||||
|   void processFactor(size_t i, const KeyVector& keys, | ||||
|                      const KeyFormatter& keyFormatter, | ||||
|                      const boost::optional<Vector2>& position, | ||||
|                      std::ostream* os) const; | ||||
| }; | ||||
|  |  | |||
|  | @ -144,7 +144,7 @@ void FactorGraph<FACTOR>::dot(std::ostream& os, | |||
|     const auto& factor = at(i); | ||||
|     if (factor) { | ||||
|       const KeyVector& factorKeys = factor->keys(); | ||||
|       writer.processFactor(i, factorKeys, boost::none, &os); | ||||
|       writer.processFactor(i, factorKeys, keyFormatter, boost::none, &os); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|  |  | |||
|  | @ -33,8 +33,10 @@ | |||
| #  include <tbb/parallel_for.h> | ||||
| #endif | ||||
| 
 | ||||
| #include <algorithm> | ||||
| #include <cmath> | ||||
| #include <fstream> | ||||
| #include <set> | ||||
| 
 | ||||
| using namespace std; | ||||
| 
 | ||||
|  | @ -127,7 +129,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, | |||
|     // Create factors and variable connections
 | ||||
|     size_t i = 0; | ||||
|     for (const KeyVector& factorKeys : structure) { | ||||
|       writer.processFactor(i++, factorKeys, boost::none, &os); | ||||
|       writer.processFactor(i++, factorKeys, keyFormatter, boost::none, &os); | ||||
|     } | ||||
|   } else { | ||||
|     // Create factors and variable connections
 | ||||
|  | @ -135,7 +137,8 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, | |||
|       const NonlinearFactor::shared_ptr& factor = at(i); | ||||
|       if (factor) { | ||||
|         const KeyVector& factorKeys = factor->keys(); | ||||
|         writer.processFactor(i, factorKeys, writer.factorPos(min, i), &os); | ||||
|         writer.processFactor(i, factorKeys, keyFormatter, | ||||
|                              writer.factorPos(min, i), &os); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  |  | |||
|  | @ -335,15 +335,21 @@ TEST(NonlinearFactorGraph, dot) { | |||
|       "graph {\n" | ||||
|       "  size=\"5,5\";\n" | ||||
|       "\n" | ||||
|       "  var7782220156096217089[label=\"l1\"];\n" | ||||
|       "  var8646911284551352321[label=\"x1\"];\n" | ||||
|       "  var8646911284551352322[label=\"x2\"];\n" | ||||
|       "  varl1[label=\"l1\"];\n" | ||||
|       "  varx1[label=\"x1\"];\n" | ||||
|       "  varx2[label=\"x2\"];\n" | ||||
|       "\n" | ||||
|       "  factor0[label=\"\", shape=point];\n" | ||||
|       "  var8646911284551352321--factor0;\n" | ||||
|       "  var8646911284551352321--var8646911284551352322;\n" | ||||
|       "  var8646911284551352321--var7782220156096217089;\n" | ||||
|       "  var8646911284551352322--var7782220156096217089;\n" | ||||
|       "  varx1--factor0;\n" | ||||
|       "  factor1[label=\"\", shape=point];\n" | ||||
|       "  varx1--factor1;\n" | ||||
|       "  varx2--factor1;\n" | ||||
|       "  factor2[label=\"\", shape=point];\n" | ||||
|       "  varx1--factor2;\n" | ||||
|       "  varl1--factor2;\n" | ||||
|       "  factor3[label=\"\", shape=point];\n" | ||||
|       "  varx2--factor3;\n" | ||||
|       "  varl1--factor3;\n" | ||||
|       "}\n"; | ||||
| 
 | ||||
|   const NonlinearFactorGraph fg = createNonlinearFactorGraph(); | ||||
|  | @ -357,15 +363,21 @@ TEST(NonlinearFactorGraph, dot_extra) { | |||
|       "graph {\n" | ||||
|       "  size=\"5,5\";\n" | ||||
|       "\n" | ||||
|       "  var7782220156096217089[label=\"l1\", pos=\"0,0!\"];\n" | ||||
|       "  var8646911284551352321[label=\"x1\", pos=\"1,0!\"];\n" | ||||
|       "  var8646911284551352322[label=\"x2\", pos=\"1,1.5!\"];\n" | ||||
|       "  varl1[label=\"l1\", pos=\"0,0!\"];\n" | ||||
|       "  varx1[label=\"x1\", pos=\"1,0!\"];\n" | ||||
|       "  varx2[label=\"x2\", pos=\"1,1.5!\"];\n" | ||||
|       "\n" | ||||
|       "  factor0[label=\"\", shape=point];\n" | ||||
|       "  var8646911284551352321--factor0;\n" | ||||
|       "  var8646911284551352321--var8646911284551352322;\n" | ||||
|       "  var8646911284551352321--var7782220156096217089;\n" | ||||
|       "  var8646911284551352322--var7782220156096217089;\n" | ||||
|       "  varx1--factor0;\n" | ||||
|       "  factor1[label=\"\", shape=point];\n" | ||||
|       "  varx1--factor1;\n" | ||||
|       "  varx2--factor1;\n" | ||||
|       "  factor2[label=\"\", shape=point];\n" | ||||
|       "  varx1--factor2;\n" | ||||
|       "  varl1--factor2;\n" | ||||
|       "  factor3[label=\"\", shape=point];\n" | ||||
|       "  varx2--factor3;\n" | ||||
|       "  varl1--factor3;\n" | ||||
|       "}\n"; | ||||
| 
 | ||||
|   const NonlinearFactorGraph fg = createNonlinearFactorGraph(); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue