move common functions to DiscreteFactor class

release/4.3a0
Varun Agrawal 2023-07-08 11:33:01 -04:00
parent 64ecb8581e
commit 50d24ab38e
6 changed files with 46 additions and 66 deletions

View File

@ -33,16 +33,13 @@ namespace gtsam {
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials)
: DiscreteFactor(keys.indices()),
ADT(potentials),
cardinalities_(keys.cardinalities()) {}
const ADT& potentials)
: DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
: DiscreteFactor(c.keys()),
AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}
: DiscreteFactor(c.keys(), c.cardinalities()),
AlgebraicDecisionTree<Key>(c) {}
/* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
@ -190,18 +187,6 @@ namespace gtsam {
return probs;
}
/* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************ */
static std::string valueFormatter(const double& v) {
std::stringstream ss;
@ -297,17 +282,15 @@ namespace gtsam {
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const vector<double>& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
const vector<double>& table)
: DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table) {}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const string& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
const string& table)
: DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table) {}
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {

View File

@ -50,10 +50,6 @@ namespace gtsam {
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;
protected:
std::map<Key, size_t> cardinalities_;
public:
/// @name Standard Constructors
/// @{
@ -119,10 +115,6 @@ namespace gtsam {
static double safe_div(const double& a, const double& b);
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
@ -184,9 +176,6 @@ namespace gtsam {
/// Get all the probabilities in order of assignment values
std::vector<double> probabilities() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/**
* @brief Prune the decision tree of discrete variables.
*
@ -265,7 +254,6 @@ namespace gtsam {
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
#endif
};

View File

@ -28,6 +28,18 @@ using namespace std;
namespace gtsam {
/* ************************************************************************ */
DiscreteKeys DiscreteFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values));

View File

@ -45,6 +45,10 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
using Values = DiscreteValues; ///< backwards compatibility
protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
public:
/// @name Standard Constructors
/// @{
@ -52,10 +56,15 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
/** Default constructor creates empty factor */
DiscreteFactor() {}
/** Construct from container of keys. This constructor is used internally from derived factor
* constructors, either from a container of keys or from a boost::assign::list_of. */
template<typename CONTAINER>
DiscreteFactor(const CONTAINER& keys) : Base(keys) {}
/**
* Construct from container of keys and map of cardinalities.
* This constructor is used internally from derived factor constructors,
* either from a container of keys or from a boost::assign::list_of.
*/
template <typename CONTAINER>
DiscreteFactor(const CONTAINER& keys,
const std::map<Key, size_t> cardinalities = {})
: Base(keys), cardinalities_(cardinalities) {}
/// @}
/// @name Testable
@ -75,6 +84,13 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
/// @name Standard Interface
/// @{
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0;
@ -130,6 +146,7 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
#endif
};

View File

@ -34,8 +34,7 @@ TableFactor::TableFactor() {}
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const TableFactor& potentials)
: DiscreteFactor(dkeys.indices()),
cardinalities_(potentials.cardinalities_) {
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) {
sparse_table_ = potentials.sparse_table_;
denominators_ = potentials.denominators_;
sorted_dkeys_ = discreteKeys();
@ -45,11 +44,11 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const Eigen::SparseVector<double>& table)
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()),
sparse_table_(table.size()) {
sparse_table_ = table;
double denom = table.size();
for (const DiscreteKey& dkey : dkeys) {
cardinalities_.insert(dkey);
denom /= dkey.second;
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
}
@ -440,18 +439,6 @@ std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
return result;
}
/* ************************************************************************ */
DiscreteKeys TableFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
// Print out header.
/* ************************************************************************ */
string TableFactor::markdown(const KeyFormatter& keyFormatter,

View File

@ -45,8 +45,6 @@ class HybridValues;
*/
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
/// SparseVector of nonzero probabilities.
Eigen::SparseVector<double> sparse_table_;
@ -184,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely)
TableFactor operator/(const TableFactor& f) const {
return apply(f, safe_div);
@ -278,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/**
* @brief Prune the decision tree of discrete variables.
*