Merge pull request #1927 from borglab/discrete-improvements
Various Discrete Improvementsrelease/4.3a0
						commit
						2c9e315a2c
					
				|  | @ -168,13 +168,9 @@ class GTSAM_EXPORT DiscreteConditional | |||
|     static_cast<const BaseConditional*>(this)->print(s, formatter); | ||||
|   } | ||||
| 
 | ||||
|   /// Evaluate, just look up in AlgebraicDecisionTree
 | ||||
|   virtual double evaluate(const Assignment<Key>& values) const override { | ||||
|     return ADT::operator()(values); | ||||
|   } | ||||
| 
 | ||||
|   using DecisionTreeFactor::error;       ///< DiscreteValues version
 | ||||
|   using DiscreteFactor::operator();      ///< DiscreteValues version
 | ||||
|   using BaseFactor::error;       ///< DiscreteValues version
 | ||||
|   using BaseFactor::evaluate;    ///< DiscreteValues version
 | ||||
|   using BaseFactor::operator();  ///< DiscreteValues version
 | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief restrict to given *parent* values. | ||||
|  |  | |||
|  | @ -40,7 +40,7 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { | |||
|   /// Default constructor needed for serialization.
 | ||||
|   DiscreteDistribution() {} | ||||
| 
 | ||||
|   /// Constructor from factor.
 | ||||
|   /// Constructor from DecisionTreeFactor.
 | ||||
|   explicit DiscreteDistribution(const DecisionTreeFactor& f) | ||||
|       : Base(f.size(), f) {} | ||||
| 
 | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ | |||
|  *  @date Feb 14, 2011 | ||||
|  *  @author Duy-Nguyen Ta | ||||
|  *  @author Frank Dellaert | ||||
|  *  @author Varun Agrawal | ||||
|  */ | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteBayesTree.h> | ||||
|  | @ -35,13 +36,12 @@ namespace gtsam { | |||
|   template class FactorGraph<DiscreteFactor>; | ||||
|   template class EliminateableFactorGraph<DiscreteFactorGraph>; | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   bool DiscreteFactorGraph::equals(const This& fg, double tol) const | ||||
|   { | ||||
|   /* ************************************************************************ */ | ||||
|   bool DiscreteFactorGraph::equals(const This& fg, double tol) const { | ||||
|     return Base::equals(fg, tol); | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   /* ************************************************************************ */ | ||||
|   KeySet DiscreteFactorGraph::keys() const { | ||||
|     KeySet keys; | ||||
|     for (const sharedFactor& factor : *this) { | ||||
|  | @ -50,11 +50,11 @@ namespace gtsam { | |||
|     return keys; | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   /* ************************************************************************ */ | ||||
|   DiscreteKeys DiscreteFactorGraph::discreteKeys() const { | ||||
|     DiscreteKeys result; | ||||
|     for (auto&& factor : *this) { | ||||
|       if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) { | ||||
|       if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) { | ||||
|         DiscreteKeys factor_keys = p->discreteKeys(); | ||||
|         result.insert(result.end(), factor_keys.begin(), factor_keys.end()); | ||||
|       } | ||||
|  | @ -63,26 +63,27 @@ namespace gtsam { | |||
|     return result; | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   /* ************************************************************************ */ | ||||
|   DecisionTreeFactor DiscreteFactorGraph::product() const { | ||||
|     DecisionTreeFactor result; | ||||
|     for(const sharedFactor& factor: *this) | ||||
|     for (const sharedFactor& factor : *this) { | ||||
|       if (factor) result = (*factor) * result; | ||||
|     } | ||||
|     return result; | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   double DiscreteFactorGraph::operator()( | ||||
|       const DiscreteValues &values) const { | ||||
|   /* ************************************************************************ */ | ||||
|   double DiscreteFactorGraph::operator()(const DiscreteValues& values) const { | ||||
|     double product = 1.0; | ||||
|     for( const sharedFactor& factor: factors_ ) | ||||
|       product *= (*factor)(values); | ||||
|     for (const sharedFactor& factor : factors_) { | ||||
|       if (factor) product *= (*factor)(values); | ||||
|     } | ||||
|     return product; | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   /* ************************************************************************ */ | ||||
|   void DiscreteFactorGraph::print(const string& s, | ||||
|       const KeyFormatter& formatter) const { | ||||
|                                   const KeyFormatter& formatter) const { | ||||
|     std::cout << s << std::endl; | ||||
|     std::cout << "size: " << size() << std::endl; | ||||
|     for (size_t i = 0; i < factors_.size(); i++) { | ||||
|  | @ -110,15 +111,18 @@ namespace gtsam { | |||
| //      }
 | ||||
| //  }
 | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   // Alternate eliminate function for MPE
 | ||||
|   std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>  //
 | ||||
|   EliminateForMPE(const DiscreteFactorGraph& factors, | ||||
|                   const Ordering& frontalKeys) { | ||||
|   /**
 | ||||
|    * @brief Multiply all the `factors` and normalize the | ||||
|    * product to prevent underflow. | ||||
|    * | ||||
|    * @param factors The factors to multiply as a DiscreteFactorGraph. | ||||
|    * @return DecisionTreeFactor | ||||
|    */ | ||||
|   static DecisionTreeFactor ProductAndNormalize( | ||||
|       const DiscreteFactorGraph& factors) { | ||||
|     // PRODUCT: multiply all factors
 | ||||
|     gttic(product); | ||||
|     DecisionTreeFactor product; | ||||
|     for (auto&& factor : factors) product = (*factor) * product; | ||||
|     DecisionTreeFactor product = factors.product(); | ||||
|     gttoc(product); | ||||
| 
 | ||||
|     // Max over all the potentials by pretending all keys are frontal:
 | ||||
|  | @ -127,6 +131,16 @@ namespace gtsam { | |||
|     // Normalize the product factor to prevent underflow.
 | ||||
|     product = product / (*normalization); | ||||
| 
 | ||||
|     return product; | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   // Alternate eliminate function for MPE
 | ||||
|   std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>  //
 | ||||
|   EliminateForMPE(const DiscreteFactorGraph& factors, | ||||
|                   const Ordering& frontalKeys) { | ||||
|     DecisionTreeFactor product = ProductAndNormalize(factors); | ||||
| 
 | ||||
|     // max out frontals, this is the factor on the separator
 | ||||
|     gttic(max); | ||||
|     DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); | ||||
|  | @ -142,8 +156,8 @@ namespace gtsam { | |||
|     // Make lookup with product
 | ||||
|     gttic(lookup); | ||||
|     size_t nrFrontals = frontalKeys.size(); | ||||
|     auto lookup = std::make_shared<DiscreteLookupTable>(nrFrontals, | ||||
|                                                           orderedKeys, product); | ||||
|     auto lookup = | ||||
|         std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product); | ||||
|     gttoc(lookup); | ||||
| 
 | ||||
|     return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max}; | ||||
|  | @ -201,20 +215,10 @@ namespace gtsam { | |||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>  //
 | ||||
|   std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>  //
 | ||||
|   EliminateDiscrete(const DiscreteFactorGraph& factors, | ||||
|                     const Ordering& frontalKeys) { | ||||
|     // PRODUCT: multiply all factors
 | ||||
|     gttic(product); | ||||
|     DecisionTreeFactor product; | ||||
|     for (auto&& factor : factors) product = (*factor) * product; | ||||
|     gttoc(product); | ||||
| 
 | ||||
|     // Max over all the potentials by pretending all keys are frontal:
 | ||||
|     auto normalization = product.max(product.size()); | ||||
| 
 | ||||
|     // Normalize the product factor to prevent underflow.
 | ||||
|     product = product / (*normalization); | ||||
|     DecisionTreeFactor product = ProductAndNormalize(factors); | ||||
| 
 | ||||
|     // sum out frontals, this is the factor on the separator
 | ||||
|     gttic(sum); | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ | |||
|  * @date Feb 14, 2011 | ||||
|  * @author Duy-Nguyen Ta | ||||
|  * @author Frank Dellaert | ||||
|  * @author Varun Agrawal | ||||
|  */ | ||||
| 
 | ||||
| #pragma once | ||||
|  | @ -48,7 +49,7 @@ class DiscreteJunctionTree; | |||
|  * @ingroup discrete | ||||
|  */ | ||||
| GTSAM_EXPORT | ||||
| std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> | ||||
| std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> | ||||
| EliminateDiscrete(const DiscreteFactorGraph& factors, | ||||
|                   const Ordering& frontalKeys); | ||||
| 
 | ||||
|  | @ -61,7 +62,7 @@ EliminateDiscrete(const DiscreteFactorGraph& factors, | |||
|  * @ingroup discrete | ||||
|  */ | ||||
| GTSAM_EXPORT | ||||
| std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> | ||||
| std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> | ||||
| EliminateForMPE(const DiscreteFactorGraph& factors, | ||||
|                 const Ordering& frontalKeys); | ||||
| 
 | ||||
|  |  | |||
|  | @ -113,7 +113,8 @@ TEST(DiscreteFactorGraph, test) { | |||
|   const Ordering frontalKeys{0}; | ||||
|   const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys); | ||||
| 
 | ||||
|   DecisionTreeFactor newFactor = *newFactorPtr; | ||||
|   DecisionTreeFactor newFactor = | ||||
|       *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr); | ||||
| 
 | ||||
|   // Normalize newFactor by max for comparison with expected
 | ||||
|   auto normalization = newFactor.max(newFactor.size()); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue