From 22c758221ddd26c58923454f5f12428039163e2e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 30 Dec 2022 19:28:42 +0530 Subject: [PATCH] make GaussianMixtureFactor store the normalizing constant as well --- gtsam/hybrid/GaussianMixture.cpp | 3 +- gtsam/hybrid/GaussianMixtureFactor.cpp | 39 +++++++++++++------- gtsam/hybrid/GaussianMixtureFactor.h | 18 ++++++--- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 43 +++++++++++++--------- 4 files changed, 66 insertions(+), 37 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 155cae10b..ddcfaf0e8 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -150,7 +150,8 @@ boost::shared_ptr GaussianMixture::likelihood( const KeyVector continuousParentKeys = continuousParents(); const GaussianMixtureFactor::Factors likelihoods( conditionals(), [&](const GaussianConditional::shared_ptr &conditional) { - return conditional->likelihood(frontals); + return std::make_pair(conditional->likelihood(frontals), + 0.5 * conditional->logDeterminant()); }); return boost::make_shared( continuousParentKeys, discreteParentKeys, likelihoods); diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 32ca1432c..0759cf3be 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -29,8 +29,11 @@ namespace gtsam { /* *******************************************************************************/ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Factors &factors) - : Base(continuousKeys, discreteKeys), factors_(factors) {} + const Mixture &factors) + : Base(continuousKeys, discreteKeys), + factors_(factors, [](const GaussianFactor::shared_ptr &gf) { + return std::make_pair(gf, 0.0); + }) {} /* *******************************************************************************/ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { @@ -44,9 +47,9 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { // Check the base and the factors: return Base::equals(*e, tol) && factors_.equals(e->factors_, - [tol](const GaussianFactor::shared_ptr &f1, - const GaussianFactor::shared_ptr &f2) { - return f1->equals(*f2, tol); + [tol](const GaussianMixtureFactor::FactorAndLogZ &f1, + const GaussianMixtureFactor::FactorAndLogZ &f2) { + return f1.first->equals(*(f2.first), tol); }); } @@ -60,7 +63,8 @@ void GaussianMixtureFactor::print(const std::string &s, } else { factors_.print( "", [&](Key k) { return formatter(k); }, - [&](const GaussianFactor::shared_ptr &gf) -> std::string { + [&](const GaussianMixtureFactor::FactorAndLogZ &gf_z) -> std::string { + auto gf = gf_z.first; RedirectCout rd; std::cout << ":\n"; if (gf && !gf->empty()) { @@ -75,8 +79,10 @@ void GaussianMixtureFactor::print(const std::string &s, } /* *******************************************************************************/ -const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { - return factors_; +const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() { + // Unzip to tree of Gaussian factors and tree of log-constants, + // and return the first tree. + return unzip(factors_).first; } /* *******************************************************************************/ @@ -95,9 +101,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add( /* *******************************************************************************/ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() const { - auto wrap = [](const GaussianFactor::shared_ptr &factor) { + auto wrap = [](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { GaussianFactorGraph result; - result.push_back(factor); + result.push_back(factor_z.first); return result; }; return {factors_, wrap}; @@ -108,8 +114,11 @@ AlgebraicDecisionTree GaussianMixtureFactor::error( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. auto errorFunc = - [continuousValues](const GaussianFactor::shared_ptr &factor) { - return factor->error(continuousValues); + [continuousValues](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { + GaussianFactor::shared_ptr factor; + double log_z; + std::tie(factor, log_z) = factor_z; + return factor->error(continuousValues) + log_z; }; DecisionTree errorTree(factors_, errorFunc); return errorTree; @@ -120,8 +129,10 @@ double GaussianMixtureFactor::error( const VectorValues &continuousValues, const DiscreteValues &discreteValues) const { // Directly index to get the conditional, no need to build the whole tree. - auto factor = factors_(discreteValues); - return factor->error(continuousValues); + GaussianFactor::shared_ptr factor; + double log_z; + std::tie(factor, log_z) = factors_(discreteValues); + return factor->error(continuousValues) + log_z; } } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index b8f475de3..b3e603bc3 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -54,8 +54,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { using Sum = DecisionTree; - /// typedef for Decision Tree of Gaussian Factors - using Factors = DecisionTree; + /// typedef of pair of Gaussian factor and log of normalizing constant. + using FactorAndLogZ = std::pair; + /// typedef for Decision Tree of Gaussian Factors and log-constant. + using Factors = DecisionTree; + using Mixture = DecisionTree; private: /// Decision tree of Gaussian factors indexed by discrete keys. @@ -87,7 +90,12 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Factors &factors); + const Mixture &factors); + + GaussianMixtureFactor(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys, + const Factors &factors_and_z) + : Base(continuousKeys, discreteKeys), factors_(factors_and_z) {} /** * @brief Construct a new GaussianMixtureFactor object using a vector of @@ -101,7 +109,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { const DiscreteKeys &discreteKeys, const std::vector &factors) : GaussianMixtureFactor(continuousKeys, discreteKeys, - Factors(discreteKeys, factors)) {} + Mixture(discreteKeys, factors)) {} /// @} /// @name Testable @@ -115,7 +123,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { /// @} /// Getter for the underlying Gaussian Factor Decision Tree. - const Factors &factors(); + const Mixture factors(); /** * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index aac37bc24..15a84b27a 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -204,16 +204,17 @@ hybridElimination(const HybridGaussianFactorGraph &factors, }; sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); - using EliminationPair = GaussianFactorGraph::EliminationResult; + using EliminationPair = + std::pair, + std::pair, double>>; KeyVector keysOfEliminated; // Not the ordering KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)? // This is the elimination method on the leaf nodes - auto eliminate = [&](const GaussianFactorGraph &graph) - -> GaussianFactorGraph::EliminationResult { + auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair { if (graph.empty()) { - return {nullptr, nullptr}; + return {nullptr, std::make_pair(nullptr, 0.0)}; } #ifdef HYBRID_TIMING @@ -222,17 +223,21 @@ hybridElimination(const HybridGaussianFactorGraph &factors, std::pair, boost::shared_ptr> - result = EliminatePreferCholesky(graph, frontalKeys); + conditional_factor = EliminatePreferCholesky(graph, frontalKeys); // Initialize the keysOfEliminated to be the keys of the // eliminated GaussianConditional - keysOfEliminated = result.first->keys(); - keysOfSeparator = result.second->keys(); + keysOfEliminated = conditional_factor.first->keys(); + keysOfSeparator = conditional_factor.second->keys(); #ifdef HYBRID_TIMING gttoc_(hybrid_eliminate); #endif + std::pair, + std::pair, double>> + result = std::make_pair(conditional_factor.first, + std::make_pair(conditional_factor.second, 0.0)); return result; }; @@ -257,16 +262,20 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // DiscreteFactor, with the error for each discrete choice. if (keysOfSeparator.empty()) { VectorValues empty_values; - auto factorProb = [&](const GaussianFactor::shared_ptr &factor) { - if (!factor) { - return 0.0; // If nullptr, return 0.0 probability - } else { - // This is the probability q(μ) at the MLE point. - double error = - 0.5 * std::abs(factor->augmentedInformation().determinant()); - return std::exp(-error); - } - }; + auto factorProb = + [&](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { + if (!factor_z.first) { + return 0.0; // If nullptr, return 0.0 probability + } else { + GaussianFactor::shared_ptr factor = factor_z.first; + double log_z = factor_z.second; + // This is the probability q(μ) at the MLE point. + double error = + 0.5 * std::abs(factor->augmentedInformation().determinant()) + + log_z; + return std::exp(-error); + } + }; DecisionTree fdt(separatorFactors, factorProb); auto discreteFactor =