From d834897b14f134c98a78307f894d4c805094c881 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 31 Oct 2022 15:38:23 -0400 Subject: [PATCH] update MixtureFactor so that all tests pass --- gtsam/hybrid/MixtureFactor.h | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index 5a2383221..511705cf3 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -107,8 +107,12 @@ class MixtureFactor : public HybridFactor { std::copy(f->keys().begin(), f->keys().end(), std::inserter(factor_keys_set, factor_keys_set.end())); - nonlinear_factors.push_back( - boost::dynamic_pointer_cast(f)); + if (auto nf = boost::dynamic_pointer_cast(f)) { + nonlinear_factors.push_back(nf); + } else { + throw std::runtime_error( + "Factors passed into MixtureFactor need to be nonlinear!"); + } } factors_ = Factors(discreteKeys, nonlinear_factors); @@ -125,10 +129,10 @@ class MixtureFactor : public HybridFactor { * @brief Compute error of the MixtureFactor as a tree. * * @param continuousVals The continuous values for which to compute the error. - * @return DecisionTree A decision tree with corresponding keys + * @return AlgebraicDecisionTree A decision tree with corresponding keys * as the factor but leaf values as the error. */ - DecisionTree error(const Values& continuousVals) const { + AlgebraicDecisionTree error(const Values& continuousVals) const { // functor to convert from sharedFactor to double error value. auto errorFunc = [continuousVals](const sharedFactor& factor) { return factor->error(continuousVals); @@ -165,7 +169,7 @@ class MixtureFactor : public HybridFactor { /// print to stdout void print( - const std::string& s = "MixtureFactor", + const std::string& s = "", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { std::cout << (s.empty() ? "" : s + " "); Base::print("", keyFormatter);