diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 8a8511aef..ed8125c2b 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -28,14 +28,9 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals, const DiscreteKeys &discreteFrontals, const KeyVector &continuousParents, const DiscreteKeys &discreteParents) - : HybridConditional( - CollectKeys( - {continuousFrontals.begin(), continuousFrontals.end()}, - KeyVector{continuousParents.begin(), continuousParents.end()}), - CollectDiscreteKeys( - {discreteFrontals.begin(), discreteFrontals.end()}, - {discreteParents.begin(), discreteParents.end()}), - continuousFrontals.size() + discreteFrontals.size()) {} + : HybridConditional(CollectKeys(continuousFrontals, continuousParents), + CollectDiscreteKeys(discreteFrontals, discreteParents), + continuousFrontals.size() + discreteFrontals.size()) {} /* ************************************************************************ */ HybridConditional::HybridConditional( @@ -56,9 +51,7 @@ HybridConditional::HybridConditional( /* ************************************************************************ */ HybridConditional::HybridConditional( const std::shared_ptr &gaussianMixture) - : BaseFactor(KeyVector(gaussianMixture->keys().begin(), - gaussianMixture->keys().begin() + - gaussianMixture->nrContinuous()), + : BaseFactor(gaussianMixture->continuousKeys(), gaussianMixture->discreteKeys()), BaseConditional(gaussianMixture->nrFrontals()) { inner_ = gaussianMixture; diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index b25e97f05..3338951bf 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -50,31 +50,43 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &keys) - : Base(keys), isContinuous_(true), continuousKeys_(keys) {} + : Base(keys), category_(Category::Continuous), continuousKeys_(keys) {} + +/* ************************************************************************ */ +HybridFactor::Category GetCategory(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys) { + if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) { + return HybridFactor::Category::Discrete; + } else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) { + return HybridFactor::Category::Continuous; + } else if ((continuousKeys.size() != 0) && (discreteKeys.size() != 0)) { + return HybridFactor::Category::Hybrid; + } else { + // Case where we have no keys. Should never happen. + return HybridFactor::Category::None; + } +} /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) : Base(CollectKeys(continuousKeys, discreteKeys)), - isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)), - isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)), - isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)), + category_(GetCategory(continuousKeys, discreteKeys)), discreteKeys_(discreteKeys), continuousKeys_(continuousKeys) {} /* ************************************************************************ */ HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), - isDiscrete_(true), + category_(Category::Discrete), discreteKeys_(discreteKeys), continuousKeys_({}) {} /* ************************************************************************ */ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { const This *e = dynamic_cast(&lf); - return e != nullptr && Base::equals(*e, tol) && - isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ && - isHybrid_ == e->isHybrid_ && continuousKeys_ == e->continuousKeys_ && + return e != nullptr && Base::equals(*e, tol) && category_ == e->category_ && + continuousKeys_ == e->continuousKeys_ && discreteKeys_ == e->discreteKeys_; } @@ -82,9 +94,21 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { void HybridFactor::print(const std::string &s, const KeyFormatter &formatter) const { std::cout << (s.empty() ? "" : s + "\n"); - if (isContinuous_) std::cout << "Continuous "; - if (isDiscrete_) std::cout << "Discrete "; - if (isHybrid_) std::cout << "Hybrid "; + switch (category_) { + case Category::Continuous: + std::cout << "Continuous "; + break; + case Category::Discrete: + std::cout << "Discrete "; + break; + case Category::Hybrid: + std::cout << "Hybrid "; + break; + case Category::None: + std::cout << "None "; + break; + } + std::cout << "["; for (size_t c = 0; c < continuousKeys_.size(); c++) { std::cout << formatter(continuousKeys_.at(c)); diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index c66116512..ad29dfdca 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -52,10 +52,13 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, * @ingroup hybrid */ class GTSAM_EXPORT HybridFactor : public Factor { + public: + /// Enum to help with categorizing hybrid factors. + enum class Category { None, Discrete, Continuous, Hybrid }; + private: - bool isDiscrete_ = false; - bool isContinuous_ = false; - bool isHybrid_ = false; + /// Record what category of HybridFactor this is. + Category category_ = Category::None; protected: // Set of DiscreteKeys for this factor. @@ -116,13 +119,13 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// @{ /// True if this is a factor of discrete variables only. - bool isDiscrete() const { return isDiscrete_; } + bool isDiscrete() const { return category_ == Category::Discrete; } /// True if this is a factor of continuous variables only. - bool isContinuous() const { return isContinuous_; } + bool isContinuous() const { return category_ == Category::Continuous; } /// True is this is a Discrete-Continuous factor. - bool isHybrid() const { return isHybrid_; } + bool isHybrid() const { return category_ == Category::Hybrid; } /// Return the number of continuous variables in this factor. size_t nrContinuous() const { return continuousKeys_.size(); } @@ -142,9 +145,7 @@ class GTSAM_EXPORT HybridFactor : public Factor { template void serialize(ARCHIVE &ar, const unsigned int /*version*/) { ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); - ar &BOOST_SERIALIZATION_NVP(isDiscrete_); - ar &BOOST_SERIALIZATION_NVP(isContinuous_); - ar &BOOST_SERIALIZATION_NVP(isHybrid_); + ar &BOOST_SERIALIZATION_NVP(category_); ar &BOOST_SERIALIZATION_NVP(discreteKeys_); ar &BOOST_SERIALIZATION_NVP(continuousKeys_); } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 74091bf95..362150745 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -114,10 +114,11 @@ void HybridGaussianFactorGraph::printErrors( << "\n"; } else { // Is hybrid - auto mixtureComponent = + auto conditionalComponent = hc->asMixture()->operator()(values.discrete()); - mixtureComponent->print(ss.str(), keyFormatter); - std::cout << "error = " << mixtureComponent->error(values) << "\n"; + conditionalComponent->print(ss.str(), keyFormatter); + std::cout << "error = " << conditionalComponent->error(values) + << "\n"; } } } else if (auto gf = std::dynamic_pointer_cast(factor)) { @@ -411,10 +412,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // Create the HybridGaussianConditional from the conditionals HybridGaussianConditional::Conditionals conditionals( eliminationResults, [](const Result &pair) { return pair.first; }); - auto gaussianMixture = std::make_shared( + auto hybridGaussian = std::make_shared( frontalKeys, continuousSeparator, discreteSeparator, conditionals); - return {std::make_shared(gaussianMixture), newFactor}; + return {std::make_shared(hybridGaussian), newFactor}; } /* ************************************************************************ @@ -465,7 +466,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, // Now we will need to know how to retrieve the corresponding continuous // densities for the assignment (c1,c2,c3) (OR (c2,c3,c1), note there is NO // defined order!). We also need to consider when there is pruning. Two - // mixture factors could have different pruning patterns - one could have + // hybrid factors could have different pruning patterns - one could have // (c1=0,c2=1) pruned, and another could have (c2=0,c3=1) pruned, and this // creates a big problem in how to identify the intersection of non-pruned // branches.