Add error and probPrime variants
							parent
							
								
									2d688a1986
								
							
						
					
					
						commit
						a4659f01c7
					
				|  | @ -468,12 +468,51 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error( | |||
|   return error_tree; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| double HybridGaussianFactorGraph::error( | ||||
|     const VectorValues &continuousValues, | ||||
|     const DiscreteValues &discreteValues) const { | ||||
|   double error = 0.0; | ||||
|   for (size_t idx = 0; idx < size(); idx++) { | ||||
|     auto factor = factors_.at(idx); | ||||
| 
 | ||||
|     if (factor->isHybrid()) { | ||||
|       if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) { | ||||
|         error += c->asMixture()->error(continuousValues, discreteValues); | ||||
|       } | ||||
|       if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) { | ||||
|         error += f->error(continuousValues, discreteValues); | ||||
|       } | ||||
| 
 | ||||
|     } else if (factor->isContinuous()) { | ||||
|       if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) { | ||||
|         error += f->inner()->error(continuousValues); | ||||
|       } | ||||
|       if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) { | ||||
|         error += cg->asGaussian()->error(continuousValues); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   return error; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| double HybridGaussianFactorGraph::probPrime( | ||||
|     const VectorValues &continuousValues, | ||||
|     const DiscreteValues &discreteValues) const { | ||||
|   double error = this->error(continuousValues, discreteValues); | ||||
|   // NOTE: The 0.5 term is handled by each factor
 | ||||
|   return std::exp(-error); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime( | ||||
|     const VectorValues &continuousValues) const { | ||||
|   AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues); | ||||
|   AlgebraicDecisionTree<Key> prob_tree = | ||||
|       error_tree.apply([](double error) { return exp(-error); }); | ||||
|   AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) { | ||||
|     // NOTE: The 0.5 term is handled by each factor
 | ||||
|     return exp(-error); | ||||
|   }); | ||||
|   return prob_tree; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -182,6 +182,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph | |||
|    */ | ||||
|   AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Compute error given a continuous vector values | ||||
|    * and a discrete assignment. | ||||
|    * | ||||
|    * @param continuousValues The continuous VectorValues | ||||
|    * for computing the error. | ||||
|    * @param discreteValues The specific discrete assignment | ||||
|    * whose error we wish to compute. | ||||
|    * @return double | ||||
|    */ | ||||
|   double error(const VectorValues& continuousValues, | ||||
|                const DiscreteValues& discreteValues) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ | ||||
|    * for each discrete assignment, and return as a tree. | ||||
|  | @ -193,6 +206,18 @@ class GTSAM_EXPORT HybridGaussianFactorGraph | |||
|   AlgebraicDecisionTree<Key> probPrime( | ||||
|       const VectorValues& continuousValues) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Compute the unnormalized posterior probability for a continuous | ||||
|    * vector values given a specific assignment. | ||||
|    * | ||||
|    * @param continuousValues The vector values for which to compute the | ||||
|    * posterior probability. | ||||
|    * @param discreteValues The specific assignment to use for the computation. | ||||
|    * @return double | ||||
|    */ | ||||
|   double probPrime(const VectorValues& continuousValues, | ||||
|                    const DiscreteValues& discreteValues) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Return a Colamd constrained ordering where the discrete keys are | ||||
|    * eliminated after the continuous keys. | ||||
|  |  | |||
|  | @ -180,6 +180,12 @@ class HybridGaussianFactorGraph { | |||
|   void print(string s = "") const; | ||||
|   bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const; | ||||
| 
 | ||||
|   // evaluation | ||||
|   double error(const gtsam::VectorValues& continuousValues, | ||||
|                const gtsam::DiscreteValues& discreteValues) const; | ||||
|   double probPrime(const gtsam::VectorValues& continuousValues, | ||||
|                    const gtsam::DiscreteValues& discreteValues) const; | ||||
| 
 | ||||
|   gtsam::HybridBayesNet* eliminateSequential(); | ||||
|   gtsam::HybridBayesNet* eliminateSequential( | ||||
|       gtsam::Ordering::OrderingType type); | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ Author: Fan Jiang | |||
| # pylint: disable=invalid-name, no-name-in-module, no-member | ||||
| 
 | ||||
| import unittest | ||||
| import math | ||||
| 
 | ||||
| import numpy as np | ||||
| from gtsam.symbol_shorthand import C, M, X, Z | ||||
|  | @ -110,7 +111,9 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | |||
|             conditional1 = GaussianConditional.FromMeanAndStddev( | ||||
|                 Z(i), I, X(0), [0], sigma=3 | ||||
|             ) | ||||
|             bayesNet.emplaceMixture([Z(i)], [X(0)], keys, [conditional0, conditional1]) | ||||
|             bayesNet.emplaceMixture( | ||||
|                 [Z(i)], [X(0)], keys, [conditional0, conditional1] | ||||
|             ) | ||||
| 
 | ||||
|         # Create prior on X(0). | ||||
|         prior_on_x0 = GaussianConditional.FromMeanAndStddev(X(0), [5.0], 5.0) | ||||
|  | @ -136,7 +139,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | |||
|         fg.push_back(factor) | ||||
|         fg.push_back(bayesNet.atGaussian(1)) | ||||
|         fg.push_back(bayesNet.atDiscrete(2)) | ||||
|          | ||||
| 
 | ||||
|         self.assertEqual(fg.size(), 3) | ||||
| 
 | ||||
|     def test_tiny2(self): | ||||
|  | @ -156,8 +159,18 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | |||
|             fg.push_back(factor) | ||||
|         fg.push_back(bayesNet.atGaussian(2)) | ||||
|         fg.push_back(bayesNet.atDiscrete(3)) | ||||
|          | ||||
| 
 | ||||
|         self.assertEqual(fg.size(), 4) | ||||
|         # Calculate ratio  between Bayes net probability and the factor graph: | ||||
|         continuousValues = gtsam.VectorValues() | ||||
|         continuousValues.insert(X(0), sample.at(X(0))) | ||||
|         discreteValues = sample.discrete() | ||||
|         expected_ratio = bayesNet.evaluate(sample) / fg.probPrime( | ||||
|             continuousValues, discreteValues | ||||
|         ) | ||||
|         print(expected_ratio) | ||||
| 
 | ||||
|         # TODO(dellaert): Change the mode to 0 and calculate the ratio again. | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue