Hybrid factor docs and minor refactor

release/4.3a0
Varun Agrawal 2022-05-27 15:15:34 -04:00
parent b3cab1bd4e
commit 865d10da9c
4 changed files with 45 additions and 29 deletions

View File

@ -19,6 +19,7 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************ */
KeyVector CollectKeys(const KeyVector &continuousKeys, KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys) { const DiscreteKeys &discreteKeys) {
KeyVector allKeys; KeyVector allKeys;
@ -30,6 +31,7 @@ KeyVector CollectKeys(const KeyVector &continuousKeys,
return allKeys; return allKeys;
} }
/* ************************************************************************ */
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) { KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) {
KeyVector allKeys; KeyVector allKeys;
std::copy(keys1.begin(), keys1.end(), std::back_inserter(allKeys)); std::copy(keys1.begin(), keys1.end(), std::back_inserter(allKeys));
@ -37,6 +39,7 @@ KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) {
return allKeys; return allKeys;
} }
/* ************************************************************************ */
DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
const DiscreteKeys &key2) { const DiscreteKeys &key2) {
DiscreteKeys allKeys; DiscreteKeys allKeys;
@ -45,29 +48,32 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
return allKeys; return allKeys;
} }
HybridFactor::HybridFactor() = default; /* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &keys) HybridFactor::HybridFactor(const KeyVector &keys)
: Base(keys), isContinuous_(true), nrContinuous(keys.size()) {} : Base(keys), isContinuous_(true), nrContinuous_(keys.size()) {}
/* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &continuousKeys, HybridFactor::HybridFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys) const DiscreteKeys &discreteKeys)
: Base(CollectKeys(continuousKeys, discreteKeys)), : Base(CollectKeys(continuousKeys, discreteKeys)),
isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)), isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)),
isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)), isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)),
isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)), isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)),
nrContinuous(continuousKeys.size()), nrContinuous_(continuousKeys.size()),
discreteKeys_(discreteKeys) {} discreteKeys_(discreteKeys) {}
/* ************************************************************************ */
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
: Base(CollectKeys({}, discreteKeys)), : Base(CollectKeys({}, discreteKeys)),
isDiscrete_(true), isDiscrete_(true),
discreteKeys_(discreteKeys) {} discreteKeys_(discreteKeys) {}
/* ************************************************************************ */
bool HybridFactor::equals(const HybridFactor &lf, double tol) const { bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
return Base::equals(lf, tol); return Base::equals(lf, tol);
} }
/* ************************************************************************ */
void HybridFactor::print(const std::string &s, void HybridFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
std::cout << s; std::cout << s;
@ -77,6 +83,4 @@ void HybridFactor::print(const std::string &s,
this->printKeys("", formatter); this->printKeys("", formatter);
} }
HybridFactor::~HybridFactor() = default;
} // namespace gtsam } // namespace gtsam

View File

@ -41,6 +41,16 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
* - GaussianMixture * - GaussianMixture
*/ */
class GTSAM_EXPORT HybridFactor : public Factor { class GTSAM_EXPORT HybridFactor : public Factor {
private:
bool isDiscrete_ = false;
bool isContinuous_ = false;
bool isHybrid_ = false;
size_t nrContinuous_ = 0;
protected:
DiscreteKeys discreteKeys_;
public: public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
typedef HybridFactor This; ///< This class typedef HybridFactor This; ///< This class
@ -48,27 +58,11 @@ class GTSAM_EXPORT HybridFactor : public Factor {
shared_ptr; ///< shared_ptr to this class shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class typedef Factor Base; ///< Our base class
bool isDiscrete_ = false;
bool isContinuous_ = false;
bool isHybrid_ = false;
size_t nrContinuous = 0;
DiscreteKeys discreteKeys_;
public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
/** Default constructor creates empty factor */ /** Default constructor creates empty factor */
HybridFactor(); HybridFactor() = default;
/** Construct from container of keys. This constructor is used internally
* from derived factor
* constructors, either from a container of keys or from a
* boost::assign::list_of. */
// template<typename CONTAINER>
// HybridFactor(const CONTAINER &keys) : Base(keys) {}
explicit HybridFactor(const KeyVector &keys); explicit HybridFactor(const KeyVector &keys);
@ -78,7 +72,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
explicit HybridFactor(const DiscreteKeys &discreteKeys); explicit HybridFactor(const DiscreteKeys &discreteKeys);
/// Virtual destructor /// Virtual destructor
virtual ~HybridFactor(); virtual ~HybridFactor() = default;
/// @} /// @}
/// @name Testable /// @name Testable
@ -96,6 +90,21 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// True if this is a factor of discrete variables only.
bool isDiscrete() const { return isDiscrete_; }
/// True if this is a factor of continuous variables only.
bool isContinuous() const { return isContinuous_; }
/// True is this is a Discrete-Continuous factor.
bool isHybrid() const { return isHybrid_; }
/// Return the number of continuous variables in this factor.
size_t nrContinuous() const { return nrContinuous_; }
/// Return vector of discrete keys.
DiscreteKeys discreteKeys() const { return discreteKeys_; }
/// @} /// @}
}; };
// HybridFactor // HybridFactor

View File

@ -23,12 +23,12 @@ namespace gtsam {
HybridGaussianFactor::HybridGaussianFactor(GaussianFactor::shared_ptr other) HybridGaussianFactor::HybridGaussianFactor(GaussianFactor::shared_ptr other)
: Base(other->keys()) { : Base(other->keys()) {
inner = other; inner_ = other;
} }
HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf) HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf)
: Base(jf.keys()), : Base(jf.keys()),
inner(boost::make_shared<JacobianFactor>(std::move(jf))) {} inner_(boost::make_shared<JacobianFactor>(std::move(jf))) {}
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
return false; return false;
@ -36,7 +36,7 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
void HybridGaussianFactor::print(const std::string &s, void HybridGaussianFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter); HybridFactor::print(s, formatter);
inner->print("inner: ", formatter); inner_->print("inner: ", formatter);
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -28,13 +28,14 @@ namespace gtsam {
* a diamond inheritance. * a diamond inheritance.
*/ */
class HybridGaussianFactor : public HybridFactor { class HybridGaussianFactor : public HybridFactor {
private:
GaussianFactor::shared_ptr inner_;
public: public:
using Base = HybridFactor; using Base = HybridFactor;
using This = HybridGaussianFactor; using This = HybridGaussianFactor;
using shared_ptr = boost::shared_ptr<This>; using shared_ptr = boost::shared_ptr<This>;
GaussianFactor::shared_ptr inner;
// Explicit conversion from a shared ptr of GF // Explicit conversion from a shared ptr of GF
explicit HybridGaussianFactor(GaussianFactor::shared_ptr other); explicit HybridGaussianFactor(GaussianFactor::shared_ptr other);
@ -47,5 +48,7 @@ class HybridGaussianFactor : public HybridFactor {
void print( void print(
const std::string &s = "HybridFactor\n", const std::string &s = "HybridFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
GaussianFactor::shared_ptr inner() const { return inner_; }
}; };
} // namespace gtsam } // namespace gtsam