use enum to categorize HybridFactor
							parent
							
								
									1c74da26f4
								
							
						
					
					
						commit
						3a7a0b84fe
					
				|  | @ -50,31 +50,37 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, | |||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| HybridFactor::HybridFactor(const KeyVector &keys) | ||||
|     : Base(keys), isContinuous_(true), continuousKeys_(keys) {} | ||||
|     : Base(keys), | ||||
|       category_(HybridCategory::Continuous), | ||||
|       continuousKeys_(keys) {} | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| HybridFactor::HybridFactor(const KeyVector &continuousKeys, | ||||
|                            const DiscreteKeys &discreteKeys) | ||||
|     : Base(CollectKeys(continuousKeys, discreteKeys)), | ||||
|       isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)), | ||||
|       isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)), | ||||
|       isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)), | ||||
|       discreteKeys_(discreteKeys), | ||||
|       continuousKeys_(continuousKeys) {} | ||||
|       continuousKeys_(continuousKeys) { | ||||
|   if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) { | ||||
|     category_ = HybridCategory::Discrete; | ||||
|   } else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) { | ||||
|     category_ = HybridCategory::Continuous; | ||||
|   } else { | ||||
|     category_ = HybridCategory::Hybrid; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) | ||||
|     : Base(CollectKeys({}, discreteKeys)), | ||||
|       isDiscrete_(true), | ||||
|       category_(HybridCategory::Discrete), | ||||
|       discreteKeys_(discreteKeys), | ||||
|       continuousKeys_({}) {} | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| bool HybridFactor::equals(const HybridFactor &lf, double tol) const { | ||||
|   const This *e = dynamic_cast<const This *>(&lf); | ||||
|   return e != nullptr && Base::equals(*e, tol) && | ||||
|          isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ && | ||||
|          isHybrid_ == e->isHybrid_ && continuousKeys_ == e->continuousKeys_ && | ||||
|   return e != nullptr && Base::equals(*e, tol) && category_ == e->category_ && | ||||
|          continuousKeys_ == e->continuousKeys_ && | ||||
|          discreteKeys_ == e->discreteKeys_; | ||||
| } | ||||
| 
 | ||||
|  | @ -82,9 +88,18 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { | |||
| void HybridFactor::print(const std::string &s, | ||||
|                          const KeyFormatter &formatter) const { | ||||
|   std::cout << (s.empty() ? "" : s + "\n"); | ||||
|   if (isContinuous_) std::cout << "Continuous "; | ||||
|   if (isDiscrete_) std::cout << "Discrete "; | ||||
|   if (isHybrid_) std::cout << "Hybrid "; | ||||
|   switch (category_) { | ||||
|     case HybridCategory::Continuous: | ||||
|       std::cout << "Continuous "; | ||||
|       break; | ||||
|     case HybridCategory::Discrete: | ||||
|       std::cout << "Discrete "; | ||||
|       break; | ||||
|     case HybridCategory::Hybrid: | ||||
|       std::cout << "Hybrid "; | ||||
|       break; | ||||
|   } | ||||
| 
 | ||||
|   std::cout << "["; | ||||
|   for (size_t c = 0; c < continuousKeys_.size(); c++) { | ||||
|     std::cout << formatter(continuousKeys_.at(c)); | ||||
|  |  | |||
|  | @ -41,6 +41,9 @@ KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); | |||
| DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, | ||||
|                                  const DiscreteKeys &key2); | ||||
| 
 | ||||
| /// Enum to help with categorizing hybrid factors.
 | ||||
| enum class HybridCategory { Discrete, Continuous, Hybrid }; | ||||
| 
 | ||||
| /**
 | ||||
|  * Base class for *truly* hybrid probabilistic factors | ||||
|  * | ||||
|  | @ -53,9 +56,8 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, | |||
|  */ | ||||
| class GTSAM_EXPORT HybridFactor : public Factor { | ||||
|  private: | ||||
|   bool isDiscrete_ = false; | ||||
|   bool isContinuous_ = false; | ||||
|   bool isHybrid_ = false; | ||||
|   /// Record what category of HybridFactor this is.
 | ||||
|   HybridCategory category_; | ||||
| 
 | ||||
|  protected: | ||||
|   // Set of DiscreteKeys for this factor.
 | ||||
|  | @ -116,13 +118,13 @@ class GTSAM_EXPORT HybridFactor : public Factor { | |||
|   /// @{
 | ||||
| 
 | ||||
|   /// True if this is a factor of discrete variables only.
 | ||||
|   bool isDiscrete() const { return isDiscrete_; } | ||||
|   bool isDiscrete() const { return category_ == HybridCategory::Discrete; } | ||||
| 
 | ||||
|   /// True if this is a factor of continuous variables only.
 | ||||
|   bool isContinuous() const { return isContinuous_; } | ||||
|   bool isContinuous() const { return category_ == HybridCategory::Continuous; } | ||||
| 
 | ||||
|   /// True is this is a Discrete-Continuous factor.
 | ||||
|   bool isHybrid() const { return isHybrid_; } | ||||
|   bool isHybrid() const { return category_ == HybridCategory::Hybrid; } | ||||
| 
 | ||||
|   /// Return the number of continuous variables in this factor.
 | ||||
|   size_t nrContinuous() const { return continuousKeys_.size(); } | ||||
|  | @ -142,9 +144,7 @@ class GTSAM_EXPORT HybridFactor : public Factor { | |||
|   template <class ARCHIVE> | ||||
|   void serialize(ARCHIVE &ar, const unsigned int /*version*/) { | ||||
|     ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); | ||||
|     ar &BOOST_SERIALIZATION_NVP(isDiscrete_); | ||||
|     ar &BOOST_SERIALIZATION_NVP(isContinuous_); | ||||
|     ar &BOOST_SERIALIZATION_NVP(isHybrid_); | ||||
|     ar &BOOST_SERIALIZATION_NVP(category_); | ||||
|     ar &BOOST_SERIALIZATION_NVP(discreteKeys_); | ||||
|     ar &BOOST_SERIALIZATION_NVP(continuousKeys_); | ||||
|   } | ||||
|  |  | |||
|  | @ -387,11 +387,13 @@ TEST(HybridBayesNet, Sampling) { | |||
|       std::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model); | ||||
|   auto one_motion = | ||||
|       std::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model); | ||||
|   std::vector<NonlinearFactorValuePair> factors = {{zero_motion, 0.0}, | ||||
|                                                    {one_motion, 0.0}}; | ||||
| 
 | ||||
|   DiscreteKeys discreteKeys{DiscreteKey(M(0), 2)}; | ||||
|   HybridNonlinearFactor::Factors factors( | ||||
|       discreteKeys, {{zero_motion, 0.0}, {one_motion, 0.0}}); | ||||
|   nfg.emplace_shared<PriorFactor<double>>(X(0), 0.0, noise_model); | ||||
|   nfg.emplace_shared<HybridNonlinearFactor>( | ||||
|       KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors); | ||||
|   nfg.emplace_shared<HybridNonlinearFactor>(KeyVector{X(0), X(1)}, discreteKeys, | ||||
|                                             factors); | ||||
| 
 | ||||
|   DiscreteKey mode(M(0), 2); | ||||
|   nfg.emplace_shared<DiscreteDistribution>(mode, "1/1"); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue