diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 7c7fc3ea5..ba6eaf4cd 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -308,6 +308,14 @@ TEST(HybridGaussianConditional, Prune) { #include +// Helper function to apply discrete values to the tree +auto choose(auto tree, const DiscreteValues &discreteValues) { + for (const auto &[key, value] : discreteValues) { + tree = tree.choose(key, value); + } + return tree; +} + /** * Return a HybridConditional by choosing branches based on the given discrete * values. If all discrete parents are specified, return a HybridConditional @@ -316,31 +324,27 @@ TEST(HybridGaussianConditional, Prune) { HybridConditional::shared_ptr choose( const HybridGaussianConditional::shared_ptr &self, const DiscreteValues &discreteValues) { - const auto &discreteParents = self->discreteKeys(); - DiscreteValues deadParentValues; - DiscreteKeys liveParents; - for (const auto &key : discreteParents) { - auto it = discreteValues.find(key.first); - if (it != discreteValues.end()) - deadParentValues[key.first] = it->second; - else - liveParents.emplace_back(key); + auto parentValues = discreteValues.filter(self->discreteKeys()); + auto unspecifiedParentKeys = discreteValues.missingKeys(self->discreteKeys()); + + // Case 1: Fully determined, return corresponding Gaussian conditional + if (parentValues.size() == self->discreteKeys().size()) { + return std::make_shared(self->choose(parentValues)); } - // If so then we just get the corresponding Gaussian conditional: - if (deadParentValues.size() == discreteParents.size()) { - // print on how many discreteParents we are choosing: - return std::make_shared(self->choose(deadParentValues)); - } else if (liveParents.size() > 0) { + + // Case 2: Some live parents remain, build a new tree + if (!unspecifiedParentKeys.empty()) { auto newTree = self->factors(); - for (auto &&[key, value] : discreteValues) { + for (const auto &[key, value] : parentValues) { newTree = newTree.choose(key, value); } return std::make_shared( - std::make_shared(liveParents, newTree)); - } else { - // Add as-is - return std::make_shared(self); + std::make_shared(unspecifiedParentKeys, + newTree)); } + + // Case 3: No changes needed, return original + return std::make_shared(self); } /* ************************************************************************* */ @@ -356,6 +360,11 @@ TEST(HybridGaussianConditional, PrunePlus) { EXPECT(oneParent->isHybrid()); EXPECT(oneParent->asHybrid()->nrComponents() == 2); + const HybridConditional::shared_ptr oneParent2 = + choose(hgc, {{M(7), 0}, {M(1), 0}}); + EXPECT(oneParent2->isHybrid()); + EXPECT(oneParent2->asHybrid()->nrComponents() == 2); + const HybridConditional::shared_ptr gaussian = choose(hgc, {{M(1), 0}, {M(2), 1}}); EXPECT(gaussian->asGaussian());