Cleaner version of eliminate
							parent
							
								
									9d70605d48
								
							
						
					
					
						commit
						fcda1536c6
					
				|  | @ -96,7 +96,6 @@ static GaussianFactorGraphTree addGaussian( | |||
| // TODO(dellaert): it's probably more efficient to first collect the discrete
 | ||||
| // keys, and then loop over all assignments to populate a vector.
 | ||||
| GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { | ||||
| 
 | ||||
|   GaussianFactorGraphTree result; | ||||
| 
 | ||||
|   for (auto &f : factors_) { | ||||
|  | @ -198,6 +197,51 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) { | |||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| using Result = std::pair<std::shared_ptr<GaussianConditional>, | ||||
|                          GaussianMixtureFactor::sharedFactor>; | ||||
| 
 | ||||
| // Integrate the probability mass in the last continuous conditional using
 | ||||
| // the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
 | ||||
| //   discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
 | ||||
| static std::shared_ptr<Factor> createDiscreteFactor( | ||||
|     const DecisionTree<Key, Result> &eliminationResults, | ||||
|     const DiscreteKeys &discreteSeparator) { | ||||
|   auto probability = [&](const Result &pair) -> double { | ||||
|     const auto &[conditional, factor] = pair; | ||||
|     static const VectorValues kEmpty; | ||||
|     // If the factor is not null, it has no keys, just contains the residual.
 | ||||
|     if (!factor) return 1.0;  // TODO(dellaert): not loving this.
 | ||||
|     return exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); | ||||
|   }; | ||||
| 
 | ||||
|   DecisionTree<Key, double> probabilities(eliminationResults, probability); | ||||
| 
 | ||||
|   return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities); | ||||
| } | ||||
| 
 | ||||
| // Create GaussianMixtureFactor on the separator, taking care to correct
 | ||||
| // for conditional constants.
 | ||||
| static std::shared_ptr<Factor> createGaussianMixtureFactor( | ||||
|     const DecisionTree<Key, Result> &eliminationResults, | ||||
|     const KeyVector &continuousSeparator, | ||||
|     const DiscreteKeys &discreteSeparator) { | ||||
|   // Correct for the normalization constant used up by the conditional
 | ||||
|   auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr { | ||||
|     const auto &[conditional, factor] = pair; | ||||
|     if (factor) { | ||||
|       auto hf = std::dynamic_pointer_cast<HessianFactor>(factor); | ||||
|       if (!hf) throw std::runtime_error("Expected HessianFactor!"); | ||||
|       hf->constantTerm() += 2.0 * conditional->logNormalizationConstant(); | ||||
|     } | ||||
|     return factor; | ||||
|   }; | ||||
|   DecisionTree<Key, GaussianFactor::shared_ptr> newFactors(eliminationResults, | ||||
|                                                            correct); | ||||
| 
 | ||||
|   return std::make_shared<GaussianMixtureFactor>(continuousSeparator, | ||||
|                                                  discreteSeparator, newFactors); | ||||
| } | ||||
| 
 | ||||
| static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> | ||||
| hybridElimination(const HybridGaussianFactorGraph &factors, | ||||
|                   const Ordering &frontalKeys, | ||||
|  | @ -217,9 +261,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | |||
|   // FG has a nullptr as we're looping over the factors.
 | ||||
|   factorGraphTree = removeEmpty(factorGraphTree); | ||||
| 
 | ||||
|   using Result = std::pair<std::shared_ptr<GaussianConditional>, | ||||
|                            GaussianMixtureFactor::sharedFactor>; | ||||
| 
 | ||||
|   // This is the elimination method on the leaf nodes
 | ||||
|   auto eliminate = [&](const GaussianFactorGraph &graph) -> Result { | ||||
|     if (graph.empty()) { | ||||
|  | @ -234,53 +275,22 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | |||
|   // Perform elimination!
 | ||||
|   DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate); | ||||
| 
 | ||||
|   // Separate out decision tree into conditionals and remaining factors.
 | ||||
|   const auto [conditionals, newFactors] = unzip(eliminationResults); | ||||
|   // If there are no more continuous parents we create a DiscreteFactor with the
 | ||||
|   // error for each discrete choice. Otherwise, create a GaussianMixtureFactor
 | ||||
|   // on the separator, taking care to correct for conditional constants.
 | ||||
|   auto newFactor = | ||||
|       continuousSeparator.empty() | ||||
|           ? createDiscreteFactor(eliminationResults, discreteSeparator) | ||||
|           : createGaussianMixtureFactor(eliminationResults, continuousSeparator, | ||||
|                                         discreteSeparator); | ||||
| 
 | ||||
|   // Create the GaussianMixture from the conditionals
 | ||||
|   GaussianMixture::Conditionals conditionals( | ||||
|       eliminationResults, [](const Result &pair) { return pair.first; }); | ||||
|   auto gaussianMixture = std::make_shared<GaussianMixture>( | ||||
|       frontalKeys, continuousSeparator, discreteSeparator, conditionals); | ||||
| 
 | ||||
|   if (continuousSeparator.empty()) { | ||||
|     // If there are no more continuous parents, then we create a
 | ||||
|     // DiscreteFactor here, with the error for each discrete choice.
 | ||||
| 
 | ||||
|     // Integrate the probability mass in the last continuous conditional using
 | ||||
|     // the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
 | ||||
|     //   discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
 | ||||
|     auto probability = [&](const Result &pair) -> double { | ||||
|       static const VectorValues kEmpty; | ||||
|       // If the factor is not null, it has no keys, just contains the residual.
 | ||||
|       const auto &factor = pair.second; | ||||
|       if (!factor) return 1.0;  // TODO(dellaert): not loving this.
 | ||||
|       return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant(); | ||||
|     }; | ||||
| 
 | ||||
|     DecisionTree<Key, double> probabilities(eliminationResults, probability); | ||||
| 
 | ||||
|     return { | ||||
|         std::make_shared<HybridConditional>(gaussianMixture), | ||||
|         std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)}; | ||||
|   } else { | ||||
|     // Otherwise, we create a resulting GaussianMixtureFactor on the separator,
 | ||||
|     // taking care to correct for conditional constant.
 | ||||
| 
 | ||||
|     // Correct for the normalization constant used up by the conditional
 | ||||
|     auto correct = [&](const Result &pair) { | ||||
|       const auto &factor = pair.second; | ||||
|       if (!factor) return; | ||||
|       auto hf = std::dynamic_pointer_cast<HessianFactor>(factor); | ||||
|       if (!hf) throw std::runtime_error("Expected HessianFactor!"); | ||||
|       hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant(); | ||||
|     }; | ||||
|     eliminationResults.visit(correct); | ||||
| 
 | ||||
|     const auto mixtureFactor = std::make_shared<GaussianMixtureFactor>( | ||||
|         continuousSeparator, discreteSeparator, newFactors); | ||||
| 
 | ||||
|     return {std::make_shared<HybridConditional>(gaussianMixture), | ||||
|             mixtureFactor}; | ||||
|   } | ||||
|   return {std::make_shared<HybridConditional>(gaussianMixture), newFactor}; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue