From 35dd42ed2329c264f77f6b9553ca2c9ecc1446e0 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 26 Sep 2024 12:53:50 -0700 Subject: [PATCH] Three constructor variants extract keys --- gtsam/hybrid/HybridFactor.h | 4 + gtsam/hybrid/HybridGaussianFactor.cpp | 115 +++++++++++++++++++++++--- gtsam/hybrid/HybridGaussianFactor.h | 55 ++++++------ 3 files changed, 134 insertions(+), 40 deletions(-) diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 1e99c1365..39a72eb26 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -140,6 +140,10 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// @} + protected: + /// protected constructor to initialize the category + HybridFactor(Category category) : category_(category) {} + private: #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 3db6e7965..952bb7017 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -26,18 +26,11 @@ #include #include +#include "gtsam/hybrid/HybridFactor.h" + namespace gtsam { -/** - * @brief Helper function to augment the [A|b] matrices in the factor components - * with the additional scalar values. - * This is done by storing the value in - * the `b` vector as an additional row. - * - * @param factors DecisionTree of GaussianFactors and arbitrary scalars. - * Gaussian factor in factors. - * @return HybridGaussianFactor::Factors - */ +/* *******************************************************************************/ HybridGaussianFactor::Factors HybridGaussianFactor::augment( const FactorValuePairs &factors) { // Find the minimum value so we can "proselytize" to positive values. @@ -76,13 +69,111 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment( return Factors(factors, update); } +/* *******************************************************************************/ +HybridGaussianFactor::HybridGaussianFactor( + const DiscreteKey &discreteKey, + const std::vector &factors) + : Base(HybridFactor::Category::Hybrid) { + // Extract continuous keys from first-null factor and verify all others + KeyVector continuousKeys; + for (const auto &factor : factors) { + if (!factor) continue; + if (continuousKeys.empty()) { + continuousKeys = factor->keys(); + } else if (factor->keys() != continuousKeys) { + throw std::invalid_argument("All factors must have the same keys"); + } + } + + // Check that this worked. + if (continuousKeys.empty()) { + throw std::invalid_argument("Need at least one non-null factor."); + } + + // Initialize the base class + Factor::keys_ = continuousKeys; + Factor::keys_.push_back(discreteKey.first); + Base::discreteKeys_ = {discreteKey}; + Base::continuousKeys_ = continuousKeys; + + // Build the DecisionTree from factor vector + factors_ = Factors({discreteKey}, factors); +} + +/* *******************************************************************************/ +HybridGaussianFactor::HybridGaussianFactor( + const DiscreteKey &discreteKey, + const std::vector &factorPairs) + : Base(HybridFactor::Category::Hybrid) { + // Extract continuous keys from first-null factor and verify all others + KeyVector continuousKeys; + for (const auto &pair : factorPairs) { + if (!pair.first) continue; + if (continuousKeys.empty()) { + continuousKeys = pair.first->keys(); + } else if (pair.first->keys() != continuousKeys) { + throw std::invalid_argument("All factors must have the same keys"); + } + } + + // Check that this worked. + if (continuousKeys.empty()) { + throw std::invalid_argument("Need at least one non-null factor."); + } + + // Initialize the base class + Factor::keys_ = continuousKeys; + Factor::keys_.push_back(discreteKey.first); + Base::discreteKeys_ = {discreteKey}; + Base::continuousKeys_ = continuousKeys; + + // Build the FactorValuePairs DecisionTree + FactorValuePairs pairTree({discreteKey}, factorPairs); + + // Assign factors_ after calling augment + factors_ = augment(pairTree); +} + +/* *******************************************************************************/ +HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys, + const FactorValuePairs &factorPairs) + : Base(HybridFactor::Category::Hybrid) { + // Verify that all factors have the same keys + KeyVector continuousKeys; + factorPairs.visit([&](const GaussianFactorValuePair &pair) { + if (pair.first) { + if (continuousKeys.empty()) { + continuousKeys = pair.first->keys(); + } else if (pair.first->keys() != continuousKeys) { + throw std::invalid_argument("All factors must have the same keys"); + } + } + }); + + // Check that this worked. + if (continuousKeys.empty()) { + throw std::invalid_argument("Need at least one non-null factor."); + } + + // Initialize the base class + Factor::keys_ = continuousKeys; + for (const auto &discreteKey : discreteKeys) { + Factor::keys_.push_back(discreteKey.first); + } + Base::discreteKeys_ = discreteKeys; + Base::continuousKeys_ = continuousKeys; + + // Assign factors_ after calling augment + factors_ = augment(factorPairs); +} + /* *******************************************************************************/ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { const This *e = dynamic_cast(&lf); if (e == nullptr) return false; - // This will return false if either factors_ is empty or e->factors_ is empty, - // but not if both are empty or both are not empty: + // This will return false if either factors_ is empty or e->factors_ is + // empty, but not if both are empty or both are not empty: if (factors_.empty() ^ e->factors_.empty()) return false; // Check the base and the factors: diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 8e36e6615..0dc80ed36 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -73,17 +73,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// Decision tree of Gaussian factors indexed by discrete keys. Factors factors_; - /// Helper function to "hide" the constants in the Jacobian factors. - static Factors augment(const FactorValuePairs &factors); - - /** - * @brief Helper function to return factors and functional to create a - * DecisionTree of Gaussian Factor Graphs. - * - * @return GaussianFactorGraphTree - */ - GaussianFactorGraphTree asGaussianFactorGraphTree() const; - public: /// @name Constructors /// @{ @@ -96,14 +85,11 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * providing the factors for each mode m as a vector of factors ϕ_m(x). * The value ϕ(x,m) for the factor is simply ϕ_m(x). * - * @param continuousKeys Vector of keys for continuous factors. * @param discreteKey The discrete key for the "mode", indexing components. * @param factors Vector of gaussian factors, one for each mode. */ - HybridGaussianFactor(const KeyVector &continuousKeys, - const DiscreteKey &discreteKey, - const std::vector &factors) - : Base(continuousKeys, {discreteKey}), factors_({discreteKey}, factors) {} + HybridGaussianFactor(const DiscreteKey &discreteKey, + const std::vector &factors); /** * @brief Construct a new HybridGaussianFactor on a single discrete key, @@ -111,15 +97,11 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * provided as a vector of pairs (ϕ_m(x), E_m). * The value ϕ(x,m) for the factor is now ϕ_m(x) + E_m. * - * @param continuousKeys Vector of keys for continuous factors. * @param discreteKey The discrete key for the "mode", indexing components. - * @param factors Vector of gaussian factor-scalar pairs, one per mode. + * @param factorPairs Vector of gaussian factor-scalar pairs, one per mode. */ - HybridGaussianFactor(const KeyVector &continuousKeys, - const DiscreteKey &discreteKey, - const std::vector &factors) - : HybridGaussianFactor(continuousKeys, {discreteKey}, - FactorValuePairs({discreteKey}, factors)) {} + HybridGaussianFactor(const DiscreteKey &discreteKey, + const std::vector &factorPairs); /** * @brief Construct a new HybridGaussianFactor on a several discrete keys M, @@ -127,14 +109,11 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * scalars are provided as a DecisionTree of pairs (ϕ_M(x), E_M). * The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m. * - * @param continuousKeys A vector of keys representing continuous variables. * @param discreteKeys Discrete variables and their cardinalities. * @param factors The decision tree of Gaussian factor/scalar pairs. */ - HybridGaussianFactor(const KeyVector &continuousKeys, - const DiscreteKeys &discreteKeys, - const FactorValuePairs &factors) - : Base(continuousKeys, discreteKeys), factors_(augment(factors)) {} + HybridGaussianFactor(const DiscreteKeys &discreteKeys, + const FactorValuePairs &factors); /// @} /// @name Testable @@ -189,7 +168,27 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { } /// @} + protected: + /** + * @brief Helper function to return factors and functional to create a + * DecisionTree of Gaussian Factor Graphs. + * + * @return GaussianFactorGraphTree + */ + GaussianFactorGraphTree asGaussianFactorGraphTree() const; + private: + /** + * @brief Helper function to augment the [A|b] matrices in the factor + * components with the additional scalar values. This is done by storing the + * value in the `b` vector as an additional row. + * + * @param factors DecisionTree of GaussianFactors and arbitrary scalars. + * Gaussian factor in factors. + * @return HybridGaussianFactor::Factors + */ + static Factors augment(const FactorValuePairs &factors); + /// Helper method to compute the error of a component. double potentiallyPrunedComponentError( const sharedFactor &gf, const VectorValues &continuousValues) const;