diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 913426a98..b1be6ef1d 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -134,7 +134,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @return AlgebraicDecisionTree A decision tree with the same keys * as the factors involved, and leaf values as the error. */ - AlgebraicDecisionTree errorTree(const VectorValues &continuousValues) const; + AlgebraicDecisionTree errorTree( + const VectorValues &continuousValues) const; /** * @brief Compute the log-likelihood, including the log-normalizing constant. diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 42402076e..f1b79b123 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -296,7 +296,10 @@ static std::shared_ptr createDiscreteFactor( // Logspace version of: // exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); - return -factor->error(kEmpty) - conditional->logNormalizationConstant(); + // We take negative of the logNormalizationConstant `log(1/k)` + // to get `log(k)`. + // factor->print("Discrete Separator"); + return -factor->error(kEmpty) + (-conditional->logNormalizationConstant()); }; AlgebraicDecisionTree logProbabilities( @@ -368,6 +371,12 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // Perform elimination! DecisionTree eliminationResults(factorGraphTree, eliminate); + // Create the GaussianMixture from the conditionals + GaussianMixture::Conditionals conditionals( + eliminationResults, [](const Result &pair) { return pair.first; }); + auto gaussianMixture = std::make_shared( + frontalKeys, continuousSeparator, discreteSeparator, conditionals); + // If there are no more continuous parents we create a DiscreteFactor with the // error for each discrete choice. Otherwise, create a GaussianMixtureFactor // on the separator, taking care to correct for conditional constants. @@ -377,12 +386,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors, : createGaussianMixtureFactor(eliminationResults, continuousSeparator, discreteSeparator); - // Create the GaussianMixture from the conditionals - GaussianMixture::Conditionals conditionals( - eliminationResults, [](const Result &pair) { return pair.first; }); - auto gaussianMixture = std::make_shared( - frontalKeys, continuousSeparator, discreteSeparator, conditionals); - return {std::make_shared(gaussianMixture), newFactor}; }