From dfef2c202ff2dd86d8a36086ef114f680aa3438f Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 13 Jan 2023 08:26:41 -0800 Subject: [PATCH] Simplify elimination --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 16 +++++------- .../tests/testHybridGaussianFactorGraph.cpp | 26 ++++++++++++++++++- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 04ee21fc9..ac6734d48 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -271,19 +271,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // If there are no more continuous parents, then we should create a // DiscreteFactor here, with the error for each discrete choice. if (continuousSeparator.empty()) { - auto factorProb = [&](const EliminationPair &conditionalAndFactor) { - // This is the probability q(μ) at the MLE point. - // conditionalAndFactor.second is a factor without keys, just containing the residual. + auto probPrime = [&](const GaussianMixtureFactor::sharedFactor &factor) { + // This is the unnormalized probability q(μ) at the mean. + // The factor has no keys, just contains the residual. static const VectorValues kEmpty; - // return exp(-conditionalAndFactor.first->logNormalizationConstant()); - // return exp(-conditionalAndFactor.first->logNormalizationConstant() - conditionalAndFactor.second->error(kEmpty)); - return exp( - conditionalAndFactor.second->error(kEmpty)); - // return 1.0; + return factor? exp(-factor->error(kEmpty)) : 1.0; }; - const DecisionTree fdt(eliminationResults, factorProb); - const auto discreteFactor = - boost::make_shared(discreteSeparator, fdt); + const auto discreteFactor = boost::make_shared( + discreteSeparator, DecisionTree(newFactors, probPrime)); return {boost::make_shared(gaussianMixture), discreteFactor}; diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 21a79e4e7..c51d65da1 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -652,7 +652,7 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) { // Check that the factor graph unnormalized probability is proportional to the // Bayes net probability for the given measurements. bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements, - const HybridGaussianFactorGraph &fg, size_t num_samples = 10) { + const HybridGaussianFactorGraph &fg, size_t num_samples = 100) { auto compute_ratio = [&](HybridValues *sample) -> double { sample->update(measurements); // update sample with given measurements: return bn.evaluate(*sample) / fg.probPrime(*sample); @@ -670,6 +670,28 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements, return true; } +/* ****************************************************************************/ +// Check that the factor graph unnormalized probability is proportional to the +// Bayes net probability for the given measurements. +bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements, + const HybridBayesNet &posterior, size_t num_samples = 100) { + auto compute_ratio = [&](HybridValues *sample) -> double { + sample->update(measurements); // update sample with given measurements: + // return bn.evaluate(*sample) / fg.probPrime(*sample); + return bn.evaluate(*sample) / posterior.evaluate(*sample); + }; + + HybridValues sample = bn.sample(&kRng); + double expected_ratio = compute_ratio(&sample); + + // Test ratios for a number of independent samples: + for (size_t i = 0; i < num_samples; i++) { + HybridValues sample = bn.sample(&kRng); + if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false; + } + return true; +} + /* ****************************************************************************/ // Check that eliminating tiny net with 1 measurement yields correct result. TEST(HybridGaussianFactorGraph, EliminateTiny1) { @@ -678,6 +700,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) { const VectorValues measurements{{Z(0), Vector1(5.0)}}; auto bn = tiny::createHybridBayesNet(num_measurements); auto fg = bn.toFactorGraph(measurements); + GTSAM_PRINT(bn); EXPECT_LONGS_EQUAL(4, fg.size()); EXPECT(ratioTest(bn, measurements, fg)); @@ -701,6 +724,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) { // Test elimination const auto posterior = fg.eliminateSequential(); EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01)); + GTSAM_PRINT(*posterior); EXPECT(ratioTest(bn, measurements, *posterior)); }