diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index fc91e0838..1e99c1365 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -56,11 +56,9 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// Enum to help with categorizing hybrid factors. enum class Category { None, Discrete, Continuous, Hybrid }; - private: + protected: /// Record what category of HybridFactor this is. Category category_ = Category::None; - - protected: // Set of DiscreteKeys for this factor. DiscreteKeys discreteKeys_; /// Record continuous keys for book-keeping diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 00ead068a..903542c24 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -61,17 +61,22 @@ HybridGaussianConditional::HybridGaussianConditional( } return {std::dynamic_pointer_cast(c), value}; }; - BaseFactor::factors_ = HybridGaussianFactor::Factors(conditionals, func); + HybridGaussianFactor::FactorValuePairs pairs(conditionals, func); - // Initialize base classes - KeyVector continuousKeys = frontals; - continuousKeys.insert(continuousKeys.end(), parents.begin(), parents.end()); - BaseFactor::keys_ = continuousKeys; - BaseFactor::discreteKeys_ = discreteParents; + // Adjust frontals size BaseConditional::nrFrontals_ = frontals.size(); - // Assign conditionals - conditionals_ = conditionals; // TODO(frank): a duplicate of factors_ !!! + // Initialize HybridFactor + HybridFactor::category_ = HybridFactor::Category::Hybrid; + HybridFactor::discreteKeys_ = discreteParents; + HybridFactor::keys_ = frontals; + keys_.insert(keys_.end(), parents.begin(), parents.end()); + + // Initialize BaseFactor + BaseFactor::factors_ = BaseFactor::augment(pairs); // TODO(frank): expensive + + // Assign local conditionals. TODO(frank): a duplicate of factors_ !!! + conditionals_ = conditionals; } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index d5773590b..af2d0fde5 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -38,11 +38,11 @@ namespace gtsam { * Gaussian factor in factors. * @return HybridGaussianFactor::Factors */ -static HybridGaussianFactor::Factors augment( - const HybridGaussianFactor::FactorValuePairs &factors) { +HybridGaussianFactor::Factors HybridGaussianFactor::augment( + const FactorValuePairs &factors) { // Find the minimum value so we can "proselytize" to positive values. // Done because we can't have sqrt of negative numbers. - HybridGaussianFactor::Factors gaussianFactors; + Factors gaussianFactors; AlgebraicDecisionTree valueTree; std::tie(gaussianFactors, valueTree) = unzip(factors); @@ -73,7 +73,7 @@ static HybridGaussianFactor::Factors augment( return std::dynamic_pointer_cast( std::make_shared(gfg)); }; - return HybridGaussianFactor::Factors(factors, update); + return Factors(factors, update); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 1aa6a0263..9aa505092 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -73,6 +73,9 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// Decision tree of Gaussian factors indexed by discrete keys. Factors factors_; + /// Helper function to "hide" the constants in the Jacobian factors. + static Factors augment(const FactorValuePairs &factors); + /** * @brief Helper function to return factors and functional to create a * DecisionTree of Gaussian Factor Graphs.