Merge branch 'fan/prototype-hybrid-tr' into feature/GaussianHybridFactorGraph

release/4.3a0
Varun Agrawal 2022-06-02 00:18:58 -04:00
commit 2cc0611f20
9 changed files with 50 additions and 17 deletions

View File

@ -79,7 +79,8 @@ GaussianMixture::asGaussianFactorGraphTree() const {
/* *******************************************************************************/ /* *******************************************************************************/
bool GaussianMixture::equals(const HybridFactor &lf, bool GaussianMixture::equals(const HybridFactor &lf,
double tol) const { double tol) const {
return BaseFactor::equals(lf, tol); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && BaseFactor::equals(*e, tol);
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -34,7 +34,8 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
/* *******************************************************************************/ /* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
return Base::equals(lf, tol); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol);
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -11,7 +11,7 @@
/** /**
* @file GaussianMixtureFactor.h * @file GaussianMixtureFactor.h
* @brief A factor that is a function of discrete and continuous variables. * @brief A set of GaussianFactors, indexed by a set of discrete keys.
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal * @author Varun Agrawal
* @author Frank Dellaert * @author Frank Dellaert
@ -32,9 +32,10 @@ class GaussianFactorGraph;
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>; using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
/** /**
* @brief A linear factor that is a function of both discrete and continuous * @brief Implementation of a discrete conditional mixture factor.
* variables, i.e. P(X, M | Z) where X is the set of continuous variables, M is * Implements a joint discrete-continuous factor where the discrete variable
* the set of discrete variables and Z is the measurement set. * serves to "select" a mixture component corresponding to a GaussianFactor type
* of measurement.
* *
* Represents the underlying Gaussian Mixture as a Decision Tree, where the set * Represents the underlying Gaussian Mixture as a Decision Tree, where the set
* of discrete variables indexes to the continuous gaussian distribution. * of discrete variables indexes to the continuous gaussian distribution.
@ -52,6 +53,7 @@ class GaussianMixtureFactor : public HybridFactor {
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>; using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
private: private:
/// Decision tree of Gaussian factors indexed by discrete keys.
Factors factors_; Factors factors_;
/** /**

View File

@ -101,7 +101,8 @@ void HybridConditional::print(const std::string &s,
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridConditional::equals(const HybridFactor &other, double tol) const { bool HybridConditional::equals(const HybridFactor &other, double tol) const {
return BaseFactor::equals(other, tol); const This *e = dynamic_cast<const This *>(&other);
return e != nullptr && BaseFactor::equals(*e, tol);
} }
} // namespace gtsam } // namespace gtsam

View File

@ -38,7 +38,9 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const { bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
return Base::equals(lf, tol); const This *e = dynamic_cast<const This *>(&lf);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -70,7 +70,10 @@ HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridFactor::equals(const HybridFactor &lf, double tol) const { bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
return Base::equals(lf, tol); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol) &&
isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ &&
isHybrid_ == e->isHybrid_ && nrContinuous_ == e->nrContinuous_;
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -34,6 +34,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
/** /**
* Base class for hybrid probabilistic factors * Base class for hybrid probabilistic factors
*
* Examples: * Examples:
* - HybridGaussianFactor * - HybridGaussianFactor
* - HybridDiscreteFactor * - HybridDiscreteFactor
@ -64,13 +65,29 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/** Default constructor creates empty factor */ /** Default constructor creates empty factor */
HybridFactor() = default; HybridFactor() = default;
/**
* @brief Construct hybrid factor from continuous keys.
*
* @param keys Vector of continuous keys.
*/
explicit HybridFactor(const KeyVector &keys); explicit HybridFactor(const KeyVector &keys);
/**
* @brief Construct hybrid factor from discrete keys.
*
* @param keys Vector of discrete keys.
*/
explicit HybridFactor(const DiscreteKeys &discreteKeys);
/**
* @brief Construct a new Hybrid Factor object.
*
* @param continuousKeys Vector of keys for continuous variables.
* @param discreteKeys Vector of keys for discrete variables.
*/
HybridFactor(const KeyVector &continuousKeys, HybridFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys); const DiscreteKeys &discreteKeys);
explicit HybridFactor(const DiscreteKeys &discreteKeys);
/// Virtual destructor /// Virtual destructor
virtual ~HybridFactor() = default; virtual ~HybridFactor() = default;

View File

@ -21,18 +21,23 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
HybridGaussianFactor::HybridGaussianFactor(GaussianFactor::shared_ptr other) HybridGaussianFactor::HybridGaussianFactor(GaussianFactor::shared_ptr other)
: Base(other->keys()) { : Base(other->keys()), inner_(other) {}
inner_ = other;
}
/* ************************************************************************* */
HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf) HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf)
: Base(jf.keys()), : Base(jf.keys()),
inner_(boost::make_shared<JacobianFactor>(std::move(jf))) {} inner_(boost::make_shared<JacobianFactor>(std::move(jf))) {}
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { /* ************************************************************************* */
return false; bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
} }
/* ************************************************************************* */
void HybridGaussianFactor::print(const std::string &s, void HybridGaussianFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter); HybridFactor::print(s, formatter);

View File

@ -25,7 +25,8 @@ namespace gtsam {
/** /**
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
* a diamond inheritance. * a diamond inheritance i.e. an extra factor type that inherits from both
* HybridFactor and GaussianFactor.
*/ */
class HybridGaussianFactor : public HybridFactor { class HybridGaussianFactor : public HybridFactor {
private: private: