multiply method for DiscreteFactor
							parent
							
								
									e9e52ad21f
								
							
						
					
					
						commit
						5d865a8cc7
					
				|  | @ -62,6 +62,18 @@ namespace gtsam { | |||
|     return error(values.discrete()); | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   DiscreteFactor::shared_ptr DecisionTreeFactor::multiply( | ||||
|       const DiscreteFactor::shared_ptr& f) const override { | ||||
|     DiscreteFactor::shared_ptr result; | ||||
|     if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) { | ||||
|       result = std::make_shared<TableFactor>((*tf) * TableFactor(*this)); | ||||
|     } else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { | ||||
|       result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf)); | ||||
|     } | ||||
|     return result; | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   double DecisionTreeFactor::safe_div(const double& a, const double& b) { | ||||
|     // The use for safe_div is when we divide the product factor by the sum
 | ||||
|  |  | |||
|  | @ -22,6 +22,7 @@ | |||
| #include <gtsam/discrete/DiscreteFactor.h> | ||||
| #include <gtsam/discrete/DiscreteKey.h> | ||||
| #include <gtsam/discrete/Ring.h> | ||||
| #include <gtsam/discrete/TableFactor.h> | ||||
| #include <gtsam/inference/Ordering.h> | ||||
| 
 | ||||
| #include <algorithm> | ||||
|  | @ -147,6 +148,10 @@ namespace gtsam { | |||
|     /// Calculate error for DiscreteValues `x`, is -log(probability).
 | ||||
|     double error(const DiscreteValues& values) const override; | ||||
| 
 | ||||
|     /// Multiply factors, DiscreteFactor::shared_ptr edition
 | ||||
|     virtual DiscreteFactor::shared_ptr multiply( | ||||
|         const DiscreteFactor::shared_ptr& f) const override; | ||||
| 
 | ||||
|     /// multiply two factors
 | ||||
|     DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { | ||||
|       return apply(f, Ring::mul); | ||||
|  |  | |||
|  | @ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { | |||
|   /// DecisionTreeFactor
 | ||||
|   virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Multiply in a DiscreteFactor and return the result as | ||||
|    * DiscreteFactor, both via shared pointers. | ||||
|    * | ||||
|    * @param df DiscreteFactor shared_ptr | ||||
|    * @return DiscreteFactor::shared_ptr | ||||
|    */ | ||||
|   virtual DiscreteFactor::shared_ptr multiply( | ||||
|       const DiscreteFactor::shared_ptr& df) const = 0; | ||||
| 
 | ||||
|   virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; | ||||
| 
 | ||||
|   /// @}
 | ||||
|  |  | |||
|  | @ -254,6 +254,18 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { | |||
|   return toDecisionTreeFactor() * f; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DiscreteFactor::shared_ptr TableFactor::multiply( | ||||
|     const DiscreteFactor::shared_ptr& f) const override { | ||||
|   DiscreteFactor::shared_ptr result; | ||||
|   if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) { | ||||
|     result = std::make_shared<TableFactor>(this->operator*(*tf)); | ||||
|   } else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { | ||||
|     result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf))); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { | ||||
|   DiscreteKeys dkeys = discreteKeys(); | ||||
|  |  | |||
|  | @ -178,6 +178,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { | |||
|   /// multiply with DecisionTreeFactor
 | ||||
|   DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; | ||||
| 
 | ||||
|   /// Multiply factors, DiscreteFactor::shared_ptr edition
 | ||||
|   virtual DiscreteFactor::shared_ptr multiply( | ||||
|       const DiscreteFactor::shared_ptr& f) const override; | ||||
| 
 | ||||
|   static double safe_div(const double& a, const double& b); | ||||
| 
 | ||||
|   /// divide by factor f (safely)
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue