Merge pull request #1928 from borglab/fix-table-factor
						commit
						137a503746
					
				|  | @ -48,7 +48,7 @@ namespace gtsam { | |||
|       return false; | ||||
|     } else { | ||||
|       const auto& f(static_cast<const DecisionTreeFactor&>(other)); | ||||
|       return ADT::equals(f, tol); | ||||
|       return Base::equals(other, tol) && ADT::equals(f, tol); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|  |  | |||
|  | @ -28,6 +28,11 @@ using namespace std; | |||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| bool DiscreteFactor::equals(const DiscreteFactor& lf, double tol) const { | ||||
|   return Base::equals(lf, tol) && cardinalities_ == lf.cardinalities_; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DiscreteKeys DiscreteFactor::discreteKeys() const { | ||||
|   DiscreteKeys result; | ||||
|  |  | |||
|  | @ -77,7 +77,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { | |||
|   /// @{
 | ||||
| 
 | ||||
|   /// equals
 | ||||
|   virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0; | ||||
|   virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const; | ||||
| 
 | ||||
|   /// print
 | ||||
|   void print( | ||||
|  |  | |||
|  | @ -92,13 +92,28 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, | |||
|                          const DecisionTreeFactor& dtf) | ||||
|     : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {} | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| TableFactor::TableFactor(const DecisionTreeFactor& dtf) | ||||
|     : TableFactor(dtf.discreteKeys(), | ||||
|                   ComputeLeafOrdering(dtf.discreteKeys(), dtf)) {} | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| TableFactor::TableFactor(const DiscreteConditional& c) | ||||
|     : TableFactor(c.discreteKeys(), c) {} | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| Eigen::SparseVector<double> TableFactor::Convert( | ||||
|     const std::vector<double>& table) { | ||||
|     const DiscreteKeys& keys, const std::vector<double>& table) { | ||||
|   size_t max_size = 1; | ||||
|   for (auto&& [_, cardinality] : keys.cardinalities()) { | ||||
|     max_size *= cardinality; | ||||
|   } | ||||
|   if (table.size() != max_size) { | ||||
|     throw std::runtime_error( | ||||
|         "The cardinalities of the keys don't match the number of values in the " | ||||
|         "input."); | ||||
|   } | ||||
| 
 | ||||
|   Eigen::SparseVector<double> sparse_table(table.size()); | ||||
|   // Count number of nonzero elements in table and reserve the space.
 | ||||
|   const uint64_t nnz = std::count_if(table.begin(), table.end(), | ||||
|  | @ -113,13 +128,14 @@ Eigen::SparseVector<double> TableFactor::Convert( | |||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) { | ||||
| Eigen::SparseVector<double> TableFactor::Convert(const DiscreteKeys& keys, | ||||
|                                                  const std::string& table) { | ||||
|   // Convert string to doubles.
 | ||||
|   std::vector<double> ys; | ||||
|   std::istringstream iss(table); | ||||
|   std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(), | ||||
|             std::back_inserter(ys)); | ||||
|   return Convert(ys); | ||||
|   return Convert(keys, ys); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
|  | @ -128,7 +144,8 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const { | |||
|     return false; | ||||
|   } else { | ||||
|     const auto& f(static_cast<const TableFactor&>(other)); | ||||
|     return sparse_table_.isApprox(f.sparse_table_, tol); | ||||
|     return Base::equals(other, tol) && | ||||
|            sparse_table_.isApprox(f.sparse_table_, tol); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
|  | @ -250,7 +267,8 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const { | |||
|     for (auto&& kv : assignment) { | ||||
|       cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; | ||||
|     } | ||||
|     cout << " | " << it.value() << " | " << it.index() << endl; | ||||
|     cout << " | " << std::setw(10) << std::left << it.value() << " | " | ||||
|          << it.index() << endl; | ||||
|   } | ||||
|   cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; | ||||
| } | ||||
|  |  | |||
|  | @ -80,12 +80,16 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { | |||
|     return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); | ||||
|   } | ||||
| 
 | ||||
|   /// Convert probability table given as doubles to SparseVector.
 | ||||
|   /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
 | ||||
|   static Eigen::SparseVector<double> Convert(const std::vector<double>& table); | ||||
|   /**
 | ||||
|    * Convert probability table given as doubles to SparseVector. | ||||
|    * Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} | ||||
|    */ | ||||
|   static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys, | ||||
|                                              const std::vector<double>& table); | ||||
| 
 | ||||
|   /// Convert probability table given as string to SparseVector.
 | ||||
|   static Eigen::SparseVector<double> Convert(const std::string& table); | ||||
|   static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys, | ||||
|                                              const std::string& table); | ||||
| 
 | ||||
|  public: | ||||
|   // typedefs needed to play nice with gtsam
 | ||||
|  | @ -111,11 +115,11 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { | |||
| 
 | ||||
|   /** Constructor from doubles */ | ||||
|   TableFactor(const DiscreteKeys& keys, const std::vector<double>& table) | ||||
|       : TableFactor(keys, Convert(table)) {} | ||||
|       : TableFactor(keys, Convert(keys, table)) {} | ||||
| 
 | ||||
|   /** Constructor from string */ | ||||
|   TableFactor(const DiscreteKeys& keys, const std::string& table) | ||||
|       : TableFactor(keys, Convert(table)) {} | ||||
|       : TableFactor(keys, Convert(keys, table)) {} | ||||
| 
 | ||||
|   /// Single-key specialization
 | ||||
|   template <class SOURCE> | ||||
|  | @ -128,6 +132,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { | |||
| 
 | ||||
|   /// Constructor from DecisionTreeFactor
 | ||||
|   TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf); | ||||
|   TableFactor(const DecisionTreeFactor& dtf); | ||||
| 
 | ||||
|   /// Constructor from DecisionTree<Key, double>/AlgebraicDecisionTree
 | ||||
|   TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree); | ||||
|  |  | |||
|  | @ -134,14 +134,17 @@ TEST(TableFactor, constructors) { | |||
|   EXPECT(assert_equal(expected, f4)); | ||||
| 
 | ||||
|   // Test for 9=3x3 values.
 | ||||
|   DiscreteKey V(0, 3), W(1, 3); | ||||
|   DiscreteKey V(0, 3), W(1, 3), O(100, 3); | ||||
|   DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11"); | ||||
|   TableFactor f5(conditional5); | ||||
|   // GTSAM_PRINT(f5);
 | ||||
|   TableFactor expected_f5( | ||||
|       X & Y, | ||||
|       "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"); | ||||
| 
 | ||||
|   std::string expected_values = | ||||
|       "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"; | ||||
|   TableFactor expected_f5(V & W, expected_values); | ||||
|   EXPECT(assert_equal(expected_f5, f5, 1e-6)); | ||||
| 
 | ||||
|   TableFactor f5_with_wrong_keys(V & O, expected_values); | ||||
|   EXPECT(assert_inequal(f5_with_wrong_keys, f5, 1e-9)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue