diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index b20dc130c..46b21b8aa 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -196,16 +196,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { double potentiallyPrunedComponentError( const sharedFactor &gf, const VectorValues &continuousValues) const; - /// Helper struct to assist in constructing the HybridGaussianFactor + /// Helper struct to assist private constructor below. struct ConstructorHelper { KeyVector continuousKeys; // Continuous keys extracted from factors DiscreteKeys discreteKeys; // Discrete keys provided to the constructors FactorValuePairs pairs; // Used only if factorsTree is empty Factors factorsTree; - ConstructorHelper( - const DiscreteKey &discreteKey, - const std::vector &factorsVec); + ConstructorHelper(const DiscreteKey &discreteKey, + const std::vector &factors); ConstructorHelper(const DiscreteKey &discreteKey, const std::vector &factorPairs); @@ -214,7 +213,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { const FactorValuePairs &factorPairs); }; - // Private constructor using ConstructorHelper + // Private constructor using ConstructorHelper above. HybridGaussianFactor(const ConstructorHelper &helper) : Base(helper.continuousKeys, helper.discreteKeys), factors_(helper.factorsTree.empty() ? augment(helper.pairs) diff --git a/gtsam/hybrid/HybridNonlinearFactor.cpp b/gtsam/hybrid/HybridNonlinearFactor.cpp index 301d79593..351a7fea4 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.cpp +++ b/gtsam/hybrid/HybridNonlinearFactor.cpp @@ -18,55 +18,59 @@ #include +#include + +#include "gtsam/nonlinear/NonlinearFactor.h" + namespace gtsam { /* *******************************************************************************/ -static void checkKeys(const KeyVector& continuousKeys, - const std::vector& pairs) { - KeySet factor_keys_set; - for (const auto& pair : pairs) { - auto f = pair.first; - // Insert all factor continuous keys in the continuous keys set. - std::copy(f->keys().begin(), f->keys().end(), - std::inserter(factor_keys_set, factor_keys_set.end())); - } - - KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end()); - if (continuous_keys_set != factor_keys_set) { +static void CopyOrCheckContinuousKeys(const NonlinearFactor::shared_ptr& factor, + KeyVector* continuousKeys) { + if (!factor) return; + if (continuousKeys->empty()) { + *continuousKeys = factor->keys(); + } else if (factor->keys() != *continuousKeys) { throw std::runtime_error( - "HybridNonlinearFactor: The specified continuous keys and the keys in " - "the factors do not match!"); + "HybridNonlinearFactor: all factors should have the same keys!"); } } /* *******************************************************************************/ -HybridNonlinearFactor::HybridNonlinearFactor( - const KeyVector& continuousKeys, const DiscreteKey& discreteKey, +HybridNonlinearFactor::ConstructorHelper::ConstructorHelper( + const DiscreteKey& discreteKey, const std::vector& factors) - : Base(continuousKeys, {discreteKey}) { + : discreteKeys({discreteKey}) { std::vector pairs; - for (auto&& f : factors) { - pairs.emplace_back(f, 0.0); + // Extract continuous keys from the first non-null factor + for (const auto& factor : factors) { + pairs.emplace_back(factor, 0.0); + CopyOrCheckContinuousKeys(factor, &continuousKeys); } - checkKeys(continuousKeys, pairs); - factors_ = FactorValuePairs({discreteKey}, pairs); + factorTree = FactorValuePairs({discreteKey}, pairs); } /* *******************************************************************************/ -HybridNonlinearFactor::HybridNonlinearFactor( - const KeyVector& continuousKeys, const DiscreteKey& discreteKey, +HybridNonlinearFactor::ConstructorHelper::ConstructorHelper( + const DiscreteKey& discreteKey, const std::vector& pairs) - : Base(continuousKeys, {discreteKey}) { - KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end()); - checkKeys(continuousKeys, pairs); - factors_ = FactorValuePairs({discreteKey}, pairs); + : discreteKeys({discreteKey}) { + // Extract continuous keys from the first non-null factor + for (const auto& pair : pairs) { + CopyOrCheckContinuousKeys(pair.first, &continuousKeys); + } + factorTree = FactorValuePairs({discreteKey}, pairs); } /* *******************************************************************************/ -HybridNonlinearFactor::HybridNonlinearFactor(const KeyVector& continuousKeys, - const DiscreteKeys& discreteKeys, - const FactorValuePairs& factors) - : Base(continuousKeys, discreteKeys), factors_(factors) {} +HybridNonlinearFactor::ConstructorHelper::ConstructorHelper( + const DiscreteKeys& discreteKeys, const FactorValuePairs& factorPairs) + : discreteKeys(discreteKeys), factorTree(factorPairs) { + // Extract continuous keys from the first non-null factor + factorPairs.visit([&](const NonlinearFactorValuePair& pair) { + CopyOrCheckContinuousKeys(pair.first, &continuousKeys); + }); +} /* *******************************************************************************/ AlgebraicDecisionTree HybridNonlinearFactor::errorTree( diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index 7843afc83..161a7b357 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -90,13 +90,12 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { * providing the factors for each mode m as a vector of factors ϕ_m(x). * The value ϕ(x,m) for the factor is simply ϕ_m(x). * - * @param continuousKeys Vector of keys for continuous factors. * @param discreteKey The discrete key for the "mode", indexing components. * @param factors Vector of gaussian factors, one for each mode. */ - HybridNonlinearFactor( - const KeyVector& continuousKeys, const DiscreteKey& discreteKey, - const std::vector& factors); + HybridNonlinearFactor(const DiscreteKey& discreteKey, + const std::vector& factors) + : HybridNonlinearFactor(ConstructorHelper(discreteKey, factors)) {} /** * @brief Construct a new HybridNonlinearFactor on a single discrete key, @@ -104,13 +103,12 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { * provided as a vector of pairs (ϕ_m(x), E_m). * The value ϕ(x,m) for the factor is now ϕ_m(x) + E_m. * - * @param continuousKeys Vector of keys for continuous factors. * @param discreteKey The discrete key for the "mode", indexing components. * @param pairs Vector of gaussian factor-scalar pairs, one per mode. */ - HybridNonlinearFactor(const KeyVector& continuousKeys, - const DiscreteKey& discreteKey, - const std::vector& pairs); + HybridNonlinearFactor(const DiscreteKey& discreteKey, + const std::vector& pairs) + : HybridNonlinearFactor(ConstructorHelper(discreteKey, pairs)) {} /** * @brief Construct a new HybridNonlinearFactor on a several discrete keys M, @@ -118,13 +116,12 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { * scalars are provided as a DecisionTree of pairs (ϕ_M(x), E_M). * The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m. * - * @param continuousKeys A vector of keys representing continuous variables. * @param discreteKeys Discrete variables and their cardinalities. * @param factors The decision tree of nonlinear factor/scalar pairs. */ - HybridNonlinearFactor(const KeyVector& continuousKeys, - const DiscreteKeys& discreteKeys, - const FactorValuePairs& factors); + HybridNonlinearFactor(const DiscreteKeys& discreteKeys, + const FactorValuePairs& factors) + : HybridNonlinearFactor(ConstructorHelper(discreteKeys, factors)) {} /** * @brief Compute error of the HybridNonlinearFactor as a tree. * @@ -181,6 +178,28 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { /// Linearize all the continuous factors to get a HybridGaussianFactor. std::shared_ptr linearize( const Values& continuousValues) const; + + private: + /// Helper struct to assist private constructor below. + struct ConstructorHelper { + KeyVector continuousKeys; // Continuous keys extracted from factors + DiscreteKeys discreteKeys; // Discrete keys provided to the constructors + FactorValuePairs factorTree; + + ConstructorHelper(const DiscreteKey& discreteKey, + const std::vector& factors); + + ConstructorHelper(const DiscreteKey& discreteKey, + const std::vector& factorPairs); + + ConstructorHelper(const DiscreteKeys& discreteKeys, + const FactorValuePairs& factorPairs); + }; + + // Private constructor using ConstructorHelper above. + HybridNonlinearFactor(const ConstructorHelper& helper) + : Base(helper.continuousKeys, helper.discreteKeys), + factors_(helper.factorTree) {} }; // traits