use enum to categorize HybridFactor

release/4.3a0
Varun Agrawal 2024-09-15 17:26:33 -04:00
parent 1c74da26f4
commit 3a7a0b84fe
3 changed files with 42 additions and 25 deletions

View File

@ -50,31 +50,37 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
/* ************************************************************************ */ /* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &keys) HybridFactor::HybridFactor(const KeyVector &keys)
: Base(keys), isContinuous_(true), continuousKeys_(keys) {} : Base(keys),
category_(HybridCategory::Continuous),
continuousKeys_(keys) {}
/* ************************************************************************ */ /* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &continuousKeys, HybridFactor::HybridFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys) const DiscreteKeys &discreteKeys)
: Base(CollectKeys(continuousKeys, discreteKeys)), : Base(CollectKeys(continuousKeys, discreteKeys)),
isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)),
isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)),
isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)),
discreteKeys_(discreteKeys), discreteKeys_(discreteKeys),
continuousKeys_(continuousKeys) {} continuousKeys_(continuousKeys) {
if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) {
category_ = HybridCategory::Discrete;
} else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) {
category_ = HybridCategory::Continuous;
} else {
category_ = HybridCategory::Hybrid;
}
}
/* ************************************************************************ */ /* ************************************************************************ */
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
: Base(CollectKeys({}, discreteKeys)), : Base(CollectKeys({}, discreteKeys)),
isDiscrete_(true), category_(HybridCategory::Discrete),
discreteKeys_(discreteKeys), discreteKeys_(discreteKeys),
continuousKeys_({}) {} continuousKeys_({}) {}
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridFactor::equals(const HybridFactor &lf, double tol) const { bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol) && return e != nullptr && Base::equals(*e, tol) && category_ == e->category_ &&
isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ && continuousKeys_ == e->continuousKeys_ &&
isHybrid_ == e->isHybrid_ && continuousKeys_ == e->continuousKeys_ &&
discreteKeys_ == e->discreteKeys_; discreteKeys_ == e->discreteKeys_;
} }
@ -82,9 +88,18 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
void HybridFactor::print(const std::string &s, void HybridFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
std::cout << (s.empty() ? "" : s + "\n"); std::cout << (s.empty() ? "" : s + "\n");
if (isContinuous_) std::cout << "Continuous "; switch (category_) {
if (isDiscrete_) std::cout << "Discrete "; case HybridCategory::Continuous:
if (isHybrid_) std::cout << "Hybrid "; std::cout << "Continuous ";
break;
case HybridCategory::Discrete:
std::cout << "Discrete ";
break;
case HybridCategory::Hybrid:
std::cout << "Hybrid ";
break;
}
std::cout << "["; std::cout << "[";
for (size_t c = 0; c < continuousKeys_.size(); c++) { for (size_t c = 0; c < continuousKeys_.size(); c++) {
std::cout << formatter(continuousKeys_.at(c)); std::cout << formatter(continuousKeys_.at(c));

View File

@ -41,6 +41,9 @@ KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
const DiscreteKeys &key2); const DiscreteKeys &key2);
/// Enum to help with categorizing hybrid factors.
enum class HybridCategory { Discrete, Continuous, Hybrid };
/** /**
* Base class for *truly* hybrid probabilistic factors * Base class for *truly* hybrid probabilistic factors
* *
@ -53,9 +56,8 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
*/ */
class GTSAM_EXPORT HybridFactor : public Factor { class GTSAM_EXPORT HybridFactor : public Factor {
private: private:
bool isDiscrete_ = false; /// Record what category of HybridFactor this is.
bool isContinuous_ = false; HybridCategory category_;
bool isHybrid_ = false;
protected: protected:
// Set of DiscreteKeys for this factor. // Set of DiscreteKeys for this factor.
@ -116,13 +118,13 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// @{ /// @{
/// True if this is a factor of discrete variables only. /// True if this is a factor of discrete variables only.
bool isDiscrete() const { return isDiscrete_; } bool isDiscrete() const { return category_ == HybridCategory::Discrete; }
/// True if this is a factor of continuous variables only. /// True if this is a factor of continuous variables only.
bool isContinuous() const { return isContinuous_; } bool isContinuous() const { return category_ == HybridCategory::Continuous; }
/// True is this is a Discrete-Continuous factor. /// True is this is a Discrete-Continuous factor.
bool isHybrid() const { return isHybrid_; } bool isHybrid() const { return category_ == HybridCategory::Hybrid; }
/// Return the number of continuous variables in this factor. /// Return the number of continuous variables in this factor.
size_t nrContinuous() const { return continuousKeys_.size(); } size_t nrContinuous() const { return continuousKeys_.size(); }
@ -142,9 +144,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
template <class ARCHIVE> template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) { void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(isDiscrete_); ar &BOOST_SERIALIZATION_NVP(category_);
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
ar &BOOST_SERIALIZATION_NVP(discreteKeys_); ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
ar &BOOST_SERIALIZATION_NVP(continuousKeys_); ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
} }

View File

@ -387,11 +387,13 @@ TEST(HybridBayesNet, Sampling) {
std::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model); std::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model);
auto one_motion = auto one_motion =
std::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model); std::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model);
std::vector<NonlinearFactorValuePair> factors = {{zero_motion, 0.0},
{one_motion, 0.0}}; DiscreteKeys discreteKeys{DiscreteKey(M(0), 2)};
HybridNonlinearFactor::Factors factors(
discreteKeys, {{zero_motion, 0.0}, {one_motion, 0.0}});
nfg.emplace_shared<PriorFactor<double>>(X(0), 0.0, noise_model); nfg.emplace_shared<PriorFactor<double>>(X(0), 0.0, noise_model);
nfg.emplace_shared<HybridNonlinearFactor>( nfg.emplace_shared<HybridNonlinearFactor>(KeyVector{X(0), X(1)}, discreteKeys,
KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors); factors);
DiscreteKey mode(M(0), 2); DiscreteKey mode(M(0), 2);
nfg.emplace_shared<DiscreteDistribution>(mode, "1/1"); nfg.emplace_shared<DiscreteDistribution>(mode, "1/1");