Refactor print, equals, convert
parent
78f8cc948d
commit
db3cb4d9ac
|
|
@ -28,7 +28,13 @@ namespace gtsam {
|
||||||
* TODO: consider eliminating this class altogether?
|
* TODO: consider eliminating this class altogether?
|
||||||
*/
|
*/
|
||||||
template<typename L>
|
template<typename L>
|
||||||
class AlgebraicDecisionTree: public DecisionTree<L, double> {
|
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
|
||||||
|
/// Default method used by `formatter` when printing.
|
||||||
|
static std::string DefaultFormatter(const L& x) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << x;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
|
@ -141,13 +147,23 @@ namespace gtsam {
|
||||||
return this->combine(labelC, &Ring::add);
|
return this->combine(labelC, &Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// print method customized to node type `double`.
|
||||||
|
void print(const std::string& s,
|
||||||
|
const typename Super::LabelFormatter& labelFormatter =
|
||||||
|
&DefaultFormatter) const {
|
||||||
|
auto valueFormatter = [](const double& v) {
|
||||||
|
return (boost::format("%4.2g") % v).str();
|
||||||
|
};
|
||||||
|
Super::print(s, labelFormatter, valueFormatter);
|
||||||
|
}
|
||||||
|
|
||||||
/// Equality method customized to node type `double`.
|
/// Equality method customized to node type `double`.
|
||||||
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const {
|
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const {
|
||||||
// lambda for comparison of two doubles upto some tolerance.
|
// lambda for comparison of two doubles upto some tolerance.
|
||||||
auto compare = [tol](double a, double b) {
|
auto compare = [tol](double a, double b) {
|
||||||
return std::abs(a - b) < tol;
|
return std::abs(a - b) < tol;
|
||||||
};
|
};
|
||||||
return this->root_->equals(*other.root_, tol, compare);
|
return Super::equals(other, compare);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// AlgebraicDecisionTree
|
// AlgebraicDecisionTree
|
||||||
|
|
|
||||||
|
|
@ -76,24 +76,25 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** equality up to tolerance */
|
/** equality up to tolerance */
|
||||||
bool equals(const Node& q, double tol,
|
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||||
const CompareFunc& compare) const override {
|
|
||||||
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
||||||
if (!other) return false;
|
if (!other) return false;
|
||||||
return compare(this->constant_, other->constant_);
|
return compare(this->constant_, other->constant_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** print */
|
/** print */
|
||||||
void print(const std::string& s,
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const FormatterFunc& formatter) const override {
|
const ValueFormatter& valueFormatter) const override {
|
||||||
bool showZero = true;
|
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
||||||
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** to graphviz file */
|
/** to graphviz file */
|
||||||
void dot(std::ostream& os, bool showZero) const override {
|
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||||
if (showZero || constant_) os << "\"" << this->id() << "\" [label=\""
|
const ValueFormatter& valueFormatter,
|
||||||
<< boost::format("%4.2g") % constant_
|
bool showZero) const override {
|
||||||
|
std::string value = valueFormatter(constant_);
|
||||||
|
if (showZero || value.compare("0"))
|
||||||
|
os << "\"" << this->id() << "\" [label=\"" << value
|
||||||
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
|
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -238,16 +239,19 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** print (as a tree) */
|
/** print (as a tree) */
|
||||||
void print(const std::string& s,
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const FormatterFunc& formatter) const override {
|
const ValueFormatter& valueFormatter) const override {
|
||||||
std::cout << s << " Choice(";
|
std::cout << s << " Choice(";
|
||||||
std::cout << formatter(label_) << ") " << std::endl;
|
std::cout << labelFormatter(label_) << ") " << std::endl;
|
||||||
for (size_t i = 0; i < branches_.size(); i++)
|
for (size_t i = 0; i < branches_.size(); i++)
|
||||||
branches_[i]->print((boost::format("%s %d") % s % i).str(), formatter);
|
branches_[i]->print((boost::format("%s %d") % s % i).str(),
|
||||||
|
labelFormatter, valueFormatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** output to graphviz (as a a graph) */
|
/** output to graphviz (as a a graph) */
|
||||||
void dot(std::ostream& os, bool showZero) const override {
|
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero) const override {
|
||||||
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
|
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
|
||||||
<< "\"]\n";
|
<< "\"]\n";
|
||||||
size_t B = branches_.size();
|
size_t B = branches_.size();
|
||||||
|
|
@ -257,7 +261,8 @@ namespace gtsam {
|
||||||
// Check if zero
|
// Check if zero
|
||||||
if (!showZero) {
|
if (!showZero) {
|
||||||
const Leaf* leaf = dynamic_cast<const Leaf*> (branch.get());
|
const Leaf* leaf = dynamic_cast<const Leaf*> (branch.get());
|
||||||
if (leaf && !leaf->constant()) continue;
|
std::string value = valueFormatter(leaf->constant());
|
||||||
|
if (leaf && value.compare("0")) continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
||||||
|
|
@ -266,7 +271,7 @@ namespace gtsam {
|
||||||
if (i > 1) os << " [style=bold]";
|
if (i > 1) os << " [style=bold]";
|
||||||
}
|
}
|
||||||
os << std::endl;
|
os << std::endl;
|
||||||
branch->dot(os, showZero);
|
branch->dot(os, labelFormatter, valueFormatter, showZero);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -280,16 +285,15 @@ namespace gtsam {
|
||||||
return (q.isLeaf() && q.sameLeaf(*this));
|
return (q.isLeaf() && q.sameLeaf(*this));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** equality up to tolerance */
|
/** equality */
|
||||||
bool equals(const Node& q, double tol,
|
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||||
const CompareFunc& compare) const override {
|
|
||||||
const Choice* other = dynamic_cast<const Choice*>(&q);
|
const Choice* other = dynamic_cast<const Choice*>(&q);
|
||||||
if (!other) return false;
|
if (!other) return false;
|
||||||
if (this->label_ != other->label_) return false;
|
if (this->label_ != other->label_) return false;
|
||||||
if (branches_.size() != other->branches_.size()) return false;
|
if (branches_.size() != other->branches_.size()) return false;
|
||||||
// we don't care about shared pointers being equal here
|
// we don't care about shared pointers being equal here
|
||||||
for (size_t i = 0; i < branches_.size(); i++)
|
for (size_t i = 0; i < branches_.size(); i++)
|
||||||
if (!(branches_[i]->equals(*(other->branches_[i]), tol, compare)))
|
if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
|
||||||
return false;
|
return false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
@ -459,7 +463,7 @@ namespace gtsam {
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||||
std::function<Y(const X&)> op) {
|
std::function<Y(const X&)> op) {
|
||||||
auto map = [](const L& label) { return label; };
|
auto map = [](const L& label) { return label; };
|
||||||
root_ = convert<L, X>(other.root_, op, map);
|
root_ = other.template convert<L, X>(op, map);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
|
|
@ -470,7 +474,7 @@ namespace gtsam {
|
||||||
std::function<L(const M&)> map_function = [&map](const M& label) -> L {
|
std::function<L(const M&)> map_function = [&map](const M& label) -> L {
|
||||||
return map.at(label);
|
return map.at(label);
|
||||||
};
|
};
|
||||||
root_ = convert<M, X>(other.root_, op, map_function);
|
root_ = other.template convert<M, X>(op, map_function);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
|
|
@ -587,7 +591,7 @@ namespace gtsam {
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert(
|
||||||
const typename DecisionTree<M, X>::NodePtr& f,
|
const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<Y(const X&)> op, std::function<L(const M&)> map) {
|
std::function<Y(const X&)> op, std::function<L(const M&)> map) const {
|
||||||
typedef DecisionTree<M, X> MX;
|
typedef DecisionTree<M, X> MX;
|
||||||
typedef typename MX::Leaf MXLeaf;
|
typedef typename MX::Leaf MXLeaf;
|
||||||
typedef typename MX::Choice MXChoice;
|
typedef typename MX::Choice MXChoice;
|
||||||
|
|
@ -596,11 +600,11 @@ namespace gtsam {
|
||||||
|
|
||||||
// ugliness below because apparently we can't have templated virtual functions
|
// ugliness below because apparently we can't have templated virtual functions
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf
|
// If leaf, apply unary conversion "op" and create a unique leaf
|
||||||
const MXLeaf* leaf = dynamic_cast<const MXLeaf*> (f.get());
|
auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
|
||||||
if (leaf) return NodePtr(new Leaf(op(leaf->constant())));
|
if (leaf) return NodePtr(new Leaf(op(leaf->constant())));
|
||||||
|
|
||||||
// Check if Choice
|
// Check if Choice
|
||||||
boost::shared_ptr<const MXChoice> choice = boost::dynamic_pointer_cast<const MXChoice> (f);
|
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
||||||
if (!choice) throw std::invalid_argument(
|
if (!choice) throw std::invalid_argument(
|
||||||
"DecisionTree::Convert: Invalid NodePtr");
|
"DecisionTree::Convert: Invalid NodePtr");
|
||||||
|
|
||||||
|
|
@ -619,15 +623,16 @@ namespace gtsam {
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other, double tol,
|
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
||||||
const CompareFunc& compare) const {
|
const CompareFunc& compare) const {
|
||||||
return root_->equals(*other.root_, tol, compare);
|
return root_->equals(*other.root_, compare);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
void DecisionTree<L, Y>::print(const std::string& s,
|
void DecisionTree<L, Y>::print(const std::string& s,
|
||||||
const FormatterFunc& formatter) const {
|
const LabelFormatter& labelFormatter,
|
||||||
root_->print(s, formatter);
|
const ValueFormatter& valueFormatter) const {
|
||||||
|
root_->print(s, labelFormatter, valueFormatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
|
|
@ -687,26 +692,34 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
void DecisionTree<L, Y>::dot(std::ostream& os, bool showZero) const {
|
void DecisionTree<L, Y>::dot(std::ostream& os,
|
||||||
|
const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero) const {
|
||||||
os << "digraph G {\n";
|
os << "digraph G {\n";
|
||||||
root_->dot(os, showZero);
|
root_->dot(os, labelFormatter, valueFormatter, showZero);
|
||||||
os << " [ordering=out]}" << std::endl;
|
os << " [ordering=out]}" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
void DecisionTree<L, Y>::dot(const std::string& name, bool showZero) const {
|
void DecisionTree<L, Y>::dot(const std::string& name,
|
||||||
|
const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero) const {
|
||||||
std::ofstream os((name + ".dot").c_str());
|
std::ofstream os((name + ".dot").c_str());
|
||||||
dot(os, showZero);
|
dot(os, labelFormatter, valueFormatter, showZero);
|
||||||
int result = system(
|
int result = system(
|
||||||
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
|
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
|
||||||
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
|
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
std::string DecisionTree<L, Y>::dot(bool showZero) const {
|
std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero) const {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
dot(ss, showZero);
|
dot(ss, labelFormatter, valueFormatter, showZero);
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,13 +39,6 @@ namespace gtsam {
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
class GTSAM_EXPORT DecisionTree {
|
class GTSAM_EXPORT DecisionTree {
|
||||||
|
|
||||||
/// Default method used by `formatter` when printing.
|
|
||||||
static std::string DefaultFormatter(const L& x) {
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << x;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Default method for comparison of two objects of type Y.
|
/// Default method for comparison of two objects of type Y.
|
||||||
static bool DefaultCompare(const Y& a, const Y& b) {
|
static bool DefaultCompare(const Y& a, const Y& b) {
|
||||||
return a == b;
|
return a == b;
|
||||||
|
|
@ -53,7 +46,8 @@ namespace gtsam {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
using FormatterFunc = std::function<std::string(L)>;
|
using LabelFormatter = std::function<std::string(L)>;
|
||||||
|
using ValueFormatter = std::function<std::string(Y)>;
|
||||||
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
||||||
|
|
||||||
/** Handy typedefs for unary and binary function types */
|
/** Handy typedefs for unary and binary function types */
|
||||||
|
|
@ -94,15 +88,16 @@ namespace gtsam {
|
||||||
const void* id() const { return this; }
|
const void* id() const { return this; }
|
||||||
|
|
||||||
// everything else is virtual, no documentation here as internal
|
// everything else is virtual, no documentation here as internal
|
||||||
virtual void print(
|
virtual void print(const std::string& s,
|
||||||
const std::string& s = "",
|
const LabelFormatter& labelFormatter,
|
||||||
const FormatterFunc& formatter = &DefaultFormatter) const = 0;
|
const ValueFormatter& valueFormatter) const = 0;
|
||||||
virtual void dot(std::ostream& os, bool showZero) const = 0;
|
virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero) const = 0;
|
||||||
virtual bool sameLeaf(const Leaf& q) const = 0;
|
virtual bool sameLeaf(const Leaf& q) const = 0;
|
||||||
virtual bool sameLeaf(const Node& q) const = 0;
|
virtual bool sameLeaf(const Node& q) const = 0;
|
||||||
virtual bool equals(
|
virtual bool equals(const Node& other, const CompareFunc& compare =
|
||||||
const Node& other, double tol = 1e-9,
|
&DefaultCompare) const = 0;
|
||||||
const CompareFunc& compare = &DefaultCompare) const = 0;
|
|
||||||
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
||||||
virtual Ptr apply(const Unary& op) const = 0;
|
virtual Ptr apply(const Unary& op) const = 0;
|
||||||
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
||||||
|
|
@ -118,11 +113,11 @@ namespace gtsam {
|
||||||
/** A function is a shared pointer to the root of a DT */
|
/** A function is a shared pointer to the root of a DT */
|
||||||
typedef typename Node::Ptr NodePtr;
|
typedef typename Node::Ptr NodePtr;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
|
||||||
/* a DecisionTree just contains the root */
|
/* a DecisionTree just contains the root */
|
||||||
NodePtr root_;
|
NodePtr root_;
|
||||||
|
|
||||||
protected:
|
|
||||||
|
|
||||||
/** Internal recursive function to create from keys, cardinalities, and Y values */
|
/** Internal recursive function to create from keys, cardinalities, and Y values */
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
||||||
|
|
@ -131,7 +126,14 @@ namespace gtsam {
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
NodePtr convert(const typename DecisionTree<M, X>::NodePtr& f,
|
NodePtr convert(const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<Y(const X&)> op,
|
std::function<Y(const X&)> op,
|
||||||
std::function<L(const M&)> map);
|
std::function<L(const M&)> map) const;
|
||||||
|
|
||||||
|
/// Convert to a different type, will not convert label if map empty.
|
||||||
|
template <typename M, typename X>
|
||||||
|
NodePtr convert(std::function<Y(const X&)> op,
|
||||||
|
std::function<L(const M&)> map) const {
|
||||||
|
return convert(root_, op, map);
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
|
@ -179,11 +181,11 @@ namespace gtsam {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** GTSAM-style print */
|
/** GTSAM-style print */
|
||||||
void print(const std::string& s = "DecisionTree",
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const FormatterFunc& formatter = &DefaultFormatter) const;
|
const ValueFormatter& valueFormatter) const;
|
||||||
|
|
||||||
// Testable
|
// Testable
|
||||||
bool equals(const DecisionTree& other, double tol = 1e-9,
|
bool equals(const DecisionTree& other,
|
||||||
const CompareFunc& compare = &DefaultCompare) const;
|
const CompareFunc& compare = &DefaultCompare) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
@ -225,13 +227,17 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** output to graphviz format, stream version */
|
/** output to graphviz format, stream version */
|
||||||
void dot(std::ostream& os, bool showZero = true) const;
|
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter, bool showZero = true) const;
|
||||||
|
|
||||||
/** output to graphviz format, open a file */
|
/** output to graphviz format, open a file */
|
||||||
void dot(const std::string& name, bool showZero = true) const;
|
void dot(const std::string& name, const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter, bool showZero = true) const;
|
||||||
|
|
||||||
/** output to graphviz format string */
|
/** output to graphviz format string */
|
||||||
std::string dot(bool showZero = true) const;
|
std::string dot(const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero = true) const;
|
||||||
|
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
||||||
|
|
@ -43,34 +43,74 @@ void dot(const T&f, const string& filename) {
|
||||||
struct Crazy {
|
struct Crazy {
|
||||||
int a;
|
int a;
|
||||||
double b;
|
double b;
|
||||||
|
|
||||||
bool equals(const Crazy& other, double tol = 1e-12) const {
|
|
||||||
return a == other.a && std::abs(b - other.b) < tol;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator==(const Crazy& other) const {
|
|
||||||
return this->equals(other);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef DecisionTree<string,Crazy> CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be)
|
// bool equals(const Crazy& other, double tol = 1e-12) const {
|
||||||
|
// return a == other.a && std::abs(b - other.b) < tol;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool operator==(const Crazy& other) const {
|
||||||
|
// return this->equals(other);
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
|
||||||
|
struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
|
||||||
|
/// print to stdout
|
||||||
|
void print(const std::string& s = "") const {
|
||||||
|
auto keyFormatter = [](const std::string& s) { return s; };
|
||||||
|
auto valueFormatter = [](const Crazy& v) {
|
||||||
|
return (boost::format("{%d,%4.2g}") % v.a % v.b).str();
|
||||||
|
};
|
||||||
|
DecisionTree<string, Crazy>::print("", keyFormatter, valueFormatter);
|
||||||
|
}
|
||||||
|
/// Equality method customized to Crazy node type
|
||||||
|
bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
|
||||||
|
auto compare = [tol](const Crazy& v, const Crazy& w) {
|
||||||
|
return v.a == w.a && std::abs(v.b - w.b) < tol;
|
||||||
|
};
|
||||||
|
return DecisionTree<string, Crazy>::equals(other, compare);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// Test string labels and int range
|
// Test string labels and int range
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
|
|
||||||
typedef DecisionTree<string, int> DT;
|
struct DT : public DecisionTree<string, int> {
|
||||||
|
using DecisionTree::DecisionTree;
|
||||||
|
DT(const DecisionTree<string, int>& dt) : root_(dt.root_) {}
|
||||||
|
|
||||||
|
/// print to stdout
|
||||||
|
void print(const std::string& s = "") const {
|
||||||
|
auto keyFormatter = [](const std::string& s) { return s; };
|
||||||
|
auto valueFormatter = [](const int& v) {
|
||||||
|
return (boost::format("%d") % v).str();
|
||||||
|
};
|
||||||
|
DecisionTree<string, int>::print("", keyFormatter, valueFormatter);
|
||||||
|
}
|
||||||
|
// /// Equality method customized to int node type
|
||||||
|
// bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
|
||||||
|
// auto compare = [tol](const int& v, const int& w) {
|
||||||
|
// return v.a == w.a && std::abs(v.b - w.b) < tol;
|
||||||
|
// };
|
||||||
|
// return DecisionTree<string, int>::equals(other, compare);
|
||||||
|
// }
|
||||||
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
template<> struct traits<DT> : public Testable<DT> {};
|
template<> struct traits<DT> : public Testable<DT> {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GTSAM_CONCEPT_TESTABLE_INST(DT)
|
||||||
|
|
||||||
struct Ring {
|
struct Ring {
|
||||||
static inline int zero() {
|
static inline int zero() {
|
||||||
return 0;
|
return 0;
|
||||||
|
|
@ -91,111 +131,111 @@ struct Ring {
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// test DT
|
// test DT
|
||||||
TEST(DT, example)
|
// TEST(DT, example)
|
||||||
{
|
// {
|
||||||
// Create labels
|
// // Create labels
|
||||||
string A("A"), B("B"), C("C");
|
// string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
// create a value
|
// // create a value
|
||||||
Assignment<string> x00, x01, x10, x11;
|
// Assignment<string> x00, x01, x10, x11;
|
||||||
x00[A] = 0, x00[B] = 0;
|
// x00[A] = 0, x00[B] = 0;
|
||||||
x01[A] = 0, x01[B] = 1;
|
// x01[A] = 0, x01[B] = 1;
|
||||||
x10[A] = 1, x10[B] = 0;
|
// x10[A] = 1, x10[B] = 0;
|
||||||
x11[A] = 1, x11[B] = 1;
|
// x11[A] = 1, x11[B] = 1;
|
||||||
|
|
||||||
// empty
|
// // empty
|
||||||
DT empty;
|
// DT empty;
|
||||||
|
|
||||||
// A
|
// // A
|
||||||
DT a(A, 0, 5);
|
// DT a(A, 0, 5);
|
||||||
LONGS_EQUAL(0,a(x00))
|
// LONGS_EQUAL(0,a(x00))
|
||||||
LONGS_EQUAL(5,a(x10))
|
// LONGS_EQUAL(5,a(x10))
|
||||||
DOT(a);
|
// DOT(a);
|
||||||
|
|
||||||
// pruned
|
// // pruned
|
||||||
DT p(A, 2, 2);
|
// DT p(A, 2, 2);
|
||||||
LONGS_EQUAL(2,p(x00))
|
// LONGS_EQUAL(2,p(x00))
|
||||||
LONGS_EQUAL(2,p(x10))
|
// LONGS_EQUAL(2,p(x10))
|
||||||
DOT(p);
|
// DOT(p);
|
||||||
|
|
||||||
// \neg B
|
// // \neg B
|
||||||
DT notb(B, 5, 0);
|
// DT notb(B, 5, 0);
|
||||||
LONGS_EQUAL(5,notb(x00))
|
// LONGS_EQUAL(5,notb(x00))
|
||||||
LONGS_EQUAL(5,notb(x10))
|
// LONGS_EQUAL(5,notb(x10))
|
||||||
DOT(notb);
|
// DOT(notb);
|
||||||
|
|
||||||
// Check supplying empty trees yields an exception
|
// // Check supplying empty trees yields an exception
|
||||||
CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
|
// CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
|
||||||
CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
|
// CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
|
||||||
CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
|
// CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
|
||||||
|
|
||||||
// apply, two nodes, in natural order
|
// // apply, two nodes, in natural order
|
||||||
DT anotb = apply(a, notb, &Ring::mul);
|
// DT anotb = apply(a, notb, &Ring::mul);
|
||||||
LONGS_EQUAL(0,anotb(x00))
|
// LONGS_EQUAL(0,anotb(x00))
|
||||||
LONGS_EQUAL(0,anotb(x01))
|
// LONGS_EQUAL(0,anotb(x01))
|
||||||
LONGS_EQUAL(25,anotb(x10))
|
// LONGS_EQUAL(25,anotb(x10))
|
||||||
LONGS_EQUAL(0,anotb(x11))
|
// LONGS_EQUAL(0,anotb(x11))
|
||||||
DOT(anotb);
|
// DOT(anotb);
|
||||||
|
|
||||||
// check pruning
|
// // check pruning
|
||||||
DT pnotb = apply(p, notb, &Ring::mul);
|
// DT pnotb = apply(p, notb, &Ring::mul);
|
||||||
LONGS_EQUAL(10,pnotb(x00))
|
// LONGS_EQUAL(10,pnotb(x00))
|
||||||
LONGS_EQUAL( 0,pnotb(x01))
|
// LONGS_EQUAL( 0,pnotb(x01))
|
||||||
LONGS_EQUAL(10,pnotb(x10))
|
// LONGS_EQUAL(10,pnotb(x10))
|
||||||
LONGS_EQUAL( 0,pnotb(x11))
|
// LONGS_EQUAL( 0,pnotb(x11))
|
||||||
DOT(pnotb);
|
// DOT(pnotb);
|
||||||
|
|
||||||
// check pruning
|
// // check pruning
|
||||||
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
|
// DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
|
||||||
LONGS_EQUAL(0,zeros(x00))
|
// LONGS_EQUAL(0,zeros(x00))
|
||||||
LONGS_EQUAL(0,zeros(x01))
|
// LONGS_EQUAL(0,zeros(x01))
|
||||||
LONGS_EQUAL(0,zeros(x10))
|
// LONGS_EQUAL(0,zeros(x10))
|
||||||
LONGS_EQUAL(0,zeros(x11))
|
// LONGS_EQUAL(0,zeros(x11))
|
||||||
DOT(zeros);
|
// DOT(zeros);
|
||||||
|
|
||||||
// apply, two nodes, in switched order
|
// // apply, two nodes, in switched order
|
||||||
DT notba = apply(a, notb, &Ring::mul);
|
// DT notba = apply(a, notb, &Ring::mul);
|
||||||
LONGS_EQUAL(0,notba(x00))
|
// LONGS_EQUAL(0,notba(x00))
|
||||||
LONGS_EQUAL(0,notba(x01))
|
// LONGS_EQUAL(0,notba(x01))
|
||||||
LONGS_EQUAL(25,notba(x10))
|
// LONGS_EQUAL(25,notba(x10))
|
||||||
LONGS_EQUAL(0,notba(x11))
|
// LONGS_EQUAL(0,notba(x11))
|
||||||
DOT(notba);
|
// DOT(notba);
|
||||||
|
|
||||||
// Test choose 0
|
// // Test choose 0
|
||||||
DT actual0 = notba.choose(A, 0);
|
// DT actual0 = notba.choose(A, 0);
|
||||||
EXPECT(assert_equal(DT(0.0), actual0));
|
// EXPECT(assert_equal(DT(0.0), actual0));
|
||||||
DOT(actual0);
|
// DOT(actual0);
|
||||||
|
|
||||||
// Test choose 1
|
// // Test choose 1
|
||||||
DT actual1 = notba.choose(A, 1);
|
// DT actual1 = notba.choose(A, 1);
|
||||||
EXPECT(assert_equal(DT(B, 25, 0), actual1));
|
// EXPECT(assert_equal(DT(B, 25, 0), actual1));
|
||||||
DOT(actual1);
|
// DOT(actual1);
|
||||||
|
|
||||||
// apply, two nodes at same level
|
// // apply, two nodes at same level
|
||||||
DT a_and_a = apply(a, a, &Ring::mul);
|
// DT a_and_a = apply(a, a, &Ring::mul);
|
||||||
LONGS_EQUAL(0,a_and_a(x00))
|
// LONGS_EQUAL(0,a_and_a(x00))
|
||||||
LONGS_EQUAL(0,a_and_a(x01))
|
// LONGS_EQUAL(0,a_and_a(x01))
|
||||||
LONGS_EQUAL(25,a_and_a(x10))
|
// LONGS_EQUAL(25,a_and_a(x10))
|
||||||
LONGS_EQUAL(25,a_and_a(x11))
|
// LONGS_EQUAL(25,a_and_a(x11))
|
||||||
DOT(a_and_a);
|
// DOT(a_and_a);
|
||||||
|
|
||||||
// create a function on C
|
// // create a function on C
|
||||||
DT c(C, 0, 5);
|
// DT c(C, 0, 5);
|
||||||
|
|
||||||
// and a model assigning stuff to C
|
// // and a model assigning stuff to C
|
||||||
Assignment<string> x101;
|
// Assignment<string> x101;
|
||||||
x101[A] = 1, x101[B] = 0, x101[C] = 1;
|
// x101[A] = 1, x101[B] = 0, x101[C] = 1;
|
||||||
|
|
||||||
// mul notba with C
|
// // mul notba with C
|
||||||
DT notbac = apply(notba, c, &Ring::mul);
|
// DT notbac = apply(notba, c, &Ring::mul);
|
||||||
LONGS_EQUAL(125,notbac(x101))
|
// LONGS_EQUAL(125,notbac(x101))
|
||||||
DOT(notbac);
|
// DOT(notbac);
|
||||||
|
|
||||||
// mul now in different order
|
// // mul now in different order
|
||||||
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
|
// DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
|
||||||
LONGS_EQUAL(125,acnotb(x101))
|
// LONGS_EQUAL(125,acnotb(x101))
|
||||||
DOT(acnotb);
|
// DOT(acnotb);
|
||||||
}
|
// }
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// test Conversion
|
// test Conversion
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue