diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index df59637aa..e17fd3afe 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -28,13 +28,25 @@ #include namespace gtsam { +HybridGaussianFactor::FactorValuePairs GetFactorValuePairs( + const HybridGaussianConditional::Conditionals &conditionals) { + auto func = [](const GaussianConditional::shared_ptr &conditional) + -> GaussianFactorValuePair { + double value = 0.0; + if (conditional) { // Check if conditional is pruned + value = conditional->logNormalizationConstant(); + } + return {std::dynamic_pointer_cast(conditional), value}; + }; + return HybridGaussianFactor::FactorValuePairs(conditionals, func); +} HybridGaussianConditional::HybridGaussianConditional( const KeyVector &continuousFrontals, const KeyVector &continuousParents, const DiscreteKeys &discreteParents, const HybridGaussianConditional::Conditionals &conditionals) : BaseFactor(CollectKeys(continuousFrontals, continuousParents), - discreteParents), + discreteParents, GetFactorValuePairs(conditionals)), BaseConditional(continuousFrontals.size()), conditionals_(conditionals) { // Calculate logConstant_ as the maximum of the log constants of the diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index eb2bbb937..82cf6ec8a 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -51,13 +51,13 @@ class HybridValues; * @ingroup hybrid */ class GTSAM_EXPORT HybridGaussianConditional - : public HybridFactor, - public Conditional { + : public HybridGaussianFactor, + public Conditional { public: using This = HybridGaussianConditional; - using shared_ptr = std::shared_ptr; - using BaseFactor = HybridFactor; - using BaseConditional = Conditional; + using shared_ptr = std::shared_ptr; + using BaseFactor = HybridGaussianFactor; + using BaseConditional = Conditional; /// typedef for Decision Tree of Gaussian Conditionals using Conditionals = DecisionTree;