diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index c105a329e..325c32f95 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include namespace gtsam { @@ -86,7 +87,22 @@ GaussianFactorGraphTree GaussianMixture::add( /* *******************************************************************************/ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const { - auto wrap = [](const GaussianConditional::shared_ptr &gc) { + auto wrap = [this](const GaussianConditional::shared_ptr &gc) { + // First check if conditional has not been pruned + if (gc) { + const double Cgm_Kgcm = + this->logConstant_ - gc->logNormalizationConstant(); + // If there is a difference in the covariances, we need to account for + // that since the error is dependent on the mode. + if (Cgm_Kgcm > 0.0) { + // We add a constant factor which will be used when computing + // the probability of the discrete variables. + Vector c(1); + c << std::sqrt(2.0 * Cgm_Kgcm); + auto constantFactor = std::make_shared(c); + return GaussianFactorGraph{gc, constantFactor}; + } + } return GaussianFactorGraph{gc}; }; return {conditionals_, wrap}; @@ -145,6 +161,8 @@ void GaussianMixture::print(const std::string &s, std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), "; } std::cout << "\n"; + std::cout << " logNormalizationConstant: " << logConstant_ << "\n" + << std::endl; conditionals_.print( "", [&](Key k) { return formatter(k); }, [&](const GaussianConditional::shared_ptr &gf) -> std::string { @@ -312,12 +330,28 @@ AlgebraicDecisionTree GaussianMixture::logProbability( return DecisionTree(conditionals_, probFunc); } +/* ************************************************************************* */ +double GaussianMixture::conditionalError( + const GaussianConditional::shared_ptr &conditional, + const VectorValues &continuousValues) const { + // Check if valid pointer + if (conditional) { + return conditional->error(continuousValues) + // + logConstant_ - conditional->logNormalizationConstant(); + } else { + // If not valid, pointer, it means this conditional was pruned, + // so we return maximum error. + // This way the negative exponential will give + // a probability value close to 0.0. + return std::numeric_limits::max(); + } +} + /* *******************************************************************************/ AlgebraicDecisionTree GaussianMixture::errorTree( const VectorValues &continuousValues) const { auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) { - return conditional->error(continuousValues) + // - logConstant_ - conditional->logNormalizationConstant(); + return conditionalError(conditional, continuousValues); }; DecisionTree error_tree(conditionals_, errorFunc); return error_tree; @@ -327,8 +361,7 @@ AlgebraicDecisionTree GaussianMixture::errorTree( double GaussianMixture::error(const HybridValues &values) const { // Directly index to get the conditional, no need to build the whole tree. auto conditional = conditionals_(values.discrete()); - return conditional->error(values.continuous()) + // - logConstant_ - conditional->logNormalizationConstant(); + return conditionalError(conditional, values.continuous()); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index c1ef504f8..714c00272 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -67,7 +67,7 @@ class GTSAM_EXPORT GaussianMixture double logConstant_; ///< log of the normalization constant. /** - * @brief Convert a DecisionTree of factors into + * @brief Convert a GaussianMixture of conditionals into * a DecisionTree of Gaussian factor graphs. */ GaussianFactorGraphTree asGaussianFactorGraphTree() const; @@ -256,6 +256,10 @@ class GTSAM_EXPORT GaussianMixture /// Check whether `given` has values for all frontal keys. bool allFrontalsGiven(const VectorValues &given) const; + /// Helper method to compute the error of a conditional. + double conditionalError(const GaussianConditional::shared_ptr &conditional, + const VectorValues &continuousValues) const; + #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index a3db16d04..94bc09407 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -54,7 +54,9 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { /* *******************************************************************************/ void GaussianMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const { - HybridFactor::print(s, formatter); + std::cout << (s.empty() ? "" : s + "\n"); + std::cout << "GaussianMixtureFactor" << std::endl; + HybridFactor::print("", formatter); std::cout << "{\n"; if (factors_.empty()) { std::cout << " empty" << std::endl; @@ -64,7 +66,7 @@ void GaussianMixtureFactor::print(const std::string &s, [&](const sharedFactor &gf) -> std::string { RedirectCout rd; std::cout << ":\n"; - if (gf && !gf->empty()) { + if (gf) { gf->print("", formatter); return rd.str(); } else { @@ -117,6 +119,5 @@ double GaussianMixtureFactor::error(const HybridValues &values) const { const sharedFactor gf = factors_(values.discrete()); return gf->error(values.continuous()); } -/* *******************************************************************************/ } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 67d12ddb0..2459af259 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -80,8 +80,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @param continuousKeys A vector of keys representing continuous variables. * @param discreteKeys A vector of keys representing discrete variables and * their cardinalities. - * @param factors The decision tree of Gaussian factors stored as the mixture - * density. + * @param factors The decision tree of Gaussian factors stored + * as the mixture density. */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, @@ -107,9 +107,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { bool equals(const HybridFactor &lf, double tol = 1e-9) const override; - void print( - const std::string &s = "GaussianMixtureFactor\n", - const KeyFormatter &formatter = DefaultKeyFormatter) const override; + void print(const std::string &s = "", const KeyFormatter &formatter = + DefaultKeyFormatter) const override; /// @} /// @name Standard API diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index b02967555..1d01baed2 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -220,15 +220,16 @@ GaussianBayesNet HybridBayesNet::choose( /* ************************************************************************* */ HybridValues HybridBayesNet::optimize() const { // Collect all the discrete factors to compute MPE - DiscreteBayesNet discrete_bn; + DiscreteFactorGraph discrete_fg; + for (auto &&conditional : *this) { if (conditional->isDiscrete()) { - discrete_bn.push_back(conditional->asDiscrete()); + discrete_fg.push_back(conditional->asDiscrete()); } } // Solve for the MPE - DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize(); + DiscreteValues mpe = discrete_fg.optimize(); // Given the MPE, compute the optimal continuous values. return HybridValues(optimize(mpe), mpe); diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 64bdcb2c1..fb6542822 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -61,7 +61,7 @@ class GTSAM_EXPORT HybridConditional public Conditional { public: // typedefs needed to play nice with gtsam - typedef HybridConditional This; ///< Typedef to this class + typedef HybridConditional This; ///< Typedef to this class typedef std::shared_ptr shared_ptr; ///< shared_ptr to this class typedef HybridFactor BaseFactor; ///< Typedef to our factor base class typedef Conditional @@ -185,7 +185,7 @@ class GTSAM_EXPORT HybridConditional * Return the log normalization constant. * Note this is 0.0 for discrete and hybrid conditionals, but depends * on the continuous parameters for Gaussian conditionals. - */ + */ double logNormalizationConstant() const override; /// Return the probability (or density) of the underlying conditional. diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index afd1c8032..a9c0e53d2 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -13,6 +13,7 @@ * @file HybridFactor.h * @date Mar 11, 2022 * @author Fan Jiang + * @author Varun Agrawal */ #pragma once diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index f7b96f694..f5a7bcdfe 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -49,15 +49,6 @@ KeySet HybridFactorGraph::discreteKeySet() const { return keys; } -/* ************************************************************************* */ -std::unordered_map HybridFactorGraph::discreteKeyMap() const { - std::unordered_map result; - for (const DiscreteKey& k : discreteKeys()) { - result[k.first] = k; - } - return result; -} - /* ************************************************************************* */ const KeySet HybridFactorGraph::continuousKeySet() const { KeySet keys; diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index 8b59fd4f9..79f2a7af1 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -38,7 +38,7 @@ using SharedFactor = std::shared_ptr; class GTSAM_EXPORT HybridFactorGraph : public FactorGraph { public: using Base = FactorGraph; - using This = HybridFactorGraph; ///< this class + using This = HybridFactorGraph; ///< this class using shared_ptr = std::shared_ptr; ///< shared_ptr to This using Values = gtsam::Values; ///< backwards compatibility @@ -66,12 +66,9 @@ class GTSAM_EXPORT HybridFactorGraph : public FactorGraph { /// Get all the discrete keys in the factor graph. std::set discreteKeys() const; - /// Get all the discrete keys in the factor graph, as a set. + /// Get all the discrete keys in the factor graph, as a set of Keys. KeySet discreteKeySet() const; - /// Get a map from Key to corresponding DiscreteKey. - std::unordered_map discreteKeyMap() const; - /// Get all the continuous keys in the factor graph. const KeySet continuousKeySet() const; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index b764dc9e0..a7a6eee5a 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -97,29 +97,27 @@ void HybridGaussianFactorGraph::printErrors( std::cout << "nullptr" << "\n"; } else { - factor->print(ss.str(), keyFormatter); - std::cout << "error = "; - gmf->errorTree(values.continuous()).print("", keyFormatter); - std::cout << std::endl; + gmf->operator()(values.discrete())->print(ss.str(), keyFormatter); + std::cout << "error = " << gmf->error(values) << std::endl; } } else if (auto hc = std::dynamic_pointer_cast(factor)) { if (factor == nullptr) { std::cout << "nullptr" << "\n"; } else { - factor->print(ss.str(), keyFormatter); - if (hc->isContinuous()) { + factor->print(ss.str(), keyFormatter); std::cout << "error = " << hc->asGaussian()->error(values) << "\n"; } else if (hc->isDiscrete()) { - std::cout << "error = "; - hc->asDiscrete()->errorTree().print("", keyFormatter); - std::cout << "\n"; + factor->print(ss.str(), keyFormatter); + std::cout << "error = " << hc->asDiscrete()->error(values.discrete()) + << "\n"; } else { // Is hybrid - std::cout << "error = "; - hc->asMixture()->errorTree(values.continuous()).print(); - std::cout << "\n"; + auto mixtureComponent = + hc->asMixture()->operator()(values.discrete()); + mixtureComponent->print(ss.str(), keyFormatter); + std::cout << "error = " << mixtureComponent->error(values) << "\n"; } } } else if (auto gf = std::dynamic_pointer_cast(factor)) { @@ -140,8 +138,7 @@ void HybridGaussianFactorGraph::printErrors( << "\n"; } else { factor->print(ss.str(), keyFormatter); - std::cout << "error = "; - df->errorTree().print("", keyFormatter); + std::cout << "error = " << df->error(values.discrete()) << std::endl; } } else { @@ -233,6 +230,25 @@ continuousElimination(const HybridGaussianFactorGraph &factors, return {std::make_shared(result.first), result.second}; } +/* ************************************************************************ */ +/** + * @brief Exponentiate log-values, not necessarily normalized, normalize, and + * return as AlgebraicDecisionTree. + * + * @param logValues DecisionTree of (unnormalized) log values. + * @return AlgebraicDecisionTree + */ +static AlgebraicDecisionTree probabilitiesFromLogValues( + const AlgebraicDecisionTree &logValues) { + // Perform normalization + double max_log = logValues.max(); + AlgebraicDecisionTree probabilities = DecisionTree( + logValues, [&max_log](const double x) { return exp(x - max_log); }); + probabilities = probabilities.normalize(probabilities.sum()); + + return probabilities; +} + /* ************************************************************************ */ static std::pair> discreteElimination(const HybridGaussianFactorGraph &factors, @@ -242,6 +258,22 @@ discreteElimination(const HybridGaussianFactorGraph &factors, for (auto &f : factors) { if (auto df = dynamic_pointer_cast(f)) { dfg.push_back(df); + } else if (auto gmf = dynamic_pointer_cast(f)) { + // Case where we have a GaussianMixtureFactor with no continuous keys. + // In this case, compute discrete probabilities. + auto logProbability = + [&](const GaussianFactor::shared_ptr &factor) -> double { + if (!factor) return 0.0; + return -factor->error(VectorValues()); + }; + AlgebraicDecisionTree logProbabilities = + DecisionTree(gmf->factors(), logProbability); + + AlgebraicDecisionTree probabilities = + probabilitiesFromLogValues(logProbabilities); + dfg.emplace_shared(gmf->discreteKeys(), + probabilities); + } else if (auto orphan = dynamic_pointer_cast(f)) { // Ignore orphaned clique. // TODO(dellaert): is this correct? If so explain here. @@ -279,21 +311,32 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) { using Result = std::pair, GaussianMixtureFactor::sharedFactor>; -// Integrate the probability mass in the last continuous conditional using -// the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean. -// discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m)) +/** + * Compute the probability p(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m) + * from the residual error ||b||^2 at the mean μ. + * The residual error contains no keys, and only + * depends on the discrete separator if present. + */ static std::shared_ptr createDiscreteFactor( const DecisionTree &eliminationResults, const DiscreteKeys &discreteSeparator) { - auto probability = [&](const Result &pair) -> double { + auto logProbability = [&](const Result &pair) -> double { const auto &[conditional, factor] = pair; static const VectorValues kEmpty; // If the factor is not null, it has no keys, just contains the residual. if (!factor) return 1.0; // TODO(dellaert): not loving this. - return exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); + + // Logspace version of: + // exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); + // We take negative of the logNormalizationConstant `log(1/k)` + // to get `log(k)`. + return -factor->error(kEmpty) - conditional->logNormalizationConstant(); }; - DecisionTree probabilities(eliminationResults, probability); + AlgebraicDecisionTree logProbabilities( + DecisionTree(eliminationResults, logProbability)); + AlgebraicDecisionTree probabilities = + probabilitiesFromLogValues(logProbabilities); return std::make_shared(discreteSeparator, probabilities); } @@ -480,18 +523,9 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, std::inserter(continuousSeparator, continuousSeparator.begin())); // Similarly for the discrete separator. - KeySet discreteSeparatorSet; - std::set discreteSeparator; - auto discreteKeySet = factors.discreteKeySet(); - std::set_difference( - discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(), - frontalKeysSet.end(), - std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin())); - // Convert from set of keys to set of DiscreteKeys - auto discreteKeyMap = factors.discreteKeyMap(); - for (auto key : discreteSeparatorSet) { - discreteSeparator.insert(discreteKeyMap.at(key)); - } + // Since we eliminate all continuous variables first, + // the discrete separator will be *all* the discrete keys. + std::set discreteSeparator = factors.discreteKeys(); return hybridElimination(factors, frontalKeys, continuousSeparator, discreteSeparator); @@ -504,10 +538,15 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( AlgebraicDecisionTree error_tree(0.0); // Iterate over each factor. - for (auto &f : factors_) { + for (auto &factor : factors_) { // TODO(dellaert): just use a virtual method defined in HybridFactor. AlgebraicDecisionTree factor_error; + auto f = factor; + if (auto hc = dynamic_pointer_cast(factor)) { + f = hc->inner(); + } + if (auto gaussianMixture = dynamic_pointer_cast(f)) { // Compute factor error and add it. error_tree = error_tree + gaussianMixture->errorTree(continuousValues); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 1708ff64b..2ca6e4c95 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -144,6 +144,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph // const std::string& s = "HybridGaussianFactorGraph", // const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + /** + * @brief Print the errors of each factor in the hybrid factor graph. + * + * @param values The HybridValues for the variables used to compute the error. + * @param str String that is output before the factor graph and errors. + * @param keyFormatter Formatter function for the keys in the factors. + * @param printCondition A condition to check if a factor should be printed. + */ void printErrors( const HybridValues& values, const std::string& str = "HybridGaussianFactorGraph: ", diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 9cc7e6bfd..b2a4981f3 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -22,9 +22,13 @@ #include #include #include +#include +#include #include #include #include +#include +#include // Include for test suite #include @@ -32,8 +36,10 @@ using namespace std; using namespace gtsam; using noiseModel::Isotropic; +using symbol_shorthand::F; using symbol_shorthand::M; using symbol_shorthand::X; +using symbol_shorthand::Z; /* ************************************************************************* */ // Check iterators of empty mixture. @@ -56,7 +62,6 @@ TEST(GaussianMixtureFactor, Sum) { auto b = Matrix::Zero(2, 1); Vector2 sigmas; sigmas << 1, 2; - auto model = noiseModel::Diagonal::Sigmas(sigmas, true); auto f10 = std::make_shared(X(1), A1, X(2), A2, b); auto f11 = std::make_shared(X(1), A1, X(2), A2, b); @@ -106,7 +111,8 @@ TEST(GaussianMixtureFactor, Printing) { GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); std::string expected = - R"(Hybrid [x1 x2; 1]{ + R"(GaussianMixtureFactor +Hybrid [x1 x2; 1]{ Choice(1) 0 Leaf : A[x1] = [ @@ -178,7 +184,8 @@ TEST(GaussianMixtureFactor, Error) { continuousValues.insert(X(2), Vector2(1, 1)); // error should return a tree of errors, with nodes for each discrete value. - AlgebraicDecisionTree error_tree = mixtureFactor.errorTree(continuousValues); + AlgebraicDecisionTree error_tree = + mixtureFactor.errorTree(continuousValues); std::vector discrete_keys = {m1}; // Error values for regression test @@ -191,8 +198,390 @@ TEST(GaussianMixtureFactor, Error) { DiscreteValues discreteValues; discreteValues[m1.first] = 1; EXPECT_DOUBLES_EQUAL( - 4.0, mixtureFactor.error({continuousValues, discreteValues}), - 1e-9); + 4.0, mixtureFactor.error({continuousValues, discreteValues}), 1e-9); +} + +namespace test_gmm { + +/** + * Function to compute P(m=1|z). For P(m=0|z), swap mus and sigmas. + * If sigma0 == sigma1, it simplifies to a sigmoid function. + * + * Follows equation 7.108 since it is more generic. + */ +double prob_m_z(double mu0, double mu1, double sigma0, double sigma1, + double z) { + double x1 = ((z - mu0) / sigma0), x2 = ((z - mu1) / sigma1); + double d = sigma1 / sigma0; + double e = d * std::exp(-0.5 * (x1 * x1 - x2 * x2)); + return 1 / (1 + e); +}; + +static HybridBayesNet GetGaussianMixtureModel(double mu0, double mu1, + double sigma0, double sigma1) { + DiscreteKey m(M(0), 2); + Key z = Z(0); + + auto model0 = noiseModel::Isotropic::Sigma(1, sigma0); + auto model1 = noiseModel::Isotropic::Sigma(1, sigma1); + + auto c0 = make_shared(z, Vector1(mu0), I_1x1, model0), + c1 = make_shared(z, Vector1(mu1), I_1x1, model1); + + auto gm = new GaussianMixture({z}, {}, {m}, {c0, c1}); + auto mixing = new DiscreteConditional(m, "0.5/0.5"); + + HybridBayesNet hbn; + hbn.emplace_back(gm); + hbn.emplace_back(mixing); + + return hbn; +} + +} // namespace test_gmm + +/* ************************************************************************* */ +/** + * Test a simple Gaussian Mixture Model represented as P(m)P(z|m) + * where m is a discrete variable and z is a continuous variable. + * m is binary and depending on m, we have 2 different means + * μ1 and μ2 for the Gaussian distribution around which we sample z. + * + * The resulting factor graph should eliminate to a Bayes net + * which represents a sigmoid function. + */ +TEST(GaussianMixtureFactor, GaussianMixtureModel) { + using namespace test_gmm; + + double mu0 = 1.0, mu1 = 3.0; + double sigma = 2.0; + + DiscreteKey m(M(0), 2); + Key z = Z(0); + + auto hbn = GetGaussianMixtureModel(mu0, mu1, sigma, sigma); + + // The result should be a sigmoid. + // So should be P(m=1|z) = 0.5 at z=3.0 - 1.0=2.0 + double midway = mu1 - mu0, lambda = 4; + { + VectorValues given; + given.insert(z, Vector1(midway)); + + HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given); + HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); + + EXPECT_DOUBLES_EQUAL( + prob_m_z(mu0, mu1, sigma, sigma, midway), + bn->at(0)->asDiscrete()->operator()(DiscreteValues{{m.first, 1}}), + 1e-8); + + // At the halfway point between the means, we should get P(m|z)=0.5 + HybridBayesNet expected; + expected.emplace_back(new DiscreteConditional(m, "0.5/0.5")); + + EXPECT(assert_equal(expected, *bn)); + } + { + // Shift by -lambda + VectorValues given; + given.insert(z, Vector1(midway - lambda)); + + HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given); + HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); + + EXPECT_DOUBLES_EQUAL( + prob_m_z(mu0, mu1, sigma, sigma, midway - lambda), + bn->at(0)->asDiscrete()->operator()(DiscreteValues{{m.first, 1}}), + 1e-8); + } + { + // Shift by lambda + VectorValues given; + given.insert(z, Vector1(midway + lambda)); + + HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given); + HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); + + EXPECT_DOUBLES_EQUAL( + prob_m_z(mu0, mu1, sigma, sigma, midway + lambda), + bn->at(0)->asDiscrete()->operator()(DiscreteValues{{m.first, 1}}), + 1e-8); + } +} + +/* ************************************************************************* */ +/** + * Test a simple Gaussian Mixture Model represented as P(m)P(z|m) + * where m is a discrete variable and z is a continuous variable. + * m is binary and depending on m, we have 2 different means + * and covariances each for the + * Gaussian distribution around which we sample z. + * + * The resulting factor graph should eliminate to a Bayes net + * which represents a Gaussian-like function + * where m1>m0 close to 3.1333. + */ +TEST(GaussianMixtureFactor, GaussianMixtureModel2) { + using namespace test_gmm; + + double mu0 = 1.0, mu1 = 3.0; + double sigma0 = 8.0, sigma1 = 4.0; + + DiscreteKey m(M(0), 2); + Key z = Z(0); + + auto hbn = GetGaussianMixtureModel(mu0, mu1, sigma0, sigma1); + + double m1_high = 3.133, lambda = 4; + { + // The result should be a bell curve like function + // with m1 > m0 close to 3.1333. + // We get 3.1333 by finding the maximum value of the function. + VectorValues given; + given.insert(z, Vector1(3.133)); + + HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given); + HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); + + EXPECT_DOUBLES_EQUAL( + prob_m_z(mu0, mu1, sigma0, sigma1, m1_high), + bn->at(0)->asDiscrete()->operator()(DiscreteValues{{M(0), 1}}), 1e-8); + + // At the halfway point between the means + HybridBayesNet expected; + expected.emplace_back(new DiscreteConditional( + m, {}, + vector{prob_m_z(mu1, mu0, sigma1, sigma0, m1_high), + prob_m_z(mu0, mu1, sigma0, sigma1, m1_high)})); + + EXPECT(assert_equal(expected, *bn)); + } + { + // Shift by -lambda + VectorValues given; + given.insert(z, Vector1(m1_high - lambda)); + + HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given); + HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); + + EXPECT_DOUBLES_EQUAL( + prob_m_z(mu0, mu1, sigma0, sigma1, m1_high - lambda), + bn->at(0)->asDiscrete()->operator()(DiscreteValues{{m.first, 1}}), + 1e-8); + } + { + // Shift by lambda + VectorValues given; + given.insert(z, Vector1(m1_high + lambda)); + + HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given); + HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); + + EXPECT_DOUBLES_EQUAL( + prob_m_z(mu0, mu1, sigma0, sigma1, m1_high + lambda), + bn->at(0)->asDiscrete()->operator()(DiscreteValues{{m.first, 1}}), + 1e-8); + } +} + +namespace test_two_state_estimation { + +/// Create Two State Bayes Network with measurements +static HybridBayesNet CreateBayesNet(double mu0, double mu1, double sigma0, + double sigma1, + bool add_second_measurement = false, + double prior_sigma = 1e-3, + double measurement_sigma = 3.0) { + DiscreteKey m1(M(1), 2); + Key z0 = Z(0), z1 = Z(1); + Key x0 = X(0), x1 = X(1); + + HybridBayesNet hbn; + + auto measurement_model = noiseModel::Isotropic::Sigma(1, measurement_sigma); + // Add measurement P(z0 | x0) + auto p_z0 = new GaussianConditional(z0, Vector1(0.0), -I_1x1, x0, I_1x1, + measurement_model); + hbn.emplace_back(p_z0); + + // Add hybrid motion model + auto model0 = noiseModel::Isotropic::Sigma(1, sigma0); + auto model1 = noiseModel::Isotropic::Sigma(1, sigma1); + auto c0 = make_shared(x1, Vector1(mu0), I_1x1, x0, + -I_1x1, model0), + c1 = make_shared(x1, Vector1(mu1), I_1x1, x0, + -I_1x1, model1); + + auto motion = new GaussianMixture({x1}, {x0}, {m1}, {c0, c1}); + hbn.emplace_back(motion); + + if (add_second_measurement) { + // Add second measurement + auto p_z1 = new GaussianConditional(z1, Vector1(0.0), -I_1x1, x1, I_1x1, + measurement_model); + hbn.emplace_back(p_z1); + } + + // Discrete uniform prior. + auto p_m1 = new DiscreteConditional(m1, "0.5/0.5"); + hbn.emplace_back(p_m1); + + return hbn; +} + +} // namespace test_two_state_estimation + +/* ************************************************************************* */ +/** + * Test a model P(z0|x0)P(x1|x0,m1)P(z1|x1)P(m1). + * + * P(f01|x1,x0,m1) has different means and same covariance. + * + * Converting to a factor graph gives us + * ϕ(x0)ϕ(x1,x0,m1)ϕ(x1)P(m1) + * + * If we only have a measurement on z0, then + * the probability of m1 should be 0.5/0.5. + * Getting a measurement on z1 gives use more information. + */ +TEST(GaussianMixtureFactor, TwoStateModel) { + using namespace test_two_state_estimation; + + double mu0 = 1.0, mu1 = 3.0; + double sigma = 2.0; + + DiscreteKey m1(M(1), 2); + Key z0 = Z(0), z1 = Z(1); + + // Start with no measurement on x1, only on x0 + HybridBayesNet hbn = CreateBayesNet(mu0, mu1, sigma, sigma, false); + + VectorValues given; + given.insert(z0, Vector1(0.5)); + + { + HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given); + HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); + + // Since no measurement on x1, we hedge our bets + DiscreteConditional expected(m1, "0.5/0.5"); + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()))); + } + + { + // Now we add a measurement z1 on x1 + hbn = CreateBayesNet(mu0, mu1, sigma, sigma, true); + + // If we see z1=2.6 (> 2.5 which is the halfway point), + // discrete mode should say m1=1 + given.insert(z1, Vector1(2.6)); + HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given); + HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); + + // Since we have a measurement on z2, we get a definite result + DiscreteConditional expected(m1, "0.49772729/0.50227271"); + // regression + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 1e-6)); + } +} + +/* ************************************************************************* */ +/** + * Test a model P(z0|x0)P(x1|x0,m1)P(z1|x1)P(m1). + * + * P(f01|x1,x0,m1) has different means and different covariances. + * + * Converting to a factor graph gives us + * ϕ(x0)ϕ(x1,x0,m1)ϕ(x1)P(m1) + * + * If we only have a measurement on z0, then + * the P(m1) should be 0.5/0.5. + * Getting a measurement on z1 gives use more information. + */ +TEST(GaussianMixtureFactor, TwoStateModel2) { + using namespace test_two_state_estimation; + + double mu0 = 1.0, mu1 = 3.0; + double sigma0 = 6.0, sigma1 = 4.0; + auto model0 = noiseModel::Isotropic::Sigma(1, sigma0); + auto model1 = noiseModel::Isotropic::Sigma(1, sigma1); + + DiscreteKey m1(M(1), 2); + Key z0 = Z(0), z1 = Z(1); + + // Start with no measurement on x1, only on x0 + HybridBayesNet hbn = CreateBayesNet(mu0, mu1, sigma0, sigma1, false); + + VectorValues given; + given.insert(z0, Vector1(0.5)); + + { + // Start with no measurement on x1, only on x0 + HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given); + + { + VectorValues vv{ + {X(0), Vector1(0.0)}, {X(1), Vector1(1.0)}, {Z(0), Vector1(0.5)}}; + HybridValues hv0(vv, DiscreteValues{{M(1), 0}}), + hv1(vv, DiscreteValues{{M(1), 1}}); + EXPECT_DOUBLES_EQUAL(gfg.error(hv0) / hbn.error(hv0), + gfg.error(hv1) / hbn.error(hv1), 1e-9); + } + { + VectorValues vv{ + {X(0), Vector1(0.5)}, {X(1), Vector1(3.0)}, {Z(0), Vector1(0.5)}}; + HybridValues hv0(vv, DiscreteValues{{M(1), 0}}), + hv1(vv, DiscreteValues{{M(1), 1}}); + EXPECT_DOUBLES_EQUAL(gfg.error(hv0) / hbn.error(hv0), + gfg.error(hv1) / hbn.error(hv1), 1e-9); + } + + HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); + + // Since no measurement on x1, we a 50/50 probability + auto p_m = bn->at(2)->asDiscrete(); + EXPECT_DOUBLES_EQUAL(0.5, p_m->operator()(DiscreteValues{{m1.first, 0}}), + 1e-9); + EXPECT_DOUBLES_EQUAL(0.5, p_m->operator()(DiscreteValues{{m1.first, 1}}), + 1e-9); + } + + { + // Now we add a measurement z1 on x1 + hbn = CreateBayesNet(mu0, mu1, sigma0, sigma1, true); + + given.insert(z1, Vector1(2.2)); + HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given); + + { + VectorValues vv{{X(0), Vector1(0.0)}, + {X(1), Vector1(1.0)}, + {Z(0), Vector1(0.5)}, + {Z(1), Vector1(2.2)}}; + HybridValues hv0(vv, DiscreteValues{{M(1), 0}}), + hv1(vv, DiscreteValues{{M(1), 1}}); + EXPECT_DOUBLES_EQUAL(gfg.error(hv0) / hbn.error(hv0), + gfg.error(hv1) / hbn.error(hv1), 1e-9); + } + { + VectorValues vv{{X(0), Vector1(0.5)}, + {X(1), Vector1(3.0)}, + {Z(0), Vector1(0.5)}, + {Z(1), Vector1(2.2)}}; + HybridValues hv0(vv, DiscreteValues{{M(1), 0}}), + hv1(vv, DiscreteValues{{M(1), 1}}); + EXPECT_DOUBLES_EQUAL(gfg.error(hv0) / hbn.error(hv0), + gfg.error(hv1) / hbn.error(hv1), 1e-9); + } + + HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); + + // Since we have a measurement on z2, we get a definite result + DiscreteConditional expected(m1, "0.44744586/0.55255414"); + // regression + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 1e-6)); + } } /* ************************************************************************* */ @@ -200,4 +589,4 @@ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ \ No newline at end of file +/* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 5be2f2742..68b3b8215 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -598,6 +598,57 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { EXPECT(assert_equal(expected_probs, probs, 1e-7)); } +/* ****************************************************************************/ +// Test hybrid gaussian factor graph errorTree when there is a HybridConditional in the graph +TEST(HybridGaussianFactorGraph, ErrorTreeWithConditional) { + using symbol_shorthand::F; + + DiscreteKey m1(M(1), 2); + Key z0 = Z(0), f01 = F(0); + Key x0 = X(0), x1 = X(1); + + HybridBayesNet hbn; + + auto prior_model = noiseModel::Isotropic::Sigma(1, 1e-1); + auto measurement_model = noiseModel::Isotropic::Sigma(1, 2.0); + + // Set a prior P(x0) at x0=0 + hbn.emplace_back( + new GaussianConditional(x0, Vector1(0.0), I_1x1, prior_model)); + + // Add measurement P(z0 | x0) + hbn.emplace_back(new GaussianConditional(z0, Vector1(0.0), -I_1x1, x0, I_1x1, + measurement_model)); + + // Add hybrid motion model + double mu = 0.0; + double sigma0 = 1e2, sigma1 = 1e-2; + auto model0 = noiseModel::Isotropic::Sigma(1, sigma0); + auto model1 = noiseModel::Isotropic::Sigma(1, sigma1); + auto c0 = make_shared(f01, Vector1(mu), I_1x1, x1, I_1x1, + x0, -I_1x1, model0), + c1 = make_shared(f01, Vector1(mu), I_1x1, x1, I_1x1, + x0, -I_1x1, model1); + hbn.emplace_back(new GaussianMixture({f01}, {x0, x1}, {m1}, {c0, c1})); + + // Discrete uniform prior. + hbn.emplace_back(new DiscreteConditional(m1, "0.5/0.5")); + + VectorValues given; + given.insert(z0, Vector1(0.0)); + given.insert(f01, Vector1(0.0)); + auto gfg = hbn.toFactorGraph(given); + + VectorValues vv; + vv.insert(x0, Vector1(1.0)); + vv.insert(x1, Vector1(2.0)); + AlgebraicDecisionTree errorTree = gfg.errorTree(vv); + + // regression + AlgebraicDecisionTree expected(m1, 59.335390372, 5050.125); + EXPECT(assert_equal(expected, errorTree, 1e-9)); +} + /* ****************************************************************************/ // Check that assembleGraphTree assembles Gaussian factor graphs for each // assignment. diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 751e84d91..2d851b0ff 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -510,6 +510,7 @@ factor 0: b = [ -10 ] No noise model factor 1: +GaussianMixtureFactor Hybrid [x0 x1; m0]{ Choice(m0) 0 Leaf : @@ -534,6 +535,7 @@ Hybrid [x0 x1; m0]{ } factor 2: +GaussianMixtureFactor Hybrid [x1 x2; m1]{ Choice(m1) 0 Leaf : @@ -675,6 +677,8 @@ factor 6: P( m1 | m0 ): size: 3 conditional 0: Hybrid P( x0 | x1 m0) Discrete Keys = (m0, 2), + logNormalizationConstant: 1.38862 + Choice(m0) 0 Leaf p(x0 | x1) R = [ 10.0499 ] @@ -692,6 +696,8 @@ conditional 0: Hybrid P( x0 | x1 m0) conditional 1: Hybrid P( x1 | x2 m0 m1) Discrete Keys = (m0, 2), (m1, 2), + logNormalizationConstant: 1.3935 + Choice(m1) 0 Choice(m0) 0 0 Leaf p(x1 | x2) @@ -725,6 +731,8 @@ conditional 1: Hybrid P( x1 | x2 m0 m1) conditional 2: Hybrid P( x2 | m0 m1) Discrete Keys = (m0, 2), (m1, 2), + logNormalizationConstant: 1.38857 + Choice(m1) 0 Choice(m0) 0 0 Leaf p(x2) diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 0b2564403..a58a4767f 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -18,6 +18,9 @@ #include #include +#include +#include +#include #include #include #include diff --git a/gtsam/linear/VectorValues.h b/gtsam/linear/VectorValues.h index 2fa50b7f6..7fbd43ffc 100644 --- a/gtsam/linear/VectorValues.h +++ b/gtsam/linear/VectorValues.h @@ -263,11 +263,6 @@ namespace gtsam { /** equals required by Testable for unit testing */ bool equals(const VectorValues& x, double tol = 1e-9) const; - /// Check equality. - friend bool operator==(const VectorValues& lhs, const VectorValues& rhs) { - return lhs.equals(rhs); - } - /// @{ /// @name Advanced Interface /// @{