address review comments
							parent
							
								
									098d2ce4a4
								
							
						
					
					
						commit
						d94b3199a0
					
				|  | @ -210,13 +210,14 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
| AlgebraicDecisionTree<Key> GaussianMixture::error( | AlgebraicDecisionTree<Key> GaussianMixture::error( | ||||||
|     const VectorValues &continuousValues) const { |     const VectorValues &continuousValues) const { | ||||||
|   // functor to convert from GaussianConditional to double error value.
 |   // functor to calculate to double error value from GaussianConditional.
 | ||||||
|   auto errorFunc = |   auto errorFunc = | ||||||
|       [continuousValues](const GaussianConditional::shared_ptr &conditional) { |       [continuousValues](const GaussianConditional::shared_ptr &conditional) { | ||||||
|         if (conditional) { |         if (conditional) { | ||||||
|           return conditional->error(continuousValues); |           return conditional->error(continuousValues); | ||||||
|         } else { |         } else { | ||||||
|           // return arbitrarily large error
 |           // Return arbitrarily large error if conditional is null
 | ||||||
|  |           // Conditional is null if it is pruned out.
 | ||||||
|           return 1e50; |           return 1e50; | ||||||
|         } |         } | ||||||
|       }; |       }; | ||||||
|  | @ -227,6 +228,7 @@ AlgebraicDecisionTree<Key> GaussianMixture::error( | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
| double GaussianMixture::error(const VectorValues &continuousValues, | double GaussianMixture::error(const VectorValues &continuousValues, | ||||||
|                               const DiscreteValues &discreteValues) const { |                               const DiscreteValues &discreteValues) const { | ||||||
|  |   // Directly index to get the conditional, no need to build the whole tree.
 | ||||||
|   auto conditional = conditionals_(discreteValues); |   auto conditional = conditionals_(discreteValues); | ||||||
|   return conditional->error(continuousValues); |   return conditional->error(continuousValues); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -112,6 +112,7 @@ AlgebraicDecisionTree<Key> GaussianMixtureFactor::error( | ||||||
| double GaussianMixtureFactor::error( | double GaussianMixtureFactor::error( | ||||||
|     const VectorValues &continuousValues, |     const VectorValues &continuousValues, | ||||||
|     const DiscreteValues &discreteValues) const { |     const DiscreteValues &discreteValues) const { | ||||||
|  |   // Directly index to get the conditional, no need to build the whole tree.
 | ||||||
|   auto factor = factors_(discreteValues); |   auto factor = factors_(discreteValues); | ||||||
|   return factor->error(continuousValues); |   return factor->error(continuousValues); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -244,13 +244,16 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error( | ||||||
|     const VectorValues &continuousValues) const { |     const VectorValues &continuousValues) const { | ||||||
|   AlgebraicDecisionTree<Key> error_tree; |   AlgebraicDecisionTree<Key> error_tree; | ||||||
| 
 | 
 | ||||||
|  |   // Iterate over each factor.
 | ||||||
|   for (size_t idx = 0; idx < size(); idx++) { |   for (size_t idx = 0; idx < size(); idx++) { | ||||||
|     AlgebraicDecisionTree<Key> conditional_error; |     AlgebraicDecisionTree<Key> conditional_error; | ||||||
|  | 
 | ||||||
|     if (factors_.at(idx)->isHybrid()) { |     if (factors_.at(idx)->isHybrid()) { | ||||||
|       // If factor is hybrid, select based on assignment.
 |       // If factor is hybrid, select based on assignment and compute error.
 | ||||||
|       GaussianMixture::shared_ptr gm = this->atMixture(idx); |       GaussianMixture::shared_ptr gm = this->atMixture(idx); | ||||||
|       conditional_error = gm->error(continuousValues); |       conditional_error = gm->error(continuousValues); | ||||||
| 
 | 
 | ||||||
|  |       // Assign for the first index, add error for subsequent ones.
 | ||||||
|       if (idx == 0) { |       if (idx == 0) { | ||||||
|         error_tree = conditional_error; |         error_tree = conditional_error; | ||||||
|       } else { |       } else { | ||||||
|  | @ -261,6 +264,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error( | ||||||
|       // If continuous only, get the (double) error
 |       // If continuous only, get the (double) error
 | ||||||
|       // and add it to the error_tree
 |       // and add it to the error_tree
 | ||||||
|       double error = this->atGaussian(idx)->error(continuousValues); |       double error = this->atGaussian(idx)->error(continuousValues); | ||||||
|  |       // Add the computed error to every leaf of the error tree.
 | ||||||
|       error_tree = error_tree.apply( |       error_tree = error_tree.apply( | ||||||
|           [error](double leaf_value) { return leaf_value + error; }); |           [error](double leaf_value) { return leaf_value + error; }); | ||||||
| 
 | 
 | ||||||
|  | @ -273,6 +277,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error( | ||||||
|   return error_tree; |   return error_tree; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /* ************************************************************************* */ | ||||||
| AlgebraicDecisionTree<Key> HybridBayesNet::probPrime( | AlgebraicDecisionTree<Key> HybridBayesNet::probPrime( | ||||||
|     const VectorValues &continuousValues) const { |     const VectorValues &continuousValues) const { | ||||||
|   AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues); |   AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues); | ||||||
|  |  | ||||||
|  | @ -428,6 +428,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error( | ||||||
|     const VectorValues &continuousValues) const { |     const VectorValues &continuousValues) const { | ||||||
|   AlgebraicDecisionTree<Key> error_tree(0.0); |   AlgebraicDecisionTree<Key> error_tree(0.0); | ||||||
| 
 | 
 | ||||||
|  |   // Iterate over each factor.
 | ||||||
|   for (size_t idx = 0; idx < size(); idx++) { |   for (size_t idx = 0; idx < size(); idx++) { | ||||||
|     AlgebraicDecisionTree<Key> factor_error; |     AlgebraicDecisionTree<Key> factor_error; | ||||||
| 
 | 
 | ||||||
|  | @ -435,8 +436,10 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error( | ||||||
|       // If factor is hybrid, select based on assignment.
 |       // If factor is hybrid, select based on assignment.
 | ||||||
|       GaussianMixtureFactor::shared_ptr gaussianMixture = |       GaussianMixtureFactor::shared_ptr gaussianMixture = | ||||||
|           boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx)); |           boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx)); | ||||||
|  |       // Compute factor error.
 | ||||||
|       factor_error = gaussianMixture->error(continuousValues); |       factor_error = gaussianMixture->error(continuousValues); | ||||||
| 
 | 
 | ||||||
|  |       // If first factor, assign error, else add it.
 | ||||||
|       if (idx == 0) { |       if (idx == 0) { | ||||||
|         error_tree = factor_error; |         error_tree = factor_error; | ||||||
|       } else { |       } else { | ||||||
|  | @ -450,7 +453,9 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error( | ||||||
|           boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx)); |           boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx)); | ||||||
|       GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner(); |       GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner(); | ||||||
| 
 | 
 | ||||||
|  |       // Compute the error of the gaussian factor.
 | ||||||
|       double error = gaussian->error(continuousValues); |       double error = gaussian->error(continuousValues); | ||||||
|  |       // Add the gaussian factor error to every leaf of the error tree.
 | ||||||
|       error_tree = error_tree.apply( |       error_tree = error_tree.apply( | ||||||
|           [error](double leaf_value) { return leaf_value + error; }); |           [error](double leaf_value) { return leaf_value + error; }); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -23,6 +23,7 @@ | ||||||
| #include <gtsam/hybrid/GaussianMixtureFactor.h> | #include <gtsam/hybrid/GaussianMixtureFactor.h> | ||||||
| #include <gtsam/hybrid/HybridNonlinearFactor.h> | #include <gtsam/hybrid/HybridNonlinearFactor.h> | ||||||
| #include <gtsam/nonlinear/NonlinearFactor.h> | #include <gtsam/nonlinear/NonlinearFactor.h> | ||||||
|  | #include <gtsam/nonlinear/NonlinearFactorGraph.h> | ||||||
| #include <gtsam/nonlinear/Symbol.h> | #include <gtsam/nonlinear/Symbol.h> | ||||||
| 
 | 
 | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
|  | @ -86,11 +87,11 @@ class MixtureFactor : public HybridFactor { | ||||||
|    * elements based on the number of discrete keys and the cardinality of the |    * elements based on the number of discrete keys and the cardinality of the | ||||||
|    * keys, so that the decision tree is constructed appropriately. |    * keys, so that the decision tree is constructed appropriately. | ||||||
|    * |    * | ||||||
|    * @tparam FACTOR The type of the factor shared pointers being passed in. Will |    * @tparam FACTOR The type of the factor shared pointers being passed in. | ||||||
|    * be typecast to NonlinearFactor shared pointers. |    * Will be typecast to NonlinearFactor shared pointers. | ||||||
|    * @param keys Vector of keys for continuous factors. |    * @param keys Vector of keys for continuous factors. | ||||||
|    * @param discreteKeys Vector of discrete keys. |    * @param discreteKeys Vector of discrete keys. | ||||||
|    * @param factors Vector of shared pointers to factors. |    * @param factors Vector of nonlinear factors. | ||||||
|    * @param normalized Flag indicating if the factor error is already |    * @param normalized Flag indicating if the factor error is already | ||||||
|    * normalized. |    * normalized. | ||||||
|    */ |    */ | ||||||
|  |  | ||||||
|  | @ -196,8 +196,10 @@ class HybridNonlinearFactorGraph { | ||||||
| 
 | 
 | ||||||
| #include <gtsam/hybrid/MixtureFactor.h> | #include <gtsam/hybrid/MixtureFactor.h> | ||||||
| class MixtureFactor : gtsam::HybridFactor { | class MixtureFactor : gtsam::HybridFactor { | ||||||
|   MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, |   MixtureFactor( | ||||||
|                 const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors, bool normalized = false); |       const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, | ||||||
|  |       const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors, | ||||||
|  |       bool normalized = false); | ||||||
| 
 | 
 | ||||||
|   template <FACTOR = {gtsam::NonlinearFactor}> |   template <FACTOR = {gtsam::NonlinearFactor}> | ||||||
|   MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, |   MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, | ||||||
|  |  | ||||||
|  | @ -104,7 +104,7 @@ TEST(GaussianMixture, Error) { | ||||||
|                                                               X(2), S2, model); |                                                               X(2), S2, model); | ||||||
| 
 | 
 | ||||||
|   // Create decision tree
 |   // Create decision tree
 | ||||||
|   DiscreteKey m1(1, 2); |   DiscreteKey m1(M(1), 2); | ||||||
|   GaussianMixture::Conditionals conditionals( |   GaussianMixture::Conditionals conditionals( | ||||||
|       {m1}, |       {m1}, | ||||||
|       vector<GaussianConditional::shared_ptr>{conditional0, conditional1}); |       vector<GaussianConditional::shared_ptr>{conditional0, conditional1}); | ||||||
|  | @ -115,12 +115,19 @@ TEST(GaussianMixture, Error) { | ||||||
|   values.insert(X(2), Vector2::Zero()); |   values.insert(X(2), Vector2::Zero()); | ||||||
|   auto error_tree = mixture.error(values); |   auto error_tree = mixture.error(values); | ||||||
| 
 | 
 | ||||||
|  |   // regression
 | ||||||
|   std::vector<DiscreteKey> discrete_keys = {m1}; |   std::vector<DiscreteKey> discrete_keys = {m1}; | ||||||
|   std::vector<double> leaves = {0.5, 4.3252595}; |   std::vector<double> leaves = {0.5, 4.3252595}; | ||||||
|   AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves); |   AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves); | ||||||
| 
 | 
 | ||||||
|   // regression
 |  | ||||||
|   EXPECT(assert_equal(expected_error, error_tree, 1e-6)); |   EXPECT(assert_equal(expected_error, error_tree, 1e-6)); | ||||||
|  | 
 | ||||||
|  |   // Regression for non-tree version.
 | ||||||
|  |   DiscreteValues assignment; | ||||||
|  |   assignment[M(1)] = 0; | ||||||
|  |   EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8); | ||||||
|  |   assignment[M(1)] = 1; | ||||||
|  |   EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), 1e-8); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
|  |  | ||||||
|  | @ -178,6 +178,7 @@ TEST(GaussianMixtureFactor, Error) { | ||||||
|   AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues); |   AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues); | ||||||
| 
 | 
 | ||||||
|   std::vector<DiscreteKey> discrete_keys = {m1}; |   std::vector<DiscreteKey> discrete_keys = {m1}; | ||||||
|  |   // Error values for regression test
 | ||||||
|   std::vector<double> errors = {1, 4}; |   std::vector<double> errors = {1, 4}; | ||||||
|   AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors); |   AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -216,8 +216,7 @@ TEST(HybridBayesNet, Error) { | ||||||
| 
 | 
 | ||||||
|   // Verify error computation and check for specific error value
 |   // Verify error computation and check for specific error value
 | ||||||
|   DiscreteValues discrete_values; |   DiscreteValues discrete_values; | ||||||
|   discrete_values[M(0)] = 1; |   insert(discrete_values)(M(0), 1)(M(1), 1); | ||||||
|   discrete_values[M(1)] = 1; |  | ||||||
| 
 | 
 | ||||||
|   double total_error = 0; |   double total_error = 0; | ||||||
|   for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { |   for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { | ||||||
|  |  | ||||||
|  | @ -41,7 +41,8 @@ TEST(MixtureFactor, Constructor) { | ||||||
|   CHECK(it == factor.end()); |   CHECK(it == factor.end()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 | /* ************************************************************************* */ | ||||||
|  | // Test .print() output.
 | ||||||
| TEST(MixtureFactor, Printing) { | TEST(MixtureFactor, Printing) { | ||||||
|   DiscreteKey m1(1, 2); |   DiscreteKey m1(1, 2); | ||||||
|   double between0 = 0.0; |   double between0 = 0.0; | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue