Add GaussianMixture tests
							parent
							
								
									9a1eb022a9
								
							
						
					
					
						commit
						2396bca22f
					
				|  | @ -19,7 +19,7 @@ | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| #include <gtsam/base/utilities.h> | #include <gtsam/base/utilities.h> | ||||||
| #include <gtsam/discrete/DecisionTree-inl.h> | #include <gtsam/discrete/DiscreteValues.h> | ||||||
| #include <gtsam/hybrid/GaussianMixture.h> | #include <gtsam/hybrid/GaussianMixture.h> | ||||||
| #include <gtsam/inference/Conditional-inst.h> | #include <gtsam/inference/Conditional-inst.h> | ||||||
| #include <gtsam/linear/GaussianFactorGraph.h> | #include <gtsam/linear/GaussianFactorGraph.h> | ||||||
|  | @ -77,8 +77,29 @@ GaussianMixture::asGaussianFactorGraphTree() const { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
| bool GaussianMixture::equals(const HybridFactor &lf, | size_t GaussianMixture::nrComponents() const { | ||||||
|                                         double tol) const { |   size_t total = 0; | ||||||
|  |   conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) { | ||||||
|  |     if (node) total += 1; | ||||||
|  |   }); | ||||||
|  |   return total; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* *******************************************************************************/ | ||||||
|  | GaussianConditional::shared_ptr GaussianMixture::operator()( | ||||||
|  |     const DiscreteValues &discreteVals) const { | ||||||
|  |   auto &ptr = conditionals_(discreteVals); | ||||||
|  |   if (!ptr) return nullptr; | ||||||
|  |   auto conditional = boost::dynamic_pointer_cast<GaussianConditional>(ptr); | ||||||
|  |   if (conditional) | ||||||
|  |     return conditional; | ||||||
|  |   else | ||||||
|  |     throw std::logic_error( | ||||||
|  |         "A GaussianMixture unexpectedly contained a non-conditional"); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* *******************************************************************************/ | ||||||
|  | bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { | ||||||
|   const This *e = dynamic_cast<const This *>(&lf); |   const This *e = dynamic_cast<const This *>(&lf); | ||||||
|   return e != nullptr && BaseFactor::equals(*e, tol); |   return e != nullptr && BaseFactor::equals(*e, tol); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -19,7 +19,9 @@ | ||||||
| 
 | 
 | ||||||
| #pragma once | #pragma once | ||||||
| 
 | 
 | ||||||
|  | #include <gtsam/discrete/DecisionTree-inl.h> | ||||||
| #include <gtsam/discrete/DecisionTree.h> | #include <gtsam/discrete/DecisionTree.h> | ||||||
|  | #include <gtsam/discrete/DiscreteKey.h> | ||||||
| #include <gtsam/hybrid/HybridFactor.h> | #include <gtsam/hybrid/HybridFactor.h> | ||||||
| #include <gtsam/inference/Conditional.h> | #include <gtsam/inference/Conditional.h> | ||||||
| #include <gtsam/linear/GaussianConditional.h> | #include <gtsam/linear/GaussianConditional.h> | ||||||
|  | @ -99,6 +101,16 @@ class GTSAM_EXPORT GaussianMixture | ||||||
|       const DiscreteKeys &discreteParents, |       const DiscreteKeys &discreteParents, | ||||||
|       const std::vector<GaussianConditional::shared_ptr> &conditionals); |       const std::vector<GaussianConditional::shared_ptr> &conditionals); | ||||||
| 
 | 
 | ||||||
|  |   /// @}
 | ||||||
|  |   /// @name Standard API
 | ||||||
|  |   /// @{
 | ||||||
|  | 
 | ||||||
|  |   GaussianConditional::shared_ptr operator()( | ||||||
|  |       const DiscreteValues &discreteVals) const; | ||||||
|  | 
 | ||||||
|  |   /// Returns the total number of continuous components
 | ||||||
|  |   size_t nrComponents() const; | ||||||
|  | 
 | ||||||
|   /// @}
 |   /// @}
 | ||||||
|   /// @name Testable
 |   /// @name Testable
 | ||||||
|   /// @{
 |   /// @{
 | ||||||
|  |  | ||||||
|  | @ -0,0 +1,92 @@ | ||||||
|  | /* ----------------------------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  |  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||||
|  |  * Atlanta, Georgia 30332-0415 | ||||||
|  |  * All Rights Reserved | ||||||
|  |  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||||
|  | 
 | ||||||
|  |  * See LICENSE for the license information | ||||||
|  | 
 | ||||||
|  |  * -------------------------------------------------------------------------- */ | ||||||
|  | 
 | ||||||
|  | /**
 | ||||||
|  |  * @file    testGaussianMixture.cpp | ||||||
|  |  * @brief   Unit tests for GaussianMixture class | ||||||
|  |  * @author  Varun Agrawal | ||||||
|  |  * @author  Fan Jiang | ||||||
|  |  * @author  Frank Dellaert | ||||||
|  |  * @date    December 2021 | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | #include <gtsam/discrete/DiscreteValues.h> | ||||||
|  | #include <gtsam/hybrid/GaussianMixture.h> | ||||||
|  | #include <gtsam/inference/Symbol.h> | ||||||
|  | #include <gtsam/linear/GaussianConditional.h> | ||||||
|  | 
 | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | // Include for test suite
 | ||||||
|  | #include <CppUnitLite/TestHarness.h> | ||||||
|  | 
 | ||||||
|  | using namespace std; | ||||||
|  | using namespace gtsam; | ||||||
|  | using noiseModel::Isotropic; | ||||||
|  | using symbol_shorthand::M; | ||||||
|  | using symbol_shorthand::X; | ||||||
|  | 
 | ||||||
|  | /* ************************************************************************* */ | ||||||
|  | TEST(GaussianConditional, Equals) { | ||||||
|  |   // create a conditional gaussian node
 | ||||||
|  |   Matrix S1(2, 2); | ||||||
|  |   S1(0, 0) = 1; | ||||||
|  |   S1(1, 0) = 2; | ||||||
|  |   S1(0, 1) = 3; | ||||||
|  |   S1(1, 1) = 4; | ||||||
|  | 
 | ||||||
|  |   Matrix S2(2, 2); | ||||||
|  |   S2(0, 0) = 6; | ||||||
|  |   S2(1, 0) = 0.2; | ||||||
|  |   S2(0, 1) = 8; | ||||||
|  |   S2(1, 1) = 0.4; | ||||||
|  | 
 | ||||||
|  |   Matrix R1(2, 2); | ||||||
|  |   R1(0, 0) = 0.1; | ||||||
|  |   R1(1, 0) = 0.3; | ||||||
|  |   R1(0, 1) = 0.0; | ||||||
|  |   R1(1, 1) = 0.34; | ||||||
|  | 
 | ||||||
|  |   Matrix R2(2, 2); | ||||||
|  |   R2(0, 0) = 0.1; | ||||||
|  |   R2(1, 0) = 0.3; | ||||||
|  |   R2(0, 1) = 0.0; | ||||||
|  |   R2(1, 1) = 0.34; | ||||||
|  | 
 | ||||||
|  |   SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); | ||||||
|  | 
 | ||||||
|  |   Vector2 d1(0.2, 0.5), d2(0.5, 0.2); | ||||||
|  | 
 | ||||||
|  |   auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1, | ||||||
|  |                                                               X(2), S1, model), | ||||||
|  |        conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2, | ||||||
|  |                                                               X(2), S2, model); | ||||||
|  | 
 | ||||||
|  |   // Create decision tree
 | ||||||
|  |   DiscreteKey m1(1, 2); | ||||||
|  |   GaussianMixture::Conditionals conditionals( | ||||||
|  |       {m1}, | ||||||
|  |       vector<GaussianConditional::shared_ptr>{conditional0, conditional1}); | ||||||
|  |   GaussianMixture mixtureFactor({X(1)}, {X(2)}, {m1}, conditionals); | ||||||
|  | 
 | ||||||
|  |   // Let's check that this worked:
 | ||||||
|  |   DiscreteValues mode; | ||||||
|  |   mode[m1.first] = 1; | ||||||
|  |   auto actual = mixtureFactor(mode); | ||||||
|  |   EXPECT(actual == conditional1); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ************************************************************************* */ | ||||||
|  | int main() { | ||||||
|  |   TestResult tr; | ||||||
|  |   return TestRegistry::runAllTests(tr); | ||||||
|  | } | ||||||
|  | /* ************************************************************************* */ | ||||||
		Loading…
	
		Reference in New Issue