move common functions to DiscreteFactor class

release/4.3a0
Varun Agrawal 2023-07-08 11:33:01 -04:00
parent 64ecb8581e
commit 50d24ab38e
6 changed files with 46 additions and 66 deletions

View File

@ -33,16 +33,13 @@ namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) const ADT& potentials)
: DiscreteFactor(keys.indices()), : DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {}
ADT(potentials),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
: DiscreteFactor(c.keys()), : DiscreteFactor(c.keys(), c.cardinalities()),
AlgebraicDecisionTree<Key>(c), AlgebraicDecisionTree<Key>(c) {}
cardinalities_(c.cardinalities_) {}
/* ************************************************************************ */ /* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other, bool DecisionTreeFactor::equals(const DiscreteFactor& other,
@ -190,18 +187,6 @@ namespace gtsam {
return probs; return probs;
} }
/* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************ */ /* ************************************************************************ */
static std::string valueFormatter(const double& v) { static std::string valueFormatter(const double& v) {
std::stringstream ss; std::stringstream ss;
@ -297,17 +282,15 @@ namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const vector<double>& table) const vector<double>& table)
: DiscreteFactor(keys.indices()), : DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table), AlgebraicDecisionTree<Key>(keys, table) {}
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const string& table) const string& table)
: DiscreteFactor(keys.indices()), : DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table), AlgebraicDecisionTree<Key>(keys, table) {}
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {

View File

@ -50,10 +50,6 @@ namespace gtsam {
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr; typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
protected:
std::map<Key, size_t> cardinalities_;
public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -119,10 +115,6 @@ namespace gtsam {
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely) /// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div); return apply(f, safe_div);
@ -184,9 +176,6 @@ namespace gtsam {
/// Get all the probabilities in order of assignment values /// Get all the probabilities in order of assignment values
std::vector<double> probabilities() const; std::vector<double> probabilities() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/** /**
* @brief Prune the decision tree of discrete variables. * @brief Prune the decision tree of discrete variables.
* *
@ -265,7 +254,6 @@ namespace gtsam {
void serialize(ARCHIVE& ar, const unsigned int /*version*/) { void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
} }
#endif #endif
}; };

View File

@ -28,6 +28,18 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************ */
DiscreteKeys DiscreteFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const { double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values)); return -std::log((*this)(values));

View File

@ -45,6 +45,10 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
using Values = DiscreteValues; ///< backwards compatibility using Values = DiscreteValues; ///< backwards compatibility
protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
public: public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -52,10 +56,15 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
/** Default constructor creates empty factor */ /** Default constructor creates empty factor */
DiscreteFactor() {} DiscreteFactor() {}
/** Construct from container of keys. This constructor is used internally from derived factor /**
* constructors, either from a container of keys or from a boost::assign::list_of. */ * Construct from container of keys and map of cardinalities.
template<typename CONTAINER> * This constructor is used internally from derived factor constructors,
DiscreteFactor(const CONTAINER& keys) : Base(keys) {} * either from a container of keys or from a boost::assign::list_of.
*/
template <typename CONTAINER>
DiscreteFactor(const CONTAINER& keys,
const std::map<Key, size_t> cardinalities = {})
: Base(keys), cardinalities_(cardinalities) {}
/// @} /// @}
/// @name Testable /// @name Testable
@ -75,6 +84,13 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// Find value for given assignment of values to variables /// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0; virtual double operator()(const DiscreteValues&) const = 0;
@ -130,6 +146,7 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
template <class ARCHIVE> template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) { void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
} }
#endif #endif
}; };

View File

@ -34,8 +34,7 @@ TableFactor::TableFactor() {}
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys, TableFactor::TableFactor(const DiscreteKeys& dkeys,
const TableFactor& potentials) const TableFactor& potentials)
: DiscreteFactor(dkeys.indices()), : DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) {
cardinalities_(potentials.cardinalities_) {
sparse_table_ = potentials.sparse_table_; sparse_table_ = potentials.sparse_table_;
denominators_ = potentials.denominators_; denominators_ = potentials.denominators_;
sorted_dkeys_ = discreteKeys(); sorted_dkeys_ = discreteKeys();
@ -45,11 +44,11 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys, TableFactor::TableFactor(const DiscreteKeys& dkeys,
const Eigen::SparseVector<double>& table) const Eigen::SparseVector<double>& table)
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { : DiscreteFactor(dkeys.indices(), dkeys.cardinalities()),
sparse_table_(table.size()) {
sparse_table_ = table; sparse_table_ = table;
double denom = table.size(); double denom = table.size();
for (const DiscreteKey& dkey : dkeys) { for (const DiscreteKey& dkey : dkeys) {
cardinalities_.insert(dkey);
denom /= dkey.second; denom /= dkey.second;
denominators_.insert(std::pair<Key, double>(dkey.first, denom)); denominators_.insert(std::pair<Key, double>(dkey.first, denom));
} }
@ -440,18 +439,6 @@ std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
return result; return result;
} }
/* ************************************************************************ */
DiscreteKeys TableFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
// Print out header. // Print out header.
/* ************************************************************************ */ /* ************************************************************************ */
string TableFactor::markdown(const KeyFormatter& keyFormatter, string TableFactor::markdown(const KeyFormatter& keyFormatter,

View File

@ -45,8 +45,6 @@ class HybridValues;
*/ */
class GTSAM_EXPORT TableFactor : public DiscreteFactor { class GTSAM_EXPORT TableFactor : public DiscreteFactor {
protected: protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
/// SparseVector of nonzero probabilities. /// SparseVector of nonzero probabilities.
Eigen::SparseVector<double> sparse_table_; Eigen::SparseVector<double> sparse_table_;
@ -184,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely) /// divide by factor f (safely)
TableFactor operator/(const TableFactor& f) const { TableFactor operator/(const TableFactor& f) const {
return apply(f, safe_div); return apply(f, safe_div);
@ -278,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// Enumerate all values into a map from values to double. /// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/** /**
* @brief Prune the decision tree of discrete variables. * @brief Prune the decision tree of discrete variables.
* *