move common functions to DiscreteFactor class
parent
64ecb8581e
commit
50d24ab38e
|
|
@ -34,15 +34,12 @@ namespace gtsam {
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
const ADT& potentials)
|
const ADT& potentials)
|
||||||
: DiscreteFactor(keys.indices()),
|
: DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {}
|
||||||
ADT(potentials),
|
|
||||||
cardinalities_(keys.cardinalities()) {}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
|
||||||
: DiscreteFactor(c.keys()),
|
: DiscreteFactor(c.keys(), c.cardinalities()),
|
||||||
AlgebraicDecisionTree<Key>(c),
|
AlgebraicDecisionTree<Key>(c) {}
|
||||||
cardinalities_(c.cardinalities_) {}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
|
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
|
||||||
|
|
@ -190,18 +187,6 @@ namespace gtsam {
|
||||||
return probs;
|
return probs;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
|
||||||
DiscreteKeys result;
|
|
||||||
for (auto&& key : keys()) {
|
|
||||||
DiscreteKey dkey(key, cardinality(key));
|
|
||||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
|
||||||
result.push_back(dkey);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::string valueFormatter(const double& v) {
|
static std::string valueFormatter(const double& v) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
|
|
@ -298,16 +283,14 @@ namespace gtsam {
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
const vector<double>& table)
|
const vector<double>& table)
|
||||||
: DiscreteFactor(keys.indices()),
|
: DiscreteFactor(keys.indices(), keys.cardinalities()),
|
||||||
AlgebraicDecisionTree<Key>(keys, table),
|
AlgebraicDecisionTree<Key>(keys, table) {}
|
||||||
cardinalities_(keys.cardinalities()) {}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
const string& table)
|
const string& table)
|
||||||
: DiscreteFactor(keys.indices()),
|
: DiscreteFactor(keys.indices(), keys.cardinalities()),
|
||||||
AlgebraicDecisionTree<Key>(keys, table),
|
AlgebraicDecisionTree<Key>(keys, table) {}
|
||||||
cardinalities_(keys.cardinalities()) {}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
||||||
|
|
|
||||||
|
|
@ -50,10 +50,6 @@ namespace gtsam {
|
||||||
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
|
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
|
||||||
typedef AlgebraicDecisionTree<Key> ADT;
|
typedef AlgebraicDecisionTree<Key> ADT;
|
||||||
|
|
||||||
protected:
|
|
||||||
std::map<Key, size_t> cardinalities_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
@ -119,10 +115,6 @@ namespace gtsam {
|
||||||
|
|
||||||
static double safe_div(const double& a, const double& b);
|
static double safe_div(const double& a, const double& b);
|
||||||
|
|
||||||
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
|
|
||||||
|
|
||||||
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
|
||||||
|
|
||||||
/// divide by factor f (safely)
|
/// divide by factor f (safely)
|
||||||
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
||||||
return apply(f, safe_div);
|
return apply(f, safe_div);
|
||||||
|
|
@ -184,9 +176,6 @@ namespace gtsam {
|
||||||
/// Get all the probabilities in order of assignment values
|
/// Get all the probabilities in order of assignment values
|
||||||
std::vector<double> probabilities() const;
|
std::vector<double> probabilities() const;
|
||||||
|
|
||||||
/// Return all the discrete keys associated with this factor.
|
|
||||||
DiscreteKeys discreteKeys() const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of discrete variables.
|
* @brief Prune the decision tree of discrete variables.
|
||||||
*
|
*
|
||||||
|
|
@ -265,7 +254,6 @@ namespace gtsam {
|
||||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
|
||||||
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,18 @@ using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteKeys DiscreteFactor::discreteKeys() const {
|
||||||
|
DiscreteKeys result;
|
||||||
|
for (auto&& key : keys()) {
|
||||||
|
DiscreteKey dkey(key, cardinality(key));
|
||||||
|
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||||
|
result.push_back(dkey);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteFactor::error(const DiscreteValues& values) const {
|
double DiscreteFactor::error(const DiscreteValues& values) const {
|
||||||
return -std::log((*this)(values));
|
return -std::log((*this)(values));
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,10 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
|
||||||
|
|
||||||
using Values = DiscreteValues; ///< backwards compatibility
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// Map of Keys and their cardinalities.
|
||||||
|
std::map<Key, size_t> cardinalities_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
@ -52,10 +56,15 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
|
||||||
/** Default constructor creates empty factor */
|
/** Default constructor creates empty factor */
|
||||||
DiscreteFactor() {}
|
DiscreteFactor() {}
|
||||||
|
|
||||||
/** Construct from container of keys. This constructor is used internally from derived factor
|
/**
|
||||||
* constructors, either from a container of keys or from a boost::assign::list_of. */
|
* Construct from container of keys and map of cardinalities.
|
||||||
template<typename CONTAINER>
|
* This constructor is used internally from derived factor constructors,
|
||||||
DiscreteFactor(const CONTAINER& keys) : Base(keys) {}
|
* either from a container of keys or from a boost::assign::list_of.
|
||||||
|
*/
|
||||||
|
template <typename CONTAINER>
|
||||||
|
DiscreteFactor(const CONTAINER& keys,
|
||||||
|
const std::map<Key, size_t> cardinalities = {})
|
||||||
|
: Base(keys), cardinalities_(cardinalities) {}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
|
@ -75,6 +84,13 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/// Return all the discrete keys associated with this factor.
|
||||||
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
|
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
|
||||||
|
|
||||||
|
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
||||||
|
|
||||||
/// Find value for given assignment of values to variables
|
/// Find value for given assignment of values to variables
|
||||||
virtual double operator()(const DiscreteValues&) const = 0;
|
virtual double operator()(const DiscreteValues&) const = 0;
|
||||||
|
|
||||||
|
|
@ -130,6 +146,7 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
|
||||||
template <class ARCHIVE>
|
template <class ARCHIVE>
|
||||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -34,8 +34,7 @@ TableFactor::TableFactor() {}
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
const TableFactor& potentials)
|
const TableFactor& potentials)
|
||||||
: DiscreteFactor(dkeys.indices()),
|
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) {
|
||||||
cardinalities_(potentials.cardinalities_) {
|
|
||||||
sparse_table_ = potentials.sparse_table_;
|
sparse_table_ = potentials.sparse_table_;
|
||||||
denominators_ = potentials.denominators_;
|
denominators_ = potentials.denominators_;
|
||||||
sorted_dkeys_ = discreteKeys();
|
sorted_dkeys_ = discreteKeys();
|
||||||
|
|
@ -45,11 +44,11 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
const Eigen::SparseVector<double>& table)
|
const Eigen::SparseVector<double>& table)
|
||||||
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
|
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()),
|
||||||
|
sparse_table_(table.size()) {
|
||||||
sparse_table_ = table;
|
sparse_table_ = table;
|
||||||
double denom = table.size();
|
double denom = table.size();
|
||||||
for (const DiscreteKey& dkey : dkeys) {
|
for (const DiscreteKey& dkey : dkeys) {
|
||||||
cardinalities_.insert(dkey);
|
|
||||||
denom /= dkey.second;
|
denom /= dkey.second;
|
||||||
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
|
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
|
||||||
}
|
}
|
||||||
|
|
@ -440,18 +439,6 @@ std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
DiscreteKeys TableFactor::discreteKeys() const {
|
|
||||||
DiscreteKeys result;
|
|
||||||
for (auto&& key : keys()) {
|
|
||||||
DiscreteKey dkey(key, cardinality(key));
|
|
||||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
|
||||||
result.push_back(dkey);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print out header.
|
// Print out header.
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
string TableFactor::markdown(const KeyFormatter& keyFormatter,
|
string TableFactor::markdown(const KeyFormatter& keyFormatter,
|
||||||
|
|
|
||||||
|
|
@ -45,8 +45,6 @@ class HybridValues;
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
protected:
|
protected:
|
||||||
/// Map of Keys and their cardinalities.
|
|
||||||
std::map<Key, size_t> cardinalities_;
|
|
||||||
/// SparseVector of nonzero probabilities.
|
/// SparseVector of nonzero probabilities.
|
||||||
Eigen::SparseVector<double> sparse_table_;
|
Eigen::SparseVector<double> sparse_table_;
|
||||||
|
|
||||||
|
|
@ -184,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
|
|
||||||
static double safe_div(const double& a, const double& b);
|
static double safe_div(const double& a, const double& b);
|
||||||
|
|
||||||
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
|
||||||
|
|
||||||
/// divide by factor f (safely)
|
/// divide by factor f (safely)
|
||||||
TableFactor operator/(const TableFactor& f) const {
|
TableFactor operator/(const TableFactor& f) const {
|
||||||
return apply(f, safe_div);
|
return apply(f, safe_div);
|
||||||
|
|
@ -278,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
/// Enumerate all values into a map from values to double.
|
/// Enumerate all values into a map from values to double.
|
||||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||||
|
|
||||||
/// Return all the discrete keys associated with this factor.
|
|
||||||
DiscreteKeys discreteKeys() const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of discrete variables.
|
* @brief Prune the decision tree of discrete variables.
|
||||||
*
|
*
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue