diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 4e896206f..d0f812d73 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -29,16 +29,20 @@ #include +#include "gtsam/linear/JacobianFactor.h" + namespace gtsam { /* *******************************************************************************/ struct HybridGaussianConditional::ConstructorHelper { std::optional nrFrontals; - HybridGaussianFactor::FactorValuePairs pairs; + FactorValuePairs pairs; + Conditionals conditionals; double minNegLogConstant; - /// Compute all variables needed for the private constructor below. + /// Construct from tree of GaussianConditionals. ConstructorHelper(const Conditionals &conditionals) - : minNegLogConstant(std::numeric_limits::infinity()) { + : conditionals(conditionals), + minNegLogConstant(std::numeric_limits::infinity()) { auto func = [this](const GaussianConditional::shared_ptr &c) -> GaussianFactorValuePair { double value = 0.0; @@ -51,38 +55,79 @@ struct HybridGaussianConditional::ConstructorHelper { } return {std::dynamic_pointer_cast(c), value}; }; - pairs = HybridGaussianFactor::FactorValuePairs(conditionals, func); + pairs = FactorValuePairs(conditionals, func); if (!nrFrontals.has_value()) { throw std::runtime_error( "HybridGaussianConditional: need at least one frontal variable."); } } + + /// Construct from means and a single sigma. + ConstructorHelper(Key x, const DiscreteKey mode, + const std::vector &means, double sigma) + : nrFrontals(1), minNegLogConstant(0) { + std::vector gcs; + for (const auto &mean : means) { + auto c = GaussianConditional::sharedMeanAndStddev(x, mean, sigma); + gcs.push_back(c); + } + conditionals = Conditionals({mode}, gcs); + pairs = FactorValuePairs(conditionals, [](const auto &c) { + return GaussianFactorValuePair{c, 0.0}; + }); + } + + /// Construct from means and a sigmas. + ConstructorHelper(Key x, const DiscreteKey mode, + const std::vector> ¶meters) + : nrFrontals(1), + minNegLogConstant(std::numeric_limits::infinity()) { + std::vector gcs; + std::vector fvs; + for (const auto &[mean, sigma] : parameters) { + auto c = GaussianConditional::sharedMeanAndStddev(x, mean, sigma); + double value = c->negLogConstant(); + minNegLogConstant = std::min(minNegLogConstant, value); + gcs.push_back(c); + fvs.push_back({c, value}); + } + conditionals = Conditionals({mode}, gcs); + pairs = FactorValuePairs({mode}, fvs); + } }; /* *******************************************************************************/ HybridGaussianConditional::HybridGaussianConditional( - const DiscreteKeys &discreteParents, - const HybridGaussianConditional::Conditionals &conditionals, - const ConstructorHelper &helper) + const DiscreteKeys &discreteParents, const ConstructorHelper &helper) : BaseFactor(discreteParents, helper.pairs), BaseConditional(*helper.nrFrontals), - conditionals_(conditionals), + conditionals_(helper.conditionals), negLogConstant_(helper.minNegLogConstant) {} -/* *******************************************************************************/ +HybridGaussianConditional::HybridGaussianConditional( + const DiscreteKey &mode, + const std::vector &conditionals) + : HybridGaussianConditional(DiscreteKeys{mode}, + Conditionals({mode}, conditionals)) {} + +HybridGaussianConditional::HybridGaussianConditional( + Key x, const DiscreteKey mode, const std::vector &means, + double sigma) + : HybridGaussianConditional(DiscreteKeys{mode}, + ConstructorHelper(x, mode, means, sigma)) {} + +HybridGaussianConditional::HybridGaussianConditional( + Key x, const DiscreteKey mode, + const std::vector> ¶meters) + : HybridGaussianConditional(DiscreteKeys{mode}, + ConstructorHelper(x, mode, parameters)) {} + HybridGaussianConditional::HybridGaussianConditional( const DiscreteKeys &discreteParents, const HybridGaussianConditional::Conditionals &conditionals) - : HybridGaussianConditional(discreteParents, conditionals, + : HybridGaussianConditional(discreteParents, ConstructorHelper(conditionals)) {} -/* *******************************************************************************/ -HybridGaussianConditional::HybridGaussianConditional( - const DiscreteKey &discreteParent, - const std::vector &conditionals) - : HybridGaussianConditional(DiscreteKeys{discreteParent}, - Conditionals({discreteParent}, conditionals)) {} - /* *******************************************************************************/ const HybridGaussianConditional::Conditionals & HybridGaussianConditional::conditionals() const { diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 02f6bba1e..662318837 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -79,14 +79,36 @@ class GTSAM_EXPORT HybridGaussianConditional /** * @brief Construct from one discrete key and vector of conditionals. * - * @param discreteParent Single discrete parent variable + * @param mode Single discrete parent variable * @param conditionals Vector of conditionals with the same size as the * cardinality of the discrete parent. */ HybridGaussianConditional( - const DiscreteKey &discreteParent, + const DiscreteKey &mode, const std::vector &conditionals); + /** + * @brief Construct from vector of means and a single sigma. + * + * @param x The continuous key. + * @param mode The discrete key. + * @param means The means for the Gaussian conditionals. + * @param sigma The standard deviation for the Gaussian conditionals. + */ + HybridGaussianConditional(Key x, const DiscreteKey mode, + const std::vector &means, double sigma); + + /** + * @brief Construct from vector of means and sigmas. + * + * @param x The continuous key. + * @param mode The discrete key. + * @param parameters The means and sigmas for the Gaussian conditionals. + */ + HybridGaussianConditional( + Key x, const DiscreteKey mode, + const std::vector> ¶meters); + /** * @brief Construct from multiple discrete keys and conditional tree. * @@ -186,10 +208,8 @@ class GTSAM_EXPORT HybridGaussianConditional struct ConstructorHelper; /// Private constructor that uses helper struct above. - HybridGaussianConditional( - const DiscreteKeys &discreteParents, - const HybridGaussianConditional::Conditionals &conditionals, - const ConstructorHelper &helper); + HybridGaussianConditional(const DiscreteKeys &discreteParents, + const ConstructorHelper &helper); /// Convert to a DecisionTree of Gaussian factor graphs. GaussianFactorGraphTree asGaussianFactorGraphTree() const; diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index cde7e4063..87ef5e25d 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -43,26 +43,6 @@ const DiscreteValues m1Assignment{{M(0), 1}}; DiscreteConditional::shared_ptr mixing = std::make_shared(m, "60/40"); -/** - * Create a simple Gaussian Mixture Model represented as p(z|m)P(m) - * where m is a discrete variable and z is a continuous variable. - * The "mode" m is binary and depending on m, we have 2 different means - * μ1 and μ2 for the Gaussian density p(z|m). - */ -HybridBayesNet GaussianMixtureModel(double mu0, double mu1, double sigma0, - double sigma1) { - HybridBayesNet hbn; - auto model0 = noiseModel::Isotropic::Sigma(1, sigma0); - auto model1 = noiseModel::Isotropic::Sigma(1, sigma1); - auto c0 = std::make_shared(Z(0), Vector1(mu0), I_1x1, - model0), - c1 = std::make_shared(Z(0), Vector1(mu1), I_1x1, - model1); - hbn.emplace_shared(m, std::vector{c0, c1}); - hbn.push_back(mixing); - return hbn; -} - /// Gaussian density function double Gaussian(double mu, double sigma, double z) { return exp(-0.5 * pow((z - mu) / sigma, 2)) / sqrt(2 * M_PI * sigma * sigma); @@ -99,7 +79,10 @@ TEST(GaussianMixture, GaussianMixtureModel) { double mu0 = 1.0, mu1 = 3.0; double sigma = 2.0; - auto hbn = GaussianMixtureModel(mu0, mu1, sigma, sigma); + HybridBayesNet hbn; + std::vector means{Vector1(mu0), Vector1(mu1)}; + hbn.emplace_shared(Z(0), m, means, sigma); + hbn.push_back(mixing); // At the halfway point between the means, we should get P(m|z)=0.5 double midway = mu1 - mu0; @@ -133,7 +116,11 @@ TEST(GaussianMixture, GaussianMixtureModel2) { double mu0 = 1.0, mu1 = 3.0; double sigma0 = 8.0, sigma1 = 4.0; - auto hbn = GaussianMixtureModel(mu0, mu1, sigma0, sigma1); + HybridBayesNet hbn; + std::vector> parameters{{Vector1(mu0), sigma0}, + {Vector1(mu1), sigma1}}; + hbn.emplace_shared(Z(0), m, parameters); + hbn.push_back(mixing); // We get zMax=3.1333 by finding the maximum value of the function, at which // point the mode m==1 is about twice as probable as m==0.