diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 12ec1d2a8..279ac4069 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -70,14 +70,13 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment( } /* *******************************************************************************/ -HybridGaussianFactor::HybridGaussianFactor( +HybridGaussianFactor::ConstructorHelper::ConstructorHelper( const DiscreteKey &discreteKey, const std::vector &factors) - : Base(HybridFactor::Category::Hybrid) { - // Extract continuous keys from first-null factor and verify all others - KeyVector continuousKeys; + : discreteKeys({discreteKey}) { + // Extract continuous keys from the first non-null factor for (const auto &factor : factors) { - if (!factor) continue; + if (!factor) continue; // Skip null factors if (continuousKeys.empty()) { continuousKeys = factor->keys(); } else if (factor->keys() != continuousKeys) { @@ -85,25 +84,18 @@ HybridGaussianFactor::HybridGaussianFactor( } } - // Initialize the base class - Factor::keys_ = continuousKeys; - Factor::keys_.push_back(discreteKey.first); - Base::discreteKeys_ = {discreteKey}; - Base::continuousKeys_ = continuousKeys; - - // Build the DecisionTree from factor vector - factors_ = Factors({discreteKey}, factors); + // Build the DecisionTree from the factor vector + factorsTree = Factors(discreteKeys, factors); } /* *******************************************************************************/ -HybridGaussianFactor::HybridGaussianFactor( +HybridGaussianFactor::ConstructorHelper::ConstructorHelper( const DiscreteKey &discreteKey, const std::vector &factorPairs) - : Base(HybridFactor::Category::Hybrid) { - // Extract continuous keys from first-null factor and verify all others - KeyVector continuousKeys; + : discreteKeys({discreteKey}) { + // Extract continuous keys from the first non-null factor for (const auto &pair : factorPairs) { - if (!pair.first) continue; + if (!pair.first) continue; // Skip null factors if (continuousKeys.empty()) { continuousKeys = pair.first->keys(); } else if (pair.first->keys() != continuousKeys) { @@ -111,45 +103,26 @@ HybridGaussianFactor::HybridGaussianFactor( } } - // Initialize the base class - Factor::keys_ = continuousKeys; - Factor::keys_.push_back(discreteKey.first); - Base::discreteKeys_ = {discreteKey}; - Base::continuousKeys_ = continuousKeys; - // Build the FactorValuePairs DecisionTree - FactorValuePairs pairTree({discreteKey}, factorPairs); - - // Assign factors_ after calling augment - factors_ = augment(pairTree); + pairs = FactorValuePairs(discreteKeys, factorPairs); } /* *******************************************************************************/ -HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys, - const FactorValuePairs &factorPairs) - : Base(HybridFactor::Category::Hybrid) { - // Verify that all factors have the same keys - KeyVector continuousKeys; +HybridGaussianFactor::ConstructorHelper::ConstructorHelper( + const DiscreteKeys &discreteKeys, const FactorValuePairs &factorPairs) + : discreteKeys(discreteKeys) { + // Extract continuous keys from the first non-null factor factorPairs.visit([&](const GaussianFactorValuePair &pair) { - if (pair.first) { - if (continuousKeys.empty()) { - continuousKeys = pair.first->keys(); - } else if (pair.first->keys() != continuousKeys) { - throw std::invalid_argument("All factors must have the same keys"); - } + if (!pair.first) return; // Skip null factors + if (continuousKeys.empty()) { + continuousKeys = pair.first->keys(); + } else if (pair.first->keys() != continuousKeys) { + throw std::invalid_argument("All factors must have the same keys"); } }); - // Initialize the base class - Factor::keys_ = continuousKeys; - for (const auto &discreteKey : discreteKeys) { - Factor::keys_.push_back(discreteKey.first); - } - Base::discreteKeys_ = discreteKeys; - Base::continuousKeys_ = continuousKeys; - - // Assign factors_ after calling augment - factors_ = augment(factorPairs); + // Build the FactorValuePairs DecisionTree + pairs = factorPairs; } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 0dc80ed36..b20dc130c 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -89,7 +89,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * @param factors Vector of gaussian factors, one for each mode. */ HybridGaussianFactor(const DiscreteKey &discreteKey, - const std::vector &factors); + const std::vector &factors) + : HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {} /** * @brief Construct a new HybridGaussianFactor on a single discrete key, @@ -101,7 +102,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * @param factorPairs Vector of gaussian factor-scalar pairs, one per mode. */ HybridGaussianFactor(const DiscreteKey &discreteKey, - const std::vector &factorPairs); + const std::vector &factorPairs) + : HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {} /** * @brief Construct a new HybridGaussianFactor on a several discrete keys M, @@ -113,7 +115,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * @param factors The decision tree of Gaussian factor/scalar pairs. */ HybridGaussianFactor(const DiscreteKeys &discreteKeys, - const FactorValuePairs &factors); + const FactorValuePairs &factors) + : HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {} /// @} /// @name Testable @@ -193,6 +196,30 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { double potentiallyPrunedComponentError( const sharedFactor &gf, const VectorValues &continuousValues) const; + /// Helper struct to assist in constructing the HybridGaussianFactor + struct ConstructorHelper { + KeyVector continuousKeys; // Continuous keys extracted from factors + DiscreteKeys discreteKeys; // Discrete keys provided to the constructors + FactorValuePairs pairs; // Used only if factorsTree is empty + Factors factorsTree; + + ConstructorHelper( + const DiscreteKey &discreteKey, + const std::vector &factorsVec); + + ConstructorHelper(const DiscreteKey &discreteKey, + const std::vector &factorPairs); + + ConstructorHelper(const DiscreteKeys &discreteKeys, + const FactorValuePairs &factorPairs); + }; + + // Private constructor using ConstructorHelper + HybridGaussianFactor(const ConstructorHelper &helper) + : Base(helper.continuousKeys, helper.discreteKeys), + factors_(helper.factorsTree.empty() ? augment(helper.pairs) + : helper.factorsTree) {} + #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 743b98c66..cde7e4063 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -101,12 +101,6 @@ TEST(GaussianMixture, GaussianMixtureModel) { auto hbn = GaussianMixtureModel(mu0, mu1, sigma, sigma); - // Check the number of keys matches what we expect - auto hgc = hbn.at(0)->asHybrid(); - EXPECT_LONGS_EQUAL(2, hgc->keys().size()); - EXPECT_LONGS_EQUAL(1, hgc->continuousKeys().size()); - EXPECT_LONGS_EQUAL(1, hgc->discreteKeys().size()); - // At the halfway point between the means, we should get P(m|z)=0.5 double midway = mu1 - mu0; auto pMid = SolveHBN(hbn, midway); @@ -141,12 +135,6 @@ TEST(GaussianMixture, GaussianMixtureModel2) { auto hbn = GaussianMixtureModel(mu0, mu1, sigma0, sigma1); - // Check the number of keys matches what we expect - auto hgc = hbn.at(0)->asHybrid(); - EXPECT_LONGS_EQUAL(2, hgc->keys().size()); - EXPECT_LONGS_EQUAL(1, hgc->continuousKeys().size()); - EXPECT_LONGS_EQUAL(1, hgc->discreteKeys().size()); - // We get zMax=3.1333 by finding the maximum value of the function, at which // point the mode m==1 is about twice as probable as m==0. double zMax = 3.133;