diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 97ec1a1f8..257eca314 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -169,4 +169,41 @@ double HybridConditional::evaluate(const HybridValues &values) const { return std::exp(logProbability(values)); } +/* ************************************************************************ */ +HybridConditional::shared_ptr HybridConditional::restrict( + const DiscreteValues &discreteValues) const { + if (auto gc = asGaussian()) { + return std::make_shared(gc); + } else if (auto dc = asDiscrete()) { + return std::make_shared(dc); + }; + + auto hgc = asHybrid(); + if (!hgc) + throw std::runtime_error( + "HybridConditional::restrict: conditional type not handled"); + + // Case 1: Fully determined, return corresponding Gaussian conditional + auto parentValues = discreteValues.filter(discreteKeys_); + if (parentValues.size() == discreteKeys_.size()) { + return std::make_shared(hgc->choose(parentValues)); + } + + // Case 2: Some live parents remain, build a new tree + auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_); + if (!unspecifiedParentKeys.empty()) { + auto newTree = hgc->factors(); + for (const auto &[key, value] : parentValues) { + newTree = newTree.choose(key, value); + } + return std::make_shared( + std::make_shared(unspecifiedParentKeys, + newTree)); + } + + // Case 3: No changes needed, return original + return std::make_shared(hgc); +} + +/* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 3cf5b80e5..075fbe411 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -215,6 +215,14 @@ class GTSAM_EXPORT HybridConditional return true; } + /** + * Return a HybridConditional by choosing branches based on the given discrete + * values. If all discrete parents are specified, return a HybridConditional + * which is just a GaussianConditional. If this conditional is *not* a hybrid + * conditional, just return that. + */ + shared_ptr restrict(const DiscreteValues& discreteValues) const; + /// @} private: diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index ba6eaf4cd..032be5a78 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -316,57 +316,34 @@ auto choose(auto tree, const DiscreteValues &discreteValues) { return tree; } -/** - * Return a HybridConditional by choosing branches based on the given discrete - * values. If all discrete parents are specified, return a HybridConditional - * which is just a GaussianConditional. +/* ************************************************************************* + * This test verifies the behavior of the restrict method in different + * scenarios: + * - When no restrictions are applied. + * - When one parent is restricted. + * - When two parents are restricted. + * - When the restriction results in a Gaussian conditional. */ -HybridConditional::shared_ptr choose( - const HybridGaussianConditional::shared_ptr &self, - const DiscreteValues &discreteValues) { - auto parentValues = discreteValues.filter(self->discreteKeys()); - auto unspecifiedParentKeys = discreteValues.missingKeys(self->discreteKeys()); +TEST(HybridGaussianConditional, Restrict) { + // Create a HybridConditional with two discrete parents P(z0|m0,m1) + const auto hc = + std::make_shared(two_mode_measurement::hgc); - // Case 1: Fully determined, return corresponding Gaussian conditional - if (parentValues.size() == self->discreteKeys().size()) { - return std::make_shared(self->choose(parentValues)); - } - - // Case 2: Some live parents remain, build a new tree - if (!unspecifiedParentKeys.empty()) { - auto newTree = self->factors(); - for (const auto &[key, value] : parentValues) { - newTree = newTree.choose(key, value); - } - return std::make_shared( - std::make_shared(unspecifiedParentKeys, - newTree)); - } - - // Case 3: No changes needed, return original - return std::make_shared(self); -} - -/* ************************************************************************* */ -// Test the pruning and dead-mode removal. -TEST(HybridGaussianConditional, PrunePlus) { - using two_mode_measurement::hgc; // two discrete parents - - const HybridConditional::shared_ptr same = choose(hgc, {}); + const HybridConditional::shared_ptr same = hc->restrict({}); EXPECT(same->isHybrid()); EXPECT(same->asHybrid()->nrComponents() == 4); - const HybridConditional::shared_ptr oneParent = choose(hgc, {{M(1), 0}}); + const HybridConditional::shared_ptr oneParent = hc->restrict({{M(1), 0}}); EXPECT(oneParent->isHybrid()); EXPECT(oneParent->asHybrid()->nrComponents() == 2); const HybridConditional::shared_ptr oneParent2 = - choose(hgc, {{M(7), 0}, {M(1), 0}}); + hc->restrict({{M(7), 0}, {M(1), 0}}); EXPECT(oneParent2->isHybrid()); EXPECT(oneParent2->asHybrid()->nrComponents() == 2); const HybridConditional::shared_ptr gaussian = - choose(hgc, {{M(1), 0}, {M(2), 1}}); + hc->restrict({{M(1), 0}, {M(2), 1}}); EXPECT(gaussian->asGaussian()); }