Merge pull request #990 from borglab/feature/api_improvements
						commit
						e5b928c610
					
				|  | @ -9,12 +9,18 @@ endif() | |||
| 
 | ||||
| # Set the version number for the library | ||||
| set (GTSAM_VERSION_MAJOR 4) | ||||
| set (GTSAM_VERSION_MINOR 1) | ||||
| set (GTSAM_VERSION_PATCH 1) | ||||
| set (GTSAM_VERSION_MINOR 2) | ||||
| set (GTSAM_VERSION_PATCH 0) | ||||
| set (GTSAM_PRERELEASE_VERSION "a0") | ||||
| math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") | ||||
| set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}") | ||||
| 
 | ||||
| set (CMAKE_PROJECT_VERSION ${GTSAM_VERSION_STRING}) | ||||
| if (${GTSAM_VERSION_PATCH} EQUAL 0) | ||||
|     set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}${GTSAM_PRERELEASE_VERSION}") | ||||
| else() | ||||
|     set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}") | ||||
| endif() | ||||
| message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}") | ||||
| 
 | ||||
| set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR}) | ||||
| set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR}) | ||||
| set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH}) | ||||
|  |  | |||
|  | @ -134,17 +134,34 @@ namespace gtsam { | |||
|     return boost::make_shared<DecisionTreeFactor>(dkeys, result); | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const { | ||||
|     // Get all possible assignments
 | ||||
|     std::vector<std::pair<Key, size_t>> pairs; | ||||
|     for (auto& key : keys()) { | ||||
|       pairs.emplace_back(key, cardinalities_.at(key)); | ||||
|     } | ||||
|     // Reverse to make cartesianProduct output a more natural ordering.
 | ||||
|     std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend()); | ||||
|     const auto assignments = cartesianProduct(rpairs); | ||||
| 
 | ||||
|     // Construct unordered_map with values
 | ||||
|     std::vector<std::pair<DiscreteValues, double>> result; | ||||
|     for (const auto& assignment : assignments) { | ||||
|       result.emplace_back(assignment, operator()(assignment)); | ||||
|     } | ||||
|     return result; | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   std::string DecisionTreeFactor::markdown( | ||||
|       const KeyFormatter& keyFormatter) const { | ||||
|     std::stringstream ss; | ||||
| 
 | ||||
|     // Print out header and construct argument for `cartesianProduct`.
 | ||||
|     std::vector<std::pair<Key, size_t>> pairs; | ||||
|     ss << "|"; | ||||
|     for (auto& key : keys()) { | ||||
|       ss << keyFormatter(key) << "|"; | ||||
|       pairs.emplace_back(key, cardinalities_.at(key)); | ||||
|     } | ||||
|     ss << "value|\n"; | ||||
| 
 | ||||
|  | @ -154,12 +171,12 @@ namespace gtsam { | |||
|     ss << ":-:|\n"; | ||||
| 
 | ||||
|     // Print out all rows.
 | ||||
|     std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend()); | ||||
|     const auto assignments = cartesianProduct(rpairs); | ||||
|     for (const auto& assignment : assignments) { | ||||
|     auto rows = enumerate(); | ||||
|     for (const auto& kv : rows) { | ||||
|       ss << "|"; | ||||
|       auto assignment = kv.first; | ||||
|       for (auto& key : keys()) ss << assignment.at(key) << "|"; | ||||
|       ss << operator()(assignment) << "|\n"; | ||||
|       ss << kv.second << "|\n"; | ||||
|     } | ||||
|     return ss.str(); | ||||
|   } | ||||
|  |  | |||
|  | @ -61,6 +61,15 @@ namespace gtsam { | |||
|         DiscreteFactor(keys.indices()), Potentials(keys, table) { | ||||
|     } | ||||
| 
 | ||||
|     /// Single-key specialization
 | ||||
|     template <class SOURCE> | ||||
|     DecisionTreeFactor(const DiscreteKey& key, SOURCE table) | ||||
|         : DecisionTreeFactor(DiscreteKeys{key}, table) {} | ||||
| 
 | ||||
|     /// Single-key specialization, with vector of doubles.
 | ||||
|     DecisionTreeFactor(const DiscreteKey& key, const std::vector<double>& row) | ||||
|         : DecisionTreeFactor(DiscreteKeys{key}, row) {} | ||||
| 
 | ||||
|     /** Construct from a DiscreteConditional type */ | ||||
|     DecisionTreeFactor(const DiscreteConditional& c); | ||||
| 
 | ||||
|  | @ -162,6 +171,9 @@ namespace gtsam { | |||
| //      Potentials::reduceWithInverse(inverseReduction);
 | ||||
| //    }
 | ||||
| 
 | ||||
|     /// Enumerate all values into a map from values to double.
 | ||||
|     std::vector<std::pair<DiscreteValues, double>> enumerate() const; | ||||
| 
 | ||||
|     /// @}
 | ||||
|     /// @name Wrapper support
 | ||||
|     /// @{
 | ||||
|  |  | |||
|  | @ -23,6 +23,7 @@ | |||
| #include <boost/shared_ptr.hpp> | ||||
| #include <gtsam/inference/BayesNet.h> | ||||
| #include <gtsam/inference/FactorGraph.h> | ||||
| #include <gtsam/discrete/DiscretePrior.h> | ||||
| #include <gtsam/discrete/DiscreteConditional.h> | ||||
| 
 | ||||
| namespace gtsam { | ||||
|  | @ -75,6 +76,11 @@ namespace gtsam { | |||
|     // Add inherited versions of add.
 | ||||
|     using Base::add; | ||||
| 
 | ||||
|     /** Add a DiscretePrior using a table or a string */ | ||||
|     void add(const DiscreteKey& key, const std::string& spec) { | ||||
|       emplace_shared<DiscretePrior>(key, spec); | ||||
|     } | ||||
| 
 | ||||
|     /** Add a DiscreteCondtional */ | ||||
|     template <typename... Args> | ||||
|     void add(Args&&... args) { | ||||
|  |  | |||
|  | @ -97,45 +97,90 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, | |||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| Potentials::ADT DiscreteConditional::choose( | ||||
|     const DiscreteValues& parentsValues) const { | ||||
| static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, | ||||
|                                        const DiscreteValues& parentsValues) { | ||||
|   // Get the big decision tree with all the levels, and then go down the
 | ||||
|   // branches based on the value of the parent variables.
 | ||||
|   ADT pFS(*this); | ||||
|   DiscreteConditional::ADT adt(conditional); | ||||
|   size_t value; | ||||
|   for (Key j : parents()) { | ||||
|   for (Key j : conditional.parents()) { | ||||
|     try { | ||||
|       value = parentsValues.at(j); | ||||
|       pFS = pFS.choose(j, value);  // ADT keeps getting smaller.
 | ||||
|     } catch (exception&) { | ||||
|       cout << "Key: " << j << "  Value: " << value << endl; | ||||
|       adt = adt.choose(j, value);  // ADT keeps getting smaller.
 | ||||
|     } catch (std::out_of_range&) { | ||||
|       parentsValues.print("parentsValues: "); | ||||
|       throw runtime_error("DiscreteConditional::choose: parent value missing"); | ||||
|     }; | ||||
|   } | ||||
|   return pFS; | ||||
|   return adt; | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor( | ||||
| DecisionTreeFactor::shared_ptr DiscreteConditional::choose( | ||||
|     const DiscreteValues& parentsValues) const { | ||||
|   ADT pFS = choose(parentsValues); | ||||
|   // Get the big decision tree with all the levels, and then go down the
 | ||||
|   // branches based on the value of the parent variables.
 | ||||
|   ADT adt(*this); | ||||
|   size_t value; | ||||
|   for (Key j : parents()) { | ||||
|     try { | ||||
|       value = parentsValues.at(j); | ||||
|       adt = adt.choose(j, value);  // ADT keeps getting smaller.
 | ||||
|     } catch (exception&) { | ||||
|       parentsValues.print("parentsValues: "); | ||||
|       throw runtime_error("DiscreteConditional::choose: parent value missing"); | ||||
|     }; | ||||
|   } | ||||
| 
 | ||||
|   // Convert ADT to factor.
 | ||||
|   if (nrFrontals() != 1) { | ||||
|     throw std::runtime_error("Expected only one frontal variable in choose."); | ||||
|   DiscreteKeys discreteKeys; | ||||
|   for (Key j : frontals()) { | ||||
|     discreteKeys.emplace_back(j, this->cardinality(j)); | ||||
|   } | ||||
|   DiscreteKeys keys; | ||||
|   const Key frontalKey = keys_[0]; | ||||
|   size_t frontalCardinality = this->cardinality(frontalKey); | ||||
|   keys.push_back(DiscreteKey(frontalKey, frontalCardinality)); | ||||
|   return boost::make_shared<DecisionTreeFactor>(keys, pFS); | ||||
|   return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( | ||||
|     const DiscreteValues& frontalValues) const { | ||||
|   // Get the big decision tree with all the levels, and then go down the
 | ||||
|   // branches based on the value of the frontal variables.
 | ||||
|   ADT adt(*this); | ||||
|   size_t value; | ||||
|   for (Key j : frontals()) { | ||||
|     try { | ||||
|       value = frontalValues.at(j); | ||||
|       adt = adt.choose(j, value);  // ADT keeps getting smaller.
 | ||||
|     } catch (exception&) { | ||||
|       frontalValues.print("frontalValues: "); | ||||
|       throw runtime_error("DiscreteConditional::choose: frontal value missing"); | ||||
|     }; | ||||
|   } | ||||
| 
 | ||||
|   // Convert ADT to factor.
 | ||||
|   DiscreteKeys discreteKeys; | ||||
|   for (Key j : parents()) { | ||||
|     discreteKeys.emplace_back(j, this->cardinality(j)); | ||||
|   } | ||||
|   return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( | ||||
|     size_t parent_value) const { | ||||
|   if (nrFrontals() != 1) | ||||
|     throw std::invalid_argument( | ||||
|         "Single value likelihood can only be invoked on single-variable " | ||||
|         "conditional"); | ||||
|   DiscreteValues values; | ||||
|   values.emplace(keys_[0], parent_value); | ||||
|   return likelihood(values); | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| void DiscreteConditional::solveInPlace(DiscreteValues* values) const { | ||||
|   // TODO: Abhijit asks: is this really the fastest way? He thinks it is.
 | ||||
|   ADT pFS = choose(*values); // P(F|S=parentsValues)
 | ||||
|   ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // Initialize
 | ||||
|   DiscreteValues mpe; | ||||
|  | @ -177,7 +222,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { | |||
| size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { | ||||
| 
 | ||||
|   // TODO: is this really the fastest way? I think it is.
 | ||||
|   ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
 | ||||
|   ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // Then, find the max over all remaining
 | ||||
|   // TODO, only works for one key now, seems horribly slow this way
 | ||||
|  | @ -203,10 +248,14 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { | |||
|   static mt19937 rng(2);  // random number generator
 | ||||
| 
 | ||||
|   // Get the correct conditional density
 | ||||
|   ADT pFS = choose(parentsValues);  // P(F|S=parentsValues)
 | ||||
|   ADT pFS = Choose(*this, parentsValues);  // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // TODO(Duy): only works for one key now, seems horribly slow this way
 | ||||
|   assert(nrFrontals() == 1); | ||||
|   if (nrFrontals() != 1) { | ||||
|     throw std::invalid_argument( | ||||
|         "DiscreteConditional::sample can only be called on single variable " | ||||
|         "conditionals"); | ||||
|   } | ||||
|   Key key = firstFrontalKey(); | ||||
|   size_t nj = cardinality(key); | ||||
|   vector<double> p(nj); | ||||
|  | @ -222,13 +271,24 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { | |||
|   return distribution(rng); | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| size_t DiscreteConditional::sample(size_t parent_value) const { | ||||
|   if (nrParents() != 1) | ||||
|     throw std::invalid_argument( | ||||
|         "Single value sample() can only be invoked on single-parent " | ||||
|         "conditional"); | ||||
|   DiscreteValues values; | ||||
|   values.emplace(keys_.back(), parent_value); | ||||
|   return sample(values); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| std::string DiscreteConditional::markdown( | ||||
|     const KeyFormatter& keyFormatter) const { | ||||
|   std::stringstream ss; | ||||
| 
 | ||||
|   // Print out signature.
 | ||||
|   ss << " $P("; | ||||
|   ss << " *P("; | ||||
|   bool first = true; | ||||
|   for (Key key : frontals()) { | ||||
|     if (!first) ss << ","; | ||||
|  | @ -237,7 +297,7 @@ std::string DiscreteConditional::markdown( | |||
|   } | ||||
|   if (nrParents() == 0) { | ||||
|    // We have no parents, call factor method.
 | ||||
|     ss << ")$:" << std::endl; | ||||
|     ss << ")*:\n" << std::endl; | ||||
|     ss << DecisionTreeFactor::markdown(keyFormatter); | ||||
|     return ss.str(); | ||||
|   } | ||||
|  | @ -250,7 +310,7 @@ std::string DiscreteConditional::markdown( | |||
|     ss << keyFormatter(parent); | ||||
|     first = false; | ||||
|   } | ||||
|   ss << ")$:" << std::endl; | ||||
|   ss << ")*:\n" << std::endl; | ||||
| 
 | ||||
|   // Print out header and construct argument for `cartesianProduct`.
 | ||||
|   std::vector<std::pair<Key, size_t>> pairs; | ||||
|  |  | |||
|  | @ -62,8 +62,6 @@ public: | |||
|    * conditional probability table (CPT) in 00 01 10 11 order. For | ||||
|    * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... | ||||
|    * | ||||
|    * The first string is parsed to add a key and parents. | ||||
|    * | ||||
|    * Example: DiscreteConditional P(D, {B,E}, table); | ||||
|    */ | ||||
|   DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, | ||||
|  | @ -75,8 +73,7 @@ public: | |||
|    * probability table (CPT) in 00 01 10 11 order. For three-valued, it would | ||||
|    * be 00 01 02 10 11 12 20 21 22, etc.... | ||||
|    * | ||||
|    * The first string is parsed to add a key and parents. The second string | ||||
|    * parses into a table. | ||||
|    * The string is parsed into a Signature::Table. | ||||
|    * | ||||
|    * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); | ||||
|    */ | ||||
|  | @ -84,6 +81,10 @@ public: | |||
|                       const std::string& spec) | ||||
|       : DiscreteConditional(Signature(key, parents, spec)) {} | ||||
| 
 | ||||
|   /// No-parent specialization; can also use DiscretePrior.
 | ||||
|   DiscreteConditional(const DiscreteKey& key, const std::string& spec) | ||||
|       : DiscreteConditional(Signature(key, {}, spec)) {} | ||||
| 
 | ||||
|   /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ | ||||
|   DiscreteConditional(const DecisionTreeFactor& joint, | ||||
|       const DecisionTreeFactor& marginal); | ||||
|  | @ -135,13 +136,17 @@ public: | |||
|     return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); | ||||
|   } | ||||
| 
 | ||||
|   /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ | ||||
|   ADT choose(const DiscreteValues& parentsValues) const; | ||||
| 
 | ||||
|   /** Restrict to given parent values, returns DecisionTreeFactor */ | ||||
|   DecisionTreeFactor::shared_ptr chooseAsFactor( | ||||
|   DecisionTreeFactor::shared_ptr choose( | ||||
|       const DiscreteValues& parentsValues) const; | ||||
| 
 | ||||
|   /** Convert to a likelihood factor by providing value before bar. */ | ||||
|   DecisionTreeFactor::shared_ptr likelihood( | ||||
|       const DiscreteValues& frontalValues) const; | ||||
| 
 | ||||
|   /** Single variable version of likelihood. */ | ||||
|   DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * solve a conditional | ||||
|    * @param parentsValues Known values of the parents | ||||
|  | @ -156,6 +161,10 @@ public: | |||
|    */ | ||||
|   size_t sample(const DiscreteValues& parentsValues) const; | ||||
| 
 | ||||
| 
 | ||||
|   /// Single value version.
 | ||||
|   size_t sample(size_t parent_value) const; | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Advanced Interface
 | ||||
|   /// @{
 | ||||
|  |  | |||
|  | @ -101,29 +101,12 @@ public: | |||
| 
 | ||||
|   /// @}
 | ||||
| 
 | ||||
|   // Add single key decision-tree factor.
 | ||||
|   template <class SOURCE> | ||||
|   void add(const DiscreteKey& j, SOURCE table) { | ||||
|     DiscreteKeys keys; | ||||
|     keys.push_back(j); | ||||
|     emplace_shared<DecisionTreeFactor>(keys, table); | ||||
|   /** Add a decision-tree factor */ | ||||
|   template <typename... Args> | ||||
|   void add(Args&&... args) { | ||||
|     emplace_shared<DecisionTreeFactor>(std::forward<Args>(args)...); | ||||
|   } | ||||
| 
 | ||||
|   // Add binary key decision-tree factor.
 | ||||
|   template <class SOURCE> | ||||
|   void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) { | ||||
|     DiscreteKeys keys; | ||||
|     keys.push_back(j1); | ||||
|     keys.push_back(j2); | ||||
|     emplace_shared<DecisionTreeFactor>(keys, table); | ||||
|   } | ||||
| 
 | ||||
|   // Add shared discreteFactor immediately from arguments.
 | ||||
|   template <class SOURCE> | ||||
|   void add(const DiscreteKeys& keys, SOURCE table) { | ||||
|     emplace_shared<DecisionTreeFactor>(keys, table); | ||||
|   } | ||||
| 
 | ||||
|        | ||||
|   /** Return the set of variables involved in the factors (set union) */ | ||||
|   KeySet keys() const; | ||||
| 
 | ||||
|  |  | |||
|  | @ -43,9 +43,7 @@ namespace gtsam { | |||
|     DiscreteKeys() : std::vector<DiscreteKey>::vector() {} | ||||
| 
 | ||||
|     /// Construct from a key
 | ||||
|     DiscreteKeys(const DiscreteKey& key) { | ||||
|       push_back(key); | ||||
|     } | ||||
|     explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); } | ||||
| 
 | ||||
|     /// Construct from a vector of keys
 | ||||
|     DiscreteKeys(const std::vector<DiscreteKey>& keys) : | ||||
|  |  | |||
|  | @ -0,0 +1,50 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /**
 | ||||
|  *  @file DiscretePrior.cpp | ||||
|  *  @date December 2021 | ||||
|  *  @author Frank Dellaert | ||||
|  */ | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscretePrior.h> | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| void DiscretePrior::print(const std::string& s, | ||||
|                           const KeyFormatter& formatter) const { | ||||
|   Base::print(s, formatter); | ||||
| } | ||||
| 
 | ||||
| double DiscretePrior::operator()(size_t value) const { | ||||
|   if (nrFrontals() != 1) | ||||
|     throw std::invalid_argument( | ||||
|         "Single value operator can only be invoked on single-variable " | ||||
|         "priors"); | ||||
|   DiscreteValues values; | ||||
|   values.emplace(keys_[0], value); | ||||
|   return Base::operator()(values); | ||||
| } | ||||
| 
 | ||||
| std::vector<double> DiscretePrior::pmf() const { | ||||
|   if (nrFrontals() != 1) | ||||
|     throw std::invalid_argument( | ||||
|         "DiscretePrior::pmf only defined for single-variable priors"); | ||||
|   const size_t nrValues = cardinalities_.at(keys_[0]); | ||||
|   std::vector<double> array; | ||||
|   array.reserve(nrValues); | ||||
|   for (size_t v = 0; v < nrValues; v++) { | ||||
|     array.push_back(operator()(v)); | ||||
|   } | ||||
|   return array; | ||||
| } | ||||
| 
 | ||||
| }  // namespace gtsam
 | ||||
|  | @ -0,0 +1,111 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /**
 | ||||
|  *  @file DiscretePrior.h | ||||
|  *  @date December 2021 | ||||
|  *  @author Frank Dellaert | ||||
|  */ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteConditional.h> | ||||
| 
 | ||||
| #include <string> | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| /**
 | ||||
|  * A prior probability on a set of discrete variables. | ||||
|  * Derives from DiscreteConditional | ||||
|  */ | ||||
| class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { | ||||
|  public: | ||||
|   using Base = DiscreteConditional; | ||||
| 
 | ||||
|   /// @name Standard Constructors
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /// Default constructor needed for serialization.
 | ||||
|   DiscretePrior() {} | ||||
| 
 | ||||
|   /// Constructor from factor.
 | ||||
|   DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {} | ||||
| 
 | ||||
|   /**
 | ||||
|    * Construct from a Signature. | ||||
|    * | ||||
|    * Example: DiscretePrior P(D % "3/2"); | ||||
|    */ | ||||
|   DiscretePrior(const Signature& s) : Base(s) {} | ||||
| 
 | ||||
|   /**
 | ||||
|    * Construct from key and a Signature::Table specifying the | ||||
|    * conditional probability table (CPT). | ||||
|    * | ||||
|    * Example: DiscretePrior P(D, table); | ||||
|    */ | ||||
|   DiscretePrior(const DiscreteKey& key, const Signature::Table& table) | ||||
|       : Base(Signature(key, {}, table)) {} | ||||
| 
 | ||||
|   /**
 | ||||
|    * Construct from key and a string specifying the conditional | ||||
|    * probability table (CPT). | ||||
|    * | ||||
|    * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); | ||||
|    */ | ||||
|   DiscretePrior(const DiscreteKey& key, const std::string& spec) | ||||
|       : DiscretePrior(Signature(key, {}, spec)) {} | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Testable
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /// GTSAM-style print
 | ||||
|   void print( | ||||
|       const std::string& s = "Discrete Prior: ", | ||||
|       const KeyFormatter& formatter = DefaultKeyFormatter) const override; | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Standard interface
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /// Evaluate given a single value.
 | ||||
|   double operator()(size_t value) const; | ||||
| 
 | ||||
|   /// We also want to keep the Base version, taking DiscreteValues:
 | ||||
|   // TODO(dellaert): does not play well with wrapper!
 | ||||
|   // using Base::operator();
 | ||||
| 
 | ||||
|   /// Return entire probability mass function.
 | ||||
|   std::vector<double> pmf() const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * solve a conditional | ||||
|    * @return MPE value of the child (1 frontal variable). | ||||
|    */ | ||||
|   size_t solve() const { return Base::solve({}); } | ||||
| 
 | ||||
|   /**
 | ||||
|    * sample | ||||
|    * @return sample from conditional | ||||
|    */ | ||||
|   size_t sample() const { return Base::sample({}); } | ||||
| 
 | ||||
|   /// @}
 | ||||
| }; | ||||
| // DiscretePrior
 | ||||
| 
 | ||||
| // traits
 | ||||
| template <> | ||||
| struct traits<DiscretePrior> : public Testable<DiscretePrior> {}; | ||||
| 
 | ||||
| }  // namespace gtsam
 | ||||
|  | @ -30,25 +30,37 @@ class DiscreteFactor { | |||
| }; | ||||
| 
 | ||||
| #include <gtsam/discrete/DecisionTreeFactor.h> | ||||
| virtual class DecisionTreeFactor: gtsam::DiscreteFactor { | ||||
| virtual class DecisionTreeFactor : gtsam::DiscreteFactor { | ||||
|   DecisionTreeFactor(); | ||||
|    | ||||
|   DecisionTreeFactor(const gtsam::DiscreteKey& key, | ||||
|                      const std::vector<double>& spec); | ||||
|   DecisionTreeFactor(const gtsam::DiscreteKey& key, string table); | ||||
|    | ||||
|   DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); | ||||
|   DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table); | ||||
|    | ||||
|   DecisionTreeFactor(const gtsam::DiscreteConditional& c); | ||||
|    | ||||
|   void print(string s = "DecisionTreeFactor\n", | ||||
|              const gtsam::KeyFormatter& keyFormatter = | ||||
|                  gtsam::DefaultKeyFormatter) const; | ||||
|   bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; | ||||
|   string dot(bool showZero = false) const; | ||||
|   std::vector<std::pair<DiscreteValues, double>> enumerate() const; | ||||
|   string markdown(const gtsam::KeyFormatter& keyFormatter = | ||||
|                  gtsam::DefaultKeyFormatter) const; | ||||
|                       gtsam::DefaultKeyFormatter) const; | ||||
| }; | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteConditional.h> | ||||
| virtual class DiscreteConditional : gtsam::DecisionTreeFactor { | ||||
|   DiscreteConditional(); | ||||
|   DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); | ||||
|   DiscreteConditional(const gtsam::DiscreteKey& key, string spec); | ||||
|   DiscreteConditional(const gtsam::DiscreteKey& key, | ||||
|                       const gtsam::DiscreteKeys& parents, string spec); | ||||
|   DiscreteConditional(const gtsam::DiscreteKey& key, | ||||
|                       const std::vector<gtsam::DiscreteKey>& parents, string spec); | ||||
|   DiscreteConditional(const gtsam::DecisionTreeFactor& joint, | ||||
|                       const gtsam::DecisionTreeFactor& marginal); | ||||
|   DiscreteConditional(const gtsam::DecisionTreeFactor& joint, | ||||
|  | @ -62,20 +74,43 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { | |||
|       string s = "Discrete Conditional: ", | ||||
|       const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; | ||||
|   gtsam::DecisionTreeFactor* toFactor() const; | ||||
|   gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const; | ||||
|   gtsam::DecisionTreeFactor* choose( | ||||
|       const gtsam::DiscreteValues& parentsValues) const; | ||||
|   gtsam::DecisionTreeFactor* likelihood( | ||||
|       const gtsam::DiscreteValues& frontalValues) const; | ||||
|   gtsam::DecisionTreeFactor* likelihood(size_t value) const; | ||||
|   size_t solve(const gtsam::DiscreteValues& parentsValues) const; | ||||
|   size_t sample(const gtsam::DiscreteValues& parentsValues) const; | ||||
|   void solveInPlace(gtsam::DiscreteValues@ parentsValues) const; | ||||
|   void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const; | ||||
|   size_t sample(size_t value) const; | ||||
|   void solveInPlace(gtsam::DiscreteValues @parentsValues) const; | ||||
|   void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; | ||||
|   string markdown(const gtsam::KeyFormatter& keyFormatter = | ||||
|                       gtsam::DefaultKeyFormatter) const; | ||||
| }; | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscretePrior.h> | ||||
| virtual class DiscretePrior : gtsam::DiscreteConditional { | ||||
|   DiscretePrior(); | ||||
|   DiscretePrior(const gtsam::DecisionTreeFactor& f); | ||||
|   DiscretePrior(const gtsam::DiscreteKey& key, string spec); | ||||
|   void print(string s = "Discrete Prior\n", | ||||
|              const gtsam::KeyFormatter& keyFormatter = | ||||
|                  gtsam::DefaultKeyFormatter) const; | ||||
|   double operator()(size_t value) const; | ||||
|   std::vector<double> pmf() const; | ||||
|   size_t solve() const; | ||||
|   size_t sample() const; | ||||
| }; | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteBayesNet.h> | ||||
| class DiscreteBayesNet {  | ||||
| class DiscreteBayesNet { | ||||
|   DiscreteBayesNet(); | ||||
|   void add(const gtsam::DiscreteConditional& s); | ||||
|   void add(const gtsam::DiscreteKey& key, string spec); | ||||
|   void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, | ||||
|            string spec); | ||||
|   void add(const gtsam::DiscreteKey& key, | ||||
|            const gtsam::DiscreteKeys& parents, string spec); | ||||
|            const std::vector<gtsam::DiscreteKey>& parents, string spec); | ||||
|   bool empty() const; | ||||
|   size_t size() const; | ||||
|   gtsam::KeySet keys() const; | ||||
|  | @ -86,15 +121,13 @@ class DiscreteBayesNet { | |||
|   bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; | ||||
|   string dot(const gtsam::KeyFormatter& keyFormatter = | ||||
|                  gtsam::DefaultKeyFormatter) const; | ||||
|   void saveGraph(string s, | ||||
|                 const gtsam::KeyFormatter& keyFormatter = | ||||
|                  gtsam::DefaultKeyFormatter) const; | ||||
|   void add(const gtsam::DiscreteConditional& s); | ||||
|   void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = | ||||
|                                gtsam::DefaultKeyFormatter) const; | ||||
|   double operator()(const gtsam::DiscreteValues& values) const; | ||||
|   gtsam::DiscreteValues optimize() const; | ||||
|   gtsam::DiscreteValues sample() const; | ||||
|   string markdown(const gtsam::KeyFormatter& keyFormatter = | ||||
|                  gtsam::DefaultKeyFormatter) const; | ||||
|                       gtsam::DefaultKeyFormatter) const; | ||||
| }; | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteBayesTree.h> | ||||
|  | @ -142,11 +175,13 @@ class DotWriter { | |||
| class DiscreteFactorGraph { | ||||
|   DiscreteFactorGraph(); | ||||
|   DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); | ||||
|    | ||||
| 
 | ||||
|   void add(const gtsam::DiscreteKey& j, string table); | ||||
|   void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table); | ||||
|   void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec); | ||||
| 
 | ||||
|   void add(const gtsam::DiscreteKeys& keys, string table); | ||||
|    | ||||
|   void add(const std::vector<gtsam::DiscreteKey>& keys, string table); | ||||
| 
 | ||||
|   bool empty() const; | ||||
|   size_t size() const; | ||||
|   gtsam::KeySet keys() const; | ||||
|  |  | |||
|  | @ -34,7 +34,7 @@ TEST( DecisionTreeFactor, constructors) | |||
|   DiscreteKey X(0,2), Y(1,3), Z(2,2); | ||||
| 
 | ||||
|   // Create factors
 | ||||
|   DecisionTreeFactor f1(X, "2 8"); | ||||
|   DecisionTreeFactor f1(X, {2, 8}); | ||||
|   DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7"); | ||||
|   DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); | ||||
|   EXPECT_LONGS_EQUAL(1,f1.size()); | ||||
|  | @ -82,11 +82,29 @@ TEST( DecisionTreeFactor, sum_max) | |||
|   DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check enumerate yields the correct list of assignment/value pairs.
 | ||||
| TEST(DecisionTreeFactor, enumerate) { | ||||
|   DiscreteKey A(12, 3), B(5, 2); | ||||
|   DecisionTreeFactor f(A & B, "1 2  3 4  5 6"); | ||||
|   auto actual = f.enumerate(); | ||||
|   std::vector<std::pair<DiscreteValues, double>> expected; | ||||
|   DiscreteValues values; | ||||
|   for (size_t a : {0, 1, 2}) { | ||||
|     for (size_t b : {0, 1}) { | ||||
|       values[12] = a; | ||||
|       values[5] = b; | ||||
|       expected.emplace_back(values, f(values)); | ||||
|     } | ||||
|   } | ||||
|   EXPECT(actual == expected); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check markdown representation looks as expected.
 | ||||
| TEST(DecisionTreeFactor, markdown) { | ||||
|   DiscreteKey A(12, 3), B(5, 2); | ||||
|   DecisionTreeFactor f1(A & B, "1 2  3 4  5 6"); | ||||
|   DecisionTreeFactor f(A & B, "1 2  3 4  5 6"); | ||||
|   string expected = | ||||
|       "|A|B|value|\n" | ||||
|       "|:-:|:-:|:-:|\n" | ||||
|  | @ -97,7 +115,7 @@ TEST(DecisionTreeFactor, markdown) { | |||
|       "|2|0|5|\n" | ||||
|       "|2|1|6|\n"; | ||||
|   auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; | ||||
|   string actual = f1.markdown(formatter); | ||||
|   string actual = f.markdown(formatter); | ||||
|   EXPECT(actual == expected); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -75,8 +75,8 @@ TEST(DiscreteBayesNet, bayesNet) { | |||
| TEST(DiscreteBayesNet, Asia) { | ||||
|   DiscreteBayesNet asia; | ||||
| 
 | ||||
|   asia.add(Asia % "99/1"); | ||||
|   asia.add(Smoking % "50/50"); | ||||
|   asia.add(Asia, "99/1"); | ||||
|   asia.add(Smoking % "50/50");  // Signature version
 | ||||
| 
 | ||||
|   asia.add(Tuberculosis | Asia = "99/1 95/5"); | ||||
|   asia.add(LungCancer | Smoking = "99/1 90/10"); | ||||
|  | @ -180,13 +180,13 @@ TEST(DiscreteBayesNet, markdown) { | |||
|   string expected = | ||||
|       "`DiscreteBayesNet` of size 2\n" | ||||
|       "\n" | ||||
|       " $P(Asia)$:\n" | ||||
|       " *P(Asia)*:\n\n" | ||||
|       "|Asia|value|\n" | ||||
|       "|:-:|:-:|\n" | ||||
|       "|0|0.99|\n" | ||||
|       "|1|0.01|\n" | ||||
|       "\n" | ||||
|       " $P(Smoking|Asia)$:\n" | ||||
|       " *P(Smoking|Asia)*:\n\n" | ||||
|       "|Asia|0|1|\n" | ||||
|       "|:-:|:-:|:-:|\n" | ||||
|       "|0|0.8|0.2|\n" | ||||
|  |  | |||
|  | @ -10,10 +10,11 @@ | |||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /*
 | ||||
|  * @file    testDecisionTreeFactor.cpp | ||||
|  * @file    testDiscreteConditional.cpp | ||||
|  * @brief   unit tests for DiscreteConditional | ||||
|  * @author  Duy-Nguyen Ta | ||||
|  * @date Feb 14, 2011 | ||||
|  * @author  Frank dellaert | ||||
|  * @date    Feb 14, 2011 | ||||
|  */ | ||||
| 
 | ||||
| #include <boost/assign/std/map.hpp> | ||||
|  | @ -30,24 +31,21 @@ using namespace std; | |||
| using namespace gtsam; | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST( DiscreteConditional, constructors) | ||||
| { | ||||
|   DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
 | ||||
| TEST(DiscreteConditional, constructors) { | ||||
|   DiscreteKey X(0, 2), Y(2, 3), Z(1, 2);  // watch ordering !
 | ||||
| 
 | ||||
|   DiscreteConditional expected(X | Y = "1/1 2/3 1/4"); | ||||
|   EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals())); | ||||
|   EXPECT_LONGS_EQUAL(2, *(expected.beginParents())); | ||||
|   EXPECT(expected.endParents() == expected.end()); | ||||
|   EXPECT(expected.endFrontals() == expected.beginParents()); | ||||
| 
 | ||||
|   DiscreteConditional::shared_ptr expected1 = //
 | ||||
|       boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4"); | ||||
|   EXPECT(expected1); | ||||
|   EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals())); | ||||
|   EXPECT_LONGS_EQUAL(2, *(expected1->beginParents())); | ||||
|   EXPECT(expected1->endParents() == expected1->end()); | ||||
|   EXPECT(expected1->endFrontals() == expected1->beginParents()); | ||||
|    | ||||
|   DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); | ||||
|   DiscreteConditional actual1(1, f1); | ||||
|   EXPECT(assert_equal(*expected1, actual1, 1e-9)); | ||||
|   EXPECT(assert_equal(expected, actual1, 1e-9)); | ||||
| 
 | ||||
|   DecisionTreeFactor f2(X & Y & Z, | ||||
|       "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); | ||||
|   DecisionTreeFactor f2( | ||||
|       X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); | ||||
|   DiscreteConditional actual2(1, f2); | ||||
|   EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); | ||||
| } | ||||
|  | @ -107,13 +105,27 @@ TEST(DiscreteConditional, Combine) { | |||
|   EXPECT(assert_equal(expected, *actual, 1e-5)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteConditional, likelihood) { | ||||
|   DiscreteKey X(0, 2), Y(1, 3); | ||||
|   DiscreteConditional conditional(X | Y = "2/8 4/6 5/5"); | ||||
| 
 | ||||
|   auto actual0 = conditional.likelihood(0); | ||||
|   DecisionTreeFactor expected0(Y, "0.2 0.4 0.5"); | ||||
|   EXPECT(assert_equal(expected0, *actual0, 1e-9)); | ||||
| 
 | ||||
|   auto actual1 = conditional.likelihood(1); | ||||
|   DecisionTreeFactor expected1(Y, "0.8 0.6 0.5"); | ||||
|   EXPECT(assert_equal(expected1, *actual1, 1e-9)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check markdown representation looks as expected, no parents.
 | ||||
| TEST(DiscreteConditional, markdown_prior) { | ||||
|   DiscreteKey A(Symbol('x', 1), 3); | ||||
|   DiscreteConditional conditional(A % "1/2/2"); | ||||
|   string expected = | ||||
|       " $P(x1)$:\n" | ||||
|       " *P(x1)*:\n\n" | ||||
|       "|x1|value|\n" | ||||
|       "|:-:|:-:|\n" | ||||
|       "|0|0.2|\n" | ||||
|  | @ -130,7 +142,7 @@ TEST(DiscreteConditional, markdown_multivalued) { | |||
|   DiscreteConditional conditional( | ||||
|       A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3"); | ||||
|   string expected = | ||||
|       " $P(a1|b1)$:\n" | ||||
|       " *P(a1|b1)*:\n\n" | ||||
|       "|b1|0|1|2|\n" | ||||
|       "|:-:|:-:|:-:|:-:|\n" | ||||
|       "|0|0.02|0.88|0.1|\n" | ||||
|  | @ -148,7 +160,7 @@ TEST(DiscreteConditional, markdown) { | |||
|   DiscreteKey A(2, 2), B(1, 2), C(0, 3); | ||||
|   DiscreteConditional conditional(A, {B, C}, "0/1 1/3  1/1 3/1  0/1 1/0"); | ||||
|   string expected = | ||||
|       " $P(A|B,C)$:\n" | ||||
|       " *P(A|B,C)*:\n\n" | ||||
|       "|B|C|0|1|\n" | ||||
|       "|:-:|:-:|:-:|:-:|\n" | ||||
|       "|0|0|0|1|\n" | ||||
|  |  | |||
|  | @ -0,0 +1,55 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /*
 | ||||
|  * @file    testDiscretePrior.cpp | ||||
|  * @brief   unit tests for DiscretePrior | ||||
|  * @author  Frank dellaert | ||||
|  * @date    December 2021 | ||||
|  */ | ||||
| 
 | ||||
| #include <CppUnitLite/TestHarness.h> | ||||
| #include <gtsam/discrete/DiscretePrior.h> | ||||
| #include <gtsam/discrete/Signature.h> | ||||
| 
 | ||||
| using namespace std; | ||||
| using namespace gtsam; | ||||
| 
 | ||||
| static const DiscreteKey X(0, 2); | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscretePrior, constructors) { | ||||
|   DiscretePrior actual(X % "2/3"); | ||||
|   DecisionTreeFactor f(X, "0.4 0.6"); | ||||
|   DiscretePrior expected(f); | ||||
|   EXPECT(assert_equal(expected, actual, 1e-9)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscretePrior, operator) { | ||||
|   DiscretePrior prior(X % "2/3"); | ||||
|   EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9); | ||||
|   EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscretePrior, to_vector) { | ||||
|   DiscretePrior prior(X % "2/3"); | ||||
|   vector<double> expected {0.4, 0.6}; | ||||
|   EXPECT(prior.pmf() ==  expected); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { | ||||
|   TestResult tr; | ||||
|   return TestRegistry::runAllTests(tr); | ||||
| } | ||||
| /* ************************************************************************* */ | ||||
|  | @ -133,10 +133,10 @@ void Scheduler::addStudentSpecificConstraints(size_t i, | |||
|       Potentials::ADT p(dummy & areaKey, | ||||
|                         available_);  // available_ is Doodle string
 | ||||
|       Potentials::ADT q = p.choose(dummyIndex, *slot); | ||||
|       DiscreteFactor::shared_ptr f(new DecisionTreeFactor(areaKey, q)); | ||||
|       CSP::push_back(f); | ||||
|       CSP::add(areaKey, q); | ||||
|     } else { | ||||
|       CSP::add(s.key_, areaKey, available_);  // available_ is Doodle string
 | ||||
|       DiscreteKeys keys {s.key_, areaKey}; | ||||
|       CSP::add(keys, available_);  // available_ is Doodle string
 | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,54 @@ | |||
| """ | ||||
| GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, | ||||
| Atlanta, Georgia 30332-0415 | ||||
| All Rights Reserved | ||||
| 
 | ||||
| See LICENSE for the license information | ||||
| 
 | ||||
| Unit tests for DecisionTreeFactors. | ||||
| Author: Frank Dellaert | ||||
| """ | ||||
| 
 | ||||
| # pylint: disable=no-name-in-module, invalid-name | ||||
| 
 | ||||
| import unittest | ||||
| 
 | ||||
| from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys | ||||
| from gtsam.utils.test_case import GtsamTestCase | ||||
| 
 | ||||
| 
 | ||||
| class TestDecisionTreeFactor(GtsamTestCase): | ||||
|     """Tests for DecisionTreeFactors.""" | ||||
| 
 | ||||
|     def setUp(self): | ||||
|         A = (12, 3) | ||||
|         B = (5, 2) | ||||
|         self.factor = DecisionTreeFactor([A, B], "1 2  3 4  5 6") | ||||
| 
 | ||||
|     def test_enumerate(self): | ||||
|         actual = self.factor.enumerate() | ||||
|         _, values = zip(*actual) | ||||
|         self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) | ||||
| 
 | ||||
|     def test_markdown(self): | ||||
|         """Test whether the _repr_markdown_ method.""" | ||||
| 
 | ||||
|         expected = \ | ||||
|             "|A|B|value|\n" \ | ||||
|             "|:-:|:-:|:-:|\n" \ | ||||
|             "|0|0|1|\n" \ | ||||
|             "|0|1|2|\n" \ | ||||
|             "|1|0|3|\n" \ | ||||
|             "|1|1|4|\n" \ | ||||
|             "|2|0|5|\n" \ | ||||
|             "|2|1|6|\n" | ||||
| 
 | ||||
|         def formatter(x: int): | ||||
|             return "A" if x == 12 else "B" | ||||
| 
 | ||||
|         actual = self.factor._repr_markdown_(formatter) | ||||
|         self.assertEqual(actual, expected) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|  | @ -14,7 +14,7 @@ Author: Frank Dellaert | |||
| import unittest | ||||
| 
 | ||||
| from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, | ||||
|                    DiscreteKeys, DiscreteValues, Ordering) | ||||
|                    DiscreteKeys, DiscretePrior, DiscreteValues, Ordering) | ||||
| from gtsam.utils.test_case import GtsamTestCase | ||||
| 
 | ||||
| 
 | ||||
|  | @ -53,24 +53,18 @@ class TestDiscreteBayesNet(GtsamTestCase): | |||
|         XRay = (2, 2) | ||||
|         Dyspnea = (1, 2) | ||||
| 
 | ||||
|         def P(keys): | ||||
|             dks = DiscreteKeys() | ||||
|             for key in keys: | ||||
|                 dks.push_back(key) | ||||
|             return dks | ||||
| 
 | ||||
|         asia = DiscreteBayesNet() | ||||
|         asia.add(Asia, P([]), "99/1") | ||||
|         asia.add(Smoking, P([]), "50/50") | ||||
|         asia.add(Asia, "99/1") | ||||
|         asia.add(Smoking, "50/50") | ||||
| 
 | ||||
|         asia.add(Tuberculosis, P([Asia]), "99/1 95/5") | ||||
|         asia.add(LungCancer, P([Smoking]), "99/1 90/10") | ||||
|         asia.add(Bronchitis, P([Smoking]), "70/30 40/60") | ||||
|         asia.add(Tuberculosis, [Asia], "99/1 95/5") | ||||
|         asia.add(LungCancer, [Smoking], "99/1 90/10") | ||||
|         asia.add(Bronchitis, [Smoking], "70/30 40/60") | ||||
| 
 | ||||
|         asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T") | ||||
|         asia.add(Either, [Tuberculosis, LungCancer], "F T T T") | ||||
| 
 | ||||
|         asia.add(XRay, P([Either]), "95/5 2/98") | ||||
|         asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9") | ||||
|         asia.add(XRay, [Either], "95/5 2/98") | ||||
|         asia.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9") | ||||
| 
 | ||||
|         # Convert to factor graph | ||||
|         fg = DiscreteFactorGraph(asia) | ||||
|  | @ -80,7 +74,7 @@ class TestDiscreteBayesNet(GtsamTestCase): | |||
|         for j in range(8): | ||||
|             ordering.push_back(j) | ||||
|         chordal = fg.eliminateSequential(ordering) | ||||
|         expected2 = DiscreteConditional(Bronchitis, P([]), "11/9") | ||||
|         expected2 = DiscretePrior(Bronchitis, "11/9") | ||||
|         self.gtsamAssertEquals(chordal.at(7), expected2) | ||||
| 
 | ||||
|         # solve | ||||
|  |  | |||
|  | @ -14,20 +14,10 @@ Author: Frank Dellaert | |||
| import unittest | ||||
| 
 | ||||
| from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, | ||||
|                    DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, | ||||
|                    Ordering) | ||||
|                    DiscreteConditional, DiscreteFactorGraph, Ordering) | ||||
| from gtsam.utils.test_case import GtsamTestCase | ||||
| 
 | ||||
| 
 | ||||
| def P(*args): | ||||
|     """ Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.""" | ||||
|     # TODO: We can make life easier by providing variable argument functions in C++ itself. | ||||
|     dks = DiscreteKeys() | ||||
|     for key in args: | ||||
|         dks.push_back(key) | ||||
|     return dks | ||||
| 
 | ||||
| 
 | ||||
| class TestDiscreteBayesNet(GtsamTestCase): | ||||
|     """Tests for Discrete Bayes Nets.""" | ||||
| 
 | ||||
|  | @ -40,25 +30,25 @@ class TestDiscreteBayesNet(GtsamTestCase): | |||
|         # Create thin-tree Bayesnet. | ||||
|         bayesNet = DiscreteBayesNet() | ||||
| 
 | ||||
|         bayesNet.add(keys[0], P(keys[8], keys[12]), "2/3 1/4 3/2 4/1") | ||||
|         bayesNet.add(keys[1], P(keys[8], keys[12]), "4/1 2/3 3/2 1/4") | ||||
|         bayesNet.add(keys[2], P(keys[9], keys[12]), "1/4 8/2 2/3 4/1") | ||||
|         bayesNet.add(keys[3], P(keys[9], keys[12]), "1/4 2/3 3/2 4/1") | ||||
|         bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1") | ||||
|         bayesNet.add(keys[1], [keys[8], keys[12]], "4/1 2/3 3/2 1/4") | ||||
|         bayesNet.add(keys[2], [keys[9], keys[12]], "1/4 8/2 2/3 4/1") | ||||
|         bayesNet.add(keys[3], [keys[9], keys[12]], "1/4 2/3 3/2 4/1") | ||||
| 
 | ||||
|         bayesNet.add(keys[4], P(keys[10], keys[13]), "2/3 1/4 3/2 4/1") | ||||
|         bayesNet.add(keys[5], P(keys[10], keys[13]), "4/1 2/3 3/2 1/4") | ||||
|         bayesNet.add(keys[6], P(keys[11], keys[13]), "1/4 3/2 2/3 4/1") | ||||
|         bayesNet.add(keys[7], P(keys[11], keys[13]), "1/4 2/3 3/2 4/1") | ||||
|         bayesNet.add(keys[4], [keys[10], keys[13]], "2/3 1/4 3/2 4/1") | ||||
|         bayesNet.add(keys[5], [keys[10], keys[13]], "4/1 2/3 3/2 1/4") | ||||
|         bayesNet.add(keys[6], [keys[11], keys[13]], "1/4 3/2 2/3 4/1") | ||||
|         bayesNet.add(keys[7], [keys[11], keys[13]], "1/4 2/3 3/2 4/1") | ||||
| 
 | ||||
|         bayesNet.add(keys[8], P(keys[12], keys[14]), "T 1/4 3/2 4/1") | ||||
|         bayesNet.add(keys[9], P(keys[12], keys[14]), "4/1 2/3 F 1/4") | ||||
|         bayesNet.add(keys[10], P(keys[13], keys[14]), "1/4 3/2 2/3 4/1") | ||||
|         bayesNet.add(keys[11], P(keys[13], keys[14]), "1/4 2/3 3/2 4/1") | ||||
|         bayesNet.add(keys[8], [keys[12], keys[14]], "T 1/4 3/2 4/1") | ||||
|         bayesNet.add(keys[9], [keys[12], keys[14]], "4/1 2/3 F 1/4") | ||||
|         bayesNet.add(keys[10], [keys[13], keys[14]], "1/4 3/2 2/3 4/1") | ||||
|         bayesNet.add(keys[11], [keys[13], keys[14]], "1/4 2/3 3/2 4/1") | ||||
| 
 | ||||
|         bayesNet.add(keys[12], P(keys[14]), "3/1 3/1") | ||||
|         bayesNet.add(keys[13], P(keys[14]), "1/3 3/1") | ||||
|         bayesNet.add(keys[12], [keys[14]], "3/1 3/1") | ||||
|         bayesNet.add(keys[13], [keys[14]], "1/3 3/1") | ||||
| 
 | ||||
|         bayesNet.add(keys[14], P(), "1/3") | ||||
|         bayesNet.add(keys[14], "1/3") | ||||
| 
 | ||||
|         # Create a factor graph out of the Bayes net. | ||||
|         factorGraph = DiscreteFactorGraph(bayesNet) | ||||
|  |  | |||
|  | @ -13,12 +13,29 @@ Author: Varun Agrawal | |||
| 
 | ||||
| import unittest | ||||
| 
 | ||||
| from gtsam import DiscreteConditional, DiscreteKeys | ||||
| from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys | ||||
| from gtsam.utils.test_case import GtsamTestCase | ||||
| 
 | ||||
| 
 | ||||
| class TestDiscreteConditional(GtsamTestCase): | ||||
|     """Tests for Discrete Conditionals.""" | ||||
| 
 | ||||
|     def test_single_value_versions(self): | ||||
|         X = (0, 2) | ||||
|         Y = (1, 3) | ||||
|         conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5") | ||||
| 
 | ||||
|         actual0 = conditional.likelihood(0) | ||||
|         expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5") | ||||
|         self.gtsamAssertEquals(actual0, expected0, 1e-9) | ||||
| 
 | ||||
|         actual1 = conditional.likelihood(1) | ||||
|         expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5") | ||||
|         self.gtsamAssertEquals(actual1, expected1, 1e-9) | ||||
| 
 | ||||
|         actual = conditional.sample(2) | ||||
|         self.assertIsInstance(actual, int) | ||||
| 
 | ||||
|     def test_markdown(self): | ||||
|         """Test whether the _repr_markdown_ method.""" | ||||
| 
 | ||||
|  | @ -32,7 +49,7 @@ class TestDiscreteConditional(GtsamTestCase): | |||
|         conditional = DiscreteConditional(A, parents, | ||||
|                                           "0/1 1/3  1/1 3/1  0/1 1/0") | ||||
|         expected = \ | ||||
|             " $P(A|B,C)$:\n" \ | ||||
|             " *P(A|B,C)*:\n\n" \ | ||||
|             "|B|C|0|1|\n" \ | ||||
|             "|:-:|:-:|:-:|:-:|\n" \ | ||||
|             "|0|0|0|1|\n" \ | ||||
|  |  | |||
|  | @ -32,11 +32,11 @@ class TestDiscreteFactorGraph(GtsamTestCase): | |||
|         graph = DiscreteFactorGraph() | ||||
| 
 | ||||
|         # Add two unary factors (priors) | ||||
|         graph.add(P1, "0.9 0.3") | ||||
|         graph.add(P1, [0.9, 0.3]) | ||||
|         graph.add(P2, "0.9 0.6") | ||||
| 
 | ||||
|         # Add a binary factor | ||||
|         graph.add(P1, P2, "4 1 10 4") | ||||
|         graph.add([P1, P2], "4 1 10 4") | ||||
| 
 | ||||
|         # Instantiate Values | ||||
|         assignment = DiscreteValues() | ||||
|  | @ -85,8 +85,8 @@ class TestDiscreteFactorGraph(GtsamTestCase): | |||
|         # A simple factor graph (A)-fAC-(C)-fBC-(B) | ||||
|         # with smoothness priors | ||||
|         graph = DiscreteFactorGraph() | ||||
|         graph.add(A, C, "3 1 1 3") | ||||
|         graph.add(C, B, "3 1 1 3") | ||||
|         graph.add([A, C], "3 1 1 3") | ||||
|         graph.add([C, B], "3 1 1 3") | ||||
| 
 | ||||
|         # Test optimization | ||||
|         expectedValues = DiscreteValues() | ||||
|  | @ -105,8 +105,8 @@ class TestDiscreteFactorGraph(GtsamTestCase): | |||
| 
 | ||||
|         # Create Factor graph | ||||
|         graph = DiscreteFactorGraph() | ||||
|         graph.add(C, A, "0.2 0.8 0.3 0.7") | ||||
|         graph.add(C, B, "0.1 0.9 0.4 0.6") | ||||
|         graph.add([C, A], "0.2 0.8 0.3 0.7") | ||||
|         graph.add([C, B], "0.1 0.9 0.4 0.6") | ||||
| 
 | ||||
|         actualMPE = graph.optimize() | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,60 @@ | |||
| """ | ||||
| GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, | ||||
| Atlanta, Georgia 30332-0415 | ||||
| All Rights Reserved | ||||
| 
 | ||||
| See LICENSE for the license information | ||||
| 
 | ||||
| Unit tests for Discrete Priors. | ||||
| Author: Varun Agrawal | ||||
| """ | ||||
| 
 | ||||
| # pylint: disable=no-name-in-module, invalid-name | ||||
| 
 | ||||
| import unittest | ||||
| 
 | ||||
| import numpy as np | ||||
| from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior | ||||
| from gtsam.utils.test_case import GtsamTestCase | ||||
| 
 | ||||
| X = 0, 2 | ||||
| 
 | ||||
| 
 | ||||
| class TestDiscretePrior(GtsamTestCase): | ||||
|     """Tests for Discrete Priors.""" | ||||
| 
 | ||||
|     def test_constructor(self): | ||||
|         """Test various constructors.""" | ||||
|         actual = DiscretePrior(X, "2/3") | ||||
|         keys = DiscreteKeys() | ||||
|         keys.push_back(X) | ||||
|         f = DecisionTreeFactor(keys, "0.4 0.6") | ||||
|         expected = DiscretePrior(f) | ||||
|         self.gtsamAssertEquals(actual, expected) | ||||
| 
 | ||||
|     def test_operator(self): | ||||
|         prior = DiscretePrior(X, "2/3") | ||||
|         self.assertAlmostEqual(prior(0), 0.4) | ||||
|         self.assertAlmostEqual(prior(1), 0.6) | ||||
| 
 | ||||
|     def test_pmf(self): | ||||
|         prior = DiscretePrior(X, "2/3") | ||||
|         expected = np.array([0.4, 0.6]) | ||||
|         np.testing.assert_allclose(expected, prior.pmf()) | ||||
| 
 | ||||
|     def test_markdown(self): | ||||
|         """Test the _repr_markdown_ method.""" | ||||
| 
 | ||||
|         prior = DiscretePrior(X, "2/3") | ||||
|         expected = " *P(0)*:\n\n" \ | ||||
|             "|0|value|\n" \ | ||||
|             "|:-:|:-:|\n" \ | ||||
|             "|0|0.4|\n" \ | ||||
|             "|1|0.6|\n" \ | ||||
| 
 | ||||
|         actual = prior._repr_markdown_() | ||||
|         self.assertEqual(actual, expected) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
		Loading…
	
		Reference in New Issue