Merge pull request #1963 from borglab/discrete-multiply
DiscreteFactor multiply methodrelease/4.3a0
						commit
						47074bd0c2
					
				|  | @ -18,9 +18,10 @@ | |||
|  */ | ||||
| 
 | ||||
| #include <gtsam/base/FastSet.h> | ||||
| #include <gtsam/hybrid/HybridValues.h> | ||||
| #include <gtsam/discrete/DecisionTreeFactor.h> | ||||
| #include <gtsam/discrete/DiscreteConditional.h> | ||||
| #include <gtsam/discrete/TableFactor.h> | ||||
| #include <gtsam/hybrid/HybridValues.h> | ||||
| 
 | ||||
| #include <utility> | ||||
| 
 | ||||
|  | @ -62,6 +63,30 @@ namespace gtsam { | |||
|     return error(values.discrete()); | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   DiscreteFactor::shared_ptr DecisionTreeFactor::multiply( | ||||
|       const DiscreteFactor::shared_ptr& f) const { | ||||
|     DiscreteFactor::shared_ptr result; | ||||
|     if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) { | ||||
|       // If f is a TableFactor, we convert `this` to a TableFactor since this
 | ||||
|       // conversion is cheaper than converting `f` to a DecisionTreeFactor. We
 | ||||
|       // then return a TableFactor.
 | ||||
|       result = std::make_shared<TableFactor>((*tf) * TableFactor(*this)); | ||||
| 
 | ||||
|     } else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { | ||||
|       // If `f` is a DecisionTreeFactor, simply call operator*.
 | ||||
|       result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf)); | ||||
| 
 | ||||
