diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 155cae10b..65c0e8522 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -149,17 +150,19 @@ boost::shared_ptr GaussianMixture::likelihood( const DiscreteKeys discreteParentKeys = discreteKeys(); const KeyVector continuousParentKeys = continuousParents(); const GaussianMixtureFactor::Factors likelihoods( - conditionals(), [&](const GaussianConditional::shared_ptr &conditional) { - return conditional->likelihood(frontals); + conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { + return GaussianMixtureFactor::FactorAndConstant{ + conditional->likelihood(frontals), + conditional->logNormalizationConstant()}; }); return boost::make_shared( continuousParentKeys, discreteParentKeys, likelihoods); } /* ************************************************************************* */ -std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { +std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { std::set s; - s.insert(dkeys.begin(), dkeys.end()); + s.insert(discreteKeys.begin(), discreteKeys.end()); return s; } @@ -184,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { const GaussianConditional::shared_ptr &conditional) -> GaussianConditional::shared_ptr { // typecast so we can use this to get probability value - DiscreteValues values(choices); + const DiscreteValues values(choices); // Case where the gaussian mixture has the same // discrete keys as the decision tree. @@ -254,11 +257,10 @@ AlgebraicDecisionTree GaussianMixture::error( } /* *******************************************************************************/ -double GaussianMixture::error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { +double GaussianMixture::error(const HybridValues &values) const { // Directly index to get the conditional, no need to build the whole tree. - auto conditional = conditionals_(discreteValues); - return conditional->error(continuousValues); + auto conditional = conditionals_(values.discrete()); + return conditional->error(values.continuous()); } } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 2cdc23b46..a9b05f250 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -30,6 +30,7 @@ namespace gtsam { class GaussianMixtureFactor; +class HybridValues; /** * @brief A conditional of gaussian mixtures indexed by discrete variables, as @@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture /// @name Constructors /// @{ - /// Defaut constructor, mainly for serialization. + /// Default constructor, mainly for serialization. GaussianMixture() = default; /** @@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture /// @name Standard API /// @{ + /// @brief Return the conditional Gaussian for the given discrete assignment. GaussianConditional::shared_ptr operator()( const DiscreteValues &discreteValues) const; @@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture * @brief Compute the error of this Gaussian Mixture given the continuous * values and a discrete assignment. * - * @param continuousValues Continuous values at which to compute the error. - * @param discreteValues The discrete assignment for a specific mode sequence. + * @param values Continuous values and discrete assignment. * @return double */ - double error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const; + double error(const HybridValues &values) const override; /** * @brief Prune the decision tree of Gaussian factors as per the discrete @@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture }; /// Return the DiscreteKey vector as a set. -std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys); +std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys); // traits template <> diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 32ca1432c..e60368717 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -22,6 +22,8 @@ #include #include #include +#include +#include #include namespace gtsam { @@ -29,8 +31,11 @@ namespace gtsam { /* *******************************************************************************/ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Factors &factors) - : Base(continuousKeys, discreteKeys), factors_(factors) {} + const Mixture &factors) + : Base(continuousKeys, discreteKeys), + factors_(factors, [](const GaussianFactor::shared_ptr &gf) { + return FactorAndConstant{gf, 0.0}; + }) {} /* *******************************************************************************/ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { @@ -43,11 +48,11 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { // Check the base and the factors: return Base::equals(*e, tol) && - factors_.equals(e->factors_, - [tol](const GaussianFactor::shared_ptr &f1, - const GaussianFactor::shared_ptr &f2) { - return f1->equals(*f2, tol); - }); + factors_.equals(e->factors_, [tol](const FactorAndConstant &f1, + const FactorAndConstant &f2) { + return f1.factor->equals(*(f2.factor), tol) && + std::abs(f1.constant - f2.constant) < tol; + }); } /* *******************************************************************************/ @@ -60,7 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s, } else { factors_.print( "", [&](Key k) { return formatter(k); }, - [&](const GaussianFactor::shared_ptr &gf) -> std::string { + [&](const FactorAndConstant &gf_z) -> std::string { + auto gf = gf_z.factor; RedirectCout rd; std::cout << ":\n"; if (gf && !gf->empty()) { @@ -75,8 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s, } /* *******************************************************************************/ -const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { - return factors_; +const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const { + return Mixture(factors_, [](const FactorAndConstant &factor_z) { + return factor_z.factor; + }); } /* *******************************************************************************/ @@ -95,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add( /* *******************************************************************************/ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() const { - auto wrap = [](const GaussianFactor::shared_ptr &factor) { + auto wrap = [](const FactorAndConstant &factor_z) { GaussianFactorGraph result; - result.push_back(factor); + result.push_back(factor_z.factor); return result; }; return {factors_, wrap}; @@ -107,21 +115,18 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() AlgebraicDecisionTree GaussianMixtureFactor::error( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = - [continuousValues](const GaussianFactor::shared_ptr &factor) { - return factor->error(continuousValues); - }; + auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) { + return factor_z.error(continuousValues); + }; DecisionTree errorTree(factors_, errorFunc); return errorTree; } /* *******************************************************************************/ -double GaussianMixtureFactor::error( - const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { - // Directly index to get the conditional, no need to build the whole tree. - auto factor = factors_(discreteValues); - return factor->error(continuousValues); +double GaussianMixtureFactor::error(const HybridValues &values) const { + const FactorAndConstant factor_z = factors_(values.discrete()); + return factor_z.error(values.continuous()); } +/* *******************************************************************************/ } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index b8f475de3..ce011fecc 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -23,17 +23,15 @@ #include #include #include -#include -#include +#include #include -#include namespace gtsam { class GaussianFactorGraph; - -// Needed for wrapper. -using GaussianFactorVector = std::vector; +class HybridValues; +class DiscreteValues; +class VectorValues; /** * @brief Implementation of a discrete conditional mixture factor. @@ -53,9 +51,29 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { using shared_ptr = boost::shared_ptr; using Sum = DecisionTree; + using sharedFactor = boost::shared_ptr; - /// typedef for Decision Tree of Gaussian Factors - using Factors = DecisionTree; + /// Gaussian factor and log of normalizing constant. + struct FactorAndConstant { + sharedFactor factor; + double constant; + + // Return error with constant correction. + double error(const VectorValues &values) const { + // Note minus sign: constant is log of normalization constant for probabilities. + // Errors is the negative log-likelihood, hence we subtract the constant here. + return factor->error(values) - constant; + } + + // Check pointer equality. + bool operator==(const FactorAndConstant &other) const { + return factor == other.factor && constant == other.constant; + } + }; + + /// typedef for Decision Tree of Gaussian factors and log-constant. + using Factors = DecisionTree; + using Mixture = DecisionTree; private: /// Decision tree of Gaussian factors indexed by discrete keys. @@ -82,12 +100,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @param continuousKeys A vector of keys representing continuous variables. * @param discreteKeys A vector of keys representing discrete variables and * their cardinalities. - * @param factors The decision tree of Gaussian Factors stored as the mixture + * @param factors The decision tree of Gaussian factors stored as the mixture * density. */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Factors &factors); + const Mixture &factors); + + GaussianMixtureFactor(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys, + const Factors &factors_and_z) + : Base(continuousKeys, discreteKeys), factors_(factors_and_z) {} /** * @brief Construct a new GaussianMixtureFactor object using a vector of @@ -99,9 +122,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const std::vector &factors) + const std::vector &factors) : GaussianMixtureFactor(continuousKeys, discreteKeys, - Factors(discreteKeys, factors)) {} + Mixture(discreteKeys, factors)) {} /// @} /// @name Testable @@ -113,9 +136,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { const std::string &s = "GaussianMixtureFactor\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard API + /// @{ /// Getter for the underlying Gaussian Factor Decision Tree. - const Factors &factors(); + const Mixture factors() const; /** * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while @@ -137,21 +162,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { AlgebraicDecisionTree error(const VectorValues &continuousValues) const; /** - * @brief Compute the error of this Gaussian Mixture given the continuous - * values and a discrete assignment. - * - * @param continuousValues Continuous values at which to compute the error. - * @param discreteValues The discrete assignment for a specific mode sequence. + * @brief Compute the log-likelihood, including the log-normalizing constant. * @return double */ - double error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const; + double error(const HybridValues &values) const override; /// Add MixtureFactor to a Sum, syntactic sugar. friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { sum = factor.add(sum); return sum; } + /// @} }; // traits diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 8e01c0c76..8be314c4e 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -1,5 +1,5 @@ /* ---------------------------------------------------------------------------- - * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) @@ -12,6 +12,7 @@ * @author Fan Jiang * @author Varun Agrawal * @author Shangjie Xue + * @author Frank Dellaert * @date January 2022 */ @@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const { } /* ************************************************************************* */ -double HybridBayesNet::error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { - GaussianBayesNet gbn = choose(discreteValues); - return gbn.error(continuousValues); +double HybridBayesNet::error(const HybridValues &values) const { + GaussianBayesNet gbn = choose(values.discrete()); + return gbn.error(values.continuous()); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index a64b3bb4f..0d2c337b7 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief 0.5 * sum of squared Mahalanobis distances * for a specific discrete assignment. * - * @param continuousValues Continuous values at which to compute the error. - * @param discreteValues Discrete assignment for a specific mode sequence. + * @param values Continuous values and discrete assignment. * @return double */ - double error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const; + double error(const HybridValues &values) const; /** * @brief Compute conditional error for each discrete assignment, diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index db03ba59c..be671d55f 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -52,7 +52,7 @@ namespace gtsam { * having diamond inheritances, and neutralized the need to change other * components of GTSAM to make hybrid elimination work. * - * A great reference to the type-erasure pattern is Eduaado Madrid's CppCon + * A great reference to the type-erasure pattern is Eduardo Madrid's CppCon * talk (https://www.youtube.com/watch?v=s082Qmd_nHs). * * @ingroup hybrid @@ -129,33 +129,6 @@ class GTSAM_EXPORT HybridConditional */ HybridConditional(boost::shared_ptr gaussianMixture); - /** - * @brief Return HybridConditional as a GaussianMixture - * @return nullptr if not a mixture - * @return GaussianMixture::shared_ptr otherwise - */ - GaussianMixture::shared_ptr asMixture() { - return boost::dynamic_pointer_cast(inner_); - } - - /** - * @brief Return HybridConditional as a GaussianConditional - * @return nullptr if not a GaussianConditional - * @return GaussianConditional::shared_ptr otherwise - */ - GaussianConditional::shared_ptr asGaussian() { - return boost::dynamic_pointer_cast(inner_); - } - - /** - * @brief Return conditional as a DiscreteConditional - * @return nullptr if not a DiscreteConditional - * @return DiscreteConditional::shared_ptr - */ - DiscreteConditional::shared_ptr asDiscrete() { - return boost::dynamic_pointer_cast(inner_); - } - /// @} /// @name Testable /// @{ @@ -169,10 +142,52 @@ class GTSAM_EXPORT HybridConditional bool equals(const HybridFactor& other, double tol = 1e-9) const override; /// @} + /// @name Standard Interface + /// @{ + + /** + * @brief Return HybridConditional as a GaussianMixture + * @return nullptr if not a mixture + * @return GaussianMixture::shared_ptr otherwise + */ + GaussianMixture::shared_ptr asMixture() const { + return boost::dynamic_pointer_cast(inner_); + } + + /** + * @brief Return HybridConditional as a GaussianConditional + * @return nullptr if not a GaussianConditional + * @return GaussianConditional::shared_ptr otherwise + */ + GaussianConditional::shared_ptr asGaussian() const { + return boost::dynamic_pointer_cast(inner_); + } + + /** + * @brief Return conditional as a DiscreteConditional + * @return nullptr if not a DiscreteConditional + * @return DiscreteConditional::shared_ptr + */ + DiscreteConditional::shared_ptr asDiscrete() const { + return boost::dynamic_pointer_cast(inner_); + } /// Get the type-erased pointer to the inner type boost::shared_ptr inner() { return inner_; } + /// Return the error of the underlying conditional. + /// Currently only implemented for Gaussian mixture. + double error(const HybridValues& values) const override { + if (auto gm = asMixture()) { + return gm->error(values); + } else { + throw std::runtime_error( + "HybridConditional::error: only implemented for Gaussian mixture"); + } + } + + /// @} + private: /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp index 0455e1e90..605ea5738 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.cpp +++ b/gtsam/hybrid/HybridDiscreteFactor.cpp @@ -17,6 +17,7 @@ */ #include +#include #include @@ -50,4 +51,10 @@ void HybridDiscreteFactor::print(const std::string &s, inner_->print("\n", formatter); }; +/* ************************************************************************ */ +double HybridDiscreteFactor::error(const HybridValues &values) const { + return -log((*inner_)(values.discrete())); +} +/* ************************************************************************ */ + } // namespace gtsam diff --git a/gtsam/hybrid/HybridDiscreteFactor.h b/gtsam/hybrid/HybridDiscreteFactor.h index 015dc46f8..6e914d38b 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.h +++ b/gtsam/hybrid/HybridDiscreteFactor.h @@ -24,10 +24,12 @@ namespace gtsam { +class HybridValues; + /** - * A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows - * us to hide the implementation of DiscreteFactor and thus avoid diamond - * inheritance. + * A HybridDiscreteFactor is a thin container for DiscreteFactor, which + * allows us to hide the implementation of DiscreteFactor and thus avoid + * diamond inheritance. * * @ingroup hybrid */ @@ -59,9 +61,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor { const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard Interface + /// @{ /// Return pointer to the internal discrete factor DiscreteFactor::shared_ptr inner() const { return inner_; } + + /// Return the error of the underlying Discrete Factor. + double error(const HybridValues &values) const override; + /// @} }; // traits diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index e0cae55c1..a28fee8ed 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -26,6 +26,8 @@ #include namespace gtsam { +class HybridValues; + KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys); KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); @@ -110,6 +112,15 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// @name Standard Interface /// @{ + /** + * @brief Compute the error of this Gaussian Mixture given the continuous + * values and a discrete assignment. + * + * @param values Continuous values and discrete assignment. + * @return double + */ + virtual double error(const HybridValues &values) const = 0; + /// True if this is a factor of discrete variables only. bool isDiscrete() const { return isDiscrete_; } diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index ba0c0bf1a..5a89a04a8 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -16,6 +16,7 @@ */ #include +#include #include #include @@ -54,4 +55,10 @@ void HybridGaussianFactor::print(const std::string &s, inner_->print("\n", formatter); }; +/* ************************************************************************ */ +double HybridGaussianFactor::error(const HybridValues &values) const { + return inner_->error(values.continuous()); +} +/* ************************************************************************ */ + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 966524b81..897da9caa 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -25,6 +25,7 @@ namespace gtsam { // Forward declarations class JacobianFactor; class HessianFactor; +class HybridValues; /** * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have @@ -92,8 +93,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard Interface + /// @{ + /// Return pointer to the internal discrete factor GaussianFactor::shared_ptr inner() const { return inner_; } + + /// Return the error of the underlying Discrete Factor. + double error(const HybridValues &values) const override; + /// @} }; // traits diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index aac37bc24..6af0fb1a9 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -55,13 +55,14 @@ namespace gtsam { +/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: template class EliminateableFactorGraph; /* ************************************************************************ */ static GaussianMixtureFactor::Sum &addGaussian( GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { using Y = GaussianFactorGraph; - // If the decision tree is not intiialized, then intialize it. + // If the decision tree is not initialized, then initialize it. if (sum.empty()) { GaussianFactorGraph result; result.push_back(factor); @@ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals( for (auto &f : factors) { if (f->isHybrid()) { - if (auto cgmf = boost::dynamic_pointer_cast(f)) { - sum = cgmf->add(sum); + // TODO(dellaert): just use a virtual method defined in HybridFactor. + if (auto gm = boost::dynamic_pointer_cast(f)) { + sum = gm->add(sum); } if (auto gm = boost::dynamic_pointer_cast(f)) { sum = gm->asMixture()->add(sum); @@ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, const KeySet &continuousSeparator, const std::set &discreteSeparatorSet) { // NOTE: since we use the special JunctionTree, - // only possiblity is continuous conditioned on discrete. + // only possibility is continuous conditioned on discrete. DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), discreteSeparatorSet.end()); @@ -204,16 +206,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors, }; sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); - using EliminationPair = GaussianFactorGraph::EliminationResult; + using EliminationPair = std::pair, + GaussianMixtureFactor::FactorAndConstant>; KeyVector keysOfEliminated; // Not the ordering KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)? // This is the elimination method on the leaf nodes - auto eliminate = [&](const GaussianFactorGraph &graph) - -> GaussianFactorGraph::EliminationResult { + auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair { if (graph.empty()) { - return {nullptr, nullptr}; + return {nullptr, {nullptr, 0.0}}; } #ifdef HYBRID_TIMING @@ -222,18 +224,18 @@ hybridElimination(const HybridGaussianFactorGraph &factors, std::pair, boost::shared_ptr> - result = EliminatePreferCholesky(graph, frontalKeys); + conditional_factor = EliminatePreferCholesky(graph, frontalKeys); // Initialize the keysOfEliminated to be the keys of the // eliminated GaussianConditional - keysOfEliminated = result.first->keys(); - keysOfSeparator = result.second->keys(); + keysOfEliminated = conditional_factor.first->keys(); + keysOfSeparator = conditional_factor.second->keys(); #ifdef HYBRID_TIMING gttoc_(hybrid_eliminate); #endif - return result; + return {conditional_factor.first, {conditional_factor.second, 0.0}}; }; // Perform elimination! @@ -246,8 +248,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // Separate out decision tree into conditionals and remaining factors. auto pair = unzip(eliminationResults); - - const GaussianMixtureFactor::Factors &separatorFactors = pair.second; + const auto &separatorFactors = pair.second; // Create the GaussianMixture from the conditionals auto conditional = boost::make_shared( @@ -257,16 +258,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // DiscreteFactor, with the error for each discrete choice. if (keysOfSeparator.empty()) { VectorValues empty_values; - auto factorProb = [&](const GaussianFactor::shared_ptr &factor) { - if (!factor) { - return 0.0; // If nullptr, return 0.0 probability - } else { - // This is the probability q(μ) at the MLE point. - double error = - 0.5 * std::abs(factor->augmentedInformation().determinant()); - return std::exp(-error); - } - }; + auto factorProb = + [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { + GaussianFactor::shared_ptr factor = factor_z.factor; + if (!factor) { + return 0.0; // If nullptr, return 0.0 probability + } else { + // This is the probability q(μ) at the MLE point. + double error = + 0.5 * std::abs(factor->augmentedInformation().determinant()) + + factor_z.constant; + return std::exp(-error); + } + }; DecisionTree fdt(separatorFactors, factorProb); auto discreteFactor = @@ -452,6 +456,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( // Iterate over each factor. for (size_t idx = 0; idx < size(); idx++) { + // TODO(dellaert): just use a virtual method defined in HybridFactor. AlgebraicDecisionTree factor_error; if (factors_.at(idx)->isHybrid()) { @@ -491,38 +496,17 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( } /* ************************************************************************ */ -double HybridGaussianFactorGraph::error( - const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { +double HybridGaussianFactorGraph::error(const HybridValues &values) const { double error = 0.0; - for (size_t idx = 0; idx < size(); idx++) { - auto factor = factors_.at(idx); - - if (factor->isHybrid()) { - if (auto c = boost::dynamic_pointer_cast(factor)) { - error += c->asMixture()->error(continuousValues, discreteValues); - } - if (auto f = boost::dynamic_pointer_cast(factor)) { - error += f->error(continuousValues, discreteValues); - } - - } else if (factor->isContinuous()) { - if (auto f = boost::dynamic_pointer_cast(factor)) { - error += f->inner()->error(continuousValues); - } - if (auto cg = boost::dynamic_pointer_cast(factor)) { - error += cg->asGaussian()->error(continuousValues); - } - } + for (auto &factor : factors_) { + error += factor->error(values); } return error; } /* ************************************************************************ */ -double HybridGaussianFactorGraph::probPrime( - const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { - double error = this->error(continuousValues, discreteValues); +double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { + double error = this->error(values); // NOTE: The 0.5 term is handled by each factor return std::exp(-error); } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 4e22bed7c..c851adfe5 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -12,7 +12,7 @@ /** * @file HybridGaussianFactorGraph.h * @brief Linearized Hybrid factor graph that uses type erasure - * @author Fan Jiang, Varun Agrawal + * @author Fan Jiang, Varun Agrawal, Frank Dellaert * @date Mar 11, 2022 */ @@ -38,6 +38,7 @@ class HybridBayesTree; class HybridJunctionTree; class DecisionTreeFactor; class JacobianFactor; +class HybridValues; /** * @brief Main elimination function for HybridGaussianFactorGraph. @@ -186,14 +187,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @brief Compute error given a continuous vector values * and a discrete assignment. * - * @param continuousValues The continuous VectorValues - * for computing the error. - * @param discreteValues The specific discrete assignment - * whose error we wish to compute. * @return double */ - double error(const VectorValues& continuousValues, - const DiscreteValues& discreteValues) const; + double error(const HybridValues& values) const; /** * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ @@ -210,13 +206,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @brief Compute the unnormalized posterior probability for a continuous * vector values given a specific assignment. * - * @param continuousValues The vector values for which to compute the - * posterior probability. - * @param discreteValues The specific assignment to use for the computation. * @return double */ - double probPrime(const VectorValues& continuousValues, - const DiscreteValues& discreteValues) const; + double probPrime(const HybridValues& values) const; /** * @brief Return a Colamd constrained ordering where the discrete keys are diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index 7776347b3..9b3e780ef 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -51,12 +51,22 @@ class HybridNonlinearFactor : public HybridFactor { const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard Interface + /// @{ NonlinearFactor::shared_ptr inner() const { return inner_; } + /// Error for HybridValues is not provided for nonlinear factor. + double error(const HybridValues &values) const override { + throw std::runtime_error( + "HybridNonlinearFactor::error(HybridValues) not implemented."); + } + /// Linearize to a HybridGaussianFactor at the linearization point `c`. boost::shared_ptr linearize(const Values &c) const { return boost::make_shared(inner_->linearize(c)); } + + /// @} }; } // namespace gtsam diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index f29a84022..fc1a9a2b8 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -161,6 +161,12 @@ class MixtureFactor : public HybridFactor { factor, continuousValues); } + /// Error for HybridValues is not provided for nonlinear hybrid factor. + double error(const HybridValues &values) const override { + throw std::runtime_error( + "MixtureFactor::error(HybridValues) not implemented."); + } + size_t dim() const { // TODO(Varun) throw std::runtime_error("MixtureFactor::dim not implemented."); diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 15687d11b..3c74d1ee2 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -183,10 +183,8 @@ class HybridGaussianFactorGraph { bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const; // evaluation - double error(const gtsam::VectorValues& continuousValues, - const gtsam::DiscreteValues& discreteValues) const; - double probPrime(const gtsam::VectorValues& continuousValues, - const gtsam::DiscreteValues& discreteValues) const; + double error(const gtsam::HybridValues& values) const; + double probPrime(const gtsam::HybridValues& values) const; gtsam::HybridBayesNet* eliminateSequential(); gtsam::HybridBayesNet* eliminateSequential( diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 242c9ba41..ff8edd46e 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -128,9 +128,9 @@ TEST(GaussianMixture, Error) { // Regression for non-tree version. DiscreteValues assignment; assignment[M(1)] = 0; - EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8); + EXPECT_DOUBLES_EQUAL(0.5, mixture.error({values, assignment}), 1e-8); assignment[M(1)] = 1; - EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), + EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error({values, assignment}), 1e-8); } @@ -179,7 +179,9 @@ TEST(GaussianMixture, Likelihood) { const GaussianMixtureFactor::Factors factors( gm.conditionals(), [measurements](const GaussianConditional::shared_ptr& conditional) { - return conditional->likelihood(measurements); + return GaussianMixtureFactor::FactorAndConstant{ + conditional->likelihood(measurements), + conditional->logNormalizationConstant()}; }); const GaussianMixtureFactor expected({X(0)}, {mode}, factors); EXPECT(assert_equal(*factor, expected)); diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index ba0622ff9..d17968a3a 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -188,7 +189,8 @@ TEST(GaussianMixtureFactor, Error) { DiscreteValues discreteValues; discreteValues[m1.first] = 1; EXPECT_DOUBLES_EQUAL( - 4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9); + 4.0, mixtureFactor.error({continuousValues, discreteValues}), + 1e-9); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 43cee6f74..58230cfde 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -188,14 +188,14 @@ TEST(HybridBayesNet, Optimize) { HybridValues delta = hybridBayesNet->optimize(); - //TODO(Varun) The expectedAssignment should be 111, not 101 + // TODO(Varun) The expectedAssignment should be 111, not 101 DiscreteValues expectedAssignment; expectedAssignment[M(0)] = 1; expectedAssignment[M(1)] = 0; expectedAssignment[M(2)] = 1; EXPECT(assert_equal(expectedAssignment, delta.discrete())); - //TODO(Varun) This should be all -Vector1::Ones() + // TODO(Varun) This should be all -Vector1::Ones() VectorValues expectedValues; expectedValues.insert(X(0), -0.999904 * Vector1::Ones()); expectedValues.insert(X(1), -0.99029 * Vector1::Ones()); @@ -243,8 +243,8 @@ TEST(HybridBayesNet, Error) { double total_error = 0; for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { if (hybridBayesNet->at(idx)->isHybrid()) { - double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(), - discrete_values); + double error = hybridBayesNet->atMixture(idx)->error( + {delta.continuous(), discrete_values}); total_error += error; } else if (hybridBayesNet->at(idx)->isContinuous()) { double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous()); @@ -253,7 +253,7 @@ TEST(HybridBayesNet, Error) { } EXPECT_DOUBLES_EQUAL( - total_error, hybridBayesNet->error(delta.continuous(), discrete_values), + total_error, hybridBayesNet->error({delta.continuous(), discrete_values}), 1e-9); EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9); diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index 927f5c047..660cb3317 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -273,7 +273,7 @@ AlgebraicDecisionTree getProbPrimeTree( continue; } - double error = graph.error(delta, assignment); + double error = graph.error({delta, assignment}); probPrimes.push_back(exp(-error)); } AlgebraicDecisionTree probPrimeTree(discrete_keys, probPrimes); @@ -487,8 +487,8 @@ TEST(HybridEstimation, CorrectnessViaSampling) { const HybridValues& sample) -> double { const DiscreteValues assignment = sample.discrete(); // Compute in log form for numerical stability - double log_ratio = bayesNet->error(sample.continuous(), assignment) - - factorGraph->error(sample.continuous(), assignment); + double log_ratio = bayesNet->error({sample.continuous(), assignment}) - + factorGraph->error({sample.continuous(), assignment}); double ratio = exp(-log_ratio); return ratio; }; diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 565c7f0a0..1bdb6d4db 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -575,18 +575,14 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) { HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(hybridOrdering); - HybridValues delta = hybridBayesNet->optimize(); - double error = graph.error(delta.continuous(), delta.discrete()); - - double expected_error = 0.490243199; - // regression - EXPECT(assert_equal(expected_error, error, 1e-9)); - - double probs = exp(-error); - double expected_probs = graph.probPrime(delta.continuous(), delta.discrete()); + const HybridValues delta = hybridBayesNet->optimize(); + const double error = graph.error(delta); // regression - EXPECT(assert_equal(expected_probs, probs, 1e-7)); + EXPECT(assert_equal(1.58886, error, 1e-5)); + + // Real test: + EXPECT(assert_equal(graph.probPrime(delta), exp(-error), 1e-7)); } /* ****************************************************************************/ diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 7cdff914f..ecfa02282 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -168,26 +168,30 @@ namespace gtsam { /* ************************************************************************* */ double GaussianConditional::logDeterminant() const { - double logDet; - if (this->get_model()) { - Vector diag = this->R().diagonal(); - this->get_model()->whitenInPlace(diag); - logDet = diag.unaryExpr([](double x) { return log(x); }).sum(); + if (get_model()) { + Vector diag = R().diagonal(); + get_model()->whitenInPlace(diag); + return diag.unaryExpr([](double x) { return log(x); }).sum(); } else { - logDet = - this->R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); + return R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); } - return logDet; } /* ************************************************************************* */ -// density = exp(-error(x)) / sqrt((2*pi)^n*det(Sigma)) -// log = -error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) -double GaussianConditional::logDensity(const VectorValues& x) const { +// normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) +// log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) +double GaussianConditional::logNormalizationConstant() const { constexpr double log2pi = 1.8378770664093454835606594728112; size_t n = d().size(); // log det(Sigma)) = - 2.0 * logDeterminant() - return - error(x) - 0.5 * n * log2pi + logDeterminant(); + return - 0.5 * n * log2pi + logDeterminant(); +} + +/* ************************************************************************* */ +// density = k exp(-error(x)) +// log = log(k) -error(x) - 0.5 * n*log(2*pi) +double GaussianConditional::logDensity(const VectorValues& x) const { + return logNormalizationConstant() - error(x); } /* ************************************************************************* */ diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index af1c5d80e..d25efb2e1 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -169,7 +169,7 @@ namespace gtsam { * * @return double */ - double determinant() const { return exp(this->logDeterminant()); } + inline double determinant() const { return exp(logDeterminant()); } /** * @brief Compute the log determinant of the R matrix. @@ -184,6 +184,19 @@ namespace gtsam { */ double logDeterminant() const; + /** + * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) + * log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) + */ + double logNormalizationConstant() const; + + /** + * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) + */ + inline double normalizationConstant() const { + return exp(logNormalizationConstant()); + } + /** * Solves a conditional Gaussian and writes the solution into the entries of * \c x for each frontal variable of the conditional. The parents are diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 481617db1..5398160dc 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -6,7 +6,7 @@ All Rights Reserved See LICENSE for the license information Unit tests for Hybrid Factor Graphs. -Author: Fan Jiang +Author: Fan Jiang, Varun Agrawal, Frank Dellaert """ # pylint: disable=invalid-name, no-name-in-module, no-member @@ -18,13 +18,14 @@ from gtsam.utils.test_case import GtsamTestCase import gtsam from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, - GaussianMixture, GaussianMixtureFactor, + GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues, HybridGaussianFactorGraph, JacobianFactor, Ordering, noiseModel) class TestHybridGaussianFactorGraph(GtsamTestCase): """Unit tests for HybridGaussianFactorGraph.""" + def test_create(self): """Test construction of hybrid factor graph.""" model = noiseModel.Unit.Create(3) @@ -81,13 +82,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): self.assertEqual(hv.atDiscrete(C(0)), 1) @staticmethod - def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet: + def tiny(num_measurements: int = 1) -> HybridBayesNet: """ Create a tiny two variable hybrid model which represents the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). """ # Create hybrid Bayes net. - bayesNet = gtsam.HybridBayesNet() + bayesNet = HybridBayesNet() # Create mode key: 0 is low-noise, 1 is high-noise. mode = (M(0), 2) @@ -113,35 +114,76 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): bayesNet.addGaussian(prior_on_x0) # Add prior on mode. - bayesNet.emplaceDiscrete(mode, "1/1") + bayesNet.emplaceDiscrete(mode, "4/6") return bayesNet + @staticmethod + def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues): + """Create a factor graph from the Bayes net with sampled measurements. + The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...` + and thus represents the same joint probability as the Bayes net. + """ + fg = HybridGaussianFactorGraph() + num_measurements = bayesNet.size() - 2 + for i in range(num_measurements): + conditional = bayesNet.atMixture(i) + measurement = gtsam.VectorValues() + measurement.insert(Z(i), sample.at(Z(i))) + factor = conditional.likelihood(measurement) + fg.push_back(factor) + fg.push_back(bayesNet.atGaussian(num_measurements)) + fg.push_back(bayesNet.atDiscrete(num_measurements+1)) + return fg + + @classmethod + def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000): + """Do importance sampling to get an estimate of the discrete marginal P(mode).""" + # Use prior on x0, mode as proposal density. + prior = cls.tiny(num_measurements=0) # just P(x0)P(mode) + + # Allocate space for marginals. + marginals = np.zeros((2,)) + + # Do importance sampling. + num_measurements = bayesNet.size() - 2 + for s in range(N): + proposed = prior.sample() + for i in range(num_measurements): + z_i = sample.at(Z(i)) + proposed.insert(Z(i), z_i) + weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed) + marginals[proposed.atDiscrete(M(0))] += weight + + # print marginals: + marginals /= marginals.sum() + return marginals + def test_tiny(self): """Test a tiny two variable hybrid model.""" bayesNet = self.tiny() sample = bayesNet.sample() # print(sample) - # Create a factor graph from the Bayes net with sampled measurements. - fg = HybridGaussianFactorGraph() - conditional = bayesNet.atMixture(0) - measurement = gtsam.VectorValues() - measurement.insert(Z(0), sample.at(Z(0))) - factor = conditional.likelihood(measurement) - fg.push_back(factor) - fg.push_back(bayesNet.atGaussian(1)) - fg.push_back(bayesNet.atDiscrete(2)) + # Estimate marginals using importance sampling. + marginals = self.estimate_marginals(bayesNet, sample) + # print(f"True mode: {sample.atDiscrete(M(0))}") + # print(f"P(mode=0; z0) = {marginals[0]}") + # print(f"P(mode=1; z0) = {marginals[1]}") + # Check that the estimate is close to the true value. + self.assertAlmostEqual(marginals[0], 0.4, delta=0.1) + self.assertAlmostEqual(marginals[1], 0.6, delta=0.1) + + fg = self.factor_graph_from_bayes_net(bayesNet, sample) self.assertEqual(fg.size(), 3) @staticmethod - def calculate_ratio(bayesNet, fg, sample): + def calculate_ratio(bayesNet: HybridBayesNet, + fg: HybridGaussianFactorGraph, + sample: HybridValues): """Calculate ratio between Bayes net probability and the factor graph.""" - continuous = gtsam.VectorValues() - continuous.insert(X(0), sample.at(X(0))) - return bayesNet.evaluate(sample) / fg.probPrime( - continuous, sample.discrete()) + return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0 def test_ratio(self): """ @@ -153,23 +195,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): # Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n) bayesNet = self.tiny(num_measurements=2) # Sample from the Bayes net. - sample: gtsam.HybridValues = bayesNet.sample() + sample: HybridValues = bayesNet.sample() # print(sample) - # Create a factor graph from the Bayes net with sampled measurements. - # The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)` - # and thus represents the same joint probability as the Bayes net. - fg = HybridGaussianFactorGraph() - for i in range(2): - conditional = bayesNet.atMixture(i) - measurement = gtsam.VectorValues() - measurement.insert(Z(i), sample.at(Z(i))) - factor = conditional.likelihood(measurement) - fg.push_back(factor) - fg.push_back(bayesNet.atGaussian(2)) - fg.push_back(bayesNet.atDiscrete(3)) + # Estimate marginals using importance sampling. + marginals = self.estimate_marginals(bayesNet, sample) + # print(f"True mode: {sample.atDiscrete(M(0))}") + # print(f"P(mode=0; z0, z1) = {marginals[0]}") + # print(f"P(mode=1; z0, z1) = {marginals[1]}") - # print(fg) + # Check marginals based on sampled mode. + if sample.atDiscrete(M(0)) == 0: + self.assertGreater(marginals[0], marginals[1]) + else: + self.assertGreater(marginals[1], marginals[0]) + + fg = self.factor_graph_from_bayes_net(bayesNet, sample) self.assertEqual(fg.size(), 4) # Calculate ratio between Bayes net probability and the factor graph: @@ -185,10 +226,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): for i in range(10): other = bayesNet.sample() other.update(measurements) - # print(other) - # ratio = self.calculate_ratio(bayesNet, fg, other) + ratio = self.calculate_ratio(bayesNet, fg, other) # print(f"Ratio: {ratio}\n") - # self.assertAlmostEqual(ratio, expected_ratio) + if (ratio > 0): + self.assertAlmostEqual(ratio, expected_ratio) if __name__ == "__main__":