Merge pull request #1919 from borglab/discrete-elimination-refactor
						commit
						82d0ebc8fe
					
				|  | @ -87,6 +87,25 @@ namespace gtsam { | |||
|     return result; | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   DiscreteFactor::shared_ptr DecisionTreeFactor::operator/( | ||||
|       const DiscreteFactor::shared_ptr& f) const { | ||||
|     if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) { | ||||
|       // Check if `f` is a TableFactor. If yes, then
 | ||||
|       // convert `this` to a TableFactor which is cheaper.
 | ||||
|       return std::make_shared<TableFactor>(tf->operator/(TableFactor(*this))); | ||||
| 
 | ||||
|     } else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { | ||||
|       // If `f` is a DecisionTreeFactor, divide normally.
 | ||||
|       return std::make_shared<DecisionTreeFactor>(this->operator/(*dtf)); | ||||
| 
 | ||||
|     } else { | ||||
|       // Else, convert `f` to a DecisionTreeFactor so we can divide
 | ||||
|       return std::make_shared<DecisionTreeFactor>( | ||||
|           this->operator/(f->toDecisionTreeFactor())); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   double DecisionTreeFactor::safe_div(const double& a, const double& b) { | ||||
|     // The use for safe_div is when we divide the product factor by the sum
 | ||||
|  |  | |||
|  | @ -184,26 +184,30 @@ namespace gtsam { | |||
|       return apply(f, safe_div); | ||||
|     } | ||||
| 
 | ||||
|     /// divide by DiscreteFactor::shared_ptr f (safely)
 | ||||
|     DiscreteFactor::shared_ptr operator/( | ||||
|         const DiscreteFactor::shared_ptr& f) const override; | ||||
| 
 | ||||
|     /// Convert into a decision tree
 | ||||
|     DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } | ||||
| 
 | ||||
|     /// Create new factor by summing all values with the same separator values
 | ||||
|     shared_ptr sum(size_t nrFrontals) const { | ||||
|     DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { | ||||
|       return combine(nrFrontals, Ring::add); | ||||
|     } | ||||
| 
 | ||||
|     /// Create new factor by summing all values with the same separator values
 | ||||
|     shared_ptr sum(const Ordering& keys) const { | ||||
|     DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { | ||||
|       return combine(keys, Ring::add); | ||||
|     } | ||||
| 
 | ||||
|     /// Create new factor by maximizing over all values with the same separator.
 | ||||
|     shared_ptr max(size_t nrFrontals) const { | ||||
|     DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { | ||||
|       return combine(nrFrontals, Ring::max); | ||||
|     } | ||||
| 
 | ||||
|     /// Create new factor by maximizing over all values with the same separator.
 | ||||
|     shared_ptr max(const Ordering& keys) const { | ||||
|     DiscreteFactor::shared_ptr max(const Ordering& keys) const override { | ||||
|       return combine(keys, Ring::max); | ||||
|     } | ||||
| 
 | ||||
|  | @ -284,6 +288,12 @@ namespace gtsam { | |||
|      */ | ||||
|     DecisionTreeFactor prune(size_t maxNrAssignments) const; | ||||
| 
 | ||||
|     /**
 | ||||
|      * Get the number of non-zero values contained in this factor. | ||||
|      * It could be much smaller than `prod_{key}(cardinality(key))`. | ||||
|      */ | ||||
|     uint64_t nrValues() const override { return nrLeaves(); } | ||||
| 
 | ||||
|     /// @}
 | ||||
|     /// @name Wrapper support
 | ||||
|     /// @{
 | ||||
|  |  | |||
|  | @ -24,13 +24,13 @@ | |||
| #include <gtsam/hybrid/HybridValues.h> | ||||
| 
 | ||||
| #include <algorithm> | ||||
| #include <cassert> | ||||
| #include <random> | ||||
| #include <set> | ||||
| #include <stdexcept> | ||||
| #include <string> | ||||
| #include <utility> | ||||
| #include <vector> | ||||
| #include <cassert> | ||||
| 
 | ||||
| using namespace std; | ||||
| using std::pair; | ||||
|  | @ -44,8 +44,9 @@ template class GTSAM_EXPORT | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| DiscreteConditional::DiscreteConditional(const size_t nrFrontals, | ||||
|                                          const DecisionTreeFactor& f) | ||||
|     : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} | ||||
|                                          const DiscreteFactor& f) | ||||
|     : BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()), | ||||
|       BaseConditional(nrFrontals) {} | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| DiscreteConditional::DiscreteConditional(size_t nrFrontals, | ||||
|  | @ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s, | |||
| /* ************************************************************************** */ | ||||
| bool DiscreteConditional::equals(const DiscreteFactor& other, | ||||
|                                  double tol) const { | ||||
|   if (!dynamic_cast<const DecisionTreeFactor*>(&other)) { | ||||
|   if (!dynamic_cast<const BaseFactor*>(&other)) { | ||||
|     return false; | ||||
|   } else { | ||||
|     const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other)); | ||||
|     return DecisionTreeFactor::equals(f, tol); | ||||
|     const BaseFactor& f(static_cast<const BaseFactor&>(other)); | ||||
|     return BaseFactor::equals(f, tol); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
|  | @ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, | |||
|   ss << "*\n" << std::endl; | ||||
|   if (nrParents() == 0) { | ||||
|     // We have no parents, call factor method.
 | ||||
|     ss << DecisionTreeFactor::markdown(keyFormatter, names); | ||||
|     ss << BaseFactor::markdown(keyFormatter, names); | ||||
|     return ss.str(); | ||||
|   } | ||||
| 
 | ||||
|  | @ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, | |||
|   ss << "</i></p>\n"; | ||||
|   if (nrParents() == 0) { | ||||
|     // We have no parents, call factor method.
 | ||||
|     ss << DecisionTreeFactor::html(keyFormatter, names); | ||||
|     ss << BaseFactor::html(keyFormatter, names); | ||||
|     return ss.str(); | ||||
|   } | ||||
| 
 | ||||
|  | @ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, | |||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| double DiscreteConditional::evaluate(const HybridValues& x) const { | ||||
|   return this->evaluate(x.discrete()); | ||||
|   return this->operator()(x.discrete()); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
|  | @ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional | |||
|   DiscreteConditional() {} | ||||
| 
 | ||||
|   /// Construct from factor, taking the first `nFrontals` keys as frontals.
 | ||||
|   DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); | ||||
|   DiscreteConditional(size_t nFrontals, const DiscreteFactor& f); | ||||
| 
 | ||||
|   /**
 | ||||
|    * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first | ||||
|  |  | |||
|  | @ -22,6 +22,7 @@ | |||
| #include <gtsam/discrete/AlgebraicDecisionTree.h> | ||||
| #include <gtsam/discrete/DiscreteValues.h> | ||||
| #include <gtsam/inference/Factor.h> | ||||
| #include <gtsam/inference/Ordering.h> | ||||
| 
 | ||||
| #include <string> | ||||
| namespace gtsam { | ||||
|  | @ -139,8 +140,30 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { | |||
|   virtual DiscreteFactor::shared_ptr multiply( | ||||
|       const DiscreteFactor::shared_ptr& df) const = 0; | ||||
| 
 | ||||
|   /// divide by DiscreteFactor::shared_ptr f (safely)
 | ||||
|   virtual DiscreteFactor::shared_ptr operator/( | ||||
|       const DiscreteFactor::shared_ptr& df) const = 0; | ||||
| 
 | ||||
|   virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; | ||||
| 
 | ||||
|   /// Create new factor by summing all values with the same separator values
 | ||||
|   virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0; | ||||
| 
 | ||||
|   /// Create new factor by summing all values with the same separator values
 | ||||
|   virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0; | ||||
| 
 | ||||
|   /// Create new factor by maximizing over all values with the same separator.
 | ||||
|   virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0; | ||||
| 
 | ||||
|   /// Create new factor by maximizing over all values with the same separator.
 | ||||
|   virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0; | ||||
| 
 | ||||
|   /**
 | ||||
|    * Get the number of non-zero values contained in this factor. | ||||
|    * It could be much smaller than `prod_{key}(cardinality(key))`. | ||||
|    */ | ||||
|   virtual uint64_t nrValues() const = 0; | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Wrapper support
 | ||||
|   /// @{
 | ||||
|  |  | |||
|  | @ -64,7 +64,7 @@ namespace gtsam { | |||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   DecisionTreeFactor DiscreteFactorGraph::product() const { | ||||
|   DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const { | ||||
|     DiscreteFactor::shared_ptr result; | ||||
|     for (auto it = this->begin(); it != this->end(); ++it) { | ||||
|       if (*it) { | ||||
|  | @ -76,7 +76,7 @@ namespace gtsam { | |||
|         } | ||||
|       } | ||||
|     } | ||||
|     return result->toDecisionTreeFactor(); | ||||
|     return result; | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|  | @ -122,20 +122,20 @@ namespace gtsam { | |||
|    * @brief Multiply all the `factors`. | ||||
|    * | ||||
|    * @param factors The factors to multiply as a DiscreteFactorGraph. | ||||
|    * @return DecisionTreeFactor | ||||
|    * @return DiscreteFactor::shared_ptr | ||||
|    */ | ||||
|   static DecisionTreeFactor DiscreteProduct( | ||||
|   static DiscreteFactor::shared_ptr DiscreteProduct( | ||||
|       const DiscreteFactorGraph& factors) { | ||||
|     // PRODUCT: multiply all factors
 | ||||
|     gttic(product); | ||||
|     DecisionTreeFactor product = factors.product(); | ||||
|     DiscreteFactor::shared_ptr product = factors.product(); | ||||
|     gttoc(product); | ||||
| 
 | ||||
|     // Max over all the potentials by pretending all keys are frontal:
 | ||||
|     auto denominator = product.max(product.size()); | ||||
|     auto denominator = product->max(product->size()); | ||||
| 
 | ||||
|     // Normalize the product factor to prevent underflow.
 | ||||
|     product = product / (*denominator); | ||||
|     product = product->operator/(denominator); | ||||
| 
 | ||||
|     return product; | ||||
|   } | ||||
|  | @ -145,25 +145,25 @@ namespace gtsam { | |||
|   std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>  //
 | ||||
|   EliminateForMPE(const DiscreteFactorGraph& factors, | ||||
|                   const Ordering& frontalKeys) { | ||||
|     DecisionTreeFactor product = DiscreteProduct(factors); | ||||
|     DiscreteFactor::shared_ptr product = DiscreteProduct(factors); | ||||
| 
 | ||||
|     // max out frontals, this is the factor on the separator
 | ||||
|     gttic(max); | ||||
|     DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); | ||||
|     DiscreteFactor::shared_ptr max = product->max(frontalKeys); | ||||
|     gttoc(max); | ||||
| 
 | ||||
|     // Ordering keys for the conditional so that frontalKeys are really in front
 | ||||
|     DiscreteKeys orderedKeys; | ||||
|     for (auto&& key : frontalKeys) | ||||
|       orderedKeys.emplace_back(key, product.cardinality(key)); | ||||
|       orderedKeys.emplace_back(key, product->cardinality(key)); | ||||
|     for (auto&& key : max->keys()) | ||||
|       orderedKeys.emplace_back(key, product.cardinality(key)); | ||||
|       orderedKeys.emplace_back(key, product->cardinality(key)); | ||||
| 
 | ||||
|     // Make lookup with product
 | ||||
|     gttic(lookup); | ||||
|     size_t nrFrontals = frontalKeys.size(); | ||||
|     auto lookup = | ||||
|         std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product); | ||||
|     auto lookup = std::make_shared<DiscreteLookupTable>( | ||||
|         nrFrontals, orderedKeys, product->toDecisionTreeFactor()); | ||||
|     gttoc(lookup); | ||||
| 
 | ||||
|     return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max}; | ||||
|  | @ -223,11 +223,11 @@ namespace gtsam { | |||
|   std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>  //
 | ||||
|   EliminateDiscrete(const DiscreteFactorGraph& factors, | ||||
|                     const Ordering& frontalKeys) { | ||||
|     DecisionTreeFactor product = DiscreteProduct(factors); | ||||
|     DiscreteFactor::shared_ptr product = DiscreteProduct(factors); | ||||
| 
 | ||||
|     // sum out frontals, this is the factor on the separator
 | ||||
|     gttic(sum); | ||||
|     DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); | ||||
|     DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); | ||||
|     gttoc(sum); | ||||
| 
 | ||||
|     // Ordering keys for the conditional so that frontalKeys are really in front
 | ||||
|  | @ -239,8 +239,9 @@ namespace gtsam { | |||
| 
 | ||||
|     // now divide product/sum to get conditional
 | ||||
|     gttic(divide); | ||||
|     auto conditional = | ||||
|         std::make_shared<DiscreteConditional>(product, *sum, orderedKeys); | ||||
|     auto conditional = std::make_shared<DiscreteConditional>( | ||||
|         product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(), | ||||
|         orderedKeys); | ||||
|     gttoc(divide); | ||||
| 
 | ||||
|     return {conditional, sum}; | ||||
|  |  | |||
|  | @ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph | |||
| 
 | ||||
|   /// @}
 | ||||
| 
 | ||||
|   //TODO(Varun): Make compatible with TableFactor
 | ||||
|   /** Add a decision-tree factor */ | ||||
|   template <typename... Args> | ||||
|   void add(Args&&... args) { | ||||
|  | @ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph | |||
|   DiscreteKeys discreteKeys() const; | ||||
| 
 | ||||
|   /** return product of all factors as a single factor */ | ||||
|   DecisionTreeFactor product() const; | ||||
|   DiscreteFactor::shared_ptr product() const; | ||||
| 
 | ||||
|   /** 
 | ||||
|    * Evaluates the factor graph given values, returns the joint probability of | ||||
|  |  | |||
|  | @ -18,6 +18,7 @@ | |||
| #pragma once | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteDistribution.h> | ||||
| #include <gtsam/discrete/TableFactor.h> | ||||
| #include <gtsam/inference/BayesNet.h> | ||||
| #include <gtsam/inference/FactorGraph.h> | ||||
| 
 | ||||
|  | @ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { | |||
|                       const ADT& potentials) | ||||
|       : DiscreteConditional(nFrontals, keys, potentials) {} | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Construct a new Discrete Lookup Table object | ||||
|    * | ||||
|    * @param nFrontals number of frontal variables | ||||
|    * @param keys a sorted list of gtsam::Keys | ||||
|    * @param potentials Discrete potentials as a TableFactor. | ||||
|    */ | ||||
|   DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, | ||||
|                       const TableFactor& potentials) | ||||
|       : DiscreteConditional(nFrontals, keys, | ||||
|                             potentials.toDecisionTreeFactor()) {} | ||||
| 
 | ||||
|   /// GTSAM-style print
 | ||||
|   void print( | ||||
|       const std::string& s = "Discrete Lookup Table: ", | ||||
|  |  | |||
|  | @ -280,6 +280,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply( | |||
|   return result; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DiscreteFactor::shared_ptr TableFactor::operator/( | ||||
|     const DiscreteFactor::shared_ptr& f) const { | ||||
|   if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) { | ||||
|     return std::make_shared<TableFactor>(this->operator/(*tf)); | ||||
|   } else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { | ||||
|     return std::make_shared<TableFactor>( | ||||
|         this->operator/(TableFactor(f->discreteKeys(), *dtf))); | ||||
|   } else { | ||||
|     TableFactor divisor(f->toDecisionTreeFactor()); | ||||
|     return std::make_shared<TableFactor>(this->operator/(divisor)); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { | ||||
|   DiscreteKeys dkeys = discreteKeys(); | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ | |||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <gtsam/discrete/DecisionTreeFactor.h> | ||||
| #include <gtsam/discrete/DiscreteFactor.h> | ||||
| #include <gtsam/discrete/DiscreteKey.h> | ||||
| #include <gtsam/discrete/Ring.h> | ||||
|  | @ -202,6 +203,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { | |||
|     return apply(f, safe_div); | ||||
|   } | ||||
| 
 | ||||
|   /// divide by DiscreteFactor::shared_ptr f (safely)
 | ||||
|   DiscreteFactor::shared_ptr operator/( | ||||
|       const DiscreteFactor::shared_ptr& f) const override; | ||||
| 
 | ||||
|   /// Convert into a decisiontree
 | ||||
|   DecisionTreeFactor toDecisionTreeFactor() const override; | ||||
| 
 | ||||
|  | @ -210,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { | |||
|                      DiscreteKeys parent_keys) const; | ||||
| 
 | ||||
|   /// Create new factor by summing all values with the same separator values
 | ||||
|   shared_ptr sum(size_t nrFrontals) const { | ||||
|   DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { | ||||
|     return combine(nrFrontals, Ring::add); | ||||
|   } | ||||
| 
 | ||||
|   /// Create new factor by summing all values with the same separator values
 | ||||
|   shared_ptr sum(const Ordering& keys) const { | ||||
|   DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { | ||||
|     return combine(keys, Ring::add); | ||||
|   } | ||||
| 
 | ||||
|   /// Create new factor by maximizing over all values with the same separator.
 | ||||
|   shared_ptr max(size_t nrFrontals) const { | ||||
|   DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { | ||||
|     return combine(nrFrontals, Ring::max); | ||||
|   } | ||||
| 
 | ||||
|   /// Create new factor by maximizing over all values with the same separator.
 | ||||
|   shared_ptr max(const Ordering& keys) const { | ||||
|   DiscreteFactor::shared_ptr max(const Ordering& keys) const override { | ||||
|     return combine(keys, Ring::max); | ||||
|   } | ||||
| 
 | ||||
|  | @ -330,6 +335,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { | |||
|    */ | ||||
|   TableFactor prune(size_t maxNrAssignments) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * Get the number of non-zero values contained in this factor. | ||||
|    * It could be much smaller than `prod_{key}(cardinality(key))`. | ||||
|    */ | ||||
|   uint64_t nrValues() const override { return sparse_table_.nonZeros(); } | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Wrapper support
 | ||||
|   /// @{
 | ||||
|  |  | |||
|  | @ -138,15 +138,18 @@ TEST(DecisionTreeFactor, sum_max) { | |||
|   DecisionTreeFactor f1(v0 & v1, "1 2  3 4  5 6"); | ||||
| 
 | ||||
|   DecisionTreeFactor expected(v1, "9 12"); | ||||
|   DecisionTreeFactor::shared_ptr actual = f1.sum(1); | ||||
|   auto actual = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.sum(1)); | ||||
|   CHECK(actual); | ||||
|   CHECK(assert_equal(expected, *actual, 1e-5)); | ||||
| 
 | ||||
|   DecisionTreeFactor expected2(v1, "5 6"); | ||||
|   DecisionTreeFactor::shared_ptr actual2 = f1.max(1); | ||||
|   auto actual2 = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.max(1)); | ||||
|   CHECK(actual2); | ||||
|   CHECK(assert_equal(expected2, *actual2)); | ||||
| 
 | ||||
|   DecisionTreeFactor f2(v1 & v0, "1 2  3 4  5 6"); | ||||
|   DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); | ||||
|   auto actual22 = std::dynamic_pointer_cast<DecisionTreeFactor>(f2.sum(1)); | ||||
|   CHECK(actual22); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
|  | @ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) { | |||
|   DecisionTreeFactor f2( | ||||
|       X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); | ||||
|   DiscreteConditional actual2(1, f2); | ||||
|   DecisionTreeFactor expected2 = f2 / *f2.sum(1); | ||||
|   DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor(); | ||||
|   EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2))); | ||||
| 
 | ||||
|   std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75}; | ||||
|  | @ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) { | |||
|   DecisionTreeFactor f2( | ||||
|       X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); | ||||
|   DiscreteConditional actual2(1, f2); | ||||
|   DecisionTreeFactor expected2 = f2 / *f2.sum(1); | ||||
|   DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor(); | ||||
|   EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2))); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { | |||
|   EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9); | ||||
| 
 | ||||
|   // Check if graph product works
 | ||||
|   DecisionTreeFactor product = graph.product(); | ||||
|   DecisionTreeFactor product = graph.product()->toDecisionTreeFactor(); | ||||
|   EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9); | ||||
| } | ||||
| 
 | ||||
|  | @ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) { | |||
|       *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr); | ||||
| 
 | ||||
|   // Normalize newFactor by max for comparison with expected
 | ||||
|   auto normalizer = newFactor.max(newFactor.size()); | ||||
|   auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor(); | ||||
| 
 | ||||
|   newFactor = newFactor / *normalizer; | ||||
|   newFactor = newFactor / denominator; | ||||
| 
 | ||||
|   // Check Conditional
 | ||||
|   CHECK(conditional); | ||||
|  | @ -131,9 +131,10 @@ TEST(DiscreteFactorGraph, test) { | |||
|   CHECK(&newFactor); | ||||
|   DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); | ||||
|   // Normalize by max.
 | ||||
|   normalizer = expectedFactor.max(expectedFactor.size()); | ||||
|   // Ensure normalizer is correct.
 | ||||
|   expectedFactor = expectedFactor / *normalizer; | ||||
|   denominator = | ||||
|       expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor(); | ||||
|   // Ensure denominator is correct.
 | ||||
|   expectedFactor = expectedFactor / denominator; | ||||
|   EXPECT(assert_equal(expectedFactor, newFactor)); | ||||
| 
 | ||||
|   // Test using elimination tree
 | ||||
|  |  | |||
|  | @ -194,15 +194,17 @@ TEST(TableFactor, Conversion) { | |||
| TEST(TableFactor, Empty) { | ||||
|   DiscreteKey X(1, 2); | ||||
| 
 | ||||
|   TableFactor single = *TableFactor({X}, "1 1").sum(1); | ||||
|   auto single = TableFactor({X}, "1 1").sum(1); | ||||
|   // Should not throw a segfault
 | ||||
|   EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1), | ||||
|                       single.toDecisionTreeFactor())); | ||||
|   auto expected_single = DecisionTreeFactor(X, "1 1").sum(1); | ||||
|   EXPECT(assert_equal(expected_single->toDecisionTreeFactor(), | ||||
|                       single->toDecisionTreeFactor())); | ||||
| 
 | ||||
|   TableFactor empty = *TableFactor({X}, "0 0").sum(1); | ||||
|   auto empty = TableFactor({X}, "0 0").sum(1); | ||||
|   // Should not throw a segfault
 | ||||
|   EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1), | ||||
|                       empty.toDecisionTreeFactor())); | ||||
|   auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1); | ||||
|   EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(), | ||||
|                       empty->toDecisionTreeFactor())); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  | @ -303,15 +305,18 @@ TEST(TableFactor, sum_max) { | |||
|   TableFactor f1(v0 & v1, "1 2  3 4  5 6"); | ||||
| 
 | ||||
|   TableFactor expected(v1, "9 12"); | ||||
|   TableFactor::shared_ptr actual = f1.sum(1); | ||||
|   auto actual = std::dynamic_pointer_cast<TableFactor>(f1.sum(1)); | ||||
|   CHECK(actual); | ||||
|   CHECK(assert_equal(expected, *actual, 1e-5)); | ||||
| 
 | ||||
|   TableFactor expected2(v1, "5 6"); | ||||
|   TableFactor::shared_ptr actual2 = f1.max(1); | ||||
|   auto actual2 = std::dynamic_pointer_cast<TableFactor>(f1.max(1)); | ||||
|   CHECK(actual2); | ||||
|   CHECK(assert_equal(expected2, *actual2)); | ||||
| 
 | ||||
|   TableFactor f2(v1 & v0, "1 2  3 4  5 6"); | ||||
|   TableFactor::shared_ptr actual22 = f2.sum(1); | ||||
|   auto actual22 = std::dynamic_pointer_cast<TableFactor>(f2.sum(1)); | ||||
|   CHECK(actual22); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
|  | @ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { | |||
|   /*
 | ||||
|    * Ensure Arc-consistency by checking every possible value of domain j. | ||||
|    * @param j domain to be checked | ||||
|    * @param (in/out) domains all domains, but only domains->at(j) will be checked. | ||||
|    * @param (in/out) domains all domains, but only domains->at(j) will be | ||||
|    * checked. | ||||
|    * @return true if domains->at(j) was changed, false otherwise. | ||||
|    */ | ||||
|   virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; | ||||
|  | @ -86,6 +87,31 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { | |||
|         this->operator*(df->toDecisionTreeFactor())); | ||||
|   } | ||||
| 
 | ||||
|   /// divide by DiscreteFactor::shared_ptr f (safely)
 | ||||
|   DiscreteFactor::shared_ptr operator/( | ||||
|       const DiscreteFactor::shared_ptr& df) const override { | ||||
|     return this->toDecisionTreeFactor() / df; | ||||
|   } | ||||
| 
 | ||||
|   /// Get the number of non-zero values contained in this factor.
 | ||||
|   uint64_t nrValues() const override { return 1; }; | ||||
| 
 | ||||
|   DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { | ||||
|     return toDecisionTreeFactor().sum(nrFrontals); | ||||
|   } | ||||
| 
 | ||||
|   DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { | ||||
|     return toDecisionTreeFactor().sum(keys); | ||||
|   } | ||||
| 
 | ||||
|   DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { | ||||
|     return toDecisionTreeFactor().max(nrFrontals); | ||||
|   } | ||||
| 
 | ||||
|   DiscreteFactor::shared_ptr max(const Ordering& keys) const override { | ||||
|     return toDecisionTreeFactor().max(keys); | ||||
|   } | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Wrapper support
 | ||||
|   /// @{
 | ||||
|  |  | |||
|  | @ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { | |||
|   /// Erase a value, non const :-(
 | ||||
|   void erase(size_t value) { values_.erase(value); } | ||||
| 
 | ||||
|   size_t nrValues() const { return values_.size(); } | ||||
|   uint64_t nrValues() const override { return values_.size(); } | ||||
| 
 | ||||
|   bool isSingleton() const { return nrValues() == 1; } | ||||
| 
 | ||||
|  |  | |||
|  | @ -124,7 +124,7 @@ TEST(CSP, allInOne) { | |||
|   EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); | ||||
| 
 | ||||
|   // Just for fun, create the product and check it
 | ||||
|   DecisionTreeFactor product = csp.product(); | ||||
|   DecisionTreeFactor product = csp.product()->toDecisionTreeFactor(); | ||||
|   // product.dot("product");
 | ||||
|   DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); | ||||
|   EXPECT(assert_equal(expectedProduct, product)); | ||||
|  |  | |||
|  | @ -113,7 +113,7 @@ TEST(schedulingExample, test) { | |||
|   EXPECT(assert_equal(expected, (DiscreteFactorGraph)s)); | ||||
| 
 | ||||
|   // Do brute force product and output that to file
 | ||||
|   DecisionTreeFactor product = s.product(); | ||||
|   DecisionTreeFactor product = s.product()->toDecisionTreeFactor(); | ||||
|   // product.dot("scheduling", false);
 | ||||
| 
 | ||||
|   // Do exact inference
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue