From 94f21358f4725b3ebc2afcfacf2a42ffe9b08358 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 29 Dec 2021 13:31:06 -0500 Subject: [PATCH 01/26] fix decision tree equality and move default constructor to public --- gtsam/discrete/DecisionTree-inl.h | 2 +- gtsam/discrete/DecisionTree.h | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index f6a64f11f..3bd2ac113 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -79,7 +79,7 @@ namespace gtsam { bool equals(const Node& q, double tol) const override { const Leaf* other = dynamic_cast (&q); if (!other) return false; - return std::abs(double(this->constant_ - other->constant_)) < tol; + return this->constant_ == other->constant_; } /** print */ diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 0a78d4635..1e2c8b509 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -113,14 +113,14 @@ namespace gtsam { convert(const typename DecisionTree::NodePtr& f, const std::map& map, std::function op); - /** Default constructor */ - DecisionTree(); - public: /// @name Standard Constructors /// @{ + /** Default constructor (for serialization) */ + DecisionTree(); + /** Create a constant */ DecisionTree(const Y& y); From ddaf9608d0855676306b379aa5c0d3cbeb4b7be6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 29 Dec 2021 14:14:13 -0500 Subject: [PATCH 02/26] add formatting capabilities to DecisionTree --- gtsam/discrete/DecisionTree-inl.h | 20 +++++++++++--------- gtsam/discrete/DecisionTree.h | 18 +++++++++++++++--- gtsam/discrete/Potentials.cpp | 4 ++-- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 3bd2ac113..e2c0a944d 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -83,7 +83,8 @@ namespace gtsam { } /** print */ - void print(const std::string& s) const override { + void print(const std::string& s, + const std::function formatter) const override { bool showZero = true; if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; } @@ -236,12 +237,11 @@ namespace gtsam { } /** print (as a tree) */ - void print(const std::string& s) const override { + void print(const std::string& s, const std::function formatter) const override { std::cout << s << " Choice("; - // std::cout << this << ","; - std::cout << label_ << ") " << std::endl; + std::cout << formatter(label_) << ") " << std::endl; for (size_t i = 0; i < branches_.size(); i++) - branches_[i]->print((boost::format("%s %d") % s % i).str()); + branches_[i]->print((boost::format("%s %d") % s % i).str(), formatter); } /** output to graphviz (as a a graph) */ @@ -591,7 +591,7 @@ namespace gtsam { // get new label M oldLabel = choice->label(); - L newLabel = map.at(oldLabel); + L newLabel = oldLabel; //map.at(oldLabel); // put together via Shannon expansion otherwise not sorted. std::vector functions; @@ -608,9 +608,11 @@ namespace gtsam { return root_->equals(*other.root_, tol); } - template - void DecisionTree::print(const std::string& s) const { - root_->print(s); + template + void DecisionTree::print( + const std::string& s, + const std::function formatter) const { + root_->print(s, formatter); } template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 1e2c8b509..68ddfa06b 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -20,13 +20,13 @@ #pragma once #include - #include #include #include #include #include +#include #include namespace gtsam { @@ -79,7 +79,13 @@ namespace gtsam { const void* id() const { return this; } // everything else is virtual, no documentation here as internal - virtual void print(const std::string& s = "") const = 0; + virtual void print( + const std::string& s = "", + const std::function formatter = [](const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + }) const = 0; virtual void dot(std::ostream& os, bool showZero) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; @@ -154,7 +160,13 @@ namespace gtsam { /// @{ /** GTSAM-style print */ - void print(const std::string& s = "DecisionTree") const; + void print( + const std::string& s = "DecisionTree", + const std::function formatter = [](const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + }) const; // Testable bool equals(const DecisionTree& other, double tol = 1e-9) const; diff --git a/gtsam/discrete/Potentials.cpp b/gtsam/discrete/Potentials.cpp index fa491eba3..057b6a265 100644 --- a/gtsam/discrete/Potentials.cpp +++ b/gtsam/discrete/Potentials.cpp @@ -51,11 +51,11 @@ bool Potentials::equals(const Potentials& other, double tol) const { /* ************************************************************************* */ void Potentials::print(const string& s, const KeyFormatter& formatter) const { - cout << s << "\n Cardinalities: {"; + cout << s << "\n Cardinalities: { "; for (const std::pair& key : cardinalities_) cout << formatter(key.first) << ":" << key.second << ", "; cout << "}" << endl; - ADT::print(" "); + ADT::print(" ", formatter); } // // /* ************************************************************************* */ From 8b5d93ad379367c941bf3351d8841286c4cbe0d7 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 29 Dec 2021 14:32:35 -0500 Subject: [PATCH 03/26] revert incorrect change --- gtsam/discrete/DecisionTree-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index e2c0a944d..c26d25420 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -591,7 +591,7 @@ namespace gtsam { // get new label M oldLabel = choice->label(); - L newLabel = oldLabel; //map.at(oldLabel); + L newLabel = map.at(oldLabel); // put together via Shannon expansion otherwise not sorted. std::vector functions; From bb6e489c372cadcd50434731e4be5a4dcb6e5ac2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 29 Dec 2021 15:15:19 -0500 Subject: [PATCH 04/26] new DecisionTree constructor and methods that takes an op to convert from one type to another # Conflicts: # gtsam/hybrid/DCMixtureFactor.h --- gtsam/discrete/DecisionTree-inl.h | 40 +++++++++++++++++++++++++++++++ gtsam/discrete/DecisionTree.h | 19 ++++++++++++++- 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index c26d25420..099ccb528 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -457,6 +457,14 @@ namespace gtsam { root_ = convert(other.root_, map, op); } + /*********************************************************************************/ + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + std::function op) { + root_ = convert(other.root_, op); + } + /*********************************************************************************/ // Called by two constructors above. // Takes a label and a corresponding range of decision trees, and creates a new @@ -602,6 +610,38 @@ namespace gtsam { return LY::compose(functions.begin(), functions.end(), newLabel); } + /*********************************************************************************/ + template + template + typename DecisionTree::NodePtr DecisionTree::convert( + const typename DecisionTree::NodePtr& f, + std::function op) { + + typedef DecisionTree LX; + typedef typename LX::Leaf LXLeaf; + typedef typename LX::Choice LXChoice; + typedef typename LX::NodePtr LXNodePtr; + typedef DecisionTree LY; + + // ugliness below because apparently we can't have templated virtual functions + // If leaf, apply unary conversion "op" and create a unique leaf + const LXLeaf* leaf = dynamic_cast (f.get()); + if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); + + // Check if Choice + boost::shared_ptr choice = boost::dynamic_pointer_cast (f); + if (!choice) throw std::invalid_argument( + "DecisionTree::Convert: Invalid NodePtr"); + + // put together via Shannon expansion otherwise not sorted. + std::vector functions; + for(const LXNodePtr& branch: choice->branches()) { + LY converted(convert(branch, op)); + functions += converted; + } + return LY::compose(functions.begin(), functions.end(), choice->label()); + } + /*********************************************************************************/ template bool DecisionTree::equals(const DecisionTree& other, double tol) const { diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 68ddfa06b..3b91def63 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -119,7 +119,12 @@ namespace gtsam { convert(const typename DecisionTree::NodePtr& f, const std::map& map, std::function op); - public: + /** Convert only node to a different type */ + template + NodePtr convert(const typename DecisionTree::NodePtr& f, + const std::function op); + + public: /// @name Standard Constructors /// @{ @@ -155,6 +160,11 @@ namespace gtsam { DecisionTree(const DecisionTree& other, const std::map& map, std::function op); + /** Convert only nodes from a different type */ + template + DecisionTree(const DecisionTree& other, + std::function op); + /// @} /// @name Testable /// @{ @@ -231,12 +241,19 @@ namespace gtsam { /** free versions of apply */ + //TODO(Varun) where are these templates Y, L and not L, Y? template DecisionTree apply(const DecisionTree& f, const typename DecisionTree::Unary& op) { return f.apply(op); } + template + DecisionTree apply(const DecisionTree& f, + const std::function& op) { + return f.apply(op); + } + template DecisionTree apply(const DecisionTree& f, const DecisionTree& g, From 315b10bb960fd765b24e78f87ae09e27a1127967 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 29 Dec 2021 16:00:09 -0500 Subject: [PATCH 05/26] minor format --- gtsam/discrete/DecisionTree-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 099ccb528..51d66c860 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -77,7 +77,7 @@ namespace gtsam { /** equality up to tolerance */ bool equals(const Node& q, double tol) const override { - const Leaf* other = dynamic_cast (&q); + const Leaf* other = dynamic_cast(&q); if (!other) return false; return this->constant_ == other->constant_; } From 28071ed23dcd543283e41e925c83fdc4ea06028b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 29 Dec 2021 20:43:16 -0500 Subject: [PATCH 06/26] added SFINAE methods for Leaf node equality checks --- gtsam/discrete/DecisionTree-inl.h | 38 +++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 51d66c860..f3b0dbf3a 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -19,22 +19,24 @@ #pragma once -#include #include +#include +#include +#include #include +#include #include #include -#include -using boost::assign::operator+=; +#include #include -#include - -#include #include #include +#include #include +using boost::assign::operator+=; + namespace gtsam { /*********************************************************************************/ @@ -75,11 +77,33 @@ namespace gtsam { return (q.isLeaf() && q.sameLeaf(*this)); } + /// @{ + /// SFINAE methods for proper substitution. + /** equality for integral types. */ + template + typename std::enable_if::value, bool>::type + equals(const T& a, const T& b, double tol) const { + return std::abs(double(a - b)) < tol; + } + /** equality for boost::shared_ptr types. */ + template + typename std::enable_if::value, bool>::type + equals(const T& a, const T& b, double tol) const { + return traits::Equals(*a, *b, tol); + } + /** equality for all other types. */ + template + typename std::enable_if::value && !std::is_integral::value, bool>::type + equals(const Y& a, const Y& b, double tol) const { + return traits::Equals(a, b, tol); + } + /// @} + /** equality up to tolerance */ bool equals(const Node& q, double tol) const override { const Leaf* other = dynamic_cast(&q); if (!other) return false; - return this->constant_ == other->constant_; + return this->equals(this->constant_, other->constant_, tol); } /** print */ From f1dedca2b791f9355130086c3185721d5fdeff18 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 29 Dec 2021 20:44:35 -0500 Subject: [PATCH 07/26] replace dot with DOT to prevent collision with vector dot product --- .../tests/testAlgebraicDecisionTree.cpp | 90 +++++++++---------- gtsam/discrete/tests/testDecisionTree.cpp | 10 +-- 2 files changed, 49 insertions(+), 51 deletions(-) diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index 7a33810c7..becc5a2a1 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -48,7 +48,7 @@ template<> struct traits : public Testable {}; #define DISABLE_DOT template -void dot(const T&f, const string& filename) { +void DOT(const T&f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif @@ -112,7 +112,7 @@ TEST(ADT, example3) ADT note(E, 0.9, 0.1); ADT cnotb = c * notb; - dot(cnotb, "ADT-cnotb"); + DOT(cnotb, "ADT-cnotb"); // a.print("a: "); // cnotb.print("cnotb: "); @@ -120,11 +120,11 @@ TEST(ADT, example3) // acnotb.print("acnotb: "); // acnotb.printCache("acnotb Cache:"); - dot(acnotb, "ADT-acnotb"); + DOT(acnotb, "ADT-acnotb"); ADT big = apply(apply(d, note, &mul), acnotb, &add_); - dot(big, "ADT-big"); + DOT(big, "ADT-big"); } /* ******************************************************************************** */ @@ -136,8 +136,8 @@ ADT create(const Signature& signature) { ADT p(signature.discreteKeys(), signature.cpt()); static size_t count = 0; const DiscreteKey& key = signature.key(); - string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); - dot(p, dotfile); + string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); + DOT(p, DOTfile); return p; } @@ -167,21 +167,21 @@ TEST(ADT, joint) resetCounts(); gttic_(asiaJoint); ADT joint = pA; - dot(joint, "Asia-A"); + DOT(joint, "Asia-A"); joint = apply(joint, pS, &mul); - dot(joint, "Asia-AS"); + DOT(joint, "Asia-AS"); joint = apply(joint, pT, &mul); - dot(joint, "Asia-AST"); + DOT(joint, "Asia-AST"); joint = apply(joint, pL, &mul); - dot(joint, "Asia-ASTL"); + DOT(joint, "Asia-ASTL"); joint = apply(joint, pB, &mul); - dot(joint, "Asia-ASTLB"); + DOT(joint, "Asia-ASTLB"); joint = apply(joint, pE, &mul); - dot(joint, "Asia-ASTLBE"); + DOT(joint, "Asia-ASTLBE"); joint = apply(joint, pX, &mul); - dot(joint, "Asia-ASTLBEX"); + DOT(joint, "Asia-ASTLBEX"); joint = apply(joint, pD, &mul); - dot(joint, "Asia-ASTLBEXD"); + DOT(joint, "Asia-ASTLBEXD"); EXPECT_LONGS_EQUAL(346, muls); gttoc_(asiaJoint); tictoc_getNode(asiaJointNode, asiaJoint); @@ -229,21 +229,21 @@ TEST(ADT, inference) resetCounts(); gttic_(asiaProd); ADT joint = pA; - dot(joint, "Joint-Product-A"); + DOT(joint, "Joint-Product-A"); joint = apply(joint, pS, &mul); - dot(joint, "Joint-Product-AS"); + DOT(joint, "Joint-Product-AS"); joint = apply(joint, pT, &mul); - dot(joint, "Joint-Product-AST"); + DOT(joint, "Joint-Product-AST"); joint = apply(joint, pL, &mul); - dot(joint, "Joint-Product-ASTL"); + DOT(joint, "Joint-Product-ASTL"); joint = apply(joint, pB, &mul); - dot(joint, "Joint-Product-ASTLB"); + DOT(joint, "Joint-Product-ASTLB"); joint = apply(joint, pE, &mul); - dot(joint, "Joint-Product-ASTLBE"); + DOT(joint, "Joint-Product-ASTLBE"); joint = apply(joint, pX, &mul); - dot(joint, "Joint-Product-ASTLBEX"); + DOT(joint, "Joint-Product-ASTLBEX"); joint = apply(joint, pD, &mul); - dot(joint, "Joint-Product-ASTLBEXD"); + DOT(joint, "Joint-Product-ASTLBEXD"); EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering gttoc_(asiaProd); tictoc_getNode(asiaProdNode, asiaProd); @@ -255,13 +255,13 @@ TEST(ADT, inference) gttic_(asiaSum); ADT marginal = joint; marginal = marginal.combine(X, &add_); - dot(marginal, "Joint-Sum-ADBLEST"); + DOT(marginal, "Joint-Sum-ADBLEST"); marginal = marginal.combine(T, &add_); - dot(marginal, "Joint-Sum-ADBLES"); + DOT(marginal, "Joint-Sum-ADBLES"); marginal = marginal.combine(S, &add_); - dot(marginal, "Joint-Sum-ADBLE"); + DOT(marginal, "Joint-Sum-ADBLE"); marginal = marginal.combine(E, &add_); - dot(marginal, "Joint-Sum-ADBL"); + DOT(marginal, "Joint-Sum-ADBL"); EXPECT_LONGS_EQUAL(161, (long)adds); gttoc_(asiaSum); tictoc_getNode(asiaSumNode, asiaSum); @@ -300,7 +300,7 @@ TEST(ADT, factor_graph) fg = apply(fg, pE, &mul); fg = apply(fg, pX, &mul); fg = apply(fg, pD, &mul); - dot(fg, "FactorGraph"); + DOT(fg, "FactorGraph"); EXPECT_LONGS_EQUAL(158, (long)muls); gttoc_(asiaFG); tictoc_getNode(asiaFGNode, asiaFG); @@ -311,15 +311,15 @@ TEST(ADT, factor_graph) resetCounts(); gttic_(marg); fg = fg.combine(X, &add_); - dot(fg, "Marginalized-6X"); + DOT(fg, "Marginalized-6X"); fg = fg.combine(T, &add_); - dot(fg, "Marginalized-5T"); + DOT(fg, "Marginalized-5T"); fg = fg.combine(S, &add_); - dot(fg, "Marginalized-4S"); + DOT(fg, "Marginalized-4S"); fg = fg.combine(E, &add_); - dot(fg, "Marginalized-3E"); + DOT(fg, "Marginalized-3E"); fg = fg.combine(L, &add_); - dot(fg, "Marginalized-2L"); + DOT(fg, "Marginalized-2L"); EXPECT(adds = 54); gttoc_(marg); tictoc_getNode(margNode, marg); @@ -333,9 +333,9 @@ TEST(ADT, factor_graph) resetCounts(); gttic_(elimX); ADT fE = pX; - dot(fE, "Eliminate-01-fEX"); + DOT(fE, "Eliminate-01-fEX"); fE = fE.combine(X, &add_); - dot(fE, "Eliminate-02-fE"); + DOT(fE, "Eliminate-02-fE"); gttoc_(elimX); tictoc_getNode(elimXNode, elimX); elapsed = elimXNode->secs() + elimXNode->wall(); @@ -347,9 +347,9 @@ TEST(ADT, factor_graph) gttic_(elimT); ADT fLE = pT; fLE = apply(fLE, pE, &mul); - dot(fLE, "Eliminate-03-fLET"); + DOT(fLE, "Eliminate-03-fLET"); fLE = fLE.combine(T, &add_); - dot(fLE, "Eliminate-04-fLE"); + DOT(fLE, "Eliminate-04-fLE"); gttoc_(elimT); tictoc_getNode(elimTNode, elimT); elapsed = elimTNode->secs() + elimTNode->wall(); @@ -362,9 +362,9 @@ TEST(ADT, factor_graph) ADT fBL = pS; fBL = apply(fBL, pL, &mul); fBL = apply(fBL, pB, &mul); - dot(fBL, "Eliminate-05-fBLS"); + DOT(fBL, "Eliminate-05-fBLS"); fBL = fBL.combine(S, &add_); - dot(fBL, "Eliminate-06-fBL"); + DOT(fBL, "Eliminate-06-fBL"); gttoc_(elimS); tictoc_getNode(elimSNode, elimS); elapsed = elimSNode->secs() + elimSNode->wall(); @@ -377,9 +377,9 @@ TEST(ADT, factor_graph) ADT fBL2 = fE; fBL2 = apply(fBL2, fLE, &mul); fBL2 = apply(fBL2, pD, &mul); - dot(fBL2, "Eliminate-07-fBLE"); + DOT(fBL2, "Eliminate-07-fBLE"); fBL2 = fBL2.combine(E, &add_); - dot(fBL2, "Eliminate-08-fBL2"); + DOT(fBL2, "Eliminate-08-fBL2"); gttoc_(elimE); tictoc_getNode(elimENode, elimE); elapsed = elimENode->secs() + elimENode->wall(); @@ -391,9 +391,9 @@ TEST(ADT, factor_graph) gttic_(elimL); ADT fB = fBL; fB = apply(fB, fBL2, &mul); - dot(fB, "Eliminate-09-fBL"); + DOT(fB, "Eliminate-09-fBL"); fB = fB.combine(L, &add_); - dot(fB, "Eliminate-10-fB"); + DOT(fB, "Eliminate-10-fB"); gttoc_(elimL); tictoc_getNode(elimLNode, elimL); elapsed = elimLNode->secs() + elimLNode->wall(); @@ -491,7 +491,7 @@ TEST(ADT, conversion) { DiscreteKey X(0,2), Y(1,2); ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); - dot(fDiscreteKey, "conversion-f1"); + DOT(fDiscreteKey, "conversion-f1"); std::map keyMap; keyMap[0] = 5; @@ -500,7 +500,7 @@ TEST(ADT, conversion) AlgebraicDecisionTree fIndexKey(fDiscreteKey, keyMap); // f1.print("f1"); // f2.print("f2"); - dot(fIndexKey, "conversion-f2"); + DOT(fIndexKey, "conversion-f2"); DiscreteValues x00, x01, x02, x10, x11, x12; x00[5] = 0, x00[2] = 0; @@ -519,7 +519,7 @@ TEST(ADT, elimination) { DiscreteKey A(0,2), B(1,3), C(2,2); ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); - dot(f1, "elimination-f1"); + DOT(f1, "elimination-f1"); { // sum out lower key diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 96f503abc..9eb06f2c4 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -32,14 +32,12 @@ using namespace std; using namespace gtsam; template -void dot(const T&f, const string& filename) { +void DOT(const T&f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif } -#define DOT(x)(dot(x,#x)) - struct Crazy { int a; double b; }; typedef DecisionTree CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) @@ -179,8 +177,8 @@ TEST(DT, example) enum Label { U, V, X, Y, Z }; -typedef DecisionTree BDT; -bool convert(const int& y) { +typedef DecisionTree BDT; +int convert(const int& y) { return y != 0; } @@ -196,7 +194,7 @@ TEST(DT, conversion) map ordering; ordering[A] = X; ordering[B] = Y; - std::function op = convert; + std::function op = convert; BDT f2(f1, ordering, op); // f1.print("f1"); // f2.print("f2"); From 1c76de40d1e4ee8d07aa3723a8931b2e18c0f626 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 29 Dec 2021 20:45:55 -0500 Subject: [PATCH 08/26] minor fix --- gtsam/discrete/tests/testDecisionTree.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 9eb06f2c4..28b8866ad 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -32,12 +32,14 @@ using namespace std; using namespace gtsam; template -void DOT(const T&f, const string& filename) { +void write_dot(const T&f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif } +#define DOT(x)(write_dot(x,#x)) + struct Crazy { int a; double b; }; typedef DecisionTree CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) From 573d0d17737cf11edb0d24b7677a2e964523775b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 29 Dec 2021 23:46:20 -0500 Subject: [PATCH 09/26] undo change to test --- gtsam/discrete/tests/testDecisionTree.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 28b8866ad..b44306723 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -179,8 +179,8 @@ TEST(DT, example) enum Label { U, V, X, Y, Z }; -typedef DecisionTree BDT; -int convert(const int& y) { +typedef DecisionTree BDT; +bool convert(const int& y) { return y != 0; } @@ -196,7 +196,7 @@ TEST(DT, conversion) map ordering; ordering[A] = X; ordering[B] = Y; - std::function op = convert; + std::function op = convert; BDT f2(f1, ordering, op); // f1.print("f1"); // f2.print("f2"); From ed839083e214a2d834d25230e3a877cbd94c6a49 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 29 Dec 2021 23:47:20 -0500 Subject: [PATCH 10/26] formatter passed as reference and added a default formatter method --- gtsam/discrete/DecisionTree-inl.h | 7 ++++--- gtsam/discrete/DecisionTree.h | 34 +++++++++++++++---------------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index f3b0dbf3a..ec6222fd7 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -108,7 +108,7 @@ namespace gtsam { /** print */ void print(const std::string& s, - const std::function formatter) const override { + const std::function& formatter) const override { bool showZero = true; if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; } @@ -261,7 +261,8 @@ namespace gtsam { } /** print (as a tree) */ - void print(const std::string& s, const std::function formatter) const override { + void print(const std::string& s, + const std::function& formatter) const override { std::cout << s << " Choice("; std::cout << formatter(label_) << ") " << std::endl; for (size_t i = 0; i < branches_.size(); i++) @@ -675,7 +676,7 @@ namespace gtsam { template void DecisionTree::print( const std::string& s, - const std::function formatter) const { + const std::function& formatter) const { root_->print(s, formatter); } diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 3b91def63..498fce5aa 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -39,6 +39,13 @@ namespace gtsam { template class GTSAM_EXPORT DecisionTree { + /// default method used by `formatter` when printing. + static std::string DefaultFormatter(const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + } + public: /** Handy typedefs for unary and binary function types */ @@ -79,13 +86,9 @@ namespace gtsam { const void* id() const { return this; } // everything else is virtual, no documentation here as internal - virtual void print( - const std::string& s = "", - const std::function formatter = [](const L& x) { - std::stringstream ss; - ss << x; - return ss.str(); - }) const = 0; + virtual void print(const std::string& s = "", + const std::function& formatter = + &DefaultFormatter) const = 0; virtual void dot(std::ostream& os, bool showZero) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; @@ -170,13 +173,9 @@ namespace gtsam { /// @{ /** GTSAM-style print */ - void print( - const std::string& s = "DecisionTree", - const std::function formatter = [](const L& x) { - std::stringstream ss; - ss << x; - return ss.str(); - }) const; + void print(const std::string& s = "DecisionTree", + const std::function& formatter = + &DefaultFormatter) const; // Testable bool equals(const DecisionTree& other, double tol = 1e-9) const; @@ -241,20 +240,19 @@ namespace gtsam { /** free versions of apply */ - //TODO(Varun) where are these templates Y, L and not L, Y? - template + template DecisionTree apply(const DecisionTree& f, const typename DecisionTree::Unary& op) { return f.apply(op); } - template + template DecisionTree apply(const DecisionTree& f, const std::function& op) { return f.apply(op); } - template + template DecisionTree apply(const DecisionTree& f, const DecisionTree& g, const typename DecisionTree::Binary& op) { From 9982057a2b832ab0ac0ce348eeb8bcd0156ebd9c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 30 Dec 2021 11:32:51 -0500 Subject: [PATCH 11/26] undo dot changes --- .../tests/testAlgebraicDecisionTree.cpp | 96 +++++++++---------- gtsam/discrete/tests/testDecisionTree.cpp | 4 +- 2 files changed, 50 insertions(+), 50 deletions(-) diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index becc5a2a1..910515b5c 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -48,7 +48,7 @@ template<> struct traits : public Testable {}; #define DISABLE_DOT template -void DOT(const T&f, const string& filename) { +void dot(const T&f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif @@ -112,7 +112,7 @@ TEST(ADT, example3) ADT note(E, 0.9, 0.1); ADT cnotb = c * notb; - DOT(cnotb, "ADT-cnotb"); + dot(cnotb, "ADT-cnotb"); // a.print("a: "); // cnotb.print("cnotb: "); @@ -120,11 +120,11 @@ TEST(ADT, example3) // acnotb.print("acnotb: "); // acnotb.printCache("acnotb Cache:"); - DOT(acnotb, "ADT-acnotb"); + dot(acnotb, "ADT-acnotb"); ADT big = apply(apply(d, note, &mul), acnotb, &add_); - DOT(big, "ADT-big"); + dot(big, "ADT-big"); } /* ******************************************************************************** */ @@ -137,7 +137,7 @@ ADT create(const Signature& signature) { static size_t count = 0; const DiscreteKey& key = signature.key(); string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); - DOT(p, DOTfile); + dot(p, DOTfile); return p; } @@ -167,21 +167,21 @@ TEST(ADT, joint) resetCounts(); gttic_(asiaJoint); ADT joint = pA; - DOT(joint, "Asia-A"); + dot(joint, "Asia-A"); joint = apply(joint, pS, &mul); - DOT(joint, "Asia-AS"); + dot(joint, "Asia-AS"); joint = apply(joint, pT, &mul); - DOT(joint, "Asia-AST"); + dot(joint, "Asia-AST"); joint = apply(joint, pL, &mul); - DOT(joint, "Asia-ASTL"); + dot(joint, "Asia-ASTL"); joint = apply(joint, pB, &mul); - DOT(joint, "Asia-ASTLB"); + dot(joint, "Asia-ASTLB"); joint = apply(joint, pE, &mul); - DOT(joint, "Asia-ASTLBE"); + dot(joint, "Asia-ASTLBE"); joint = apply(joint, pX, &mul); - DOT(joint, "Asia-ASTLBEX"); + dot(joint, "Asia-ASTLBEX"); joint = apply(joint, pD, &mul); - DOT(joint, "Asia-ASTLBEXD"); + dot(joint, "Asia-ASTLBEXD"); EXPECT_LONGS_EQUAL(346, muls); gttoc_(asiaJoint); tictoc_getNode(asiaJointNode, asiaJoint); @@ -229,21 +229,21 @@ TEST(ADT, inference) resetCounts(); gttic_(asiaProd); ADT joint = pA; - DOT(joint, "Joint-Product-A"); + dot(joint, "Joint-Product-A"); joint = apply(joint, pS, &mul); - DOT(joint, "Joint-Product-AS"); + dot(joint, "Joint-Product-AS"); joint = apply(joint, pT, &mul); - DOT(joint, "Joint-Product-AST"); + dot(joint, "Joint-Product-AST"); joint = apply(joint, pL, &mul); - DOT(joint, "Joint-Product-ASTL"); + dot(joint, "Joint-Product-ASTL"); joint = apply(joint, pB, &mul); - DOT(joint, "Joint-Product-ASTLB"); + dot(joint, "Joint-Product-ASTLB"); joint = apply(joint, pE, &mul); - DOT(joint, "Joint-Product-ASTLBE"); + dot(joint, "Joint-Product-ASTLBE"); joint = apply(joint, pX, &mul); - DOT(joint, "Joint-Product-ASTLBEX"); + dot(joint, "Joint-Product-ASTLBEX"); joint = apply(joint, pD, &mul); - DOT(joint, "Joint-Product-ASTLBEXD"); + dot(joint, "Joint-Product-ASTLBEXD"); EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering gttoc_(asiaProd); tictoc_getNode(asiaProdNode, asiaProd); @@ -255,13 +255,13 @@ TEST(ADT, inference) gttic_(asiaSum); ADT marginal = joint; marginal = marginal.combine(X, &add_); - DOT(marginal, "Joint-Sum-ADBLEST"); + dot(marginal, "Joint-Sum-ADBLEST"); marginal = marginal.combine(T, &add_); - DOT(marginal, "Joint-Sum-ADBLES"); + dot(marginal, "Joint-Sum-ADBLES"); marginal = marginal.combine(S, &add_); - DOT(marginal, "Joint-Sum-ADBLE"); + dot(marginal, "Joint-Sum-ADBLE"); marginal = marginal.combine(E, &add_); - DOT(marginal, "Joint-Sum-ADBL"); + dot(marginal, "Joint-Sum-ADBL"); EXPECT_LONGS_EQUAL(161, (long)adds); gttoc_(asiaSum); tictoc_getNode(asiaSumNode, asiaSum); @@ -300,7 +300,7 @@ TEST(ADT, factor_graph) fg = apply(fg, pE, &mul); fg = apply(fg, pX, &mul); fg = apply(fg, pD, &mul); - DOT(fg, "FactorGraph"); + dot(fg, "FactorGraph"); EXPECT_LONGS_EQUAL(158, (long)muls); gttoc_(asiaFG); tictoc_getNode(asiaFGNode, asiaFG); @@ -311,15 +311,15 @@ TEST(ADT, factor_graph) resetCounts(); gttic_(marg); fg = fg.combine(X, &add_); - DOT(fg, "Marginalized-6X"); + dot(fg, "Marginalized-6X"); fg = fg.combine(T, &add_); - DOT(fg, "Marginalized-5T"); + dot(fg, "Marginalized-5T"); fg = fg.combine(S, &add_); - DOT(fg, "Marginalized-4S"); + dot(fg, "Marginalized-4S"); fg = fg.combine(E, &add_); - DOT(fg, "Marginalized-3E"); + dot(fg, "Marginalized-3E"); fg = fg.combine(L, &add_); - DOT(fg, "Marginalized-2L"); + dot(fg, "Marginalized-2L"); EXPECT(adds = 54); gttoc_(marg); tictoc_getNode(margNode, marg); @@ -333,9 +333,9 @@ TEST(ADT, factor_graph) resetCounts(); gttic_(elimX); ADT fE = pX; - DOT(fE, "Eliminate-01-fEX"); + dot(fE, "Eliminate-01-fEX"); fE = fE.combine(X, &add_); - DOT(fE, "Eliminate-02-fE"); + dot(fE, "Eliminate-02-fE"); gttoc_(elimX); tictoc_getNode(elimXNode, elimX); elapsed = elimXNode->secs() + elimXNode->wall(); @@ -347,9 +347,9 @@ TEST(ADT, factor_graph) gttic_(elimT); ADT fLE = pT; fLE = apply(fLE, pE, &mul); - DOT(fLE, "Eliminate-03-fLET"); + dot(fLE, "Eliminate-03-fLET"); fLE = fLE.combine(T, &add_); - DOT(fLE, "Eliminate-04-fLE"); + dot(fLE, "Eliminate-04-fLE"); gttoc_(elimT); tictoc_getNode(elimTNode, elimT); elapsed = elimTNode->secs() + elimTNode->wall(); @@ -362,9 +362,9 @@ TEST(ADT, factor_graph) ADT fBL = pS; fBL = apply(fBL, pL, &mul); fBL = apply(fBL, pB, &mul); - DOT(fBL, "Eliminate-05-fBLS"); + dot(fBL, "Eliminate-05-fBLS"); fBL = fBL.combine(S, &add_); - DOT(fBL, "Eliminate-06-fBL"); + dot(fBL, "Eliminate-06-fBL"); gttoc_(elimS); tictoc_getNode(elimSNode, elimS); elapsed = elimSNode->secs() + elimSNode->wall(); @@ -377,9 +377,9 @@ TEST(ADT, factor_graph) ADT fBL2 = fE; fBL2 = apply(fBL2, fLE, &mul); fBL2 = apply(fBL2, pD, &mul); - DOT(fBL2, "Eliminate-07-fBLE"); + dot(fBL2, "Eliminate-07-fBLE"); fBL2 = fBL2.combine(E, &add_); - DOT(fBL2, "Eliminate-08-fBL2"); + dot(fBL2, "Eliminate-08-fBL2"); gttoc_(elimE); tictoc_getNode(elimENode, elimE); elapsed = elimENode->secs() + elimENode->wall(); @@ -391,9 +391,9 @@ TEST(ADT, factor_graph) gttic_(elimL); ADT fB = fBL; fB = apply(fB, fBL2, &mul); - DOT(fB, "Eliminate-09-fBL"); + dot(fB, "Eliminate-09-fBL"); fB = fB.combine(L, &add_); - DOT(fB, "Eliminate-10-fB"); + dot(fB, "Eliminate-10-fB"); gttoc_(elimL); tictoc_getNode(elimLNode, elimL); elapsed = elimLNode->secs() + elimLNode->wall(); @@ -414,13 +414,13 @@ TEST(ADT, equality_noparser) // Check straight equality ADT pA1 = create(A % tableA); ADT pA2 = create(A % tableA); - EXPECT(pA1 == pA2); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % tableB); ADT pAB1 = apply(pA1, pB, &mul); ADT pAB2 = apply(pB, pA1, &mul); - EXPECT(pAB2 == pAB1); + EXPECT(pAB2.equals(pAB1)); } /* ************************************************************************* */ @@ -431,13 +431,13 @@ TEST(ADT, equality_parser) // Check straight equality ADT pA1 = create(A % "80/20"); ADT pA2 = create(A % "80/20"); - EXPECT(pA1 == pA2); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % "60/40"); ADT pAB1 = apply(pA1, pB, &mul); ADT pAB2 = apply(pB, pA1, &mul); - EXPECT(pAB2 == pAB1); + EXPECT(pAB2.equals(pAB1)); } /* ******************************************************************************** */ @@ -491,7 +491,7 @@ TEST(ADT, conversion) { DiscreteKey X(0,2), Y(1,2); ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); - DOT(fDiscreteKey, "conversion-f1"); + dot(fDiscreteKey, "conversion-f1"); std::map keyMap; keyMap[0] = 5; @@ -500,7 +500,7 @@ TEST(ADT, conversion) AlgebraicDecisionTree fIndexKey(fDiscreteKey, keyMap); // f1.print("f1"); // f2.print("f2"); - DOT(fIndexKey, "conversion-f2"); + dot(fIndexKey, "conversion-f2"); DiscreteValues x00, x01, x02, x10, x11, x12; x00[5] = 0, x00[2] = 0; @@ -519,7 +519,7 @@ TEST(ADT, elimination) { DiscreteKey A(0,2), B(1,3), C(2,2); ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); - DOT(f1, "elimination-f1"); + dot(f1, "elimination-f1"); { // sum out lower key diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index b44306723..96f503abc 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -32,13 +32,13 @@ using namespace std; using namespace gtsam; template -void write_dot(const T&f, const string& filename) { +void dot(const T&f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif } -#define DOT(x)(write_dot(x,#x)) +#define DOT(x)(dot(x,#x)) struct Crazy { int a; double b; }; typedef DecisionTree CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) From b24da8399a363e27be3d281c5aed39aa11e07d57 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 30 Dec 2021 11:34:05 -0500 Subject: [PATCH 12/26] add comparator as argument to equals method # Conflicts: # gtsam/hybrid/DCMixtureFactor.h --- gtsam/discrete/AlgebraicDecisionTree.h | 13 +++++++ gtsam/discrete/DecisionTree-inl.h | 46 ++++++++--------------- gtsam/discrete/DecisionTree.h | 15 ++++++-- gtsam/discrete/tests/testDecisionTree.cpp | 14 ++++++- 4 files changed, 53 insertions(+), 35 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 9cc55ed6a..3469ac5cf 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -30,6 +30,13 @@ namespace gtsam { template class AlgebraicDecisionTree: public DecisionTree { + /** + * @brief Default method for comparison of two doubles upto some tolerance. + */ + static bool DefaultComparator(double a, double b, double tol) { + return std::abs(a - b) < tol; + } + public: typedef DecisionTree Super; @@ -138,6 +145,12 @@ namespace gtsam { return this->combine(labelC, &Ring::add); } + /// Equality method customized to node type `double`. + bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9, + const std::function& comparator = + &DefaultComparator) const { + return this->root_->equals(*other.root_, tol, comparator); + } }; // AlgebraicDecisionTree diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index ec6222fd7..f531e2b98 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -20,7 +20,6 @@ #pragma once #include -#include #include #include @@ -77,33 +76,13 @@ namespace gtsam { return (q.isLeaf() && q.sameLeaf(*this)); } - /// @{ - /// SFINAE methods for proper substitution. - /** equality for integral types. */ - template - typename std::enable_if::value, bool>::type - equals(const T& a, const T& b, double tol) const { - return std::abs(double(a - b)) < tol; - } - /** equality for boost::shared_ptr types. */ - template - typename std::enable_if::value, bool>::type - equals(const T& a, const T& b, double tol) const { - return traits::Equals(*a, *b, tol); - } - /** equality for all other types. */ - template - typename std::enable_if::value && !std::is_integral::value, bool>::type - equals(const Y& a, const Y& b, double tol) const { - return traits::Equals(a, b, tol); - } - /// @} - /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { + bool equals(const Node& q, double tol, + const std::function& + comparator) const override { const Leaf* other = dynamic_cast(&q); if (!other) return false; - return this->equals(this->constant_, other->constant_, tol); + return comparator(this->constant_, other->constant_, tol); } /** print */ @@ -304,14 +283,17 @@ namespace gtsam { } /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { - const Choice* other = dynamic_cast (&q); + bool equals(const Node& q, double tol, + const std::function& + comparator) const override { + const Choice* other = dynamic_cast(&q); if (!other) return false; if (this->label_ != other->label_) return false; if (branches_.size() != other->branches_.size()) return false; // we don't care about shared pointers being equal here for (size_t i = 0; i < branches_.size(); i++) - if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false; + if (!(branches_[i]->equals(*(other->branches_[i]), tol, comparator))) + return false; return true; } @@ -668,9 +650,11 @@ namespace gtsam { } /*********************************************************************************/ - template - bool DecisionTree::equals(const DecisionTree& other, double tol) const { - return root_->equals(*other.root_, tol); + template + bool DecisionTree::equals( + const DecisionTree& other, double tol, + const std::function& comparator) const { + return root_->equals(*other.root_, tol, comparator); } template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 498fce5aa..eb23bc5ce 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -39,13 +39,18 @@ namespace gtsam { template class GTSAM_EXPORT DecisionTree { - /// default method used by `formatter` when printing. + /// Default method used by `formatter` when printing. static std::string DefaultFormatter(const L& x) { std::stringstream ss; ss << x; return ss.str(); } + /// Default method for comparison of two objects of type Y. + static bool DefaultComparator(const Y& a, const Y& b, double tol) { + return a == b; + } + public: /** Handy typedefs for unary and binary function types */ @@ -92,7 +97,9 @@ namespace gtsam { virtual void dot(std::ostream& os, bool showZero) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; - virtual bool equals(const Node& other, double tol = 1e-9) const = 0; + virtual bool equals(const Node& other, double tol = 1e-9, + const std::function& + comparator = &DefaultComparator) const = 0; virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; @@ -178,7 +185,9 @@ namespace gtsam { &DefaultFormatter) const; // Testable - bool equals(const DecisionTree& other, double tol = 1e-9) const; + bool equals(const DecisionTree& other, double tol = 1e-9, + const std::function& + comparator = &DefaultComparator) const; /// @} /// @name Standard Interface diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 96f503abc..dbf6cc44b 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -40,7 +40,19 @@ void dot(const T&f, const string& filename) { #define DOT(x)(dot(x,#x)) -struct Crazy { int a; double b; }; +struct Crazy { + int a; + double b; + + bool equals(const Crazy& other, double tol = 1e-12) const { + return a == other.a && std::abs(b - other.b) < tol; + } + + bool operator==(const Crazy& other) const { + return this->equals(other); + } +}; + typedef DecisionTree CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) // traits From 26c48a8c1fc33f9bcebfbcfd0c813ea30c551811 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 30 Dec 2021 15:28:07 -0500 Subject: [PATCH 13/26] address more review comments # Conflicts: # gtsam/hybrid/DCGaussianMixtureFactor.h # gtsam/hybrid/DCMixtureFactor.h --- gtsam/discrete/AlgebraicDecisionTree.h | 17 ++++++----------- gtsam/discrete/DecisionTree-inl.h | 21 ++++++++------------- gtsam/discrete/DecisionTree.h | 21 +++++++++++---------- 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 3469ac5cf..f47e01668 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -30,14 +30,7 @@ namespace gtsam { template class AlgebraicDecisionTree: public DecisionTree { - /** - * @brief Default method for comparison of two doubles upto some tolerance. - */ - static bool DefaultComparator(double a, double b, double tol) { - return std::abs(a - b) < tol; - } - - public: + public: typedef DecisionTree Super; @@ -146,9 +139,11 @@ namespace gtsam { } /// Equality method customized to node type `double`. - bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9, - const std::function& comparator = - &DefaultComparator) const { + bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const { + // lambda for comparison of two doubles upto some tolerance. + auto comparator = [](double a, double b, double tol) { + return std::abs(a - b) < tol; + }; return this->root_->equals(*other.root_, tol, comparator); } }; diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index f531e2b98..fb6c78148 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -19,7 +19,6 @@ #pragma once -#include #include #include @@ -78,8 +77,7 @@ namespace gtsam { /** equality up to tolerance */ bool equals(const Node& q, double tol, - const std::function& - comparator) const override { + const ComparatorFunc& comparator) const override { const Leaf* other = dynamic_cast(&q); if (!other) return false; return comparator(this->constant_, other->constant_, tol); @@ -87,7 +85,7 @@ namespace gtsam { /** print */ void print(const std::string& s, - const std::function& formatter) const override { + const FormatterFunc& formatter) const override { bool showZero = true; if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; } @@ -241,7 +239,7 @@ namespace gtsam { /** print (as a tree) */ void print(const std::string& s, - const std::function& formatter) const override { + const FormatterFunc& formatter) const override { std::cout << s << " Choice("; std::cout << formatter(label_) << ") " << std::endl; for (size_t i = 0; i < branches_.size(); i++) @@ -284,8 +282,7 @@ namespace gtsam { /** equality up to tolerance */ bool equals(const Node& q, double tol, - const std::function& - comparator) const override { + const ComparatorFunc& comparator) const override { const Choice* other = dynamic_cast(&q); if (!other) return false; if (this->label_ != other->label_) return false; @@ -651,16 +648,14 @@ namespace gtsam { /*********************************************************************************/ template - bool DecisionTree::equals( - const DecisionTree& other, double tol, - const std::function& comparator) const { + bool DecisionTree::equals(const DecisionTree& other, double tol, + const ComparatorFunc& comparator) const { return root_->equals(*other.root_, tol, comparator); } template - void DecisionTree::print( - const std::string& s, - const std::function& formatter) const { + void DecisionTree::print(const std::string& s, + const FormatterFunc& formatter) const { root_->print(s, formatter); } diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index eb23bc5ce..8e6c0e4d7 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -53,6 +53,9 @@ namespace gtsam { public: + using FormatterFunc = std::function; + using ComparatorFunc = std::function; + /** Handy typedefs for unary and binary function types */ typedef std::function Unary; typedef std::function Binary; @@ -91,15 +94,15 @@ namespace gtsam { const void* id() const { return this; } // everything else is virtual, no documentation here as internal - virtual void print(const std::string& s = "", - const std::function& formatter = - &DefaultFormatter) const = 0; + virtual void print( + const std::string& s = "", + const FormatterFunc& formatter = &DefaultFormatter) const = 0; virtual void dot(std::ostream& os, bool showZero) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; - virtual bool equals(const Node& other, double tol = 1e-9, - const std::function& - comparator = &DefaultComparator) const = 0; + virtual bool equals( + const Node& other, double tol = 1e-9, + const ComparatorFunc& comparator = &DefaultComparator) const = 0; virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; @@ -181,13 +184,11 @@ namespace gtsam { /** GTSAM-style print */ void print(const std::string& s = "DecisionTree", - const std::function& formatter = - &DefaultFormatter) const; + const FormatterFunc& formatter = &DefaultFormatter) const; // Testable bool equals(const DecisionTree& other, double tol = 1e-9, - const std::function& - comparator = &DefaultComparator) const; + const ComparatorFunc& comparator = &DefaultComparator) const; /// @} /// @name Standard Interface From 731cff746bf3d2224883693afe5cbf4d40ccef17 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 30 Dec 2021 17:27:40 -0500 Subject: [PATCH 14/26] rename comparator to compare and capture tol in the function lambda. # Conflicts: # gtsam/hybrid/DCMixtureFactor.h --- gtsam/discrete/AlgebraicDecisionTree.h | 4 ++-- gtsam/discrete/DecisionTree-inl.h | 12 ++++++------ gtsam/discrete/DecisionTree.h | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index f47e01668..acdbf63a3 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -141,10 +141,10 @@ namespace gtsam { /// Equality method customized to node type `double`. bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const { // lambda for comparison of two doubles upto some tolerance. - auto comparator = [](double a, double b, double tol) { + auto compare = [tol](double a, double b) { return std::abs(a - b) < tol; }; - return this->root_->equals(*other.root_, tol, comparator); + return this->root_->equals(*other.root_, tol, compare); } }; // AlgebraicDecisionTree diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index fb6c78148..209c2ad80 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -77,10 +77,10 @@ namespace gtsam { /** equality up to tolerance */ bool equals(const Node& q, double tol, - const ComparatorFunc& comparator) const override { + const CompareFunc& compare) const override { const Leaf* other = dynamic_cast(&q); if (!other) return false; - return comparator(this->constant_, other->constant_, tol); + return compare(this->constant_, other->constant_); } /** print */ @@ -282,14 +282,14 @@ namespace gtsam { /** equality up to tolerance */ bool equals(const Node& q, double tol, - const ComparatorFunc& comparator) const override { + const CompareFunc& compare) const override { const Choice* other = dynamic_cast(&q); if (!other) return false; if (this->label_ != other->label_) return false; if (branches_.size() != other->branches_.size()) return false; // we don't care about shared pointers being equal here for (size_t i = 0; i < branches_.size(); i++) - if (!(branches_[i]->equals(*(other->branches_[i]), tol, comparator))) + if (!(branches_[i]->equals(*(other->branches_[i]), tol, compare))) return false; return true; } @@ -649,8 +649,8 @@ namespace gtsam { /*********************************************************************************/ template bool DecisionTree::equals(const DecisionTree& other, double tol, - const ComparatorFunc& comparator) const { - return root_->equals(*other.root_, tol, comparator); + const CompareFunc& compare) const { + return root_->equals(*other.root_, tol, compare); } template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 8e6c0e4d7..26817bf79 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -47,14 +47,14 @@ namespace gtsam { } /// Default method for comparison of two objects of type Y. - static bool DefaultComparator(const Y& a, const Y& b, double tol) { + static bool DefaultCompare(const Y& a, const Y& b) { return a == b; } public: using FormatterFunc = std::function; - using ComparatorFunc = std::function; + using CompareFunc = std::function; /** Handy typedefs for unary and binary function types */ typedef std::function Unary; @@ -102,7 +102,7 @@ namespace gtsam { virtual bool sameLeaf(const Node& q) const = 0; virtual bool equals( const Node& other, double tol = 1e-9, - const ComparatorFunc& comparator = &DefaultComparator) const = 0; + const CompareFunc& compare = &DefaultCompare) const = 0; virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; @@ -188,7 +188,7 @@ namespace gtsam { // Testable bool equals(const DecisionTree& other, double tol = 1e-9, - const ComparatorFunc& comparator = &DefaultComparator) const; + const CompareFunc& compare = &DefaultCompare) const; /// @} /// @name Standard Interface From 7f3f332d09acadd5a6dbcacea7a0c3aca24b5a66 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 1 Jan 2022 22:10:54 -0500 Subject: [PATCH 15/26] Removed copy/paste convert --- gtsam/discrete/AlgebraicDecisionTree.h | 9 ++-- gtsam/discrete/DecisionTree-inl.h | 69 ++++++++------------------ gtsam/discrete/DecisionTree.h | 26 ++++------ 3 files changed, 37 insertions(+), 67 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index acdbf63a3..72ea5e79f 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -108,9 +108,12 @@ namespace gtsam { /** Convert */ template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, - const std::map& map) { - this->root_ = this->template convert(other.root_, map, - Ring::id); + const std::map& map) { + std::function map_function = [&map](const M& label) -> L { + return map.at(label); + }; + std::function op = Ring::id; + this->root_ = this->template convert(other.root_, op, map_function); } /** sum */ diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 209c2ad80..96f1421ce 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -453,20 +453,24 @@ namespace gtsam { root_ = compose(functions.begin(), functions.end(), label); } - /*********************************************************************************/ - template - template - DecisionTree::DecisionTree(const DecisionTree& other, - const std::map& map, std::function op) { - root_ = convert(other.root_, map, op); - } - /*********************************************************************************/ template template DecisionTree::DecisionTree(const DecisionTree& other, std::function op) { - root_ = convert(other.root_, op); + auto map = [](const L& label) { return label; }; + root_ = convert(other.root_, op, map); + } + + /*********************************************************************************/ + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + const std::map& map, std::function op) { + std::function map_function = [&map](const M& label) -> L { + return map.at(label); + }; + root_ = convert(other.root_, op, map_function); } /*********************************************************************************/ @@ -579,12 +583,11 @@ namespace gtsam { } /*********************************************************************************/ - template - template + template + template typename DecisionTree::NodePtr DecisionTree::convert( - const typename DecisionTree::NodePtr& f, const std::map& map, - std::function op) { - + const typename DecisionTree::NodePtr& f, + std::function op, std::function map) { typedef DecisionTree MX; typedef typename MX::Leaf MXLeaf; typedef typename MX::Choice MXChoice; @@ -602,50 +605,18 @@ namespace gtsam { "DecisionTree::Convert: Invalid NodePtr"); // get new label - M oldLabel = choice->label(); - L newLabel = map.at(oldLabel); + const M oldLabel = choice->label(); + const L newLabel = map(oldLabel); // put together via Shannon expansion otherwise not sorted. std::vector functions; for(const MXNodePtr& branch: choice->branches()) { - LY converted(convert(branch, map, op)); + LY converted(convert(branch, op, map)); functions += converted; } return LY::compose(functions.begin(), functions.end(), newLabel); } - /*********************************************************************************/ - template - template - typename DecisionTree::NodePtr DecisionTree::convert( - const typename DecisionTree::NodePtr& f, - std::function op) { - - typedef DecisionTree LX; - typedef typename LX::Leaf LXLeaf; - typedef typename LX::Choice LXChoice; - typedef typename LX::NodePtr LXNodePtr; - typedef DecisionTree LY; - - // ugliness below because apparently we can't have templated virtual functions - // If leaf, apply unary conversion "op" and create a unique leaf - const LXLeaf* leaf = dynamic_cast (f.get()); - if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); - - // Check if Choice - boost::shared_ptr choice = boost::dynamic_pointer_cast (f); - if (!choice) throw std::invalid_argument( - "DecisionTree::Convert: Invalid NodePtr"); - - // put together via Shannon expansion otherwise not sorted. - std::vector functions; - for(const LXNodePtr& branch: choice->branches()) { - LY converted(convert(branch, op)); - functions += converted; - } - return LY::compose(functions.begin(), functions.end(), choice->label()); - } - /*********************************************************************************/ template bool DecisionTree::equals(const DecisionTree& other, double tol, diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 26817bf79..baf2a79fa 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -127,15 +127,11 @@ namespace gtsam { template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; - /** Convert to a different type */ - template NodePtr - convert(const typename DecisionTree::NodePtr& f, const std::map& map, std::function op); - - /** Convert only node to a different type */ - template - NodePtr convert(const typename DecisionTree::NodePtr& f, - const std::function op); + /// Convert to a different type, will not convert label if map empty. + template + NodePtr convert(const typename DecisionTree::NodePtr& f, + std::function op, + std::function map); public: @@ -168,16 +164,16 @@ namespace gtsam { DecisionTree(const L& label, // const DecisionTree& f0, const DecisionTree& f1); - /** Convert from a different type */ - template - DecisionTree(const DecisionTree& other, - const std::map& map, std::function op); - - /** Convert only nodes from a different type */ + /** Convert from a different type. */ template DecisionTree(const DecisionTree& other, std::function op); + /** Convert from a different type, also transate labels via map. */ + template + DecisionTree(const DecisionTree& other, + const std::map& map, std::function op); + /// @} /// @name Testable /// @{ From 78f8cc948d05e1fd11e150e092b638e59dcadf93 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 09:47:15 -0500 Subject: [PATCH 16/26] Define empty and check for it in apply variants --- gtsam/discrete/DecisionTree-inl.h | 10 ++++++++++ gtsam/discrete/DecisionTree.h | 3 +++ gtsam/discrete/tests/testDecisionTree.cpp | 11 +++++++++++ 3 files changed, 24 insertions(+) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 96f1421ce..fbdeae460 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -642,6 +642,11 @@ namespace gtsam { template DecisionTree DecisionTree::apply(const Unary& op) const { + // It is unclear what should happen if tree is empty: + if (empty()) { + throw std::runtime_error( + "DecisionTree::apply(unary op) undefined for empty tree."); + } return DecisionTree(root_->apply(op)); } @@ -649,6 +654,11 @@ namespace gtsam { template DecisionTree DecisionTree::apply(const DecisionTree& g, const Binary& op) const { + // It is unclear what should happen if either tree is empty: + if (empty() or g.empty()) { + throw std::runtime_error( + "DecisionTree::apply(binary op) undefined for empty trees."); + } // apply the operaton on the root of both diagrams NodePtr h = root_->apply_f_op_g(*g.root_, op); // create a new class with the resulting root "h" diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index baf2a79fa..5cf92f157 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -194,6 +194,9 @@ namespace gtsam { virtual ~DecisionTree() { } + /** empty tree? */ + bool empty() const { return !root_; } + /** equality */ bool operator==(const DecisionTree& q) const; diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index dbf6cc44b..c7ee6cc2a 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -78,6 +78,9 @@ struct Ring { static inline int one() { return 1; } + static inline int id(const int& a) { + return a; + } static inline int add(const int& a, const int& b) { return a + b; } @@ -100,6 +103,9 @@ TEST(DT, example) x10[A] = 1, x10[B] = 0; x11[A] = 1, x11[B] = 1; + // empty + DT empty; + // A DT a(A, 0, 5); LONGS_EQUAL(0,a(x00)) @@ -118,6 +124,11 @@ TEST(DT, example) LONGS_EQUAL(5,notb(x10)) DOT(notb); + // Check supplying empty trees yields an exception + CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); + CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); + CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); + // apply, two nodes, in natural order DT anotb = apply(a, notb, &Ring::mul); LONGS_EQUAL(0,anotb(x00)) From db3cb4d9ac53f71872962fbab38eb4d82bf24321 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 13:57:12 -0500 Subject: [PATCH 17/26] Refactor print, equals, convert --- gtsam/discrete/AlgebraicDecisionTree.h | 20 +- gtsam/discrete/DecisionTree-inl.h | 91 +++++---- gtsam/discrete/DecisionTree.h | 54 ++--- gtsam/discrete/tests/testDecisionTree.cpp | 236 +++++++++++++--------- 4 files changed, 238 insertions(+), 163 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 72ea5e79f..37130ab72 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -28,7 +28,13 @@ namespace gtsam { * TODO: consider eliminating this class altogether? */ template - class AlgebraicDecisionTree: public DecisionTree { + class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree { + /// Default method used by `formatter` when printing. + static std::string DefaultFormatter(const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + } public: @@ -141,13 +147,23 @@ namespace gtsam { return this->combine(labelC, &Ring::add); } + /// print method customized to node type `double`. + void print(const std::string& s, + const typename Super::LabelFormatter& labelFormatter = + &DefaultFormatter) const { + auto valueFormatter = [](const double& v) { + return (boost::format("%4.2g") % v).str(); + }; + Super::print(s, labelFormatter, valueFormatter); + } + /// Equality method customized to node type `double`. bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const { // lambda for comparison of two doubles upto some tolerance. auto compare = [tol](double a, double b) { return std::abs(a - b) < tol; }; - return this->root_->equals(*other.root_, tol, compare); + return Super::equals(other, compare); } }; // AlgebraicDecisionTree diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index fbdeae460..b31773702 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -76,25 +76,26 @@ namespace gtsam { } /** equality up to tolerance */ - bool equals(const Node& q, double tol, - const CompareFunc& compare) const override { + bool equals(const Node& q, const CompareFunc& compare) const override { const Leaf* other = dynamic_cast(&q); if (!other) return false; return compare(this->constant_, other->constant_); } /** print */ - void print(const std::string& s, - const FormatterFunc& formatter) const override { - bool showZero = true; - if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const override { + std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; } /** to graphviz file */ - void dot(std::ostream& os, bool showZero) const override { - if (showZero || constant_) os << "\"" << this->id() << "\" [label=\"" - << boost::format("%4.2g") % constant_ - << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const override { + std::string value = valueFormatter(constant_); + if (showZero || value.compare("0")) + os << "\"" << this->id() << "\" [label=\"" << value + << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, } /** evaluate */ @@ -238,16 +239,19 @@ namespace gtsam { } /** print (as a tree) */ - void print(const std::string& s, - const FormatterFunc& formatter) const override { + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const override { std::cout << s << " Choice("; - std::cout << formatter(label_) << ") " << std::endl; + std::cout << labelFormatter(label_) << ") " << std::endl; for (size_t i = 0; i < branches_.size(); i++) - branches_[i]->print((boost::format("%s %d") % s % i).str(), formatter); + branches_[i]->print((boost::format("%s %d") % s % i).str(), + labelFormatter, valueFormatter); } /** output to graphviz (as a a graph) */ - void dot(std::ostream& os, bool showZero) const override { + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const override { os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ << "\"]\n"; size_t B = branches_.size(); @@ -257,7 +261,8 @@ namespace gtsam { // Check if zero if (!showZero) { const Leaf* leaf = dynamic_cast (branch.get()); - if (leaf && !leaf->constant()) continue; + std::string value = valueFormatter(leaf->constant()); + if (leaf && value.compare("0")) continue; } os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; @@ -266,7 +271,7 @@ namespace gtsam { if (i > 1) os << " [style=bold]"; } os << std::endl; - branch->dot(os, showZero); + branch->dot(os, labelFormatter, valueFormatter, showZero); } } @@ -280,16 +285,15 @@ namespace gtsam { return (q.isLeaf() && q.sameLeaf(*this)); } - /** equality up to tolerance */ - bool equals(const Node& q, double tol, - const CompareFunc& compare) const override { + /** equality */ + bool equals(const Node& q, const CompareFunc& compare) const override { const Choice* other = dynamic_cast(&q); if (!other) return false; if (this->label_ != other->label_) return false; if (branches_.size() != other->branches_.size()) return false; // we don't care about shared pointers being equal here for (size_t i = 0; i < branches_.size(); i++) - if (!(branches_[i]->equals(*(other->branches_[i]), tol, compare))) + if (!(branches_[i]->equals(*(other->branches_[i]), compare))) return false; return true; } @@ -459,7 +463,7 @@ namespace gtsam { DecisionTree::DecisionTree(const DecisionTree& other, std::function op) { auto map = [](const L& label) { return label; }; - root_ = convert(other.root_, op, map); + root_ = other.template convert(op, map); } /*********************************************************************************/ @@ -470,7 +474,7 @@ namespace gtsam { std::function map_function = [&map](const M& label) -> L { return map.at(label); }; - root_ = convert(other.root_, op, map_function); + root_ = other.template convert(op, map_function); } /*********************************************************************************/ @@ -587,7 +591,7 @@ namespace gtsam { template typename DecisionTree::NodePtr DecisionTree::convert( const typename DecisionTree::NodePtr& f, - std::function op, std::function map) { + std::function op, std::function map) const { typedef DecisionTree MX; typedef typename MX::Leaf MXLeaf; typedef typename MX::Choice MXChoice; @@ -596,11 +600,11 @@ namespace gtsam { // ugliness below because apparently we can't have templated virtual functions // If leaf, apply unary conversion "op" and create a unique leaf - const MXLeaf* leaf = dynamic_cast (f.get()); + auto leaf = boost::dynamic_pointer_cast(f); if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); // Check if Choice - boost::shared_ptr choice = boost::dynamic_pointer_cast (f); + auto choice = boost::dynamic_pointer_cast(f); if (!choice) throw std::invalid_argument( "DecisionTree::Convert: Invalid NodePtr"); @@ -619,15 +623,16 @@ namespace gtsam { /*********************************************************************************/ template - bool DecisionTree::equals(const DecisionTree& other, double tol, + bool DecisionTree::equals(const DecisionTree& other, const CompareFunc& compare) const { - return root_->equals(*other.root_, tol, compare); + return root_->equals(*other.root_, compare); } template void DecisionTree::print(const std::string& s, - const FormatterFunc& formatter) const { - root_->print(s, formatter); + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const { + root_->print(s, labelFormatter, valueFormatter); } template @@ -687,26 +692,34 @@ namespace gtsam { } /*********************************************************************************/ - template - void DecisionTree::dot(std::ostream& os, bool showZero) const { + template + void DecisionTree::dot(std::ostream& os, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { os << "digraph G {\n"; - root_->dot(os, showZero); + root_->dot(os, labelFormatter, valueFormatter, showZero); os << " [ordering=out]}" << std::endl; } - template - void DecisionTree::dot(const std::string& name, bool showZero) const { + template + void DecisionTree::dot(const std::string& name, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { std::ofstream os((name + ".dot").c_str()); - dot(os, showZero); + dot(os, labelFormatter, valueFormatter, showZero); int result = system( ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); } - template - std::string DecisionTree::dot(bool showZero) const { + template + std::string DecisionTree::dot(const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { std::stringstream ss; - dot(ss, showZero); + dot(ss, labelFormatter, valueFormatter, showZero); return ss.str(); } diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 5cf92f157..4c79c5841 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -39,13 +39,6 @@ namespace gtsam { template class GTSAM_EXPORT DecisionTree { - /// Default method used by `formatter` when printing. - static std::string DefaultFormatter(const L& x) { - std::stringstream ss; - ss << x; - return ss.str(); - } - /// Default method for comparison of two objects of type Y. static bool DefaultCompare(const Y& a, const Y& b) { return a == b; @@ -53,7 +46,8 @@ namespace gtsam { public: - using FormatterFunc = std::function; + using LabelFormatter = std::function; + using ValueFormatter = std::function; using CompareFunc = std::function; /** Handy typedefs for unary and binary function types */ @@ -94,15 +88,16 @@ namespace gtsam { const void* id() const { return this; } // everything else is virtual, no documentation here as internal - virtual void print( - const std::string& s = "", - const FormatterFunc& formatter = &DefaultFormatter) const = 0; - virtual void dot(std::ostream& os, bool showZero) const = 0; + virtual void print(const std::string& s, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const = 0; + virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; - virtual bool equals( - const Node& other, double tol = 1e-9, - const CompareFunc& compare = &DefaultCompare) const = 0; + virtual bool equals(const Node& other, const CompareFunc& compare = + &DefaultCompare) const = 0; virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; @@ -118,11 +113,11 @@ namespace gtsam { /** A function is a shared pointer to the root of a DT */ typedef typename Node::Ptr NodePtr; + protected: + /* a DecisionTree just contains the root */ NodePtr root_; - protected: - /** Internal recursive function to create from keys, cardinalities, and Y values */ template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; @@ -131,7 +126,14 @@ namespace gtsam { template NodePtr convert(const typename DecisionTree::NodePtr& f, std::function op, - std::function map); + std::function map) const; + + /// Convert to a different type, will not convert label if map empty. + template + NodePtr convert(std::function op, + std::function map) const { + return convert(root_, op, map); + } public: @@ -179,11 +181,11 @@ namespace gtsam { /// @{ /** GTSAM-style print */ - void print(const std::string& s = "DecisionTree", - const FormatterFunc& formatter = &DefaultFormatter) const; + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const; // Testable - bool equals(const DecisionTree& other, double tol = 1e-9, + bool equals(const DecisionTree& other, const CompareFunc& compare = &DefaultCompare) const; /// @} @@ -225,13 +227,17 @@ namespace gtsam { } /** output to graphviz format, stream version */ - void dot(std::ostream& os, bool showZero = true) const; + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, bool showZero = true) const; /** output to graphviz format, open a file */ - void dot(const std::string& name, bool showZero = true) const; + void dot(const std::string& name, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, bool showZero = true) const; /** output to graphviz format string */ - std::string dot(bool showZero = true) const; + std::string dot(const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero = true) const; /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index c7ee6cc2a..f42a590ae 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -43,34 +43,74 @@ void dot(const T&f, const string& filename) { struct Crazy { int a; double b; - - bool equals(const Crazy& other, double tol = 1e-12) const { - return a == other.a && std::abs(b - other.b) < tol; - } - - bool operator==(const Crazy& other) const { - return this->equals(other); - } }; -typedef DecisionTree CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) +// bool equals(const Crazy& other, double tol = 1e-12) const { +// return a == other.a && std::abs(b - other.b) < tol; +// } + +// bool operator==(const Crazy& other) const { +// return this->equals(other); +// } +// }; + +struct CrazyDecisionTree : public DecisionTree { + /// print to stdout + void print(const std::string& s = "") const { + auto keyFormatter = [](const std::string& s) { return s; }; + auto valueFormatter = [](const Crazy& v) { + return (boost::format("{%d,%4.2g}") % v.a % v.b).str(); + }; + DecisionTree::print("", keyFormatter, valueFormatter); + } + /// Equality method customized to Crazy node type + bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const { + auto compare = [tol](const Crazy& v, const Crazy& w) { + return v.a == w.a && std::abs(v.b - w.b) < tol; + }; + return DecisionTree::equals(other, compare); + } +}; // traits namespace gtsam { template<> struct traits : public Testable {}; } +GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) + /* ******************************************************************************** */ // Test string labels and int range /* ******************************************************************************** */ -typedef DecisionTree DT; +struct DT : public DecisionTree { + using DecisionTree::DecisionTree; + DT(const DecisionTree& dt) : root_(dt.root_) {} + + /// print to stdout + void print(const std::string& s = "") const { + auto keyFormatter = [](const std::string& s) { return s; }; + auto valueFormatter = [](const int& v) { + return (boost::format("%d") % v).str(); + }; + DecisionTree::print("", keyFormatter, valueFormatter); + } + // /// Equality method customized to int node type + // bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const { + // auto compare = [tol](const int& v, const int& w) { + // return v.a == w.a && std::abs(v.b - w.b) < tol; + // }; + // return DecisionTree::equals(other, compare); + // } +}; // traits namespace gtsam { template<> struct traits
: public Testable
{}; } +GTSAM_CONCEPT_TESTABLE_INST(DT) + struct Ring { static inline int zero() { return 0; @@ -91,111 +131,111 @@ struct Ring { /* ******************************************************************************** */ // test DT -TEST(DT, example) -{ - // Create labels - string A("A"), B("B"), C("C"); +// TEST(DT, example) +// { +// // Create labels +// string A("A"), B("B"), C("C"); - // create a value - Assignment x00, x01, x10, x11; - x00[A] = 0, x00[B] = 0; - x01[A] = 0, x01[B] = 1; - x10[A] = 1, x10[B] = 0; - x11[A] = 1, x11[B] = 1; +// // create a value +// Assignment x00, x01, x10, x11; +// x00[A] = 0, x00[B] = 0; +// x01[A] = 0, x01[B] = 1; +// x10[A] = 1, x10[B] = 0; +// x11[A] = 1, x11[B] = 1; - // empty - DT empty; +// // empty +// DT empty; - // A - DT a(A, 0, 5); - LONGS_EQUAL(0,a(x00)) - LONGS_EQUAL(5,a(x10)) - DOT(a); +// // A +// DT a(A, 0, 5); +// LONGS_EQUAL(0,a(x00)) +// LONGS_EQUAL(5,a(x10)) +// DOT(a); - // pruned - DT p(A, 2, 2); - LONGS_EQUAL(2,p(x00)) - LONGS_EQUAL(2,p(x10)) - DOT(p); +// // pruned +// DT p(A, 2, 2); +// LONGS_EQUAL(2,p(x00)) +// LONGS_EQUAL(2,p(x10)) +// DOT(p); - // \neg B - DT notb(B, 5, 0); - LONGS_EQUAL(5,notb(x00)) - LONGS_EQUAL(5,notb(x10)) - DOT(notb); +// // \neg B +// DT notb(B, 5, 0); +// LONGS_EQUAL(5,notb(x00)) +// LONGS_EQUAL(5,notb(x10)) +// DOT(notb); - // Check supplying empty trees yields an exception - CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); - CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); - CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); +// // Check supplying empty trees yields an exception +// CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); +// CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); +// CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); - // apply, two nodes, in natural order - DT anotb = apply(a, notb, &Ring::mul); - LONGS_EQUAL(0,anotb(x00)) - LONGS_EQUAL(0,anotb(x01)) - LONGS_EQUAL(25,anotb(x10)) - LONGS_EQUAL(0,anotb(x11)) - DOT(anotb); +// // apply, two nodes, in natural order +// DT anotb = apply(a, notb, &Ring::mul); +// LONGS_EQUAL(0,anotb(x00)) +// LONGS_EQUAL(0,anotb(x01)) +// LONGS_EQUAL(25,anotb(x10)) +// LONGS_EQUAL(0,anotb(x11)) +// DOT(anotb); - // check pruning - DT pnotb = apply(p, notb, &Ring::mul); - LONGS_EQUAL(10,pnotb(x00)) - LONGS_EQUAL( 0,pnotb(x01)) - LONGS_EQUAL(10,pnotb(x10)) - LONGS_EQUAL( 0,pnotb(x11)) - DOT(pnotb); +// // check pruning +// DT pnotb = apply(p, notb, &Ring::mul); +// LONGS_EQUAL(10,pnotb(x00)) +// LONGS_EQUAL( 0,pnotb(x01)) +// LONGS_EQUAL(10,pnotb(x10)) +// LONGS_EQUAL( 0,pnotb(x11)) +// DOT(pnotb); - // check pruning - DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); - LONGS_EQUAL(0,zeros(x00)) - LONGS_EQUAL(0,zeros(x01)) - LONGS_EQUAL(0,zeros(x10)) - LONGS_EQUAL(0,zeros(x11)) - DOT(zeros); +// // check pruning +// DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); +// LONGS_EQUAL(0,zeros(x00)) +// LONGS_EQUAL(0,zeros(x01)) +// LONGS_EQUAL(0,zeros(x10)) +// LONGS_EQUAL(0,zeros(x11)) +// DOT(zeros); - // apply, two nodes, in switched order - DT notba = apply(a, notb, &Ring::mul); - LONGS_EQUAL(0,notba(x00)) - LONGS_EQUAL(0,notba(x01)) - LONGS_EQUAL(25,notba(x10)) - LONGS_EQUAL(0,notba(x11)) - DOT(notba); +// // apply, two nodes, in switched order +// DT notba = apply(a, notb, &Ring::mul); +// LONGS_EQUAL(0,notba(x00)) +// LONGS_EQUAL(0,notba(x01)) +// LONGS_EQUAL(25,notba(x10)) +// LONGS_EQUAL(0,notba(x11)) +// DOT(notba); - // Test choose 0 - DT actual0 = notba.choose(A, 0); - EXPECT(assert_equal(DT(0.0), actual0)); - DOT(actual0); +// // Test choose 0 +// DT actual0 = notba.choose(A, 0); +// EXPECT(assert_equal(DT(0.0), actual0)); +// DOT(actual0); - // Test choose 1 - DT actual1 = notba.choose(A, 1); - EXPECT(assert_equal(DT(B, 25, 0), actual1)); - DOT(actual1); +// // Test choose 1 +// DT actual1 = notba.choose(A, 1); +// EXPECT(assert_equal(DT(B, 25, 0), actual1)); +// DOT(actual1); - // apply, two nodes at same level - DT a_and_a = apply(a, a, &Ring::mul); - LONGS_EQUAL(0,a_and_a(x00)) - LONGS_EQUAL(0,a_and_a(x01)) - LONGS_EQUAL(25,a_and_a(x10)) - LONGS_EQUAL(25,a_and_a(x11)) - DOT(a_and_a); +// // apply, two nodes at same level +// DT a_and_a = apply(a, a, &Ring::mul); +// LONGS_EQUAL(0,a_and_a(x00)) +// LONGS_EQUAL(0,a_and_a(x01)) +// LONGS_EQUAL(25,a_and_a(x10)) +// LONGS_EQUAL(25,a_and_a(x11)) +// DOT(a_and_a); - // create a function on C - DT c(C, 0, 5); +// // create a function on C +// DT c(C, 0, 5); - // and a model assigning stuff to C - Assignment x101; - x101[A] = 1, x101[B] = 0, x101[C] = 1; +// // and a model assigning stuff to C +// Assignment x101; +// x101[A] = 1, x101[B] = 0, x101[C] = 1; - // mul notba with C - DT notbac = apply(notba, c, &Ring::mul); - LONGS_EQUAL(125,notbac(x101)) - DOT(notbac); +// // mul notba with C +// DT notbac = apply(notba, c, &Ring::mul); +// LONGS_EQUAL(125,notbac(x101)) +// DOT(notbac); - // mul now in different order - DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); - LONGS_EQUAL(125,acnotb(x101)) - DOT(acnotb); -} +// // mul now in different order +// DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); +// LONGS_EQUAL(125,acnotb(x101)) +// DOT(acnotb); +// } /* ******************************************************************************** */ // test Conversion From 5c4038c7c02ed49958fe12336cfef7a9c76bf039 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 18:08:09 -0500 Subject: [PATCH 18/26] Fixed dot to have right arguments --- gtsam/discrete/DecisionTreeFactor.cpp | 25 +++++++++++++++++++++++++ gtsam/discrete/DecisionTreeFactor.h | 14 ++++++++++++++ gtsam/discrete/discrete.i | 4 +++- 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 7aed00c57..75018cf92 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -153,6 +153,31 @@ namespace gtsam { return result; } + /* ************************************************************************* */ + static std::string valueFormatter(const double& v) { + return (boost::format("%4.2g") % v).str(); + } + + /** output to graphviz format, stream version */ + void DecisionTreeFactor::dot(std::ostream& os, + const KeyFormatter& keyFormatter, + bool showZero) const { + Potentials::dot(os, keyFormatter, valueFormatter, showZero); + } + + /** output to graphviz format, open a file */ + void DecisionTreeFactor::dot(const std::string& name, + const KeyFormatter& keyFormatter, + bool showZero) const { + Potentials::dot(name, keyFormatter, valueFormatter, showZero); + } + + /** output to graphviz format string */ + std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter, + bool showZero) const { + return Potentials::dot(keyFormatter, valueFormatter, showZero); + } + /* ************************************************************************* */ std::string DecisionTreeFactor::markdown( const KeyFormatter& keyFormatter) const { diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index f90af56dd..46509db82 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -178,6 +178,20 @@ namespace gtsam { /// @name Wrapper support /// @{ + /** output to graphviz format, stream version */ + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** output to graphviz format, open a file */ + void dot(const std::string& name, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** output to graphviz format string */ + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + /// Render as markdown table. std::string markdown( const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 36caccfc8..5bd4a2913 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -46,7 +46,9 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; - string dot(bool showZero = false) const; + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + bool showZero = false) const; std::vector> enumerate() const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; From 6c23fd1e866d967e9c752df7d616859963e58df8 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 18:08:45 -0500 Subject: [PATCH 19/26] Renamed protected method convert -> convertFrom --- gtsam/discrete/AlgebraicDecisionTree.h | 4 +- gtsam/discrete/DecisionTree-inl.h | 28 +-- gtsam/discrete/DecisionTree.h | 29 ++- gtsam/discrete/tests/testDecisionTree.cpp | 204 +++++++++++----------- 4 files changed, 130 insertions(+), 135 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 37130ab72..17a38f7cf 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -115,11 +115,11 @@ namespace gtsam { template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, const std::map& map) { - std::function map_function = [&map](const M& label) -> L { + std::function L_of_M = [&map](const M& label) -> L { return map.at(label); }; std::function op = Ring::id; - this->root_ = this->template convert(other.root_, op, map_function); + this->root_ = this->template convertFrom(other.root_, L_of_M, op); } /** sum */ diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index b31773702..af52a6daf 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -461,20 +461,21 @@ namespace gtsam { template template DecisionTree::DecisionTree(const DecisionTree& other, - std::function op) { - auto map = [](const L& label) { return label; }; - root_ = other.template convert(op, map); + std::function Y_of_X) { + auto L_of_L = [](const L& label) { return label; }; + root_ = convertFrom(Y_of_X, L_of_L); } /*********************************************************************************/ - template - template + template + template DecisionTree::DecisionTree(const DecisionTree& other, - const std::map& map, std::function op) { - std::function map_function = [&map](const M& label) -> L { + const std::map& map, + std::function Y_of_X) { + std::function L_of_M = [&map](const M& label) -> L { return map.at(label); }; - root_ = other.template convert(op, map_function); + root_ = convertFrom(other.root_, L_of_M, Y_of_X); } /*********************************************************************************/ @@ -589,9 +590,10 @@ namespace gtsam { /*********************************************************************************/ template template - typename DecisionTree::NodePtr DecisionTree::convert( + typename DecisionTree::NodePtr DecisionTree::convertFrom( const typename DecisionTree::NodePtr& f, - std::function op, std::function map) const { + std::function L_of_M, + std::function Y_of_X) const { typedef DecisionTree MX; typedef typename MX::Leaf MXLeaf; typedef typename MX::Choice MXChoice; @@ -601,7 +603,7 @@ namespace gtsam { // ugliness below because apparently we can't have templated virtual functions // If leaf, apply unary conversion "op" and create a unique leaf auto leaf = boost::dynamic_pointer_cast(f); - if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); + if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant()))); // Check if Choice auto choice = boost::dynamic_pointer_cast(f); @@ -610,12 +612,12 @@ namespace gtsam { // get new label const M oldLabel = choice->label(); - const L newLabel = map(oldLabel); + const L newLabel = L_of_M(oldLabel); // put together via Shannon expansion otherwise not sorted. std::vector functions; for(const MXNodePtr& branch: choice->branches()) { - LY converted(convert(branch, op, map)); + LY converted(convertFrom(branch, L_of_M, Y_of_X)); functions += converted; } return LY::compose(functions.begin(), functions.end(), newLabel); diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 4c79c5841..ecc3d17dc 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -113,27 +113,20 @@ namespace gtsam { /** A function is a shared pointer to the root of a DT */ typedef typename Node::Ptr NodePtr; - protected: - - /* a DecisionTree just contains the root */ + /// a DecisionTree just contains the root. TODO(dellaert): make protected. NodePtr root_; + protected: + /** Internal recursive function to create from keys, cardinalities, and Y values */ template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; - /// Convert to a different type, will not convert label if map empty. + /// Convert from a DecisionTree. template - NodePtr convert(const typename DecisionTree::NodePtr& f, - std::function op, - std::function map) const; - - /// Convert to a different type, will not convert label if map empty. - template - NodePtr convert(std::function op, - std::function map) const { - return convert(root_, op, map); - } + NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X) const; public: @@ -169,12 +162,12 @@ namespace gtsam { /** Convert from a different type. */ template DecisionTree(const DecisionTree& other, - std::function op); + std::function Y_of_X); /** Convert from a different type, also transate labels via map. */ - template - DecisionTree(const DecisionTree& other, - const std::map& map, std::function op); + template + DecisionTree(const DecisionTree& other, const std::map& L_of_M, + std::function Y_of_X); /// @} /// @name Testable diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index f42a590ae..cc61a382f 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -84,8 +84,11 @@ GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) /* ******************************************************************************** */ struct DT : public DecisionTree { + using Base = DecisionTree; using DecisionTree::DecisionTree; - DT(const DecisionTree& dt) : root_(dt.root_) {} + DT() = default; + + DT(const Base& dt) : Base(dt) {} /// print to stdout void print(const std::string& s = "") const { @@ -93,15 +96,13 @@ struct DT : public DecisionTree { auto valueFormatter = [](const int& v) { return (boost::format("%d") % v).str(); }; - DecisionTree::print("", keyFormatter, valueFormatter); + Base::print("", keyFormatter, valueFormatter); + } + /// Equality method customized to int node type + bool equals(const Base& other, double tol = 1e-9) const { + auto compare = [](const int& v, const int& w) { return v == w; }; + return Base::equals(other, compare); } - // /// Equality method customized to int node type - // bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const { - // auto compare = [tol](const int& v, const int& w) { - // return v.a == w.a && std::abs(v.b - w.b) < tol; - // }; - // return DecisionTree::equals(other, compare); - // } }; // traits @@ -131,111 +132,111 @@ struct Ring { /* ******************************************************************************** */ // test DT -// TEST(DT, example) -// { -// // Create labels -// string A("A"), B("B"), C("C"); +TEST(DT, example) +{ + // Create labels + string A("A"), B("B"), C("C"); -// // create a value -// Assignment x00, x01, x10, x11; -// x00[A] = 0, x00[B] = 0; -// x01[A] = 0, x01[B] = 1; -// x10[A] = 1, x10[B] = 0; -// x11[A] = 1, x11[B] = 1; + // create a value + Assignment x00, x01, x10, x11; + x00[A] = 0, x00[B] = 0; + x01[A] = 0, x01[B] = 1; + x10[A] = 1, x10[B] = 0; + x11[A] = 1, x11[B] = 1; -// // empty -// DT empty; + // empty + DT empty; -// // A -// DT a(A, 0, 5); -// LONGS_EQUAL(0,a(x00)) -// LONGS_EQUAL(5,a(x10)) -// DOT(a); + // A + DT a(A, 0, 5); + LONGS_EQUAL(0,a(x00)) + LONGS_EQUAL(5,a(x10)) + DOT(a); -// // pruned -// DT p(A, 2, 2); -// LONGS_EQUAL(2,p(x00)) -// LONGS_EQUAL(2,p(x10)) -// DOT(p); + // pruned + DT p(A, 2, 2); + LONGS_EQUAL(2,p(x00)) + LONGS_EQUAL(2,p(x10)) + DOT(p); -// // \neg B -// DT notb(B, 5, 0); -// LONGS_EQUAL(5,notb(x00)) -// LONGS_EQUAL(5,notb(x10)) -// DOT(notb); + // \neg B + DT notb(B, 5, 0); + LONGS_EQUAL(5,notb(x00)) + LONGS_EQUAL(5,notb(x10)) + DOT(notb); -// // Check supplying empty trees yields an exception -// CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); -// CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); -// CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); + // Check supplying empty trees yields an exception + CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); + CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); + CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); -// // apply, two nodes, in natural order -// DT anotb = apply(a, notb, &Ring::mul); -// LONGS_EQUAL(0,anotb(x00)) -// LONGS_EQUAL(0,anotb(x01)) -// LONGS_EQUAL(25,anotb(x10)) -// LONGS_EQUAL(0,anotb(x11)) -// DOT(anotb); + // apply, two nodes, in natural order + DT anotb = apply(a, notb, &Ring::mul); + LONGS_EQUAL(0,anotb(x00)) + LONGS_EQUAL(0,anotb(x01)) + LONGS_EQUAL(25,anotb(x10)) + LONGS_EQUAL(0,anotb(x11)) + DOT(anotb); -// // check pruning -// DT pnotb = apply(p, notb, &Ring::mul); -// LONGS_EQUAL(10,pnotb(x00)) -// LONGS_EQUAL( 0,pnotb(x01)) -// LONGS_EQUAL(10,pnotb(x10)) -// LONGS_EQUAL( 0,pnotb(x11)) -// DOT(pnotb); + // check pruning + DT pnotb = apply(p, notb, &Ring::mul); + LONGS_EQUAL(10,pnotb(x00)) + LONGS_EQUAL( 0,pnotb(x01)) + LONGS_EQUAL(10,pnotb(x10)) + LONGS_EQUAL( 0,pnotb(x11)) + DOT(pnotb); -// // check pruning -// DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); -// LONGS_EQUAL(0,zeros(x00)) -// LONGS_EQUAL(0,zeros(x01)) -// LONGS_EQUAL(0,zeros(x10)) -// LONGS_EQUAL(0,zeros(x11)) -// DOT(zeros); + // check pruning + DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); + LONGS_EQUAL(0,zeros(x00)) + LONGS_EQUAL(0,zeros(x01)) + LONGS_EQUAL(0,zeros(x10)) + LONGS_EQUAL(0,zeros(x11)) + DOT(zeros); -// // apply, two nodes, in switched order -// DT notba = apply(a, notb, &Ring::mul); -// LONGS_EQUAL(0,notba(x00)) -// LONGS_EQUAL(0,notba(x01)) -// LONGS_EQUAL(25,notba(x10)) -// LONGS_EQUAL(0,notba(x11)) -// DOT(notba); + // apply, two nodes, in switched order + DT notba = apply(a, notb, &Ring::mul); + LONGS_EQUAL(0,notba(x00)) + LONGS_EQUAL(0,notba(x01)) + LONGS_EQUAL(25,notba(x10)) + LONGS_EQUAL(0,notba(x11)) + DOT(notba); -// // Test choose 0 -// DT actual0 = notba.choose(A, 0); -// EXPECT(assert_equal(DT(0.0), actual0)); -// DOT(actual0); + // Test choose 0 + DT actual0 = notba.choose(A, 0); + EXPECT(assert_equal(DT(0.0), actual0)); + DOT(actual0); -// // Test choose 1 -// DT actual1 = notba.choose(A, 1); -// EXPECT(assert_equal(DT(B, 25, 0), actual1)); -// DOT(actual1); + // Test choose 1 + DT actual1 = notba.choose(A, 1); + EXPECT(assert_equal(DT(B, 25, 0), actual1)); + DOT(actual1); -// // apply, two nodes at same level -// DT a_and_a = apply(a, a, &Ring::mul); -// LONGS_EQUAL(0,a_and_a(x00)) -// LONGS_EQUAL(0,a_and_a(x01)) -// LONGS_EQUAL(25,a_and_a(x10)) -// LONGS_EQUAL(25,a_and_a(x11)) -// DOT(a_and_a); + // apply, two nodes at same level + DT a_and_a = apply(a, a, &Ring::mul); + LONGS_EQUAL(0,a_and_a(x00)) + LONGS_EQUAL(0,a_and_a(x01)) + LONGS_EQUAL(25,a_and_a(x10)) + LONGS_EQUAL(25,a_and_a(x11)) + DOT(a_and_a); -// // create a function on C -// DT c(C, 0, 5); + // create a function on C + DT c(C, 0, 5); -// // and a model assigning stuff to C -// Assignment x101; -// x101[A] = 1, x101[B] = 0, x101[C] = 1; + // and a model assigning stuff to C + Assignment x101; + x101[A] = 1, x101[B] = 0, x101[C] = 1; -// // mul notba with C -// DT notbac = apply(notba, c, &Ring::mul); -// LONGS_EQUAL(125,notbac(x101)) -// DOT(notbac); + // mul notba with C + DT notbac = apply(notba, c, &Ring::mul); + LONGS_EQUAL(125,notbac(x101)) + DOT(notbac); -// // mul now in different order -// DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); -// LONGS_EQUAL(125,acnotb(x101)) -// DOT(acnotb); -// } + // mul now in different order + DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); + LONGS_EQUAL(125,acnotb(x101)) + DOT(acnotb); +} /* ******************************************************************************** */ // test Conversion @@ -243,9 +244,6 @@ enum Label { U, V, X, Y, Z }; typedef DecisionTree BDT; -bool convert(const int& y) { - return y != 0; -} TEST(DT, conversion) { @@ -259,8 +257,10 @@ TEST(DT, conversion) map ordering; ordering[A] = X; ordering[B] = Y; - std::function op = convert; - BDT f2(f1, ordering, op); + std::function bool_of_int = [](const int& y) { + return y != 0; + }; + BDT f2(f1, ordering, bool_of_int); // f1.print("f1"); // f2.print("f2"); From 636425404d18347061114c735b8ffc77cdaf0d69 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 19:57:07 -0500 Subject: [PATCH 20/26] Fix compile error on windows --- gtsam/discrete/DecisionTree-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index af52a6daf..259489f06 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -662,7 +662,7 @@ namespace gtsam { DecisionTree DecisionTree::apply(const DecisionTree& g, const Binary& op) const { // It is unclear what should happen if either tree is empty: - if (empty() or g.empty()) { + if (empty() || g.empty()) { throw std::runtime_error( "DecisionTree::apply(binary op) undefined for empty trees."); } From a9b2c326693b6c087b190628a3fd3780c671a094 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 2 Jan 2022 23:45:01 -0500 Subject: [PATCH 21/26] Move DefaultFormatter to base class and add defaults. Also replace Super with Base and add using. --- gtsam/discrete/AlgebraicDecisionTree.h | 48 +++++++++--------- gtsam/discrete/DecisionTree.h | 69 +++++++++++++++++++------- 2 files changed, 75 insertions(+), 42 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 17a38f7cf..0b13f408e 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -29,16 +29,9 @@ namespace gtsam { */ template class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree { - /// Default method used by `formatter` when printing. - static std::string DefaultFormatter(const L& x) { - std::stringstream ss; - ss << x; - return ss.str(); - } - public: - typedef DecisionTree Super; + using Base = DecisionTree; /** The Real ring with addition and multiplication */ struct Ring { @@ -66,33 +59,33 @@ namespace gtsam { }; AlgebraicDecisionTree() : - Super(1.0) { + Base(1.0) { } - AlgebraicDecisionTree(const Super& add) : - Super(add) { + AlgebraicDecisionTree(const Base& add) : + Base(add) { } /** Create a new leaf function splitting on a variable */ AlgebraicDecisionTree(const L& label, double y1, double y2) : - Super(label, y1, y2) { + Base(label, y1, y2) { } /** Create a new leaf function splitting on a variable */ - AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) : - Super(labelC, y1, y2) { + AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : + Base(labelC, y1, y2) { } /** Create from keys and vector table */ AlgebraicDecisionTree // - (const std::vector& labelCs, const std::vector& ys) { - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + (const std::vector& labelCs, const std::vector& ys) { + this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create from keys and string table */ AlgebraicDecisionTree // - (const std::vector& labelCs, const std::string& table) { + (const std::vector& labelCs, const std::string& table) { // Convert string to doubles std::vector ys; std::istringstream iss(table); @@ -100,18 +93,23 @@ namespace gtsam { std::istream_iterator(), std::back_inserter(ys)); // now call recursive Create - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create a new function splitting on a variable */ template AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : - Super(nullptr) { + Base(nullptr) { this->root_ = compose(begin, end, label); } - /** Convert */ + /** + * Convert labels from type M to type L. + * + * @param other: The AlgebraicDecisionTree with label type M to convert. + * @param map: Map from label type M to label type L. + */ template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, const std::map& map) { @@ -143,18 +141,18 @@ namespace gtsam { } /** sum out variable */ - AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const { + AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const { return this->combine(labelC, &Ring::add); } /// print method customized to node type `double`. void print(const std::string& s, - const typename Super::LabelFormatter& labelFormatter = - &DefaultFormatter) const { + const typename Base::LabelFormatter& labelFormatter = + &Base::DefaultFormatter) const { auto valueFormatter = [](const double& v) { return (boost::format("%4.2g") % v).str(); }; - Super::print(s, labelFormatter, valueFormatter); + Base::print(s, labelFormatter, valueFormatter); } /// Equality method customized to node type `double`. @@ -163,7 +161,7 @@ namespace gtsam { auto compare = [tol](double a, double b) { return std::abs(a - b) < tol; }; - return Super::equals(other, compare); + return Base::equals(other, compare); } }; // AlgebraicDecisionTree diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index ecc3d17dc..b02c2b302 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -39,11 +39,24 @@ namespace gtsam { template class GTSAM_EXPORT DecisionTree { + protected: /// Default method for comparison of two objects of type Y. static bool DefaultCompare(const Y& a, const Y& b) { return a == b; } + /** + * @brief Default method used by `labelFormatter` or `valueFormatter` when printing. + * + * @param x The value passed to format. + * @return std::string + */ + static std::string DefaultFormatter(const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + } + public: using LabelFormatter = std::function; @@ -88,12 +101,14 @@ namespace gtsam { const void* id() const { return this; } // everything else is virtual, no documentation here as internal - virtual void print(const std::string& s, - const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter) const = 0; - virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter, - bool showZero) const = 0; + virtual void print( + const std::string& s, + const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter) const = 0; + virtual void dot(std::ostream& os, + const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter, + bool showZero = true) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; virtual bool equals(const Node& other, const CompareFunc& compare = @@ -111,7 +126,7 @@ namespace gtsam { public: /** A function is a shared pointer to the root of a DT */ - typedef typename Node::Ptr NodePtr; + using NodePtr = typename Node::Ptr; /// a DecisionTree just contains the root. TODO(dellaert): make protected. NodePtr root_; @@ -164,7 +179,16 @@ namespace gtsam { DecisionTree(const DecisionTree& other, std::function Y_of_X); - /** Convert from a different type, also transate labels via map. */ + /** + * @brief Convert from a different node type X to node type Y, also transate + * labels via map from type M to L. + * + * @tparam M Previous label type. + * @tparam X Previous node type. + * @param other The decision tree to convert. + * @param L_of_M Map from label type M to type L. + * @param Y_of_X Functor to convert from type X to type Y. + */ template DecisionTree(const DecisionTree& other, const std::map& L_of_M, std::function Y_of_X); @@ -173,9 +197,16 @@ namespace gtsam { /// @name Testable /// @{ - /** GTSAM-style print */ - void print(const std::string& s, const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter) const; + /** + * @brief GTSAM-style print + * + * @param s Prefix string. + * @param labelFormatter Functor to format the node label. + * @param valueFormatter Functor to format the node value. + */ + void print(const std::string& s, + const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter) const; // Testable bool equals(const DecisionTree& other, @@ -220,16 +251,20 @@ namespace gtsam { } /** output to graphviz format, stream version */ - void dot(std::ostream& os, const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter, bool showZero = true) const; + void dot(std::ostream& os, + const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter, + bool showZero = true) const; /** output to graphviz format, open a file */ - void dot(const std::string& name, const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter, bool showZero = true) const; + void dot(const std::string& name, + const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter, + bool showZero = true) const; /** output to graphviz format string */ - std::string dot(const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter, + std::string dot(const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter, bool showZero = true) const; /// @name Advanced Interface From 174490eb510dc39b5cc2b9f2c50764081f99f092 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 2 Jan 2022 23:49:47 -0500 Subject: [PATCH 22/26] kill commented out code --- gtsam/discrete/tests/testDecisionTree.cpp | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index cc61a382f..53f3c4379 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -45,15 +45,6 @@ struct Crazy { double b; }; -// bool equals(const Crazy& other, double tol = 1e-12) const { -// return a == other.a && std::abs(b - other.b) < tol; -// } - -// bool operator==(const Crazy& other) const { -// return this->equals(other); -// } -// }; - struct CrazyDecisionTree : public DecisionTree { /// print to stdout void print(const std::string& s = "") const { @@ -261,8 +252,6 @@ TEST(DT, conversion) return y != 0; }; BDT f2(f1, ordering, bool_of_int); - // f1.print("f1"); - // f2.print("f2"); // create a value Assignment