diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 0ffc22ba1..adba9e88b 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -188,11 +188,28 @@ discreteElimination(const HybridGaussianFactorGraph &factors, boost::make_shared(result.second)}; } +/* ************************************************************************ */ +// If any GaussianFactorGraph in the decision tree contains a nullptr, convert +// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will +// otherwise create a GFG with a single (null) factor. +GaussianMixtureFactor::Sum removeEmpty(const GaussianMixtureFactor::Sum &sum) { + auto emptyGaussian = [](const GaussianMixtureFactor::GraphAndConstant + &graph_z) { + bool hasNull = + std::any_of(graph_z.graph.begin(), graph_z.graph.end(), + [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }); + return hasNull + ? GaussianMixtureFactor::GraphAndConstant{GaussianFactorGraph(), + 0.0} + : graph_z; + }; + return GaussianMixtureFactor::Sum(sum, emptyGaussian); +} /* ************************************************************************ */ static std::pair hybridElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys, - const KeySet &continuousSeparator, + const KeyVector &continuousSeparator, const std::set &discreteSeparatorSet) { // NOTE: since we use the special JunctionTree, // only possibility is continuous conditioned on discrete. @@ -203,27 +220,12 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // decision tree indexed by all discrete keys involved. GaussianMixtureFactor::Sum sum = factors.SumFrontals(); - // If a tree leaf contains nullptr, convert that leaf to an empty - // GaussianFactorGraph. Needed since the DecisionTree will otherwise create a - // GFG with a single (null) factor. - // TODO(dellaert): can SumFrontals not guarantee this? - auto emptyGaussian = - [](const GaussianMixtureFactor::GraphAndConstant &graph_z) { - bool hasNull = std::any_of( - graph_z.graph.begin(), graph_z.graph.end(), - [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }); - - return hasNull ? GaussianMixtureFactor::GraphAndConstant( - GaussianFactorGraph(), 0.0) - : graph_z; - }; - sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); + // TODO(dellaert): does SumFrontals not guarantee we do not need this? + sum = removeEmpty(sum); using EliminationPair = std::pair, GaussianMixtureFactor::FactorAndConstant>; - KeyVector keysOfSeparator; - // This is the elimination method on the leaf nodes auto eliminate = [&](const GaussianMixtureFactor::GraphAndConstant &graph_z) -> EliminationPair { @@ -240,15 +242,12 @@ hybridElimination(const HybridGaussianFactorGraph &factors, boost::tie(conditional, newFactor) = EliminatePreferCholesky(graph_z.graph, frontalKeys); - // TODO(dellaert): always the same, and we already computed this in caller? - keysOfSeparator = newFactor->keys(); - #ifdef HYBRID_TIMING gttoc_(hybrid_eliminate); #endif // Get the log of the log normalization constant inverse. - double logZ = -conditional->logNormalizationConstant() + graph_z.constant; + double logZ = graph_z.constant - conditional->logNormalizationConstant(); return {conditional, {newFactor, logZ}}; }; @@ -261,35 +260,35 @@ hybridElimination(const HybridGaussianFactorGraph &factors, #endif // Separate out decision tree into conditionals and remaining factors. - auto pair = unzip(eliminationResults); + GaussianMixture::Conditionals conditionals; + GaussianMixtureFactor::Factors newFactors; + std::tie(conditionals, newFactors) = unzip(eliminationResults); // Create the GaussianMixture from the conditionals - auto conditional = boost::make_shared( - frontalKeys, keysOfSeparator, discreteSeparator, pair.first); + auto gaussianMixture = boost::make_shared( + frontalKeys, continuousSeparator, discreteSeparator, conditionals); // If there are no more continuous parents, then we should create a // DiscreteFactor here, with the error for each discrete choice. - const auto &separatorFactors = pair.second; - if (keysOfSeparator.empty()) { + if (continuousSeparator.empty()) { auto factorProb = [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { // This is the probability q(μ) at the MLE point. return factor_z.error(VectorValues()); }; - const DecisionTree fdt(separatorFactors, factorProb); + const DecisionTree fdt(newFactors, factorProb); const auto discreteFactor = boost::make_shared(discreteSeparator, fdt); - return {boost::make_shared(conditional), + return {boost::make_shared(gaussianMixture), boost::make_shared(discreteFactor)}; } else { // Create a resulting GaussianMixtureFactor on the separator. - // Keys can be computed from the factors, so we should not need to pass them - // in. - const auto factor = boost::make_shared( - KeyVector(continuousSeparator.begin(), continuousSeparator.end()), - discreteSeparator, separatorFactors); - return {boost::make_shared(conditional), factor}; + // TODO(dellaert): Keys can be computed from the factors, so we should not + // need to pass them in. + return {boost::make_shared(gaussianMixture), + boost::make_shared( + continuousSeparator, discreteSeparator, newFactors)}; } } @@ -389,12 +388,12 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, // Fill in discrete discrete separator keys and continuous separator keys. std::set discreteSeparatorSet; - KeySet continuousSeparator; + KeyVector continuousSeparator; for (auto &k : separatorKeys) { if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k)); } else { - continuousSeparator.insert(k); + continuousSeparator.push_back(k); } }