diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 3826ef899..ac03bd3a3 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -33,6 +33,19 @@ #include namespace gtsam { + +/* *******************************************************************************/ +GaussianConditional::shared_ptr checkConditional( + const GaussianFactor::shared_ptr &factor) { + if (auto conditional = + std::dynamic_pointer_cast(factor)) { + return conditional; + } else { + throw std::logic_error( + "A HybridGaussianConditional unexpectedly contained a non-conditional"); + } +} + /* *******************************************************************************/ /** * @brief Helper struct for constructing HybridGaussianConditional objects @@ -92,10 +105,7 @@ struct HybridGaussianConditional::Helper { explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) { auto func = [this](const GaussianFactorValuePair &pair) { if (!pair.first) return; - auto gc = std::dynamic_pointer_cast(pair.first); - if (!gc) - throw std::runtime_error( - "HybridGaussianConditional called with non-conditional."); + auto gc = checkConditional(pair.first); if (!nrFrontals) nrFrontals = gc->nrFrontals(); minNegLogConstant = std::min(minNegLogConstant, pair.second); }; @@ -179,14 +189,11 @@ size_t HybridGaussianConditional::nrComponents() const { /* *******************************************************************************/ GaussianConditional::shared_ptr HybridGaussianConditional::choose( const DiscreteValues &discreteValues) const { - auto &[ptr, _] = factors()(discreteValues); - if (!ptr) return nullptr; - auto conditional = std::dynamic_pointer_cast(ptr); - if (conditional) - return conditional; - else - throw std::logic_error( - "A HybridGaussianConditional unexpectedly contained a non-conditional"); + auto &[factor, _] = factors()(discreteValues); + if (!factor) return nullptr; + + auto conditional = checkConditional(factor); + return conditional; } /* *******************************************************************************/ @@ -320,28 +327,23 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( }; FactorValuePairs prunedConditionals = factors().apply(pruner); - return std::make_shared(discreteKeys(), prunedConditionals); + return std::make_shared(discreteKeys(), + prunedConditionals); } /* *******************************************************************************/ double HybridGaussianConditional::logProbability( const HybridValues &values) const { auto [factor, _] = factors()(values.discrete()); - if (auto conditional = std::dynamic_pointer_cast(factor)) - return conditional->logProbability(values.continuous()); - else - throw std::logic_error( - "A HybridGaussianConditional unexpectedly contained a non-conditional"); + auto conditional = checkConditional(factor); + return conditional->logProbability(values.continuous()); } /* *******************************************************************************/ double HybridGaussianConditional::evaluate(const HybridValues &values) const { auto [factor, _] = factors()(values.discrete()); - if (auto conditional = std::dynamic_pointer_cast(factor)) - return conditional->evaluate(values.continuous()); - else - throw std::logic_error( - "A HybridGaussianConditional unexpectedly contained a non-conditional"); + auto conditional = checkConditional(factor); + return conditional->evaluate(values.continuous()); } } // namespace gtsam