diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 2fbd4bd88..a4e0bf874 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -151,12 +151,26 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree() return {factors_, wrap}; } +/* *******************************************************************************/ +double HybridGaussianFactor::potentiallyPrunedComponentError( + const sharedFactor &gf, const VectorValues &values) const { + // Check if valid pointer + if (gf) { + return gf->error(values); + } else { + // If not valid, pointer, it means this component was pruned, + // so we return maximum error. + // This way the negative exponential will give + // a probability value close to 0.0. + return std::numeric_limits::max(); + } +} /* *******************************************************************************/ AlgebraicDecisionTree HybridGaussianFactor::errorTree( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = [&continuousValues](const sharedFactor &gf) { - return gf->error(continuousValues); + auto errorFunc = [this, &continuousValues](const sharedFactor &gf) { + return this->potentiallyPrunedComponentError(gf, continuousValues); }; DecisionTree error_tree(factors_, errorFunc); return error_tree; @@ -164,8 +178,9 @@ AlgebraicDecisionTree HybridGaussianFactor::errorTree( /* *******************************************************************************/ double HybridGaussianFactor::error(const HybridValues &values) const { + // Directly index to get the component, no need to build the whole tree. const sharedFactor gf = factors_(values.discrete()); - return gf->error(values.continuous()); + return potentiallyPrunedComponentError(gf, values.continuous()); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 3bf9f3dfd..b1b93dc32 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -169,6 +169,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// @} private: + /// Helper method to compute the error of a component. + double potentiallyPrunedComponentError( + const sharedFactor &gf, const VectorValues &continuousValues) const; + #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access;