From b56595c6f804928bffcd1bd34ffc3282f7df4c9c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 8 Oct 2024 18:10:27 +0900 Subject: [PATCH] Get rid of double storage --- gtsam/hybrid/HybridGaussianConditional.cpp | 67 +++++++++++----------- gtsam/hybrid/HybridGaussianConditional.h | 7 +-- 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 58724163e..8ab00fb67 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -45,7 +45,6 @@ namespace gtsam { struct HybridGaussianConditional::Helper { std::optional nrFrontals; FactorValuePairs pairs; - Conditionals conditionals; double minNegLogConstant; using GC = GaussianConditional; @@ -70,14 +69,12 @@ struct HybridGaussianConditional::Helper { gcs.push_back(gaussianConditional); } - conditionals = Conditionals({mode}, gcs); pairs = FactorValuePairs({mode}, fvs); } /// Construct from tree of GaussianConditionals. explicit Helper(const Conditionals &conditionals) - : conditionals(conditionals), - minNegLogConstant(std::numeric_limits::infinity()) { + : minNegLogConstant(std::numeric_limits::infinity()) { auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair { if (!gc) return {nullptr, std::numeric_limits::infinity()}; if (!nrFrontals) nrFrontals = gc->nrFrontals(); @@ -106,7 +103,6 @@ HybridGaussianConditional::HybridGaussianConditional( pair.second - helper.minNegLogConstant}; })), BaseConditional(*helper.nrFrontals), - conditionals_(helper.conditionals), negLogConstant_(helper.minNegLogConstant) {} HybridGaussianConditional::HybridGaussianConditional( @@ -143,24 +139,26 @@ HybridGaussianConditional::HybridGaussianConditional( : HybridGaussianConditional(discreteParents, Helper(conditionals)) {} /* *******************************************************************************/ -const HybridGaussianConditional::Conditionals & +const HybridGaussianConditional::Conditionals HybridGaussianConditional::conditionals() const { - return conditionals_; + return Conditionals(factors(), [](const auto& pair) { + return std::dynamic_pointer_cast(pair.first); + }); } /* *******************************************************************************/ size_t HybridGaussianConditional::nrComponents() const { size_t total = 0; - conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) { - if (node) total += 1; + factors().visit([&total](const auto& node) { + if (node.first) total += 1; }); return total; } /* *******************************************************************************/ GaussianConditional::shared_ptr HybridGaussianConditional::choose( - const DiscreteValues &discreteValues) const { - auto &ptr = conditionals_(discreteValues); + const DiscreteValues& discreteValues) const { + auto& [ptr, _] = factors()(discreteValues); if (!ptr) return nullptr; auto conditional = std::dynamic_pointer_cast(ptr); if (conditional) @@ -176,18 +174,15 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf, const This *e = dynamic_cast(&lf); if (e == nullptr) return false; - // This will return false if either conditionals_ is empty or e->conditionals_ - // is empty, but not if both are empty or both are not empty: - if (conditionals_.empty() ^ e->conditionals_.empty()) return false; - - // Check the base and the factors: - return BaseFactor::equals(*e, tol) && - conditionals_.equals(e->conditionals_, - [tol](const GaussianConditional::shared_ptr &f1, - const GaussianConditional::shared_ptr &f2) { - return (!f1 && !f2) || - (f1 && f2 && f1->equals(*f2, tol)); - }); + // Factors existence and scalar values are checked in BaseFactor::equals. + // Here we check additionally that the factors *are* conditionals and are equal. + auto compareFunc = [tol](const GaussianFactorValuePair& pair1, + const GaussianFactorValuePair& pair2) { + auto c1 = std::dynamic_pointer_cast(pair1.first), + c2 = std::dynamic_pointer_cast(pair2.first); + return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol)); + }; + return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc); } /* *******************************************************************************/ @@ -202,7 +197,7 @@ void HybridGaussianConditional::print(const std::string &s, std::cout << std::endl << " logNormalizationConstant: " << -negLogConstant() << std::endl << std::endl; - conditionals_.print( + conditionals().print( "", [&](Key k) { return formatter(k); }, [&](const GaussianConditional::shared_ptr &gf) -> std::string { RedirectCout rd; @@ -254,7 +249,7 @@ std::shared_ptr HybridGaussianConditional::likelihood( const DiscreteKeys discreteParentKeys = discreteKeys(); const KeyVector continuousParentKeys = continuousParents(); const HybridGaussianFactor::FactorValuePairs likelihoods( - conditionals_, + conditionals(), [&](const GaussianConditional::shared_ptr &conditional) -> GaussianFactorValuePair { const auto likelihood_m = conditional->likelihood(given); @@ -294,22 +289,30 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( return (max->evaluate(choices) == 0.0) ? nullptr : conditional; }; - auto pruned_conditionals = conditionals_.apply(pruner); + auto pruned_conditionals = conditionals().apply(pruner); return std::make_shared(discreteKeys(), pruned_conditionals); } /* *******************************************************************************/ double HybridGaussianConditional::logProbability( - const HybridValues &values) const { - auto conditional = conditionals_(values.discrete()); - return conditional->logProbability(values.continuous()); + const HybridValues& values) const { + auto [factor, _] = factors()(values.discrete()); + if (auto conditional = std::dynamic_pointer_cast(factor)) + return conditional->logProbability(values.continuous()); + else + throw std::logic_error( + "A HybridGaussianConditional unexpectedly contained a non-conditional"); } /* *******************************************************************************/ -double HybridGaussianConditional::evaluate(const HybridValues &values) const { - auto conditional = conditionals_(values.discrete()); - return conditional->evaluate(values.continuous()); +double HybridGaussianConditional::evaluate(const HybridValues& values) const { + auto [factor, _] = factors()(values.discrete()); + if (auto conditional = std::dynamic_pointer_cast(factor)) + return conditional->evaluate(values.continuous()); + else + throw std::logic_error( + "A HybridGaussianConditional unexpectedly contained a non-conditional"); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 4cc3d3196..6e0d2800c 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -64,8 +64,6 @@ class GTSAM_EXPORT HybridGaussianConditional using Conditionals = DecisionTree; private: - Conditionals conditionals_; ///< a decision tree of Gaussian conditionals. - ///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))). ///< Take advantage of the neg-log space so everything is a minimization double negLogConstant_; @@ -192,8 +190,8 @@ class GTSAM_EXPORT HybridGaussianConditional std::shared_ptr likelihood( const VectorValues &given) const; - /// Getter for the underlying Conditionals DecisionTree - const Conditionals &conditionals() const; + /// Get Conditionals DecisionTree (dynamic cast from factors) + const Conditionals conditionals() const; /** * @brief Compute the logProbability of this hybrid Gaussian conditional. @@ -241,7 +239,6 @@ class GTSAM_EXPORT HybridGaussianConditional void serialize(Archive &ar, const unsigned int /*version*/) { ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); - ar &BOOST_SERIALIZATION_NVP(conditionals_); } #endif };