diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index ad8e19f02..69149ac05 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -62,6 +62,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { // Note: constant is log of normalization constant for probabilities. // Errors is the negative log-likelihood, // hence we subtract the constant here. + if (!factor) return 0.0; // If nullptr, return 0.0 error return factor->error(values) - constant; } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index caef85068..0ffc22ba1 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -199,14 +199,14 @@ hybridElimination(const HybridGaussianFactorGraph &factors, DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), discreteSeparatorSet.end()); - // Collect all the frontal factors to create Gaussian factor graphs - // indexed on the discrete keys. + // Collect all the factors to create a set of Gaussian factor graphs in a + // decision tree indexed by all discrete keys involved. GaussianMixtureFactor::Sum sum = factors.SumFrontals(); - // If a tree leaf contains nullptr, - // convert that leaf to an empty GaussianFactorGraph. - // Needed since the DecisionTree will otherwise create - // a GFG with a single (null) factor. + // If a tree leaf contains nullptr, convert that leaf to an empty + // GaussianFactorGraph. Needed since the DecisionTree will otherwise create a + // GFG with a single (null) factor. + // TODO(dellaert): can SumFrontals not guarantee this? auto emptyGaussian = [](const GaussianMixtureFactor::GraphAndConstant &graph_z) { bool hasNull = std::any_of( @@ -222,7 +222,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors, using EliminationPair = std::pair, GaussianMixtureFactor::FactorAndConstant>; - KeyVector keysOfEliminated; // Not the ordering KeyVector keysOfSeparator; // This is the elimination method on the leaf nodes @@ -236,24 +235,21 @@ hybridElimination(const HybridGaussianFactorGraph &factors, gttic_(hybrid_eliminate); #endif - std::pair, - boost::shared_ptr> - conditional_factor = - EliminatePreferCholesky(graph_z.graph, frontalKeys); + boost::shared_ptr conditional; + boost::shared_ptr newFactor; + boost::tie(conditional, newFactor) = + EliminatePreferCholesky(graph_z.graph, frontalKeys); - // Initialize the keysOfEliminated to be the keys of the - // eliminated GaussianConditional - keysOfEliminated = conditional_factor.first->keys(); - keysOfSeparator = conditional_factor.second->keys(); + // TODO(dellaert): always the same, and we already computed this in caller? + keysOfSeparator = newFactor->keys(); #ifdef HYBRID_TIMING gttoc_(hybrid_eliminate); #endif - GaussianConditional::shared_ptr conditional = conditional_factor.first; // Get the log of the log normalization constant inverse. double logZ = -conditional->logNormalizationConstant() + graph_z.constant; - return {conditional, {conditional_factor.second, logZ}}; + return {conditional, {newFactor, logZ}}; }; // Perform elimination! @@ -266,44 +262,31 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // Separate out decision tree into conditionals and remaining factors. auto pair = unzip(eliminationResults); - const auto &separatorFactors = pair.second; // Create the GaussianMixture from the conditionals auto conditional = boost::make_shared( frontalKeys, keysOfSeparator, discreteSeparator, pair.first); - // If there are no more continuous parents, then we should create here a - // DiscreteFactor, with the error for each discrete choice. + // If there are no more continuous parents, then we should create a + // DiscreteFactor here, with the error for each discrete choice. + const auto &separatorFactors = pair.second; if (keysOfSeparator.empty()) { auto factorProb = [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { - GaussianFactor::shared_ptr factor = factor_z.factor; - if (!factor) { - return 0.0; // If nullptr, return 0.0 probability - } else { - // This is the probability q(μ) at the MLE point. - double error = factor_z.error(VectorValues()); - return std::exp(-error); - } + // This is the probability q(μ) at the MLE point. + return factor_z.error(VectorValues()); }; - DecisionTree fdt(separatorFactors, factorProb); - // Normalize the values of decision tree to be valid probabilities - double sum = 0.0; - auto visitor = [&](double y) { sum += y; }; - fdt.visit(visitor); - // fdt = DecisionTree(fdt, - // [sum](const double &x) { return x / sum; - // }); - - auto discreteFactor = + const DecisionTree fdt(separatorFactors, factorProb); + const auto discreteFactor = boost::make_shared(discreteSeparator, fdt); return {boost::make_shared(conditional), boost::make_shared(discreteFactor)}; - } else { // Create a resulting GaussianMixtureFactor on the separator. - auto factor = boost::make_shared( + // Keys can be computed from the factors, so we should not need to pass them + // in. + const auto factor = boost::make_shared( KeyVector(continuousSeparator.begin(), continuousSeparator.end()), discreteSeparator, separatorFactors); return {boost::make_shared(conditional), factor};