move common functions to DiscreteFactor class
parent
64ecb8581e
commit
50d24ab38e
|
|
@ -33,16 +33,13 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const ADT& potentials)
|
||||
: DiscreteFactor(keys.indices()),
|
||||
ADT(potentials),
|
||||
cardinalities_(keys.cardinalities()) {}
|
||||
const ADT& potentials)
|
||||
: DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
|
||||
: DiscreteFactor(c.keys()),
|
||||
AlgebraicDecisionTree<Key>(c),
|
||||
cardinalities_(c.cardinalities_) {}
|
||||
: DiscreteFactor(c.keys(), c.cardinalities()),
|
||||
AlgebraicDecisionTree<Key>(c) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
|
||||
|
|
@ -190,18 +187,6 @@ namespace gtsam {
|
|||
return probs;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
for (auto&& key : keys()) {
|
||||
DiscreteKey dkey(key, cardinality(key));
|
||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||
result.push_back(dkey);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::string valueFormatter(const double& v) {
|
||||
std::stringstream ss;
|
||||
|
|
@ -297,17 +282,15 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const vector<double>& table)
|
||||
: DiscreteFactor(keys.indices()),
|
||||
AlgebraicDecisionTree<Key>(keys, table),
|
||||
cardinalities_(keys.cardinalities()) {}
|
||||
const vector<double>& table)
|
||||
: DiscreteFactor(keys.indices(), keys.cardinalities()),
|
||||
AlgebraicDecisionTree<Key>(keys, table) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const string& table)
|
||||
: DiscreteFactor(keys.indices()),
|
||||
AlgebraicDecisionTree<Key>(keys, table),
|
||||
cardinalities_(keys.cardinalities()) {}
|
||||
const string& table)
|
||||
: DiscreteFactor(keys.indices(), keys.cardinalities()),
|
||||
AlgebraicDecisionTree<Key>(keys, table) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
||||
|
|
|
|||
|
|
@ -50,10 +50,6 @@ namespace gtsam {
|
|||
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
|
||||
typedef AlgebraicDecisionTree<Key> ADT;
|
||||
|
||||
protected:
|
||||
std::map<Key, size_t> cardinalities_;
|
||||
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
|
|
@ -119,10 +115,6 @@ namespace gtsam {
|
|||
|
||||
static double safe_div(const double& a, const double& b);
|
||||
|
||||
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
|
||||
|
||||
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
||||
|
||||
/// divide by factor f (safely)
|
||||
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
||||
return apply(f, safe_div);
|
||||
|
|
@ -184,9 +176,6 @@ namespace gtsam {
|
|||
/// Get all the probabilities in order of assignment values
|
||||
std::vector<double> probabilities() const;
|
||||
|
||||
/// Return all the discrete keys associated with this factor.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/**
|
||||
* @brief Prune the decision tree of discrete variables.
|
||||
*
|
||||
|
|
@ -265,7 +254,6 @@ namespace gtsam {
|
|||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
|
||||
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
|
|
|||
|
|
@ -28,6 +28,18 @@ using namespace std;
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys DiscreteFactor::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
for (auto&& key : keys()) {
|
||||
DiscreteKey dkey(key, cardinality(key));
|
||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||
result.push_back(dkey);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteFactor::error(const DiscreteValues& values) const {
|
||||
return -std::log((*this)(values));
|
||||
|
|
|
|||
|
|
@ -45,6 +45,10 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
|
|||
|
||||
using Values = DiscreteValues; ///< backwards compatibility
|
||||
|
||||
protected:
|
||||
/// Map of Keys and their cardinalities.
|
||||
std::map<Key, size_t> cardinalities_;
|
||||
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
|
@ -52,10 +56,15 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
|
|||
/** Default constructor creates empty factor */
|
||||
DiscreteFactor() {}
|
||||
|
||||
/** Construct from container of keys. This constructor is used internally from derived factor
|
||||
* constructors, either from a container of keys or from a boost::assign::list_of. */
|
||||
template<typename CONTAINER>
|
||||
DiscreteFactor(const CONTAINER& keys) : Base(keys) {}
|
||||
/**
|
||||
* Construct from container of keys and map of cardinalities.
|
||||
* This constructor is used internally from derived factor constructors,
|
||||
* either from a container of keys or from a boost::assign::list_of.
|
||||
*/
|
||||
template <typename CONTAINER>
|
||||
DiscreteFactor(const CONTAINER& keys,
|
||||
const std::map<Key, size_t> cardinalities = {})
|
||||
: Base(keys), cardinalities_(cardinalities) {}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
|
|
@ -75,6 +84,13 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Return all the discrete keys associated with this factor.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
|
||||
|
||||
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
||||
|
||||
/// Find value for given assignment of values to variables
|
||||
virtual double operator()(const DiscreteValues&) const = 0;
|
||||
|
||||
|
|
@ -130,6 +146,7 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
|
|||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
|
|
|||
|
|
@ -34,8 +34,7 @@ TableFactor::TableFactor() {}
|
|||
/* ************************************************************************ */
|
||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||
const TableFactor& potentials)
|
||||
: DiscreteFactor(dkeys.indices()),
|
||||
cardinalities_(potentials.cardinalities_) {
|
||||
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) {
|
||||
sparse_table_ = potentials.sparse_table_;
|
||||
denominators_ = potentials.denominators_;
|
||||
sorted_dkeys_ = discreteKeys();
|
||||
|
|
@ -45,11 +44,11 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
|||
/* ************************************************************************ */
|
||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||
const Eigen::SparseVector<double>& table)
|
||||
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
|
||||
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()),
|
||||
sparse_table_(table.size()) {
|
||||
sparse_table_ = table;
|
||||
double denom = table.size();
|
||||
for (const DiscreteKey& dkey : dkeys) {
|
||||
cardinalities_.insert(dkey);
|
||||
denom /= dkey.second;
|
||||
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
|
||||
}
|
||||
|
|
@ -440,18 +439,6 @@ std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys TableFactor::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
for (auto&& key : keys()) {
|
||||
DiscreteKey dkey(key, cardinality(key));
|
||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||
result.push_back(dkey);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Print out header.
|
||||
/* ************************************************************************ */
|
||||
string TableFactor::markdown(const KeyFormatter& keyFormatter,
|
||||
|
|
|
|||
|
|
@ -45,8 +45,6 @@ class HybridValues;
|
|||
*/
|
||||
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||
protected:
|
||||
/// Map of Keys and their cardinalities.
|
||||
std::map<Key, size_t> cardinalities_;
|
||||
/// SparseVector of nonzero probabilities.
|
||||
Eigen::SparseVector<double> sparse_table_;
|
||||
|
||||
|
|
@ -184,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
|
||||
static double safe_div(const double& a, const double& b);
|
||||
|
||||
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
||||
|
||||
/// divide by factor f (safely)
|
||||
TableFactor operator/(const TableFactor& f) const {
|
||||
return apply(f, safe_div);
|
||||
|
|
@ -278,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
/// Enumerate all values into a map from values to double.
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
|
||||
/// Return all the discrete keys associated with this factor.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/**
|
||||
* @brief Prune the decision tree of discrete variables.
|
||||
*
|
||||
|
|
|
|||
Loading…
Reference in New Issue