diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 8bb83cac4..7c7fc3ea5 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -238,22 +238,27 @@ TEST(HybridGaussianConditional, Likelihood2) { EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); } +/* ************************************************************************* */ +namespace two_mode_measurement { +// Create a two key conditional: +const DiscreteKeys modes{{M(1), 2}, {M(2), 2}}; +const std::vector gcs = { + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(1), 1), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(2), 2), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(3), 3), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(4), 4)}; +const HybridGaussianConditional::Conditionals conditionals(modes, gcs); +const auto hgc = + std::make_shared(modes, conditionals); +} // namespace two_mode_measurement + /* ************************************************************************* */ // Test pruning a HybridGaussianConditional with two discrete keys, based on a // DecisionTreeFactor with 3 keys: TEST(HybridGaussianConditional, Prune) { - // Create a two key conditional: - DiscreteKeys modes{{M(1), 2}, {M(2), 2}}; - std::vector gcs; - for (size_t i = 0; i < 4; i++) { - gcs.push_back( - GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1)); - } - auto empty = std::make_shared(); - HybridGaussianConditional::Conditionals conditionals(modes, gcs); - HybridGaussianConditional hgc(modes, conditionals); + using two_mode_measurement::hgc; - DiscreteKeys keys = modes; + DiscreteKeys keys = two_mode_measurement::modes; keys.push_back({M(3), 2}); { for (size_t i = 0; i < 8; i++) { @@ -262,7 +267,7 @@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); // Prune the HybridGaussianConditional const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 1 conditional EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); } @@ -273,14 +278,14 @@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 2 conditionals EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); // Check that the minimum negLogConstant is set correctly EXPECT_DOUBLES_EQUAL( - hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(), + hgc->conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(), pruned->negLogConstant(), 1e-9); } { @@ -289,18 +294,74 @@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 3 conditionals EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); // Check that the minimum negLogConstant is correct - EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9); + EXPECT_DOUBLES_EQUAL(hgc->negLogConstant(), pruned->negLogConstant(), 1e-9); } } -/* ************************************************************************* +/* ************************************************************************* */ + +#include + +/** + * Return a HybridConditional by choosing branches based on the given discrete + * values. If all discrete parents are specified, return a HybridConditional + * which is just a GaussianConditional. */ +HybridConditional::shared_ptr choose( + const HybridGaussianConditional::shared_ptr &self, + const DiscreteValues &discreteValues) { + const auto &discreteParents = self->discreteKeys(); + DiscreteValues deadParentValues; + DiscreteKeys liveParents; + for (const auto &key : discreteParents) { + auto it = discreteValues.find(key.first); + if (it != discreteValues.end()) + deadParentValues[key.first] = it->second; + else + liveParents.emplace_back(key); + } + // If so then we just get the corresponding Gaussian conditional: + if (deadParentValues.size() == discreteParents.size()) { + // print on how many discreteParents we are choosing: + return std::make_shared(self->choose(deadParentValues)); + } else if (liveParents.size() > 0) { + auto newTree = self->factors(); + for (auto &&[key, value] : discreteValues) { + newTree = newTree.choose(key, value); + } + return std::make_shared( + std::make_shared(liveParents, newTree)); + } else { + // Add as-is + return std::make_shared(self); + } +} + +/* ************************************************************************* */ +// Test the pruning and dead-mode removal. +TEST(HybridGaussianConditional, PrunePlus) { + using two_mode_measurement::hgc; // two discrete parents + + const HybridConditional::shared_ptr same = choose(hgc, {}); + EXPECT(same->isHybrid()); + EXPECT(same->asHybrid()->nrComponents() == 4); + + const HybridConditional::shared_ptr oneParent = choose(hgc, {{M(1), 0}}); + EXPECT(oneParent->isHybrid()); + EXPECT(oneParent->asHybrid()->nrComponents() == 2); + + const HybridConditional::shared_ptr gaussian = + choose(hgc, {{M(1), 0}, {M(2), 1}}); + EXPECT(gaussian->asGaussian()); +} + +/* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);