diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index e686241fc..b5af0bf7f 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -203,8 +203,8 @@ boost::shared_ptr GaussianMixture::likelihood( const KeyVector continuousParentKeys = continuousParents(); const GaussianMixtureFactor::Factors likelihoods( conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { - return GaussianMixtureFactor::FactorAndConstant{ - conditional->likelihood(given), 0.0}; + return GaussianMixtureFactor::sharedFactor{ + conditional->likelihood(given)}; }); return boost::make_shared( continuousParentKeys, discreteParentKeys, likelihoods); diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 57f42e6f1..e8a07b42a 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -31,11 +31,8 @@ namespace gtsam { /* *******************************************************************************/ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Mixture &factors) - : Base(continuousKeys, discreteKeys), - factors_(factors, [](const GaussianFactor::shared_ptr &gf) { - return FactorAndConstant{gf, 0.0}; - }) {} + const Factors &factors) + : Base(continuousKeys, discreteKeys), factors_(factors) {} /* *******************************************************************************/ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { @@ -48,11 +45,10 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { // Check the base and the factors: return Base::equals(*e, tol) && - factors_.equals(e->factors_, [tol](const FactorAndConstant &f1, - const FactorAndConstant &f2) { - return f1.factor->equals(*(f2.factor), tol) && - std::abs(f1.constant - f2.constant) < tol; - }); + factors_.equals(e->factors_, + [tol](const sharedFactor &f1, const sharedFactor &f2) { + return f1->equals(*f2, tol); + }); } /* *******************************************************************************/ @@ -65,8 +61,7 @@ void GaussianMixtureFactor::print(const std::string &s, } else { factors_.print( "", [&](Key k) { return formatter(k); }, - [&](const FactorAndConstant &gf_z) -> std::string { - auto gf = gf_z.factor; + [&](const sharedFactor &gf) -> std::string { RedirectCout rd; std::cout << ":\n"; if (gf && !gf->empty()) { @@ -81,14 +76,9 @@ void GaussianMixtureFactor::print(const std::string &s, } /* *******************************************************************************/ -GaussianFactor::shared_ptr GaussianMixtureFactor::factor( +GaussianMixtureFactor::sharedFactor GaussianMixtureFactor::operator()( const DiscreteValues &assignment) const { - return factors_(assignment).factor; -} - -/* *******************************************************************************/ -double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const { - return factors_(assignment).constant; + return factors_(assignment); } /* *******************************************************************************/ @@ -107,10 +97,10 @@ GaussianFactorGraphTree GaussianMixtureFactor::add( /* *******************************************************************************/ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree() const { - auto wrap = [](const FactorAndConstant &factor_z) { + auto wrap = [](const sharedFactor &gf) { GaussianFactorGraph result; - result.push_back(factor_z.factor); - return GraphAndConstant(result, factor_z.constant); + result.push_back(gf); + return GraphAndConstant(result, 0.0); }; return {factors_, wrap}; } @@ -119,8 +109,8 @@ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree() AlgebraicDecisionTree GaussianMixtureFactor::error( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) { - return factor_z.error(continuousValues); + auto errorFunc = [&continuousValues](const sharedFactor &gf) { + return gf->error(continuousValues); }; DecisionTree errorTree(factors_, errorFunc); return errorTree; @@ -128,8 +118,8 @@ AlgebraicDecisionTree GaussianMixtureFactor::error( /* *******************************************************************************/ double GaussianMixtureFactor::error(const HybridValues &values) const { - const FactorAndConstant factor_z = factors_(values.discrete()); - return factor_z.error(values.continuous()); + const sharedFactor gf = factors_(values.discrete()); + return gf->error(values.continuous()); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 01de2f0f7..aa8f2a199 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -39,7 +39,7 @@ class VectorValues; * serves to "select" a mixture component corresponding to a GaussianFactor type * of measurement. * - * Represents the underlying Gaussian Mixture as a Decision Tree, where the set + * Represents the underlying Gaussian mixture as a Decision Tree, where the set * of discrete variables indexes to the continuous gaussian distribution. * * @ingroup hybrid @@ -52,38 +52,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { using sharedFactor = boost::shared_ptr; - /// Gaussian factor and log of normalizing constant. - struct FactorAndConstant { - sharedFactor factor; - double constant; - - // Return error with constant correction. - double error(const VectorValues &values) const { - // Note: constant is log of normalization constant for probabilities. - // Errors is the negative log-likelihood, - // hence we subtract the constant here. - if (!factor) return 0.0; // If nullptr, return 0.0 error - return factor->error(values) - constant; - } - - // Check pointer equality. - bool operator==(const FactorAndConstant &other) const { - return factor == other.factor && constant == other.constant; - } - - private: - /** Serialization function */ - friend class boost::serialization::access; - template - void serialize(ARCHIVE &ar, const unsigned int /*version*/) { - ar &BOOST_SERIALIZATION_NVP(factor); - ar &BOOST_SERIALIZATION_NVP(constant); - } - }; - /// typedef for Decision Tree of Gaussian factors and log-constant. - using Factors = DecisionTree; - using Mixture = DecisionTree; + using Factors = DecisionTree; private: /// Decision tree of Gaussian factors indexed by discrete keys. @@ -105,7 +75,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { GaussianMixtureFactor() = default; /** - * @brief Construct a new Gaussian Mixture Factor object. + * @brief Construct a new Gaussian mixture factor. * * @param continuousKeys A vector of keys representing continuous variables. * @param discreteKeys A vector of keys representing discrete variables and @@ -115,12 +85,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Mixture &factors); - - GaussianMixtureFactor(const KeyVector &continuousKeys, - const DiscreteKeys &discreteKeys, - const Factors &factors_and_z) - : Base(continuousKeys, discreteKeys), factors_(factors_and_z) {} + const Factors &factors); /** * @brief Construct a new GaussianMixtureFactor object using a vector of @@ -134,7 +99,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { const DiscreteKeys &discreteKeys, const std::vector &factors) : GaussianMixtureFactor(continuousKeys, discreteKeys, - Mixture(discreteKeys, factors)) {} + Factors(discreteKeys, factors)) {} /// @} /// @name Testable @@ -151,10 +116,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { /// @{ /// Get factor at a given discrete assignment. - sharedFactor factor(const DiscreteValues &assignment) const; - - /// Get constant at a given discrete assignment. - double constant(const DiscreteValues &assignment) const; + sharedFactor operator()(const DiscreteValues &assignment) const; /** * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index c59187f4e..04ee21fc9 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -213,20 +213,20 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // Collect all the factors to create a set of Gaussian factor graphs in a // decision tree indexed by all discrete keys involved. - GaussianFactorGraphTree sum = factors.assembleGraphTree(); + GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree(); // Convert factor graphs with a nullptr to an empty factor graph. // This is done after assembly since it is non-trivial to keep track of which // FG has a nullptr as we're looping over the factors. - sum = removeEmpty(sum); + factorGraphTree = removeEmpty(factorGraphTree); using EliminationPair = std::pair, - GaussianMixtureFactor::FactorAndConstant>; + GaussianMixtureFactor::sharedFactor>; // This is the elimination method on the leaf nodes auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair { if (graph_z.graph.empty()) { - return {nullptr, {nullptr, 0.0}}; + return {nullptr, nullptr}; } #ifdef HYBRID_TIMING @@ -240,27 +240,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // Get the log of the log normalization constant inverse and // add it to the previous constant. - const double logZ = - graph_z.constant - conditional->logNormalizationConstant(); - // Get the log of the log normalization constant inverse. - // double logZ = -conditional->logNormalizationConstant(); - // // IF this is the last continuous variable to eliminated, we need to - // // calculate the error here: the value of all factors at the mean, see - // // ml_map_rao.pdf. - // if (continuousSeparator.empty()) { - // const auto posterior_mean = conditional->solve(VectorValues()); - // logZ += graph_z.graph.error(posterior_mean); - // } + // const double logZ = + // graph_z.constant - conditional->logNormalizationConstant(); #ifdef HYBRID_TIMING gttoc_(hybrid_eliminate); #endif - return {conditional, {newFactor, logZ}}; + return {conditional, newFactor}; }; // Perform elimination! - DecisionTree eliminationResults(sum, eliminateFunc); + DecisionTree eliminationResults(factorGraphTree, + eliminateFunc); #ifdef HYBRID_TIMING tictoc_print_(); @@ -279,26 +271,17 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // If there are no more continuous parents, then we should create a // DiscreteFactor here, with the error for each discrete choice. if (continuousSeparator.empty()) { - auto factorProb = - [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { - // This is the probability q(μ) at the MLE point. - // factor_z.factor is a factor without keys, - // just containing the residual. - return exp(-factor_z.error(VectorValues())); - }; + auto factorProb = [&](const EliminationPair &conditionalAndFactor) { + // This is the probability q(μ) at the MLE point. + // conditionalAndFactor.second is a factor without keys, just containing the residual. + static const VectorValues kEmpty; + // return exp(-conditionalAndFactor.first->logNormalizationConstant()); + // return exp(-conditionalAndFactor.first->logNormalizationConstant() - conditionalAndFactor.second->error(kEmpty)); + return exp( - conditionalAndFactor.second->error(kEmpty)); + // return 1.0; + }; - const DecisionTree fdt(newFactors, factorProb); - // // Normalize the values of decision tree to be valid probabilities - // double sum = 0.0; - // auto visitor = [&](double y) { sum += y; }; - // fdt.visit(visitor); - // // Check if sum is 0, and update accordingly. - // if (sum == 0) { - // sum = 1.0; - // } - // fdt = DecisionTree(fdt, - // [sum](const double &x) { return x / sum; - // }); + const DecisionTree fdt(eliminationResults, factorProb); const auto discreteFactor = boost::make_shared(discreteSeparator, fdt); @@ -375,6 +358,11 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, // PREPROCESS: Identify the nature of the current elimination + // TODO(dellaert): just check the factors: + // 1. if all factors are discrete, then we can do discrete elimination: + // 2. if all factors are continuous, then we can do continuous elimination: + // 3. if not, we do hybrid elimination: + // First, identify the separator keys, i.e. all keys that are not frontal. KeySet separatorKeys; for (auto &&factor : factors) { diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 024aafbc7..4cca91b72 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -197,8 +197,7 @@ TEST(GaussianMixture, Likelihood) { const GaussianMixtureFactor::Factors factors( gm.conditionals(), [measurements](const GaussianConditional::shared_ptr& conditional) { - return GaussianMixtureFactor::FactorAndConstant{ - conditional->likelihood(measurements), 0.0}; + return conditional->likelihood(measurements); }); const GaussianMixtureFactor expected({X(0)}, {mode}, factors); EXPECT(assert_equal(expected, *factor)); diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 3af131f09..fcb65a7de 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -34,6 +34,7 @@ using namespace gtsam; using noiseModel::Isotropic; using symbol_shorthand::M; using symbol_shorthand::X; +using symbol_shorthand::Z; static const Key asiaKey = 0; static const DiscreteKey Asia(asiaKey, 2); @@ -73,8 +74,12 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) { /* ****************************************************************************/ // Test creation of a tiny hybrid Bayes net. TEST(HybridBayesNet, Tiny) { - auto bayesNet = tiny::createHybridBayesNet(); - EXPECT_LONGS_EQUAL(3, bayesNet.size()); + auto bn = tiny::createHybridBayesNet(); + EXPECT_LONGS_EQUAL(3, bn.size()); + + const VectorValues measurements{{Z(0), Vector1(5.0)}}; + auto fg = bn.toFactorGraph(measurements); + EXPECT_LONGS_EQUAL(4, fg.size()); } /* ****************************************************************************/ diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index ef89c0bfd..21a79e4e7 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -57,6 +57,9 @@ using gtsam::symbol_shorthand::X; using gtsam::symbol_shorthand::Y; using gtsam::symbol_shorthand::Z; +// Set up sampling +std::mt19937_64 kRng(42); + /* ************************************************************************* */ TEST(HybridGaussianFactorGraph, Creation) { HybridConditional conditional; @@ -638,24 +641,47 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) { // f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0) GaussianFactorGraphTree expected{ M(0), - {GaussianFactorGraph(std::vector{mixture->factor(d0), prior}), - mixture->constant(d0)}, - {GaussianFactorGraph(std::vector{mixture->factor(d1), prior}), - mixture->constant(d1)}}; + {GaussianFactorGraph(std::vector{(*mixture)(d0), prior}), 0.0}, + {GaussianFactorGraph(std::vector{(*mixture)(d1), prior}), 0.0}}; EXPECT(assert_equal(expected(d0), actual(d0), 1e-5)); EXPECT(assert_equal(expected(d1), actual(d1), 1e-5)); } +/* ****************************************************************************/ +// Check that the factor graph unnormalized probability is proportional to the +// Bayes net probability for the given measurements. +bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements, + const HybridGaussianFactorGraph &fg, size_t num_samples = 10) { + auto compute_ratio = [&](HybridValues *sample) -> double { + sample->update(measurements); // update sample with given measurements: + return bn.evaluate(*sample) / fg.probPrime(*sample); + // return bn.evaluate(*sample) / posterior->evaluate(*sample); + }; + + HybridValues sample = bn.sample(&kRng); + double expected_ratio = compute_ratio(&sample); + + // Test ratios for a number of independent samples: + for (size_t i = 0; i < num_samples; i++) { + HybridValues sample = bn.sample(&kRng); + if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false; + } + return true; +} + /* ****************************************************************************/ // Check that eliminating tiny net with 1 measurement yields correct result. TEST(HybridGaussianFactorGraph, EliminateTiny1) { using symbol_shorthand::Z; const int num_measurements = 1; - auto fg = tiny::createHybridGaussianFactorGraph( - num_measurements, VectorValues{{Z(0), Vector1(5.0)}}); + const VectorValues measurements{{Z(0), Vector1(5.0)}}; + auto bn = tiny::createHybridBayesNet(num_measurements); + auto fg = bn.toFactorGraph(measurements); EXPECT_LONGS_EQUAL(4, fg.size()); + EXPECT(ratioTest(bn, measurements, fg)); + // Create expected Bayes Net: HybridBayesNet expectedBayesNet; @@ -675,6 +701,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) { // Test elimination const auto posterior = fg.eliminateSequential(); EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01)); + + EXPECT(ratioTest(bn, measurements, *posterior)); } /* ****************************************************************************/ @@ -683,9 +711,9 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) { // Create factor graph with 2 measurements such that posterior mean = 5.0. using symbol_shorthand::Z; const int num_measurements = 2; - auto fg = tiny::createHybridGaussianFactorGraph( - num_measurements, - VectorValues{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}}); + const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}}; + auto bn = tiny::createHybridBayesNet(num_measurements); + auto fg = bn.toFactorGraph(measurements); EXPECT_LONGS_EQUAL(6, fg.size()); // Create expected Bayes Net: @@ -707,6 +735,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) { // Test elimination const auto posterior = fg.eliminateSequential(); EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01)); + + EXPECT(ratioTest(bn, measurements, *posterior)); } /* ****************************************************************************/ @@ -723,32 +753,12 @@ TEST(HybridGaussianFactorGraph, EliminateTiny22) { auto fg = bn.toFactorGraph(measurements); EXPECT_LONGS_EQUAL(7, fg.size()); + EXPECT(ratioTest(bn, measurements, fg)); + // Test elimination const auto posterior = fg.eliminateSequential(); - // Compute the log-ratio between the Bayes net and the factor graph. - auto compute_ratio = [&](HybridValues *sample) -> double { - // update sample with given measurements: - sample->update(measurements); - return bn.evaluate(*sample) / posterior->evaluate(*sample); - }; - - // Set up sampling - std::mt19937_64 rng(42); - - // The error evaluated by the factor graph and the Bayes net should differ by - // the normalizing term computed via the Bayes net determinant. - HybridValues sample = bn.sample(&rng); - double expected_ratio = compute_ratio(&sample); - // regression - EXPECT_DOUBLES_EQUAL(0.018253037966018862, expected_ratio, 1e-6); - - // Test ratios for a number of independent samples: - constexpr int num_samples = 100; - for (size_t i = 0; i < num_samples; i++) { - HybridValues sample = bn.sample(&rng); - EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(&sample), 1e-6); - } + EXPECT(ratioTest(bn, measurements, *posterior)); } /* ****************************************************************************/ @@ -818,31 +828,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) { // Test resulting posterior Bayes net has correct size: EXPECT_LONGS_EQUAL(8, posterior->size()); - // TODO(dellaert): below is copy/pasta from above, refactor - - // Compute the log-ratio between the Bayes net and the factor graph. - auto compute_ratio = [&](HybridValues *sample) -> double { - // update sample with given measurements: - sample->update(measurements); - return bn.evaluate(*sample) / posterior->evaluate(*sample); - }; - - // Set up sampling - std::mt19937_64 rng(42); - - // The error evaluated by the factor graph and the Bayes net should differ by - // the normalizing term computed via the Bayes net determinant. - HybridValues sample = bn.sample(&rng); - double expected_ratio = compute_ratio(&sample); - // regression - EXPECT_DOUBLES_EQUAL(0.0094526745785019472, expected_ratio, 1e-6); - - // Test ratios for a number of independent samples: - constexpr int num_samples = 100; - for (size_t i = 0; i < num_samples; i++) { - HybridValues sample = bn.sample(&rng); - EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(&sample), 1e-6); - } + EXPECT(ratioTest(bn, measurements, *posterior)); } /* ************************************************************************* */