release/4.3a0
Frank Dellaert 2022-01-22 12:50:35 -05:00
parent 94c692ddd1
commit ca329daa13
2 changed files with 106 additions and 122 deletions

View File

@ -17,9 +17,9 @@
* @author Frank Dellaert
*/
#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp>
#include <boost/format.hpp>
@ -29,42 +29,42 @@ using namespace std;
namespace gtsam {
/* ******************************************************************************** */
DecisionTreeFactor::DecisionTreeFactor() {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor() {}
/* ******************************************************************************** */
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) :
DiscreteFactor(keys.indices()), ADT(potentials),
cardinalities_(keys.cardinalities()) {
}
const ADT& potentials)
: DiscreteFactor(keys.indices()),
ADT(potentials),
cardinalities_(keys.cardinalities()) {}
/* *************************************************************************/
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
: DiscreteFactor(c.keys()),
AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}
/* ************************************************************************* */
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) {
/* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false;
}
else {
} else {
const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol);
}
}
/* ************************************************************************* */
double DecisionTreeFactor::safe_div(const double &a, const double &b) {
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ************************************************************************* */
/* ************************************************************************ */
void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const {
cout << s;
@ -75,31 +75,32 @@ namespace gtsam {
ADT::print("", formatter);
}
/* ************************************************************************* */
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const {
map<Key,size_t> cs; // new cardinalities
ADT::Binary op) const {
map<Key, size_t> cs; // new cardinalities
// make unique key-cardinality map
for(Key j: keys()) cs[j] = cardinality(j);
for(Key j: f.keys()) cs[j] = f.cardinality(j);
for (Key j : keys()) cs[j] = cardinality(j);
for (Key j : f.keys()) cs[j] = f.cardinality(j);
// Convert map into keys
DiscreteKeys keys;
for(const std::pair<const Key,size_t>& key: cs)
keys.push_back(key);
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
// apply operand
ADT result = ADT::apply(f, op);
// Make a new factor
return DecisionTreeFactor(keys, result);
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
ADT::Binary op) const {
if (nrFrontals > size()) throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
% nrFrontals % size()).str());
/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const {
if (nrFrontals > size())
throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal "
"keys %d, nr.keys=%d") %
nrFrontals % size())
.str());
// sum over nrFrontals keys
size_t i;
@ -113,20 +114,21 @@ namespace gtsam {
DiscreteKeys dkeys;
for (; i < keys().size(); i++) {
Key j = keys()[i];
dkeys.push_back(DiscreteKey(j,cardinality(j)));
dkeys.push_back(DiscreteKey(j, cardinality(j)));
}
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
ADT::Binary op) const {
if (frontalKeys.size() > size()) throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
% frontalKeys.size() % size()).str());
/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
const Ordering& frontalKeys, ADT::Binary op) const {
if (frontalKeys.size() > size())
throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal "
"keys %d, nr.keys=%d") %
frontalKeys.size() % size())
.str());
// sum over nrFrontals keys
size_t i;
@ -137,20 +139,22 @@ namespace gtsam {
}
// create new factor, note we collect keys that are not in frontalKeys
// TODO: why do we need this??? result should contain correct keys!!!
// TODO(frank): why do we need this??? result should contain correct keys!!!
DiscreteKeys dkeys;
for (i = 0; i < keys().size(); i++) {
Key j = keys()[i];
// TODO: inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
// TODO(frank): inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
frontalKeys.end())
continue;
dkeys.push_back(DiscreteKey(j,cardinality(j)));
dkeys.push_back(DiscreteKey(j, cardinality(j)));
}
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}
/* ************************************************************************* */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
/* ************************************************************************ */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const {
// Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) {
@ -168,7 +172,7 @@ namespace gtsam {
return result;
}
/* ************************************************************************* */
/* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
@ -180,7 +184,7 @@ namespace gtsam {
return result;
}
/* ************************************************************************* */
/* ************************************************************************ */
static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str();
}
@ -194,8 +198,8 @@ namespace gtsam {
/** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter,
bool showZero) const {
const KeyFormatter& keyFormatter,
bool showZero) const {
ADT::dot(name, keyFormatter, valueFormatter, showZero);
}
@ -205,8 +209,8 @@ namespace gtsam {
return ADT::dot(keyFormatter, valueFormatter, showZero);
}
// Print out header.
/* ************************************************************************* */
// Print out header.
/* ************************************************************************ */
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
@ -271,17 +275,19 @@ namespace gtsam {
return ss.str();
}
/* ************************************************************************* */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const vector<double>& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const string& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */
} // namespace gtsam
/* ************************************************************************ */
} // namespace gtsam

View File

@ -18,16 +18,18 @@
#pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Ordering.h>
#include <algorithm>
#include <boost/shared_ptr.hpp>
#include <vector>
#include <exception>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
@ -36,21 +38,19 @@ namespace gtsam {
/**
* A discrete probabilistic factor
*/
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> {
public:
class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
public AlgebraicDecisionTree<Key> {
public:
// typedefs needed to play nice with gtsam
typedef DecisionTreeFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class
typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;
protected:
std::map<Key,size_t> cardinalities_;
public:
protected:
std::map<Key, size_t> cardinalities_;
public:
/// @name Standard Constructors
/// @{
@ -61,7 +61,8 @@ namespace gtsam {
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from doubles */
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table);
DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table);
/** Constructor from string */
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
@ -86,7 +87,8 @@ namespace gtsam {
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
// print
void print(const std::string& s = "DecisionTreeFactor:\n",
void print(
const std::string& s = "DecisionTreeFactor:\n",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @}
@ -105,7 +107,7 @@ namespace gtsam {
static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j);}
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
@ -113,9 +115,7 @@ namespace gtsam {
}
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override {
return *this;
}
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
@ -164,27 +164,6 @@ namespace gtsam {
*/
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
// /**
// * @brief Permutes the keys in Potentials and DiscreteFactor
// *
// * This re-implements the permuteWithInverse() in both Potentials
// * and DiscreteFactor by doing both of them together.
// */
//
// void permuteWithInverse(const Permutation& inversePermutation){
// DiscreteFactor::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
// }
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
// DiscreteFactor::reduceWithInverse(inverseReduction);
// Potentials::reduceWithInverse(inverseReduction);
// }
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
@ -194,16 +173,16 @@ namespace gtsam {
/// @}
/// @name Wrapper support
/// @{
/** output to graphviz format, stream version */
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/** output to graphviz format, open a file */
void dot(const std::string& name,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/** output to graphviz format string */
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
@ -217,7 +196,7 @@ namespace gtsam {
* @return std::string a markdown string.
*/
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
const Names& names = {}) const override;
/**
* @brief Render as html table
@ -227,14 +206,13 @@ namespace gtsam {
* @return std::string a html string.
*/
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
const Names& names = {}) const override;
/// @}
};
// DecisionTreeFactor
};
// traits
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
template <>
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
}// namespace gtsam
} // namespace gtsam