Refactor print, equals, convert

release/4.3a0
Frank Dellaert 2022-01-02 13:57:12 -05:00
parent 78f8cc948d
commit db3cb4d9ac
4 changed files with 238 additions and 163 deletions

View File

@ -28,7 +28,13 @@ namespace gtsam {
* TODO: consider eliminating this class altogether? * TODO: consider eliminating this class altogether?
*/ */
template<typename L> template<typename L>
class AlgebraicDecisionTree: public DecisionTree<L, double> { class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
/// Default method used by `formatter` when printing.
static std::string DefaultFormatter(const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}
public: public:
@ -141,13 +147,23 @@ namespace gtsam {
return this->combine(labelC, &Ring::add); return this->combine(labelC, &Ring::add);
} }
/// print method customized to node type `double`.
void print(const std::string& s,
const typename Super::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str();
};
Super::print(s, labelFormatter, valueFormatter);
}
/// Equality method customized to node type `double`. /// Equality method customized to node type `double`.
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const { bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const {
// lambda for comparison of two doubles upto some tolerance. // lambda for comparison of two doubles upto some tolerance.
auto compare = [tol](double a, double b) { auto compare = [tol](double a, double b) {
return std::abs(a - b) < tol; return std::abs(a - b) < tol;
}; };
return this->root_->equals(*other.root_, tol, compare); return Super::equals(other, compare);
} }
}; };
// AlgebraicDecisionTree // AlgebraicDecisionTree

View File

@ -76,25 +76,26 @@ namespace gtsam {
} }
/** equality up to tolerance */ /** equality up to tolerance */
bool equals(const Node& q, double tol, bool equals(const Node& q, const CompareFunc& compare) const override {
const CompareFunc& compare) const override {
const Leaf* other = dynamic_cast<const Leaf*>(&q); const Leaf* other = dynamic_cast<const Leaf*>(&q);
if (!other) return false; if (!other) return false;
return compare(this->constant_, other->constant_); return compare(this->constant_, other->constant_);
} }
/** print */ /** print */
void print(const std::string& s, void print(const std::string& s, const LabelFormatter& labelFormatter,
const FormatterFunc& formatter) const override { const ValueFormatter& valueFormatter) const override {
bool showZero = true; std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl;
} }
/** to graphviz file */ /** to graphviz file */
void dot(std::ostream& os, bool showZero) const override { void dot(std::ostream& os, const LabelFormatter& labelFormatter,
if (showZero || constant_) os << "\"" << this->id() << "\" [label=\"" const ValueFormatter& valueFormatter,
<< boost::format("%4.2g") % constant_ bool showZero) const override {
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, std::string value = valueFormatter(constant_);
if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
} }
/** evaluate */ /** evaluate */
@ -238,16 +239,19 @@ namespace gtsam {
} }
/** print (as a tree) */ /** print (as a tree) */
void print(const std::string& s, void print(const std::string& s, const LabelFormatter& labelFormatter,
const FormatterFunc& formatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Choice("; std::cout << s << " Choice(";
std::cout << formatter(label_) << ") " << std::endl; std::cout << labelFormatter(label_) << ") " << std::endl;
for (size_t i = 0; i < branches_.size(); i++) for (size_t i = 0; i < branches_.size(); i++)
branches_[i]->print((boost::format("%s %d") % s % i).str(), formatter); branches_[i]->print((boost::format("%s %d") % s % i).str(),
labelFormatter, valueFormatter);
} }
/** output to graphviz (as a a graph) */ /** output to graphviz (as a a graph) */
void dot(std::ostream& os, bool showZero) const override { void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
<< "\"]\n"; << "\"]\n";
size_t B = branches_.size(); size_t B = branches_.size();
@ -257,7 +261,8 @@ namespace gtsam {
// Check if zero // Check if zero
if (!showZero) { if (!showZero) {
const Leaf* leaf = dynamic_cast<const Leaf*> (branch.get()); const Leaf* leaf = dynamic_cast<const Leaf*> (branch.get());
if (leaf && !leaf->constant()) continue; std::string value = valueFormatter(leaf->constant());
if (leaf && value.compare("0")) continue;
} }
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
@ -266,7 +271,7 @@ namespace gtsam {
if (i > 1) os << " [style=bold]"; if (i > 1) os << " [style=bold]";
} }
os << std::endl; os << std::endl;
branch->dot(os, showZero); branch->dot(os, labelFormatter, valueFormatter, showZero);
} }
} }
@ -280,16 +285,15 @@ namespace gtsam {
return (q.isLeaf() && q.sameLeaf(*this)); return (q.isLeaf() && q.sameLeaf(*this));
} }
/** equality up to tolerance */ /** equality */
bool equals(const Node& q, double tol, bool equals(const Node& q, const CompareFunc& compare) const override {
const CompareFunc& compare) const override {
const Choice* other = dynamic_cast<const Choice*>(&q); const Choice* other = dynamic_cast<const Choice*>(&q);
if (!other) return false; if (!other) return false;
if (this->label_ != other->label_) return false; if (this->label_ != other->label_) return false;
if (branches_.size() != other->branches_.size()) return false; if (branches_.size() != other->branches_.size()) return false;
// we don't care about shared pointers being equal here // we don't care about shared pointers being equal here
for (size_t i = 0; i < branches_.size(); i++) for (size_t i = 0; i < branches_.size(); i++)
if (!(branches_[i]->equals(*(other->branches_[i]), tol, compare))) if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
return false; return false;
return true; return true;
} }
@ -459,7 +463,7 @@ namespace gtsam {
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> op) { std::function<Y(const X&)> op) {
auto map = [](const L& label) { return label; }; auto map = [](const L& label) { return label; };
root_ = convert<L, X>(other.root_, op, map); root_ = other.template convert<L, X>(op, map);
} }
/*********************************************************************************/ /*********************************************************************************/
@ -470,7 +474,7 @@ namespace gtsam {
std::function<L(const M&)> map_function = [&map](const M& label) -> L { std::function<L(const M&)> map_function = [&map](const M& label) -> L {
return map.at(label); return map.at(label);
}; };
root_ = convert<M, X>(other.root_, op, map_function); root_ = other.template convert<M, X>(op, map_function);
} }
/*********************************************************************************/ /*********************************************************************************/
@ -587,7 +591,7 @@ namespace gtsam {
template <typename M, typename X> template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert(
const typename DecisionTree<M, X>::NodePtr& f, const typename DecisionTree<M, X>::NodePtr& f,
std::function<Y(const X&)> op, std::function<L(const M&)> map) { std::function<Y(const X&)> op, std::function<L(const M&)> map) const {
typedef DecisionTree<M, X> MX; typedef DecisionTree<M, X> MX;
typedef typename MX::Leaf MXLeaf; typedef typename MX::Leaf MXLeaf;
typedef typename MX::Choice MXChoice; typedef typename MX::Choice MXChoice;
@ -596,11 +600,11 @@ namespace gtsam {
// ugliness below because apparently we can't have templated virtual functions // ugliness below because apparently we can't have templated virtual functions
// If leaf, apply unary conversion "op" and create a unique leaf // If leaf, apply unary conversion "op" and create a unique leaf
const MXLeaf* leaf = dynamic_cast<const MXLeaf*> (f.get()); auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); if (leaf) return NodePtr(new Leaf(op(leaf->constant())));
// Check if Choice // Check if Choice
boost::shared_ptr<const MXChoice> choice = boost::dynamic_pointer_cast<const MXChoice> (f); auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument( if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr"); "DecisionTree::Convert: Invalid NodePtr");
@ -619,15 +623,16 @@ namespace gtsam {
/*********************************************************************************/ /*********************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other, double tol, bool DecisionTree<L, Y>::equals(const DecisionTree& other,
const CompareFunc& compare) const { const CompareFunc& compare) const {
return root_->equals(*other.root_, tol, compare); return root_->equals(*other.root_, compare);
} }
template <typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::print(const std::string& s, void DecisionTree<L, Y>::print(const std::string& s,
const FormatterFunc& formatter) const { const LabelFormatter& labelFormatter,
root_->print(s, formatter); const ValueFormatter& valueFormatter) const {
root_->print(s, labelFormatter, valueFormatter);
} }
template<typename L, typename Y> template<typename L, typename Y>
@ -687,26 +692,34 @@ namespace gtsam {
} }
/*********************************************************************************/ /*********************************************************************************/
template<typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os, bool showZero) const { void DecisionTree<L, Y>::dot(std::ostream& os,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
os << "digraph G {\n"; os << "digraph G {\n";
root_->dot(os, showZero); root_->dot(os, labelFormatter, valueFormatter, showZero);
os << " [ordering=out]}" << std::endl; os << " [ordering=out]}" << std::endl;
} }
template<typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::dot(const std::string& name, bool showZero) const { void DecisionTree<L, Y>::dot(const std::string& name,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
std::ofstream os((name + ".dot").c_str()); std::ofstream os((name + ".dot").c_str());
dot(os, showZero); dot(os, labelFormatter, valueFormatter, showZero);
int result = system( int result = system(
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
} }
template<typename L, typename Y> template <typename L, typename Y>
std::string DecisionTree<L, Y>::dot(bool showZero) const { std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
std::stringstream ss; std::stringstream ss;
dot(ss, showZero); dot(ss, labelFormatter, valueFormatter, showZero);
return ss.str(); return ss.str();
} }

View File

@ -39,13 +39,6 @@ namespace gtsam {
template<typename L, typename Y> template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree { class GTSAM_EXPORT DecisionTree {
/// Default method used by `formatter` when printing.
static std::string DefaultFormatter(const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}
/// Default method for comparison of two objects of type Y. /// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) { static bool DefaultCompare(const Y& a, const Y& b) {
return a == b; return a == b;
@ -53,7 +46,8 @@ namespace gtsam {
public: public:
using FormatterFunc = std::function<std::string(L)>; using LabelFormatter = std::function<std::string(L)>;
using ValueFormatter = std::function<std::string(Y)>;
using CompareFunc = std::function<bool(const Y&, const Y&)>; using CompareFunc = std::function<bool(const Y&, const Y&)>;
/** Handy typedefs for unary and binary function types */ /** Handy typedefs for unary and binary function types */
@ -94,15 +88,16 @@ namespace gtsam {
const void* id() const { return this; } const void* id() const { return this; }
// everything else is virtual, no documentation here as internal // everything else is virtual, no documentation here as internal
virtual void print( virtual void print(const std::string& s,
const std::string& s = "", const LabelFormatter& labelFormatter,
const FormatterFunc& formatter = &DefaultFormatter) const = 0; const ValueFormatter& valueFormatter) const = 0;
virtual void dot(std::ostream& os, bool showZero) const = 0; virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const = 0;
virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0;
virtual bool sameLeaf(const Node& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0;
virtual bool equals( virtual bool equals(const Node& other, const CompareFunc& compare =
const Node& other, double tol = 1e-9, &DefaultCompare) const = 0;
const CompareFunc& compare = &DefaultCompare) const = 0;
virtual const Y& operator()(const Assignment<L>& x) const = 0; virtual const Y& operator()(const Assignment<L>& x) const = 0;
virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply(const Unary& op) const = 0;
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
@ -118,11 +113,11 @@ namespace gtsam {
/** A function is a shared pointer to the root of a DT */ /** A function is a shared pointer to the root of a DT */
typedef typename Node::Ptr NodePtr; typedef typename Node::Ptr NodePtr;
protected:
/* a DecisionTree just contains the root */ /* a DecisionTree just contains the root */
NodePtr root_; NodePtr root_;
protected:
/** Internal recursive function to create from keys, cardinalities, and Y values */ /** Internal recursive function to create from keys, cardinalities, and Y values */
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
@ -131,7 +126,14 @@ namespace gtsam {
template <typename M, typename X> template <typename M, typename X>
NodePtr convert(const typename DecisionTree<M, X>::NodePtr& f, NodePtr convert(const typename DecisionTree<M, X>::NodePtr& f,
std::function<Y(const X&)> op, std::function<Y(const X&)> op,
std::function<L(const M&)> map); std::function<L(const M&)> map) const;
/// Convert to a different type, will not convert label if map empty.
template <typename M, typename X>
NodePtr convert(std::function<Y(const X&)> op,
std::function<L(const M&)> map) const {
return convert(root_, op, map);
}
public: public:
@ -179,11 +181,11 @@ namespace gtsam {
/// @{ /// @{
/** GTSAM-style print */ /** GTSAM-style print */
void print(const std::string& s = "DecisionTree", void print(const std::string& s, const LabelFormatter& labelFormatter,
const FormatterFunc& formatter = &DefaultFormatter) const; const ValueFormatter& valueFormatter) const;
// Testable // Testable
bool equals(const DecisionTree& other, double tol = 1e-9, bool equals(const DecisionTree& other,
const CompareFunc& compare = &DefaultCompare) const; const CompareFunc& compare = &DefaultCompare) const;
/// @} /// @}
@ -225,13 +227,17 @@ namespace gtsam {
} }
/** output to graphviz format, stream version */ /** output to graphviz format, stream version */
void dot(std::ostream& os, bool showZero = true) const; void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, bool showZero = true) const;
/** output to graphviz format, open a file */ /** output to graphviz format, open a file */
void dot(const std::string& name, bool showZero = true) const; void dot(const std::string& name, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, bool showZero = true) const;
/** output to graphviz format string */ /** output to graphviz format string */
std::string dot(bool showZero = true) const; std::string dot(const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero = true) const;
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{

View File

@ -43,34 +43,74 @@ void dot(const T&f, const string& filename) {
struct Crazy { struct Crazy {
int a; int a;
double b; double b;
bool equals(const Crazy& other, double tol = 1e-12) const {
return a == other.a && std::abs(b - other.b) < tol;
}
bool operator==(const Crazy& other) const {
return this->equals(other);
}
}; };
typedef DecisionTree<string,Crazy> CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) // bool equals(const Crazy& other, double tol = 1e-12) const {
// return a == other.a && std::abs(b - other.b) < tol;
// }
// bool operator==(const Crazy& other) const {
// return this->equals(other);
// }
// };
struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
/// print to stdout
void print(const std::string& s = "") const {
auto keyFormatter = [](const std::string& s) { return s; };
auto valueFormatter = [](const Crazy& v) {
return (boost::format("{%d,%4.2g}") % v.a % v.b).str();
};
DecisionTree<string, Crazy>::print("", keyFormatter, valueFormatter);
}
/// Equality method customized to Crazy node type
bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
auto compare = [tol](const Crazy& v, const Crazy& w) {
return v.a == w.a && std::abs(v.b - w.b) < tol;
};
return DecisionTree<string, Crazy>::equals(other, compare);
}
};
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {}; template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
} }
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ******************************************************************************** */ /* ******************************************************************************** */
// Test string labels and int range // Test string labels and int range
/* ******************************************************************************** */ /* ******************************************************************************** */
typedef DecisionTree<string, int> DT; struct DT : public DecisionTree<string, int> {
using DecisionTree::DecisionTree;
DT(const DecisionTree<string, int>& dt) : root_(dt.root_) {}
/// print to stdout
void print(const std::string& s = "") const {
auto keyFormatter = [](const std::string& s) { return s; };
auto valueFormatter = [](const int& v) {
return (boost::format("%d") % v).str();
};
DecisionTree<string, int>::print("", keyFormatter, valueFormatter);
}
// /// Equality method customized to int node type
// bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
// auto compare = [tol](const int& v, const int& w) {
// return v.a == w.a && std::abs(v.b - w.b) < tol;
// };
// return DecisionTree<string, int>::equals(other, compare);
// }
};
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<DT> : public Testable<DT> {}; template<> struct traits<DT> : public Testable<DT> {};
} }
GTSAM_CONCEPT_TESTABLE_INST(DT)
struct Ring { struct Ring {
static inline int zero() { static inline int zero() {
return 0; return 0;
@ -91,111 +131,111 @@ struct Ring {
/* ******************************************************************************** */ /* ******************************************************************************** */
// test DT // test DT
TEST(DT, example) // TEST(DT, example)
{ // {
// Create labels // // Create labels
string A("A"), B("B"), C("C"); // string A("A"), B("B"), C("C");
// create a value // // create a value
Assignment<string> x00, x01, x10, x11; // Assignment<string> x00, x01, x10, x11;
x00[A] = 0, x00[B] = 0; // x00[A] = 0, x00[B] = 0;
x01[A] = 0, x01[B] = 1; // x01[A] = 0, x01[B] = 1;
x10[A] = 1, x10[B] = 0; // x10[A] = 1, x10[B] = 0;
x11[A] = 1, x11[B] = 1; // x11[A] = 1, x11[B] = 1;
// empty // // empty
DT empty; // DT empty;
// A // // A
DT a(A, 0, 5); // DT a(A, 0, 5);
LONGS_EQUAL(0,a(x00)) // LONGS_EQUAL(0,a(x00))
LONGS_EQUAL(5,a(x10)) // LONGS_EQUAL(5,a(x10))
DOT(a); // DOT(a);
// pruned // // pruned
DT p(A, 2, 2); // DT p(A, 2, 2);
LONGS_EQUAL(2,p(x00)) // LONGS_EQUAL(2,p(x00))
LONGS_EQUAL(2,p(x10)) // LONGS_EQUAL(2,p(x10))
DOT(p); // DOT(p);
// \neg B // // \neg B
DT notb(B, 5, 0); // DT notb(B, 5, 0);
LONGS_EQUAL(5,notb(x00)) // LONGS_EQUAL(5,notb(x00))
LONGS_EQUAL(5,notb(x10)) // LONGS_EQUAL(5,notb(x10))
DOT(notb); // DOT(notb);
// Check supplying empty trees yields an exception // // Check supplying empty trees yields an exception
CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); // CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); // CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); // CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
// apply, two nodes, in natural order // // apply, two nodes, in natural order
DT anotb = apply(a, notb, &Ring::mul); // DT anotb = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,anotb(x00)) // LONGS_EQUAL(0,anotb(x00))
LONGS_EQUAL(0,anotb(x01)) // LONGS_EQUAL(0,anotb(x01))
LONGS_EQUAL(25,anotb(x10)) // LONGS_EQUAL(25,anotb(x10))
LONGS_EQUAL(0,anotb(x11)) // LONGS_EQUAL(0,anotb(x11))
DOT(anotb); // DOT(anotb);
// check pruning // // check pruning
DT pnotb = apply(p, notb, &Ring::mul); // DT pnotb = apply(p, notb, &Ring::mul);
LONGS_EQUAL(10,pnotb(x00)) // LONGS_EQUAL(10,pnotb(x00))
LONGS_EQUAL( 0,pnotb(x01)) // LONGS_EQUAL( 0,pnotb(x01))
LONGS_EQUAL(10,pnotb(x10)) // LONGS_EQUAL(10,pnotb(x10))
LONGS_EQUAL( 0,pnotb(x11)) // LONGS_EQUAL( 0,pnotb(x11))
DOT(pnotb); // DOT(pnotb);
// check pruning // // check pruning
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); // DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
LONGS_EQUAL(0,zeros(x00)) // LONGS_EQUAL(0,zeros(x00))
LONGS_EQUAL(0,zeros(x01)) // LONGS_EQUAL(0,zeros(x01))
LONGS_EQUAL(0,zeros(x10)) // LONGS_EQUAL(0,zeros(x10))
LONGS_EQUAL(0,zeros(x11)) // LONGS_EQUAL(0,zeros(x11))
DOT(zeros); // DOT(zeros);
// apply, two nodes, in switched order // // apply, two nodes, in switched order
DT notba = apply(a, notb, &Ring::mul); // DT notba = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,notba(x00)) // LONGS_EQUAL(0,notba(x00))
LONGS_EQUAL(0,notba(x01)) // LONGS_EQUAL(0,notba(x01))
LONGS_EQUAL(25,notba(x10)) // LONGS_EQUAL(25,notba(x10))
LONGS_EQUAL(0,notba(x11)) // LONGS_EQUAL(0,notba(x11))
DOT(notba); // DOT(notba);
// Test choose 0 // // Test choose 0
DT actual0 = notba.choose(A, 0); // DT actual0 = notba.choose(A, 0);
EXPECT(assert_equal(DT(0.0), actual0)); // EXPECT(assert_equal(DT(0.0), actual0));
DOT(actual0); // DOT(actual0);
// Test choose 1 // // Test choose 1
DT actual1 = notba.choose(A, 1); // DT actual1 = notba.choose(A, 1);
EXPECT(assert_equal(DT(B, 25, 0), actual1)); // EXPECT(assert_equal(DT(B, 25, 0), actual1));
DOT(actual1); // DOT(actual1);
// apply, two nodes at same level // // apply, two nodes at same level
DT a_and_a = apply(a, a, &Ring::mul); // DT a_and_a = apply(a, a, &Ring::mul);
LONGS_EQUAL(0,a_and_a(x00)) // LONGS_EQUAL(0,a_and_a(x00))
LONGS_EQUAL(0,a_and_a(x01)) // LONGS_EQUAL(0,a_and_a(x01))
LONGS_EQUAL(25,a_and_a(x10)) // LONGS_EQUAL(25,a_and_a(x10))
LONGS_EQUAL(25,a_and_a(x11)) // LONGS_EQUAL(25,a_and_a(x11))
DOT(a_and_a); // DOT(a_and_a);
// create a function on C // // create a function on C
DT c(C, 0, 5); // DT c(C, 0, 5);
// and a model assigning stuff to C // // and a model assigning stuff to C
Assignment<string> x101; // Assignment<string> x101;
x101[A] = 1, x101[B] = 0, x101[C] = 1; // x101[A] = 1, x101[B] = 0, x101[C] = 1;
// mul notba with C // // mul notba with C
DT notbac = apply(notba, c, &Ring::mul); // DT notbac = apply(notba, c, &Ring::mul);
LONGS_EQUAL(125,notbac(x101)) // LONGS_EQUAL(125,notbac(x101))
DOT(notbac); // DOT(notbac);
// mul now in different order // // mul now in different order
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); // DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
LONGS_EQUAL(125,acnotb(x101)) // LONGS_EQUAL(125,acnotb(x101))
DOT(acnotb); // DOT(acnotb);
} // }
/* ******************************************************************************** */ /* ******************************************************************************** */
// test Conversion // test Conversion