diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 1663236f0..000057518 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -79,7 +79,8 @@ GaussianMixture::asGaussianFactorGraphTree() const { /* *******************************************************************************/ bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { - return BaseFactor::equals(lf, tol); + const This *e = dynamic_cast(&lf); + return e != nullptr && BaseFactor::equals(*e, tol); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 589e5c660..a81cf341d 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -34,7 +34,8 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, /* *******************************************************************************/ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { - return Base::equals(lf, tol); + const This *e = dynamic_cast(&lf); + return e != nullptr && Base::equals(*e, tol); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index b2fbe4aef..b3569183b 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -11,7 +11,7 @@ /** * @file GaussianMixtureFactor.h - * @brief A factor that is a function of discrete and continuous variables. + * @brief A set of GaussianFactors, indexed by a set of discrete keys. * @author Fan Jiang * @author Varun Agrawal * @author Frank Dellaert @@ -32,9 +32,10 @@ class GaussianFactorGraph; using GaussianFactorVector = std::vector; /** - * @brief A linear factor that is a function of both discrete and continuous - * variables, i.e. P(X, M | Z) where X is the set of continuous variables, M is - * the set of discrete variables and Z is the measurement set. + * @brief Implementation of a discrete conditional mixture factor. + * Implements a joint discrete-continuous factor where the discrete variable + * serves to "select" a mixture component corresponding to a GaussianFactor type + * of measurement. * * Represents the underlying Gaussian Mixture as a Decision Tree, where the set * of discrete variables indexes to the continuous gaussian distribution. @@ -52,6 +53,7 @@ class GaussianMixtureFactor : public HybridFactor { using Factors = DecisionTree; private: + /// Decision tree of Gaussian factors indexed by discrete keys. Factors factors_; /** diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 8f95caf34..8e071532d 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -101,7 +101,8 @@ void HybridConditional::print(const std::string &s, /* ************************************************************************ */ bool HybridConditional::equals(const HybridFactor &other, double tol) const { - return BaseFactor::equals(other, tol); + const This *e = dynamic_cast(&other); + return e != nullptr && BaseFactor::equals(*e, tol); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp index 989127a28..2bdcdee8c 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.cpp +++ b/gtsam/hybrid/HybridDiscreteFactor.cpp @@ -38,7 +38,9 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf) /* ************************************************************************ */ bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const { - return Base::equals(lf, tol); + const This *e = dynamic_cast(&lf); + // TODO(Varun) How to compare inner_ when they are abstract types? + return e != nullptr && Base::equals(*e, tol); } /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index 815ea0415..127c9761c 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -70,7 +70,10 @@ HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) /* ************************************************************************ */ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { - return Base::equals(lf, tol); + const This *e = dynamic_cast(&lf); + return e != nullptr && Base::equals(*e, tol) && + isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ && + isHybrid_ == e->isHybrid_ && nrContinuous_ == e->nrContinuous_; } /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index da103da43..244fba4cc 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -34,6 +34,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, /** * Base class for hybrid probabilistic factors + * * Examples: * - HybridGaussianFactor * - HybridDiscreteFactor @@ -64,13 +65,29 @@ class GTSAM_EXPORT HybridFactor : public Factor { /** Default constructor creates empty factor */ HybridFactor() = default; + /** + * @brief Construct hybrid factor from continuous keys. + * + * @param keys Vector of continuous keys. + */ explicit HybridFactor(const KeyVector &keys); + /** + * @brief Construct hybrid factor from discrete keys. + * + * @param keys Vector of discrete keys. + */ + explicit HybridFactor(const DiscreteKeys &discreteKeys); + + /** + * @brief Construct a new Hybrid Factor object. + * + * @param continuousKeys Vector of keys for continuous variables. + * @param discreteKeys Vector of keys for discrete variables. + */ HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys); - explicit HybridFactor(const DiscreteKeys &discreteKeys); - /// Virtual destructor virtual ~HybridFactor() = default; diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 721cb4cc7..59d20fb79 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -21,18 +21,23 @@ namespace gtsam { +/* ************************************************************************* */ HybridGaussianFactor::HybridGaussianFactor(GaussianFactor::shared_ptr other) - : Base(other->keys()) { - inner_ = other; -} + : Base(other->keys()), inner_(other) {} +/* ************************************************************************* */ HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf) : Base(jf.keys()), inner_(boost::make_shared(std::move(jf))) {} -bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { - return false; +/* ************************************************************************* */ +bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const { + const This *e = dynamic_cast(&other); + // TODO(Varun) How to compare inner_ when they are abstract types? + return e != nullptr && Base::equals(*e, tol); } + +/* ************************************************************************* */ void HybridGaussianFactor::print(const std::string &s, const KeyFormatter &formatter) const { HybridFactor::print(s, formatter); diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 8d457e778..6fa83b726 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -25,7 +25,8 @@ namespace gtsam { /** * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have - * a diamond inheritance. + * a diamond inheritance i.e. an extra factor type that inherits from both + * HybridFactor and GaussianFactor. */ class HybridGaussianFactor : public HybridFactor { private: