turns out we can merge DiscreteConditional and DiscreteLookupTable
							parent
							
								
									5550537709
								
							
						
					
					
						commit
						52f1aba10c
					
				|  | @ -235,7 +235,10 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( | |||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| size_t DiscreteConditional::argmax() const { | ||||
| size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { | ||||
|   ADT pFS = choose(parentsValues, true);  // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // Initialize
 | ||||
|   size_t maxValue = 0; | ||||
|   double maxP = 0; | ||||
|   assert(nrFrontals() == 1); | ||||
|  | @ -254,6 +257,33 @@ size_t DiscreteConditional::argmax() const { | |||
|   return maxValue; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| void DiscreteConditional::argmaxInPlace(DiscreteValues* values) const { | ||||
|   ADT pFS = choose(*values, true);  // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // Initialize
 | ||||
|   DiscreteValues mpe; | ||||
|   double maxP = 0; | ||||
| 
 | ||||
|   // Get all Possible Configurations
 | ||||
|   const auto allPosbValues = frontalAssignments(); | ||||
| 
 | ||||
|   // Find the maximum
 | ||||
|   for (const auto& frontalVals : allPosbValues) { | ||||
|     double pValueS = pFS(frontalVals);  // P(F=value|S=parentsValues)
 | ||||
|     // Update maximum solution if better
 | ||||
|     if (pValueS > maxP) { | ||||
|       maxP = pValueS; | ||||
|       mpe = frontalVals; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // set values (inPlace) to maximum
 | ||||
|   for (Key j : frontals()) { | ||||
|     (*values)[j] = mpe[j]; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { | ||||
|   assert(nrFrontals() == 1); | ||||
|  | @ -459,7 +489,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, | |||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| double DiscreteConditional::evaluate(const HybridValues& x) const{ | ||||
| double DiscreteConditional::evaluate(const HybridValues& x) const { | ||||
|   return this->evaluate(x.discrete()); | ||||
| } | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
|  | @ -18,9 +18,9 @@ | |||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <gtsam/inference/Conditional-inst.h> | ||||
| #include <gtsam/discrete/DecisionTreeFactor.h> | ||||
| #include <gtsam/discrete/Signature.h> | ||||
| #include <gtsam/inference/Conditional-inst.h> | ||||
| 
 | ||||
| #include <memory> | ||||
| #include <string> | ||||
|  | @ -39,7 +39,7 @@ class GTSAM_EXPORT DiscreteConditional | |||
|       public Conditional<DecisionTreeFactor, DiscreteConditional> { | ||||
|  public: | ||||
|   // typedefs needed to play nice with gtsam
 | ||||
|   typedef DiscreteConditional This;            ///< Typedef to this class
 | ||||
|   typedef DiscreteConditional This;          ///< Typedef to this class
 | ||||
|   typedef std::shared_ptr<This> shared_ptr;  ///< shared_ptr to this class
 | ||||
|   typedef DecisionTreeFactor BaseFactor;  ///< Typedef to our factor base class
 | ||||
|   typedef Conditional<BaseFactor, This> | ||||
|  | @ -159,9 +159,7 @@ class GTSAM_EXPORT DiscreteConditional | |||
|   /// @{
 | ||||
| 
 | ||||
|   /// Log-probability is just -error(x).
 | ||||
|   double logProbability(const DiscreteValues& x) const  { | ||||
|     return -error(x); | ||||
|   } | ||||
|   double logProbability(const DiscreteValues& x) const { return -error(x); } | ||||
| 
 | ||||
|   /// print index signature only
 | ||||
|   void printSignature( | ||||
|  | @ -214,11 +212,18 @@ class GTSAM_EXPORT DiscreteConditional | |||
|   size_t sample() const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Return assignment that maximizes distribution. | ||||
|    * @return Optimal assignment (1 frontal variable). | ||||
|    * @brief Return assignment for single frontal variable that maximizes value. | ||||
|    * @param parentsValues Known assignments for the parents. | ||||
|    * @return maximizing assignment for the frontal variable. | ||||
|    */ | ||||
|   size_t argmax() const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Calculate assignment for frontal variables that maximizes value. | ||||
|    * @param (in/out) parentsValues Known assignments for the parents. | ||||
|    */ | ||||
|   void argmaxInPlace(DiscreteValues* parentsValues) const; | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Advanced Interface
 | ||||
|   /// @{
 | ||||
|  | @ -244,7 +249,6 @@ class GTSAM_EXPORT DiscreteConditional | |||
|   std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, | ||||
|                    const Names& names = {}) const override; | ||||
| 
 | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name HybridValues methods.
 | ||||
|   /// @{
 | ||||
|  |  | |||
|  | @ -29,97 +29,20 @@ using std::vector; | |||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
 | ||||
| void DiscreteLookupTable::print(const std::string& s, | ||||
|                                 const KeyFormatter& formatter) const { | ||||
|   using std::cout; | ||||
|   using std::endl; | ||||
| 
 | ||||
|   cout << s << " g( "; | ||||
|   for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { | ||||
|     cout << formatter(*it) << " "; | ||||
|   } | ||||
|   if (nrParents()) { | ||||
|     cout << "; "; | ||||
|     for (const_iterator it = beginParents(); it != endParents(); ++it) { | ||||
|       cout << formatter(*it) << " "; | ||||
|     } | ||||
|   } | ||||
|   cout << "):\n"; | ||||
|   ADT::print("", formatter); | ||||
|   cout << endl; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { | ||||
|   ADT pFS = choose(*values, true);  // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // Initialize
 | ||||
|   DiscreteValues mpe; | ||||
|   double maxP = 0; | ||||
| 
 | ||||
|   // Get all Possible Configurations
 | ||||
|   const auto allPosbValues = frontalAssignments(); | ||||
| 
 | ||||
|   // Find the maximum
 | ||||
|   for (const auto& frontalVals : allPosbValues) { | ||||
|     double pValueS = pFS(frontalVals);  // P(F=value|S=parentsValues)
 | ||||
|     // Update maximum solution if better
 | ||||
|     if (pValueS > maxP) { | ||||
|       maxP = pValueS; | ||||
|       mpe = frontalVals; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // set values (inPlace) to maximum
 | ||||
|   for (Key j : frontals()) { | ||||
|     (*values)[j] = mpe[j]; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { | ||||
|   ADT pFS = choose(parentsValues, true);  // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // Then, find the max over all remaining
 | ||||
|   // TODO(Duy): only works for one key now, seems horribly slow this way
 | ||||
|   size_t mpe = 0; | ||||
|   double maxP = 0; | ||||
|   DiscreteValues frontals; | ||||
|   assert(nrFrontals() == 1); | ||||
|   Key j = (firstFrontalKey()); | ||||
|   for (size_t value = 0; value < cardinality(j); value++) { | ||||
|     frontals[j] = value; | ||||
|     double pValueS = pFS(frontals);  // P(F=value|S=parentsValues)
 | ||||
|     // Update MPE solution if better
 | ||||
|     if (pValueS > maxP) { | ||||
|       maxP = pValueS; | ||||
|       mpe = value; | ||||
|     } | ||||
|   } | ||||
|   return mpe; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( | ||||
|     const DiscreteBayesNet& bayesNet) { | ||||
|   DiscreteLookupDAG dag; | ||||
|   for (auto&& conditional : bayesNet) { | ||||
|     if (auto lookupTable = | ||||
|             std::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) { | ||||
|       dag.push_back(lookupTable); | ||||
|     } else { | ||||
|       throw std::runtime_error( | ||||
|           "DiscreteFactorGraph::maxProduct: Expected look up table."); | ||||
|     } | ||||
|     dag.push_back(conditional); | ||||
|   } | ||||
|   return dag; | ||||
| } | ||||
| 
 | ||||
| DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { | ||||
|   // Argmax each node in turn in topological sort order (parents first).
 | ||||
|   for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { | ||||
|   for (auto it = std::make_reverse_iterator(end()); | ||||
|        it != std::make_reverse_iterator(begin()); ++it) { | ||||
|     // dereference to get the sharedFactor to the lookup table
 | ||||
|     (*it)->argmaxInPlace(&result); | ||||
|   } | ||||
|  |  | |||
|  | @ -37,41 +37,9 @@ class DiscreteBayesNet; | |||
|  * Inherits from discrete conditional for convenience, but is not normalized. | ||||
|  * Is used in the max-product algorithm. | ||||
|  */ | ||||
| class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { | ||||
|  public: | ||||
|   using This = DiscreteLookupTable; | ||||
|   using shared_ptr = std::shared_ptr<This>; | ||||
|   using BaseConditional = Conditional<DecisionTreeFactor, This>; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Construct a new Discrete Lookup Table object | ||||
|    * | ||||
|    * @param nFrontals number of frontal variables | ||||
|    * @param keys a sorted list of gtsam::Keys | ||||
|    * @param potentials the algebraic decision tree with lookup values | ||||
|    */ | ||||
|   DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, | ||||
|                       const ADT& potentials) | ||||
|       : DiscreteConditional(nFrontals, keys, potentials) {} | ||||
| 
 | ||||
|   /// GTSAM-style print
 | ||||
|   void print( | ||||
|       const std::string& s = "Discrete Lookup Table: ", | ||||
|       const KeyFormatter& formatter = DefaultKeyFormatter) const override; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief return assignment for single frontal variable that maximizes value. | ||||
|    * @param parentsValues Known assignments for the parents. | ||||
|    * @return maximizing assignment for the frontal variable. | ||||
|    */ | ||||
|   size_t argmax(const DiscreteValues& parentsValues) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Calculate assignment for frontal variables that maximizes value. | ||||
|    * @param (in/out) parentsValues Known assignments for the parents. | ||||
|    */ | ||||
|   void argmaxInPlace(DiscreteValues* parentsValues) const; | ||||
| }; | ||||
| // Typedef for backwards compatibility
 | ||||
| // TODO(Varun): Remove
 | ||||
| using DiscreteLookupTable = DiscreteConditional; | ||||
| 
 | ||||
| /** A DAG made from lookup tables, as defined above. */ | ||||
| class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue