Merge pull request #1960 from borglab/table-factor-fixes
						commit
						05d8030af4
					
				|  | @ -87,7 +87,15 @@ static Eigen::SparseVector<double> ComputeSparseTable( | ||||||
|   }); |   }); | ||||||
|   sparseTable.reserve(nrValues); |   sparseTable.reserve(nrValues); | ||||||
| 
 | 
 | ||||||
|   std::set<Key> allKeys(dt.keys().begin(), dt.keys().end()); |   KeySet allKeys(dt.keys().begin(), dt.keys().end()); | ||||||
|  | 
 | ||||||
|  |   // Compute denominators to be used in computing sparse table indices
 | ||||||
|  |   std::map<Key, size_t> denominators; | ||||||
|  |   double denom = sparseTable.size(); | ||||||
|  |   for (const DiscreteKey& dkey : dkeys) { | ||||||
|  |     denom /= dkey.second; | ||||||
|  |     denominators.insert(std::pair<Key, double>(dkey.first, denom)); | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Functor which is called by the DecisionTree for each leaf. |    * @brief Functor which is called by the DecisionTree for each leaf. | ||||||
|  | @ -102,13 +110,13 @@ static Eigen::SparseVector<double> ComputeSparseTable( | ||||||
|   auto op = [&](const Assignment<Key>& assignment, double p) { |   auto op = [&](const Assignment<Key>& assignment, double p) { | ||||||
|     if (p > 0) { |     if (p > 0) { | ||||||
|       // Get all the keys involved in this assignment
 |       // Get all the keys involved in this assignment
 | ||||||
|       std::set<Key> assignmentKeys; |       KeySet assignmentKeys; | ||||||
|       for (auto&& [k, _] : assignment) { |       for (auto&& [k, _] : assignment) { | ||||||
|         assignmentKeys.insert(k); |         assignmentKeys.insert(k); | ||||||
|       } |       } | ||||||
| 
 | 
 | ||||||
|       // Find the keys missing in the assignment
 |       // Find the keys missing in the assignment
 | ||||||
|       std::vector<Key> diff; |       KeyVector diff; | ||||||
|       std::set_difference(allKeys.begin(), allKeys.end(), |       std::set_difference(allKeys.begin(), allKeys.end(), | ||||||
|                           assignmentKeys.begin(), assignmentKeys.end(), |                           assignmentKeys.begin(), assignmentKeys.end(), | ||||||
|                           std::back_inserter(diff)); |                           std::back_inserter(diff)); | ||||||
|  | @ -127,12 +135,10 @@ static Eigen::SparseVector<double> ComputeSparseTable( | ||||||
| 
 | 
 | ||||||
|         // Generate index and add to the sparse vector.
 |         // Generate index and add to the sparse vector.
 | ||||||
|         Eigen::Index idx = 0; |         Eigen::Index idx = 0; | ||||||
|         size_t previousCardinality = 1; |  | ||||||
|         // We go in reverse since a DecisionTree has the highest label first
 |         // We go in reverse since a DecisionTree has the highest label first
 | ||||||
|         for (auto&& it = updatedAssignment.rbegin(); |         for (auto&& it = updatedAssignment.rbegin(); | ||||||
|              it != updatedAssignment.rend(); it++) { |              it != updatedAssignment.rend(); it++) { | ||||||
|           idx += previousCardinality * it->second; |           idx += it->second * denominators.at(it->first); | ||||||
|           previousCardinality *= dt.cardinality(it->first); |  | ||||||
|         } |         } | ||||||
|         sparseTable.coeffRef(idx) = p; |         sparseTable.coeffRef(idx) = p; | ||||||
|       } |       } | ||||||
|  | @ -252,9 +258,19 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { | ||||||
| DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { | DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { | ||||||
|   DiscreteKeys dkeys = discreteKeys(); |   DiscreteKeys dkeys = discreteKeys(); | ||||||
| 
 | 
 | ||||||
|   std::vector<double> table; |   // If no keys, then return empty DecisionTreeFactor
 | ||||||
|   for (auto i = 0; i < sparse_table_.size(); i++) { |   if (dkeys.size() == 0) { | ||||||
|     table.push_back(sparse_table_.coeff(i)); |     AlgebraicDecisionTree<Key> tree; | ||||||
|  |     // We can have an empty sparse_table_ or one with a single value.
 | ||||||
|  |     if (sparse_table_.size() != 0) { | ||||||
|  |       tree = AlgebraicDecisionTree<Key>(sparse_table_.coeff(0)); | ||||||
|  |     } | ||||||
|  |     return DecisionTreeFactor(dkeys, tree); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   std::vector<double> table(sparse_table_.size(), 0.0); | ||||||
|  |   for (SparseIt it(sparse_table_); it; ++it) { | ||||||
|  |     table[it.index()] = it.value(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   AlgebraicDecisionTree<Key> tree(dkeys, table); |   AlgebraicDecisionTree<Key> tree(dkeys, table); | ||||||
|  |  | ||||||
|  | @ -173,6 +173,36 @@ TEST(TableFactor, Conversion) { | ||||||
|   TableFactor tf(dtf.discreteKeys(), dtf); |   TableFactor tf(dtf.discreteKeys(), dtf); | ||||||
| 
 | 
 | ||||||
|   EXPECT(assert_equal(dtf, tf.toDecisionTreeFactor())); |   EXPECT(assert_equal(dtf, tf.toDecisionTreeFactor())); | ||||||
|  | 
 | ||||||
|  |   // Test for correct construction when keys are not in reverse order.
 | ||||||
|  |   // This is possible in conditionals e.g. P(x1 | x0)
 | ||||||
|  |   DiscreteKey X(1, 2), Y(0, 2); | ||||||
|  |   DiscreteConditional dtf2( | ||||||
|  |       X, {Y}, std::vector<double>{0.33333333, 0.6, 0.66666667, 0.4}); | ||||||
|  | 
 | ||||||
|  |   TableFactor tf2(dtf2); | ||||||
|  |   // GTSAM_PRINT(dtf2);
 | ||||||
|  |   // GTSAM_PRINT(tf2);
 | ||||||
|  |   // GTSAM_PRINT(tf2.toDecisionTreeFactor());
 | ||||||
|  | 
 | ||||||
|  |   // Check for ADT equality since the order of keys is irrelevant
 | ||||||
|  |   EXPECT(assert_equal<AlgebraicDecisionTree<Key>>(dtf2, | ||||||
|  |                                                   tf2.toDecisionTreeFactor())); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ************************************************************************* */ | ||||||
|  | TEST(TableFactor, Empty) { | ||||||
|  |   DiscreteKey X(1, 2); | ||||||
|  | 
 | ||||||
|  |   TableFactor single = *TableFactor({X}, "1 1").sum(1); | ||||||
|  |   // Should not throw a segfault
 | ||||||
|  |   EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1), | ||||||
|  |                       single.toDecisionTreeFactor())); | ||||||
|  | 
 | ||||||
|  |   TableFactor empty = *TableFactor({X}, "0 0").sum(1); | ||||||
|  |   // Should not throw a segfault
 | ||||||
|  |   EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1), | ||||||
|  |                       empty.toDecisionTreeFactor())); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue