diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index b25e97f05..89b1943cd 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -50,31 +50,37 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &keys) - : Base(keys), isContinuous_(true), continuousKeys_(keys) {} + : Base(keys), + category_(HybridCategory::Continuous), + continuousKeys_(keys) {} /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) : Base(CollectKeys(continuousKeys, discreteKeys)), - isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)), - isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)), - isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)), discreteKeys_(discreteKeys), - continuousKeys_(continuousKeys) {} + continuousKeys_(continuousKeys) { + if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) { + category_ = HybridCategory::Discrete; + } else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) { + category_ = HybridCategory::Continuous; + } else { + category_ = HybridCategory::Hybrid; + } +} /* ************************************************************************ */ HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), - isDiscrete_(true), + category_(HybridCategory::Discrete), discreteKeys_(discreteKeys), continuousKeys_({}) {} /* ************************************************************************ */ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { const This *e = dynamic_cast(&lf); - return e != nullptr && Base::equals(*e, tol) && - isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ && - isHybrid_ == e->isHybrid_ && continuousKeys_ == e->continuousKeys_ && + return e != nullptr && Base::equals(*e, tol) && category_ == e->category_ && + continuousKeys_ == e->continuousKeys_ && discreteKeys_ == e->discreteKeys_; } @@ -82,9 +88,18 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { void HybridFactor::print(const std::string &s, const KeyFormatter &formatter) const { std::cout << (s.empty() ? "" : s + "\n"); - if (isContinuous_) std::cout << "Continuous "; - if (isDiscrete_) std::cout << "Discrete "; - if (isHybrid_) std::cout << "Hybrid "; + switch (category_) { + case HybridCategory::Continuous: + std::cout << "Continuous "; + break; + case HybridCategory::Discrete: + std::cout << "Discrete "; + break; + case HybridCategory::Hybrid: + std::cout << "Hybrid "; + break; + } + std::cout << "["; for (size_t c = 0; c < continuousKeys_.size(); c++) { std::cout << formatter(continuousKeys_.at(c)); diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index c66116512..2cc7453f4 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -41,6 +41,9 @@ KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, const DiscreteKeys &key2); +/// Enum to help with categorizing hybrid factors. +enum class HybridCategory { Discrete, Continuous, Hybrid }; + /** * Base class for *truly* hybrid probabilistic factors * @@ -53,9 +56,8 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, */ class GTSAM_EXPORT HybridFactor : public Factor { private: - bool isDiscrete_ = false; - bool isContinuous_ = false; - bool isHybrid_ = false; + /// Record what category of HybridFactor this is. + HybridCategory category_; protected: // Set of DiscreteKeys for this factor. @@ -116,13 +118,13 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// @{ /// True if this is a factor of discrete variables only. - bool isDiscrete() const { return isDiscrete_; } + bool isDiscrete() const { return category_ == HybridCategory::Discrete; } /// True if this is a factor of continuous variables only. - bool isContinuous() const { return isContinuous_; } + bool isContinuous() const { return category_ == HybridCategory::Continuous; } /// True is this is a Discrete-Continuous factor. - bool isHybrid() const { return isHybrid_; } + bool isHybrid() const { return category_ == HybridCategory::Hybrid; } /// Return the number of continuous variables in this factor. size_t nrContinuous() const { return continuousKeys_.size(); } @@ -142,9 +144,7 @@ class GTSAM_EXPORT HybridFactor : public Factor { template void serialize(ARCHIVE &ar, const unsigned int /*version*/) { ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); - ar &BOOST_SERIALIZATION_NVP(isDiscrete_); - ar &BOOST_SERIALIZATION_NVP(isContinuous_); - ar &BOOST_SERIALIZATION_NVP(isHybrid_); + ar &BOOST_SERIALIZATION_NVP(category_); ar &BOOST_SERIALIZATION_NVP(discreteKeys_); ar &BOOST_SERIALIZATION_NVP(continuousKeys_); } diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index cf4231dba..99487a84a 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -387,11 +387,13 @@ TEST(HybridBayesNet, Sampling) { std::make_shared>(X(0), X(1), 0, noise_model); auto one_motion = std::make_shared>(X(0), X(1), 1, noise_model); - std::vector factors = {{zero_motion, 0.0}, - {one_motion, 0.0}}; + + DiscreteKeys discreteKeys{DiscreteKey(M(0), 2)}; + HybridNonlinearFactor::Factors factors( + discreteKeys, {{zero_motion, 0.0}, {one_motion, 0.0}}); nfg.emplace_shared>(X(0), 0.0, noise_model); - nfg.emplace_shared( - KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors); + nfg.emplace_shared(KeyVector{X(0), X(1)}, discreteKeys, + factors); DiscreteKey mode(M(0), 2); nfg.emplace_shared(mode, "1/1");