improved equality checks

release/4.3a0
Varun Agrawal 2022-05-28 18:03:52 -04:00
parent f443cf30e0
commit 841e6b01df
8 changed files with 31 additions and 15 deletions

View File

@ -79,7 +79,8 @@ GaussianMixtureConditional::asGaussianFactorGraphTree() const {
/* *******************************************************************************/ /* *******************************************************************************/
bool GaussianMixtureConditional::equals(const HybridFactor &lf, bool GaussianMixtureConditional::equals(const HybridFactor &lf,
double tol) const { double tol) const {
return BaseFactor::equals(lf, tol); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && BaseFactor::equals(*e, tol);
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -34,7 +34,8 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
/* *******************************************************************************/ /* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
return Base::equals(lf, tol); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol);
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -11,7 +11,7 @@
/** /**
* @file GaussianMixtureFactor.h * @file GaussianMixtureFactor.h
* @brief A factor that is a function of discrete and continuous variables. * @brief A set of GaussianFactors, indexed by a set of discrete keys.
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal * @author Varun Agrawal
* @author Frank Dellaert * @author Frank Dellaert
@ -32,9 +32,10 @@ class GaussianFactorGraph;
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>; using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
/** /**
* @brief A linear factor that is a function of both discrete and continuous * @brief Implementation of a discrete conditional mixture factor.
* variables, i.e. P(X, M | Z) where X is the set of continuous variables, M is * Implements a joint discrete-continuous factor where the discrete variable
* the set of discrete variables and Z is the measurement set. * serves to "select" a mixture component corresponding to a GaussianFactor type
* of measurement.
* *
* Represents the underlying Gaussian Mixture as a Decision Tree, where the set * Represents the underlying Gaussian Mixture as a Decision Tree, where the set
* of discrete variables indexes to the continuous gaussian distribution. * of discrete variables indexes to the continuous gaussian distribution.
@ -52,6 +53,7 @@ class GaussianMixtureFactor : public HybridFactor {
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>; using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
private: private:
/// Decision tree of Gaussian factors indexed by discrete keys.
Factors factors_; Factors factors_;
/** /**

View File

@ -101,7 +101,8 @@ void HybridConditional::print(const std::string &s,
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridConditional::equals(const HybridFactor &other, double tol) const { bool HybridConditional::equals(const HybridFactor &other, double tol) const {
return BaseFactor::equals(other, tol); const This *e = dynamic_cast<const This *>(&other);
return e != nullptr && BaseFactor::equals(*e, tol);
} }
} // namespace gtsam } // namespace gtsam

View File

@ -38,7 +38,9 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const { bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
return Base::equals(lf, tol); const This *e = dynamic_cast<const This *>(&lf);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -70,7 +70,10 @@ HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridFactor::equals(const HybridFactor &lf, double tol) const { bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
return Base::equals(lf, tol); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol) &&
isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ &&
isHybrid_ == e->isHybrid_ && nrContinuous_ == e->nrContinuous_;
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -21,18 +21,23 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
HybridGaussianFactor::HybridGaussianFactor(GaussianFactor::shared_ptr other) HybridGaussianFactor::HybridGaussianFactor(GaussianFactor::shared_ptr other)
: Base(other->keys()) { : Base(other->keys()), inner_(other) {}
inner_ = other;
}
/* ************************************************************************* */
HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf) HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf)
: Base(jf.keys()), : Base(jf.keys()),
inner_(boost::make_shared<JacobianFactor>(std::move(jf))) {} inner_(boost::make_shared<JacobianFactor>(std::move(jf))) {}
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { /* ************************************************************************* */
return false; bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
} }
/* ************************************************************************* */
void HybridGaussianFactor::print(const std::string &s, void HybridGaussianFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter); HybridFactor::print(s, formatter);

View File

@ -25,7 +25,8 @@ namespace gtsam {
/** /**
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
* a diamond inheritance. * a diamond inheritance i.e. an extra factor type that inherits from both
* HybridFactor and GaussianFactor.
*/ */
class HybridGaussianFactor : public HybridFactor { class HybridGaussianFactor : public HybridFactor {
private: private: