diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 68f8f432d..c912a74fc 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -220,12 +220,11 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // FG has a nullptr as we're looping over the factors. factorGraphTree = removeEmpty(factorGraphTree); - using EliminationPair = std::pair, - GaussianMixtureFactor::sharedFactor>; + using Result = std::pair, + GaussianMixtureFactor::sharedFactor>; // This is the elimination method on the leaf nodes - auto eliminateFunc = - [&](const GaussianFactorGraph &graph) -> EliminationPair { + auto eliminate = [&](const GaussianFactorGraph &graph) -> Result { if (graph.empty()) { return {nullptr, nullptr}; } @@ -234,21 +233,17 @@ hybridElimination(const HybridGaussianFactorGraph &factors, gttic_(hybrid_eliminate); #endif - boost::shared_ptr conditional; - boost::shared_ptr newFactor; - boost::tie(conditional, newFactor) = - EliminatePreferCholesky(graph, frontalKeys); + auto result = EliminatePreferCholesky(graph, frontalKeys); #ifdef HYBRID_TIMING gttoc_(hybrid_eliminate); #endif - return {conditional, newFactor}; + return result; }; // Perform elimination! - DecisionTree eliminationResults(factorGraphTree, - eliminateFunc); + DecisionTree eliminationResults(factorGraphTree, eliminate); #ifdef HYBRID_TIMING tictoc_print_(); @@ -264,30 +259,46 @@ hybridElimination(const HybridGaussianFactorGraph &factors, auto gaussianMixture = boost::make_shared( frontalKeys, continuousSeparator, discreteSeparator, conditionals); - // If there are no more continuous parents, then we should create a - // DiscreteFactor here, with the error for each discrete choice. if (continuousSeparator.empty()) { - auto probPrime = [&](const EliminationPair &pair) { - // This is the unnormalized probability q(μ;m) at the mean. - // q(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m)) - // The factor has no keys, just contains the residual. + // If there are no more continuous parents, then we create a + // DiscreteFactor here, with the error for each discrete choice. + + // Integrate the probability mass in the last continuous conditional using + // the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean. + // discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m)) + auto probability = [&](const Result &pair) -> double { static const VectorValues kEmpty; - return pair.second ? exp(-pair.second->error(kEmpty)) / - pair.first->normalizationConstant() - : 1.0; + // If the factor is not null, it has no keys, just contains the residual. + const auto &factor = pair.second; + if (!factor) return 1.0; // TODO(dellaert): not loving this. + return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant(); }; - const auto discreteFactor = boost::make_shared( - discreteSeparator, - DecisionTree(eliminationResults, probPrime)); + DecisionTree probabilities(eliminationResults, probability); + return {boost::make_shared(gaussianMixture), + boost::make_shared(discreteSeparator, + probabilities)}; + } else { + // Otherwise, we create a resulting GaussianMixtureFactor on the separator, + // taking care to correct for conditional constant. + + // Correct for the normalization constant used up by the conditional + auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr { + const auto &factor = pair.second; + if (!factor) return factor; // TODO(dellaert): not loving this. + auto hf = boost::dynamic_pointer_cast(factor); + if (!hf) throw std::runtime_error("Expected HessianFactor!"); + hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant(); + return hf; + }; + + GaussianMixtureFactor::Factors correctedFactors(eliminationResults, + correct); + const auto mixtureFactor = boost::make_shared( + continuousSeparator, discreteSeparator, newFactors); return {boost::make_shared(gaussianMixture), - discreteFactor}; - } else { - // Create a resulting GaussianMixtureFactor on the separator. - return {boost::make_shared(gaussianMixture), - boost::make_shared( - continuousSeparator, discreteSeparator, newFactors)}; + mixtureFactor}; } }