diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 04e29024c..ff18268b1 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -33,16 +33,13 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, - const ADT& potentials) - : DiscreteFactor(keys.indices()), - ADT(potentials), - cardinalities_(keys.cardinalities()) {} + const ADT& potentials) + : DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {} /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) - : DiscreteFactor(c.keys()), - AlgebraicDecisionTree(c), - cardinalities_(c.cardinalities_) {} + : DiscreteFactor(c.keys(), c.cardinalities()), + AlgebraicDecisionTree(c) {} /* ************************************************************************ */ bool DecisionTreeFactor::equals(const DiscreteFactor& other, @@ -190,18 +187,6 @@ namespace gtsam { return probs; } - /* ************************************************************************ */ - DiscreteKeys DecisionTreeFactor::discreteKeys() const { - DiscreteKeys result; - for (auto&& key : keys()) { - DiscreteKey dkey(key, cardinality(key)); - if (std::find(result.begin(), result.end(), dkey) == result.end()) { - result.push_back(dkey); - } - } - return result; - } - /* ************************************************************************ */ static std::string valueFormatter(const double& v) { std::stringstream ss; @@ -297,17 +282,15 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, - const vector& table) - : DiscreteFactor(keys.indices()), - AlgebraicDecisionTree(keys, table), - cardinalities_(keys.cardinalities()) {} + const vector& table) + : DiscreteFactor(keys.indices(), keys.cardinalities()), + AlgebraicDecisionTree(keys, table) {} /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, - const string& table) - : DiscreteFactor(keys.indices()), - AlgebraicDecisionTree(keys, table), - cardinalities_(keys.cardinalities()) {} + const string& table) + : DiscreteFactor(keys.indices(), keys.cardinalities()), + AlgebraicDecisionTree(keys, table) {} /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 42639095f..6cce6e5d4 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -50,10 +50,6 @@ namespace gtsam { typedef std::shared_ptr shared_ptr; typedef AlgebraicDecisionTree ADT; - protected: - std::map cardinalities_; - - public: /// @name Standard Constructors /// @{ @@ -119,10 +115,6 @@ namespace gtsam { static double safe_div(const double& a, const double& b); - std::map cardinalities() const { return cardinalities_; } - - size_t cardinality(Key j) const { return cardinalities_.at(j); } - /// divide by factor f (safely) DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { return apply(f, safe_div); @@ -184,9 +176,6 @@ namespace gtsam { /// Get all the probabilities in order of assignment values std::vector probabilities() const; - /// Return all the discrete keys associated with this factor. - DiscreteKeys discreteKeys() const; - /** * @brief Prune the decision tree of discrete variables. * @@ -265,7 +254,6 @@ namespace gtsam { void serialize(ARCHIVE& ar, const unsigned int /*version*/) { ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT); - ar& BOOST_SERIALIZATION_NVP(cardinalities_); } #endif }; diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index 2b1bc36a3..b44d4fce2 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -28,6 +28,18 @@ using namespace std; namespace gtsam { +/* ************************************************************************ */ +DiscreteKeys DiscreteFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; +} + /* ************************************************************************* */ double DiscreteFactor::error(const DiscreteValues& values) const { return -std::log((*this)(values)); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 2132e1cc8..24b2b55e4 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -45,6 +45,10 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { using Values = DiscreteValues; ///< backwards compatibility + protected: + /// Map of Keys and their cardinalities. + std::map cardinalities_; + public: /// @name Standard Constructors /// @{ @@ -52,10 +56,15 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { /** Default constructor creates empty factor */ DiscreteFactor() {} - /** Construct from container of keys. This constructor is used internally from derived factor - * constructors, either from a container of keys or from a boost::assign::list_of. */ - template - DiscreteFactor(const CONTAINER& keys) : Base(keys) {} + /** + * Construct from container of keys and map of cardinalities. + * This constructor is used internally from derived factor constructors, + * either from a container of keys or from a boost::assign::list_of. + */ + template + DiscreteFactor(const CONTAINER& keys, + const std::map cardinalities = {}) + : Base(keys), cardinalities_(cardinalities) {} /// @} /// @name Testable @@ -75,6 +84,13 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { /// @name Standard Interface /// @{ + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; + + std::map cardinalities() const { return cardinalities_; } + + size_t cardinality(Key j) const { return cardinalities_.at(j); } + /// Find value for given assignment of values to variables virtual double operator()(const DiscreteValues&) const = 0; @@ -130,6 +146,7 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { template void serialize(ARCHIVE& ar, const unsigned int /*version*/) { ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_NVP(cardinalities_); } #endif }; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index c59b7b72c..74eb3ddb3 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -34,8 +34,7 @@ TableFactor::TableFactor() {} /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteKeys& dkeys, const TableFactor& potentials) - : DiscreteFactor(dkeys.indices()), - cardinalities_(potentials.cardinalities_) { + : DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) { sparse_table_ = potentials.sparse_table_; denominators_ = potentials.denominators_; sorted_dkeys_ = discreteKeys(); @@ -45,11 +44,11 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteKeys& dkeys, const Eigen::SparseVector& table) - : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { + : DiscreteFactor(dkeys.indices(), dkeys.cardinalities()), + sparse_table_(table.size()) { sparse_table_ = table; double denom = table.size(); for (const DiscreteKey& dkey : dkeys) { - cardinalities_.insert(dkey); denom /= dkey.second; denominators_.insert(std::pair(dkey.first, denom)); } @@ -440,18 +439,6 @@ std::vector> TableFactor::enumerate() const { return result; } -/* ************************************************************************ */ -DiscreteKeys TableFactor::discreteKeys() const { - DiscreteKeys result; - for (auto&& key : keys()) { - DiscreteKey dkey(key, cardinality(key)); - if (std::find(result.begin(), result.end(), dkey) == result.end()) { - result.push_back(dkey); - } - } - return result; -} - // Print out header. /* ************************************************************************ */ string TableFactor::markdown(const KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 08c675b67..bd637bb7d 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -45,8 +45,6 @@ class HybridValues; */ class GTSAM_EXPORT TableFactor : public DiscreteFactor { protected: - /// Map of Keys and their cardinalities. - std::map cardinalities_; /// SparseVector of nonzero probabilities. Eigen::SparseVector sparse_table_; @@ -184,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { static double safe_div(const double& a, const double& b); - size_t cardinality(Key j) const { return cardinalities_.at(j); } - /// divide by factor f (safely) TableFactor operator/(const TableFactor& f) const { return apply(f, safe_div); @@ -278,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// Enumerate all values into a map from values to double. std::vector> enumerate() const; - /// Return all the discrete keys associated with this factor. - DiscreteKeys discreteKeys() const; - /** * @brief Prune the decision tree of discrete variables. *