From 1fe09f5e09a1c68e0461f41167c4f2d897718de0 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 16 Oct 2024 19:53:36 +0900 Subject: [PATCH] Avoid using slow conditionals() --- gtsam/hybrid/HybridGaussianConditional.cpp | 42 +++++++++++++--------- gtsam/hybrid/HybridGaussianConditional.h | 9 +++++ 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 8ab00fb67..6e5fe93c4 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -197,11 +197,11 @@ void HybridGaussianConditional::print(const std::string &s, std::cout << std::endl << " logNormalizationConstant: " << -negLogConstant() << std::endl << std::endl; - conditionals().print( + factors().print( "", [&](Key k) { return formatter(k); }, - [&](const GaussianConditional::shared_ptr &gf) -> std::string { + [&](const GaussianFactorValuePair &pair) -> std::string { RedirectCout rd; - if (gf && !gf->empty()) { + if (auto gf = std::dynamic_pointer_cast(pair.first)) { gf->print("", formatter); return rd.str(); } else { @@ -249,12 +249,18 @@ std::shared_ptr HybridGaussianConditional::likelihood( const DiscreteKeys discreteParentKeys = discreteKeys(); const KeyVector continuousParentKeys = continuousParents(); const HybridGaussianFactor::FactorValuePairs likelihoods( - conditionals(), - [&](const GaussianConditional::shared_ptr &conditional) - -> GaussianFactorValuePair { - const auto likelihood_m = conditional->likelihood(given); - const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_; - return {likelihood_m, Cgm_Kgcm}; + factors(), + [&](const GaussianFactorValuePair &pair) -> GaussianFactorValuePair { + if (auto conditional = + std::dynamic_pointer_cast(pair.first)) { + const auto likelihood_m = conditional->likelihood(given); + // scalar is already correct. + assert(pair.second == + conditional->negLogConstant() - negLogConstant_); + return {likelihood_m, pair.second}; + } else { + return {nullptr, std::numeric_limits::infinity()}; + } }); return std::make_shared(discreteParentKeys, likelihoods); @@ -283,15 +289,19 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( // Check the max value for every combination of our keys. // If the max value is 0.0, we can prune the corresponding conditional. - auto pruner = [&](const Assignment &choices, - const GaussianConditional::shared_ptr &conditional) - -> GaussianConditional::shared_ptr { - return (max->evaluate(choices) == 0.0) ? nullptr : conditional; + auto pruner = + [&](const Assignment &choices, + const GaussianFactorValuePair &pair) -> GaussianFactorValuePair { + if (max->evaluate(choices) == 0.0) + return {nullptr, std::numeric_limits::infinity()}; + else + return pair; }; - auto pruned_conditionals = conditionals().apply(pruner); - return std::make_shared(discreteKeys(), - pruned_conditionals); + FactorValuePairs prunedConditionals = factors().apply(pruner); + return std::shared_ptr( + new HybridGaussianConditional(discreteKeys(), nrFrontals_, + prunedConditionals, negLogConstant_)); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 6e0d2800c..38b7b9795 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -191,6 +191,7 @@ class GTSAM_EXPORT HybridGaussianConditional const VectorValues &given) const; /// Get Conditionals DecisionTree (dynamic cast from factors) + /// @note Slow: avoid using in favor of factors(), which uses existing tree. const Conditionals conditionals() const; /** @@ -229,6 +230,14 @@ class GTSAM_EXPORT HybridGaussianConditional HybridGaussianConditional(const DiscreteKeys &discreteParents, const Helper &helper); + /// Private constructor used when constants have already been calculated. + HybridGaussianConditional(const DiscreteKeys &discreteKeys, int nrFrontals, + const FactorValuePairs &factors, + double negLogConstant) + : BaseFactor(discreteKeys, factors), + BaseConditional(nrFrontals), + negLogConstant_(negLogConstant) {} + /// Check whether `given` has values for all frontal keys. bool allFrontalsGiven(const VectorValues &given) const;