Merge branch 'fan/prototype-hybrid-tr' into feature/GaussianHybridFactorGraph

release/4.3a0
Varun Agrawal 2022-06-02 00:18:58 -04:00
commit 2cc0611f20
9 changed files with 50 additions and 17 deletions

View File

@ -79,7 +79,8 @@ GaussianMixture::asGaussianFactorGraphTree() const {
/* *******************************************************************************/
bool GaussianMixture::equals(const HybridFactor &lf,
double tol) const {
return BaseFactor::equals(lf, tol);
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && BaseFactor::equals(*e, tol);
}
/* *******************************************************************************/

View File

@ -34,7 +34,8 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
/* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
return Base::equals(lf, tol);
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol);
}
/* *******************************************************************************/

View File

@ -11,7 +11,7 @@
/**
* @file GaussianMixtureFactor.h
* @brief A factor that is a function of discrete and continuous variables.
* @brief A set of GaussianFactors, indexed by a set of discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
* @author Frank Dellaert
@ -32,9 +32,10 @@ class GaussianFactorGraph;
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
/**
* @brief A linear factor that is a function of both discrete and continuous
* variables, i.e. P(X, M | Z) where X is the set of continuous variables, M is
* the set of discrete variables and Z is the measurement set.
* @brief Implementation of a discrete conditional mixture factor.
* Implements a joint discrete-continuous factor where the discrete variable
* serves to "select" a mixture component corresponding to a GaussianFactor type
* of measurement.
*
* Represents the underlying Gaussian Mixture as a Decision Tree, where the set
* of discrete variables indexes to the continuous gaussian distribution.
@ -52,6 +53,7 @@ class GaussianMixtureFactor : public HybridFactor {
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
private:
/// Decision tree of Gaussian factors indexed by discrete keys.
Factors factors_;
/**

View File

@ -101,7 +101,8 @@ void HybridConditional::print(const std::string &s,
/* ************************************************************************ */
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
return BaseFactor::equals(other, tol);
const This *e = dynamic_cast<const This *>(&other);
return e != nullptr && BaseFactor::equals(*e, tol);
}
} // namespace gtsam

View File

@ -38,7 +38,9 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
/* ************************************************************************ */
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
return Base::equals(lf, tol);
const This *e = dynamic_cast<const This *>(&lf);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
}
/* ************************************************************************ */

View File

@ -70,7 +70,10 @@ HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
/* ************************************************************************ */
bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
return Base::equals(lf, tol);
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol) &&
isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ &&
isHybrid_ == e->isHybrid_ && nrContinuous_ == e->nrContinuous_;
}
/* ************************************************************************ */

View File

@ -34,6 +34,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
/**
* Base class for hybrid probabilistic factors
*
* Examples:
* - HybridGaussianFactor
* - HybridDiscreteFactor
@ -64,13 +65,29 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/** Default constructor creates empty factor */
HybridFactor() = default;
/**
* @brief Construct hybrid factor from continuous keys.
*
* @param keys Vector of continuous keys.
*/
explicit HybridFactor(const KeyVector &keys);
/**
* @brief Construct hybrid factor from discrete keys.
*
* @param keys Vector of discrete keys.
*/
explicit HybridFactor(const DiscreteKeys &discreteKeys);
/**
* @brief Construct a new Hybrid Factor object.
*
* @param continuousKeys Vector of keys for continuous variables.
* @param discreteKeys Vector of keys for discrete variables.
*/
HybridFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys);
explicit HybridFactor(const DiscreteKeys &discreteKeys);
/// Virtual destructor
virtual ~HybridFactor() = default;

View File

@ -21,18 +21,23 @@
namespace gtsam {
/* ************************************************************************* */
HybridGaussianFactor::HybridGaussianFactor(GaussianFactor::shared_ptr other)
: Base(other->keys()) {
inner_ = other;
}
: Base(other->keys()), inner_(other) {}
/* ************************************************************************* */
HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf)
: Base(jf.keys()),
inner_(boost::make_shared<JacobianFactor>(std::move(jf))) {}
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
return false;
/* ************************************************************************* */
bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
}
/* ************************************************************************* */
void HybridGaussianFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);

View File

@ -25,7 +25,8 @@ namespace gtsam {
/**
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
* a diamond inheritance.
* a diamond inheritance i.e. an extra factor type that inherits from both
* HybridFactor and GaussianFactor.
*/
class HybridGaussianFactor : public HybridFactor {
private: