diff --git a/gtsam/hybrid/HybridNonlinearFactor.cpp b/gtsam/hybrid/HybridNonlinearFactor.cpp index 9613233b1..a541722c4 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.cpp +++ b/gtsam/hybrid/HybridNonlinearFactor.cpp @@ -21,10 +21,57 @@ namespace gtsam { /* *******************************************************************************/ -HybridNonlinearFactor::HybridNonlinearFactor(const KeyVector& keys, +static void checkKeys(const KeyVector& continuousKeys, + std::vector& pairs) { + KeySet factor_keys_set; + for (const auto& pair : pairs) { + auto f = pair.first; + // Insert all factor continuous keys in the continuous keys set. + std::copy(f->keys().begin(), f->keys().end(), + std::inserter(factor_keys_set, factor_keys_set.end())); + } + + KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end()); + if (continuous_keys_set != factor_keys_set) { + throw std::runtime_error( + "HybridNonlinearFactor: The specified continuous keys and the keys in " + "the factors do not match!"); + } +} + +/* *******************************************************************************/ +HybridNonlinearFactor::HybridNonlinearFactor( + const KeyVector& continuousKeys, const DiscreteKey& discreteKey, + const std::vector& factors) + : Base(continuousKeys, {discreteKey}) { + std::vector pairs; + for (auto&& f : factors) { + pairs.emplace_back(f, 0.0); + } + checkKeys(continuousKeys, pairs); + factors_ = FactorValuePairs({discreteKey}, pairs); +} + +/* *******************************************************************************/ +HybridNonlinearFactor::HybridNonlinearFactor( + const KeyVector& continuousKeys, const DiscreteKey& discreteKey, + const std::vector& factors) + : Base(continuousKeys, {discreteKey}) { + std::vector pairs; + KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end()); + KeySet factor_keys_set; + for (auto&& [f, val] : factors) { + pairs.emplace_back(f, val); + } + checkKeys(continuousKeys, pairs); + factors_ = FactorValuePairs({discreteKey}, pairs); +} + +/* *******************************************************************************/ +HybridNonlinearFactor::HybridNonlinearFactor(const KeyVector& continuousKeys, const DiscreteKeys& discreteKeys, - const Factors& factors) - : Base(keys, discreteKeys), factors_(factors) {} + const FactorValuePairs& factors) + : Base(continuousKeys, discreteKeys), factors_(factors) {} /* *******************************************************************************/ AlgebraicDecisionTree HybridNonlinearFactor::errorTree( diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index 9852602de..766467518 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -68,11 +68,11 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { * @brief typedef for DecisionTree which has Keys as node labels and * pairs of NonlinearFactor & an arbitrary scalar as leaf nodes. */ - using Factors = DecisionTree; + using FactorValuePairs = DecisionTree; private: - /// Decision tree of Gaussian factors indexed by discrete keys. - Factors factors_; + /// Decision tree of nonlinear factors indexed by discrete keys. + FactorValuePairs factors_; /// HybridFactor method implementation. Should not be used. AlgebraicDecisionTree errorTree( @@ -82,62 +82,49 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { } public: + /// Default constructor, mainly for serialization. HybridNonlinearFactor() = default; /** - * @brief Construct from Decision tree. + * @brief Construct a new HybridNonlinearFactor on a single discrete key, + * providing the factors for each mode m as a vector of factors ϕ_m(x). + * The value ϕ(x,m) for the factor is simply ϕ_m(x). * - * @param keys Vector of keys for continuous factors. - * @param discreteKeys Vector of discrete keys. - * @param factors Decision tree with of shared factors. + * @param continuousKeys Vector of keys for continuous factors. + * @param discreteKey The discrete key for the "mode", indexing components. + * @param factors Vector of gaussian factors, one for each mode. */ - HybridNonlinearFactor(const KeyVector& keys, const DiscreteKeys& discreteKeys, - const Factors& factors); + HybridNonlinearFactor( + const KeyVector& continuousKeys, const DiscreteKey& discreteKey, + const std::vector& factors); /** - * @brief Convenience constructor that generates the underlying factor - * decision tree for us. + * @brief Construct a new HybridNonlinearFactor on a single discrete key, + * including a scalar error value for each mode m. The factors and scalars are + * provided as a vector of pairs (ϕ_m(x), E_m). + * The value ϕ(x,m) for the factor is now ϕ_m(x) + E_m. * - * Here it is important that the vector of factors has the correct number of - * elements based on the number of discrete keys and the cardinality of the - * keys, so that the decision tree is constructed appropriately. - * - * @tparam FACTOR The type of the factor shared pointers being passed in. - * Will be typecast to NonlinearFactor shared pointers. - * @param keys Vector of keys for continuous factors. - * @param discreteKey The discrete key indexing each component factor. - * @param factors Vector of nonlinear factor and scalar pairs. - * Same size as the cardinality of discreteKey. + * @param continuousKeys Vector of keys for continuous factors. + * @param discreteKey The discrete key for the "mode", indexing components. + * @param factors Vector of gaussian factor-scalar pairs, one per mode. */ - template - HybridNonlinearFactor( - const KeyVector& keys, const DiscreteKey& discreteKey, - const std::vector, double>>& factors) - : Base(keys, {discreteKey}) { - std::vector nonlinear_factors; - KeySet continuous_keys_set(keys.begin(), keys.end()); - KeySet factor_keys_set; - for (auto&& [f, val] : factors) { - // Insert all factor continuous keys in the continuous keys set. - std::copy(f->keys().begin(), f->keys().end(), - std::inserter(factor_keys_set, factor_keys_set.end())); - - if (auto nf = std::dynamic_pointer_cast(f)) { - nonlinear_factors.emplace_back(nf, val); - } else { - throw std::runtime_error( - "Factors passed into HybridNonlinearFactor need to be nonlinear!"); - } - } - factors_ = Factors({discreteKey}, nonlinear_factors); - - if (continuous_keys_set != factor_keys_set) { - throw std::runtime_error( - "The specified continuous keys and the keys in the factors don't " - "match!"); - } - } + HybridNonlinearFactor(const KeyVector& continuousKeys, + const DiscreteKey& discreteKey, + const std::vector& factors); + /** + * @brief Construct a new HybridNonlinearFactor on a several discrete keys M, + * including a scalar error value for each assignment m. The factors and + * scalars are provided as a DecisionTree of pairs (ϕ_M(x), E_M). + * The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m. + * + * @param continuousKeys A vector of keys representing continuous variables. + * @param discreteKeys Discrete variables and their cardinalities. + * @param factors The decision tree of nonlinear factor/scalar pairs. + */ + HybridNonlinearFactor(const KeyVector& continuousKeys, + const DiscreteKeys& discreteKeys, + const FactorValuePairs& factors); /** * @brief Compute error of the HybridNonlinearFactor as a tree. * @@ -196,4 +183,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { const Values& continuousValues) const; }; +// traits +template <> +struct traits : public Testable { +}; + } // namespace gtsam diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 82881ac47..d1b8fbf6d 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -246,14 +246,18 @@ class HybridNonlinearFactorGraph { #include class HybridNonlinearFactor : gtsam::HybridFactor { HybridNonlinearFactor( - const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, - const gtsam::DecisionTree< - gtsam::Key, std::pair>& factors); + const gtsam::KeyVector& keys, const gtsam::DiscreteKey& discreteKey, + const std::vector& factors); HybridNonlinearFactor( const gtsam::KeyVector& keys, const gtsam::DiscreteKey& discreteKey, const std::vector>& factors); + HybridNonlinearFactor( + const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, + const gtsam::DecisionTree< + gtsam::Key, std::pair>& factors); + double error(const gtsam::Values& continuousValues, const gtsam::DiscreteValues& discreteValues) const;