Merge pull request #1960 from borglab/table-factor-fixes
						commit
						05d8030af4
					
				|  | @ -87,7 +87,15 @@ static Eigen::SparseVector<double> ComputeSparseTable( | |||
|   }); | ||||
|   sparseTable.reserve(nrValues); | ||||
| 
 | ||||
|   std::set<Key> allKeys(dt.keys().begin(), dt.keys().end()); | ||||
|   KeySet allKeys(dt.keys().begin(), dt.keys().end()); | ||||
| 
 | ||||
|   // Compute denominators to be used in computing sparse table indices
 | ||||
|   std::map<Key, size_t> denominators; | ||||
|   double denom = sparseTable.size(); | ||||
|   for (const DiscreteKey& dkey : dkeys) { | ||||
|     denom /= dkey.second; | ||||
|     denominators.insert(std::pair<Key, double>(dkey.first, denom)); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Functor which is called by the DecisionTree for each leaf. | ||||
|  | @ -102,13 +110,13 @@ static Eigen::SparseVector<double> ComputeSparseTable( | |||
|   auto op = [&](const Assignment<Key>& assignment, double p) { | ||||
|     if (p > 0) { | ||||
|       // Get all the keys involved in this assignment
 | ||||
|       std::set<Key> assignmentKeys; | ||||
|       KeySet assignmentKeys; | ||||
|       for (auto&& [k, _] : assignment) { | ||||
|         assignmentKeys.insert(k); | ||||
|       } | ||||
| 
 | ||||
|       // Find the keys missing in the assignment
 | ||||
|       std::vector<Key> diff; | ||||
|       KeyVector diff; | ||||
|       std::set_difference(allKeys.begin(), allKeys.end(), | ||||
|                           assignmentKeys.begin(), assignmentKeys.end(), | ||||
|                           std::back_inserter(diff)); | ||||
|  | @ -127,12 +135,10 @@ static Eigen::SparseVector<double> ComputeSparseTable( | |||
| 
 | ||||
|         // Generate index and add to the sparse vector.
 | ||||
|         Eigen::Index idx = 0; | ||||
|         size_t previousCardinality = 1; | ||||
|         // We go in reverse since a DecisionTree has the highest label first
 | ||||
|         for (auto&& it = updatedAssignment.rbegin(); | ||||
|              it != updatedAssignment.rend(); it++) { | ||||
|           idx += previousCardinality * it->second; | ||||
|           previousCardinality *= dt.cardinality(it->first); | ||||
|           idx += it->second * denominators.at(it->first); | ||||
|         } | ||||
|         sparseTable.coeffRef(idx) = p; | ||||
|       } | ||||
|  | @ -252,9 +258,19 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { | |||
| DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { | ||||
|   DiscreteKeys dkeys = discreteKeys(); | ||||
| 
 | ||||
|   std::vector<double> table; | ||||
|   for (auto i = 0; i < sparse_table_.size(); i++) { | ||||
|     table.push_back(sparse_table_.coeff(i)); | ||||
|   // If no keys, then return empty DecisionTreeFactor
 | ||||
|   if (dkeys.size() == 0) { | ||||
|     AlgebraicDecisionTree<Key> tree; | ||||
|     // We can have an empty sparse_table_ or one with a single value.
 | ||||
|     if (sparse_table_.size() != 0) { | ||||
|       tree = AlgebraicDecisionTree<Key>(sparse_table_.coeff(0)); | ||||
|     } | ||||
|     return DecisionTreeFactor(dkeys, tree); | ||||
|   } | ||||
| 
 | ||||
|   std::vector<double> table(sparse_table_.size(), 0.0); | ||||
|   for (SparseIt it(sparse_table_); it; ++it) { | ||||
|     table[it.index()] = it.value(); | ||||
|   } | ||||
| 
 | ||||
|   AlgebraicDecisionTree<Key> tree(dkeys, table); | ||||
|  |  | |||
|  | @ -173,6 +173,36 @@ TEST(TableFactor, Conversion) { | |||
|   TableFactor tf(dtf.discreteKeys(), dtf); | ||||
| 
 | ||||
|   EXPECT(assert_equal(dtf, tf.toDecisionTreeFactor())); | ||||
| 
 | ||||
|   // Test for correct construction when keys are not in reverse order.
 | ||||
|   // This is possible in conditionals e.g. P(x1 | x0)
 | ||||
|   DiscreteKey X(1, 2), Y(0, 2); | ||||
|   DiscreteConditional dtf2( | ||||
|       X, {Y}, std::vector<double>{0.33333333, 0.6, 0.66666667, 0.4}); | ||||
| 
 | ||||
|   TableFactor tf2(dtf2); | ||||
|   // GTSAM_PRINT(dtf2);
 | ||||
|   // GTSAM_PRINT(tf2);
 | ||||
|   // GTSAM_PRINT(tf2.toDecisionTreeFactor());
 | ||||
| 
 | ||||
|   // Check for ADT equality since the order of keys is irrelevant
 | ||||
|   EXPECT(assert_equal<AlgebraicDecisionTree<Key>>(dtf2, | ||||
|                                                   tf2.toDecisionTreeFactor())); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(TableFactor, Empty) { | ||||
|   DiscreteKey X(1, 2); | ||||
| 
 | ||||
|   TableFactor single = *TableFactor({X}, "1 1").sum(1); | ||||
|   // Should not throw a segfault
 | ||||
|   EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1), | ||||
|                       single.toDecisionTreeFactor())); | ||||
| 
 | ||||
|   TableFactor empty = *TableFactor({X}, "0 0").sum(1); | ||||
|   // Should not throw a segfault
 | ||||
|   EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1), | ||||
|                       empty.toDecisionTreeFactor())); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue