diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index e17fd3afe..e0ae16e82 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -33,8 +33,10 @@ HybridGaussianFactor::FactorValuePairs GetFactorValuePairs( auto func = [](const GaussianConditional::shared_ptr &conditional) -> GaussianFactorValuePair { double value = 0.0; - if (conditional) { // Check if conditional is pruned - value = conditional->logNormalizationConstant(); + // Check if conditional is pruned + if (conditional) { + // Assign log(|2πΣ|) = -2*log(1 / sqrt(|2πΣ|)) + value = -2.0 * conditional->logNormalizationConstant(); } return {std::dynamic_pointer_cast(conditional), value}; }; @@ -49,14 +51,14 @@ HybridGaussianConditional::HybridGaussianConditional( discreteParents, GetFactorValuePairs(conditionals)), BaseConditional(continuousFrontals.size()), conditionals_(conditionals) { - // Calculate logConstant_ as the maximum of the log constants of the + // Calculate logNormalizer_ as the minimum of the log normalizers of the // conditionals, by visiting the decision tree: - logConstant_ = -std::numeric_limits::infinity(); + logNormalizer_ = std::numeric_limits::infinity(); conditionals_.visit( [this](const GaussianConditional::shared_ptr &conditional) { if (conditional) { - this->logConstant_ = std::max( - this->logConstant_, conditional->logNormalizationConstant()); + this->logNormalizer_ = std::min( + this->logNormalizer_, -conditional->logNormalizationConstant()); } }); } @@ -98,7 +100,7 @@ GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree() // First check if conditional has not been pruned if (gc) { const double Cgm_Kgcm = - this->logConstant_ - gc->logNormalizationConstant(); + -gc->logNormalizationConstant() - this->logNormalizer_; // If there is a difference in the covariances, we need to account for // that since the error is dependent on the mode. if (Cgm_Kgcm > 0.0) { @@ -169,7 +171,8 @@ void HybridGaussianConditional::print(const std::string &s, std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), "; } std::cout << std::endl - << " logNormalizationConstant: " << logConstant_ << std::endl + << " logNormalizationConstant: " << logNormalizationConstant() + << std::endl << std::endl; conditionals_.print( "", [&](Key k) { return formatter(k); }, @@ -228,7 +231,7 @@ std::shared_ptr HybridGaussianConditional::likelihood( -> GaussianFactorValuePair { const auto likelihood_m = conditional->likelihood(given); const double Cgm_Kgcm = - logConstant_ - conditional->logNormalizationConstant(); + -conditional->logNormalizationConstant() - logNormalizer_; if (Cgm_Kgcm == 0.0) { return {likelihood_m, 0.0}; } else { @@ -342,7 +345,7 @@ double HybridGaussianConditional::conditionalError( // Check if valid pointer if (conditional) { return conditional->error(continuousValues) + // - logConstant_ - conditional->logNormalizationConstant(); + -conditional->logNormalizationConstant() - logNormalizer_; } else { // If not valid, pointer, it means this conditional was pruned, // so we return maximum error. diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 82cf6ec8a..434750bc9 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -64,7 +64,8 @@ class GTSAM_EXPORT HybridGaussianConditional private: Conditionals conditionals_; ///< a decision tree of Gaussian conditionals. - double logConstant_; ///< log of the normalization constant. + double logNormalizer_; ///< log of the normalization constant + ///< (log(\sqrt(|2πΣ|))). /** * @brief Convert a HybridGaussianConditional of conditionals into @@ -149,7 +150,7 @@ class GTSAM_EXPORT HybridGaussianConditional /// The log normalization constant is max of the the individual /// log-normalization constants. - double logNormalizationConstant() const override { return logConstant_; } + double logNormalizationConstant() const override { return -logNormalizer_; } /** * Create a likelihood factor for a hybrid Gaussian conditional, diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 406203db8..70cc55712 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -100,7 +100,7 @@ TEST(HybridGaussianConditional, Error) { auto actual = hybrid_conditional.errorTree(vv); // Check result. - std::vector discrete_keys = {mode}; + DiscreteKeys discrete_keys{mode}; std::vector leaves = {conditionals[0]->error(vv), conditionals[1]->error(vv)}; AlgebraicDecisionTree expected(discrete_keys, leaves); @@ -172,6 +172,37 @@ TEST(HybridGaussianConditional, ContinuousParents) { EXPECT(continuousParentKeys[0] == X(0)); } +/* ************************************************************************* */ +/// Check error with mode dependent constants. +TEST(HybridGaussianConditional, Error2) { + using namespace mode_dependent_constants; + auto actual = mixture.errorTree(vv); + + // Check result. + DiscreteKeys discrete_keys{mode}; + double logNormalizer0 = -conditionals[0]->logNormalizationConstant(); + double logNormalizer1 = -conditionals[1]->logNormalizationConstant(); + double minLogNormalizer = std::min(logNormalizer0, logNormalizer1); + + // Expected error is e(X) + log(|2πΣ|). + // We normalize log(|2πΣ|) with min(logNormalizers) so it is non-negative. + std::vector leaves = { + conditionals[0]->error(vv) + logNormalizer0 - minLogNormalizer, + conditionals[1]->error(vv) + logNormalizer1 - minLogNormalizer}; + AlgebraicDecisionTree expected(discrete_keys, leaves); + + EXPECT(assert_equal(expected, actual, 1e-6)); + + // Check for non-tree version. + for (size_t mode : {0, 1}) { + const HybridValues hv{vv, {{M(0), mode}}}; + EXPECT_DOUBLES_EQUAL(conditionals[mode]->error(vv) - + conditionals[mode]->logNormalizationConstant() - + minLogNormalizer, + mixture.error(hv), 1e-8); + } +} + /* ************************************************************************* */ /// Check that the likelihood is proportional to the conditional density given /// the measurements.