|     } else { | ||||
|       // Simulate double dispatch in C++
 | ||||
|       // Useful for other classes which inherit from DiscreteFactor and have
 | ||||
|       // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
 | ||||
|       // need to be updated.
 | ||||
|       result = std::make_shared<DecisionTreeFactor>(f->operator*(*this)); | ||||
|     } | ||||
|     return result; | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   double DecisionTreeFactor::safe_div(const double& a, const double& b) { | ||||
|     // The use for safe_div is when we divide the product factor by the sum
 | ||||
|  |  | |||
|  | @ -147,6 +147,23 @@ namespace gtsam { | |||
|     /// Calculate error for DiscreteValues `x`, is -log(probability).
 | ||||
|     double error(const DiscreteValues& values) const override; | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Multiply factors, DiscreteFactor::shared_ptr edition. | ||||
|      * | ||||
|      * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic | ||||
|      * dispatch and specializations to perform the most efficient | ||||
|      * multiplication. | ||||
|      * | ||||
|      * While converting a DecisionTreeFactor to a TableFactor is efficient, the | ||||
|      * reverse is not. Hence we specialize the code to return a TableFactor if | ||||
|      * `f` is a TableFactor, and DecisionTreeFactor otherwise. | ||||
|      * | ||||
|      * @param f The factor to multiply with. | ||||
|      * @return DiscreteFactor::shared_ptr | ||||
|      */ | ||||
|     virtual DiscreteFactor::shared_ptr multiply( | ||||
|         const DiscreteFactor::shared_ptr& f) const override; | ||||
| 
 | ||||
|     /// multiply two factors
 | ||||
|     DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { | ||||
|       return apply(f, Ring::mul); | ||||
|  |  | |||
|  | @ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { | |||
|   /// DecisionTreeFactor
 | ||||
|   virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Multiply in a DiscreteFactor and return the result as | ||||
|    * DiscreteFactor, both via shared pointers. | ||||
|    * | ||||
|    * @param df DiscreteFactor shared_ptr | ||||
|    * @return DiscreteFactor::shared_ptr | ||||
|    */ | ||||
|   virtual DiscreteFactor::shared_ptr multiply( | ||||
|       const DiscreteFactor::shared_ptr& df) const = 0; | ||||
| 
 | ||||
|   virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; | ||||
| 
 | ||||
|   /// @}
 | ||||
|  |  | |||
|  | @ -65,11 +65,18 @@ namespace gtsam { | |||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   DecisionTreeFactor DiscreteFactorGraph::product() const { | ||||
|     DecisionTreeFactor result; | ||||
|     for (const sharedFactor& factor : *this) { | ||||
|       if (factor) result = (*factor) * result; | ||||
|     DiscreteFactor::shared_ptr result; | ||||
|     for (auto it = this->begin(); it != this->end(); ++it) { | ||||
|       if (*it) { | ||||
|         if (result) { | ||||
|           result = result->multiply(*it); | ||||
|         } else { | ||||
|           // Assign to the first non-null factor
 | ||||
|           result = *it; | ||||
|         } | ||||
|     return result; | ||||
|       } | ||||
|     } | ||||
|     return result->toDecisionTreeFactor(); | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|  |  | |||
|  | @ -254,6 +254,32 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { | |||
|   return toDecisionTreeFactor() * f; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DiscreteFactor::shared_ptr TableFactor::multiply( | ||||
|     const DiscreteFactor::shared_ptr& f) const { | ||||
|   DiscreteFactor::shared_ptr result; | ||||
|   if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) { | ||||
|     // If `f` is a TableFactor, we can simply call `operator*`.
 | ||||
|     result = std::make_shared<TableFactor>(this->operator*(*tf)); | ||||
| 
 | ||||
|   } else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { | ||||
|     // If `f` is a DecisionTreeFactor, we convert to a TableFactor which is
 | ||||
|     // cheaper than converting `this` to a DecisionTreeFactor.
 | ||||
|     result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf))); | ||||
| 
 | ||||
|   } else { | ||||
|     // Simulate double dispatch in C++
 | ||||
|     // Useful for other classes which inherit from DiscreteFactor and have
 | ||||
|     // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
 | ||||
|     // need to be updated to know about TableFactor.
 | ||||
|     // Those classes can be specialized to use TableFactor
 | ||||
|     // if efficiency is a problem.
 | ||||
|     result = std::make_shared<DecisionTreeFactor>( | ||||
|         f->operator*(this->toDecisionTreeFactor())); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { | ||||
|   DiscreteKeys dkeys = discreteKeys(); | ||||
|  |  | |||
|  | @ -178,6 +178,23 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { | |||
|   /// multiply with DecisionTreeFactor
 | ||||
|   DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Multiply factors, DiscreteFactor::shared_ptr edition. | ||||
|    * | ||||
|    * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic | ||||
|    * dispatch and specializations to perform the most efficient | ||||
|    * multiplication. | ||||
|    * | ||||
|    * While converting a DecisionTreeFactor to a TableFactor is efficient, the | ||||
|    * reverse is not. | ||||
|    * Hence we specialize the code to return a TableFactor always. | ||||
|    * | ||||
|    * @param f The factor to multiply with. | ||||
|    * @return DiscreteFactor::shared_ptr | ||||
|    */ | ||||
|   virtual DiscreteFactor::shared_ptr multiply( | ||||
|       const DiscreteFactor::shared_ptr& f) const override; | ||||
| 
 | ||||
|   static double safe_div(const double& a, const double& b); | ||||
| 
 | ||||
|   /// divide by factor f (safely)
 | ||||
|  |  | |||
|  | @ -78,6 +78,14 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { | |||
| 
 | ||||
|   /// Partially apply known values, domain version
 | ||||
|   virtual shared_ptr partiallyApply(const Domains&) const = 0; | ||||
| 
 | ||||
|   /// Multiply factors, DiscreteFactor::shared_ptr edition
 | ||||
|   DiscreteFactor::shared_ptr multiply( | ||||
|       const DiscreteFactor::shared_ptr& df) const override { | ||||
|     return std::make_shared<DecisionTreeFactor>( | ||||
|         this->operator*(df->toDecisionTreeFactor())); | ||||
|   } | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Wrapper support
 | ||||
|   /// @{
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue