diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 9cc55ed6a..60f3017f4 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -28,11 +28,22 @@ namespace gtsam { * TODO: consider eliminating this class altogether? */ template - class AlgebraicDecisionTree: public DecisionTree { + class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree { + /** + * @brief Default method used by `labelFormatter` or `valueFormatter` when printing. + * + * @param x The value passed to format. + * @return std::string + */ + static std::string DefaultFormatter(const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + } - public: + public: - typedef DecisionTree Super; + using Base = DecisionTree; /** The Real ring with addition and multiplication */ struct Ring { @@ -60,33 +71,33 @@ namespace gtsam { }; AlgebraicDecisionTree() : - Super(1.0) { + Base(1.0) { } - AlgebraicDecisionTree(const Super& add) : - Super(add) { + AlgebraicDecisionTree(const Base& add) : + Base(add) { } /** Create a new leaf function splitting on a variable */ AlgebraicDecisionTree(const L& label, double y1, double y2) : - Super(label, y1, y2) { + Base(label, y1, y2) { } /** Create a new leaf function splitting on a variable */ - AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) : - Super(labelC, y1, y2) { + AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : + Base(labelC, y1, y2) { } /** Create from keys and vector table */ AlgebraicDecisionTree // - (const std::vector& labelCs, const std::vector& ys) { - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + (const std::vector& labelCs, const std::vector& ys) { + this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create from keys and string table */ AlgebraicDecisionTree // - (const std::vector& labelCs, const std::string& table) { + (const std::vector& labelCs, const std::string& table) { // Convert string to doubles std::vector ys; std::istringstream iss(table); @@ -94,23 +105,32 @@ namespace gtsam { std::istream_iterator(), std::back_inserter(ys)); // now call recursive Create - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create a new function splitting on a variable */ template AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : - Super(nullptr) { + Base(nullptr) { this->root_ = compose(begin, end, label); } - /** Convert */ + /** + * Convert labels from type M to type L. + * + * @param other: The AlgebraicDecisionTree with label type M to convert. + * @param map: Map from label type M to label type L. + */ template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, - const std::map& map) { - this->root_ = this->template convert(other.root_, map, - Ring::id); + const std::map& map) { + // Functor for label conversion so we can use `convertFrom`. + std::function L_of_M = [&map](const M& label) -> L { + return map.at(label); + }; + std::function op = Ring::id; + this->root_ = this->template convertFrom(other.root_, L_of_M, op); } /** sum */ @@ -134,10 +154,28 @@ namespace gtsam { } /** sum out variable */ - AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const { + AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const { return this->combine(labelC, &Ring::add); } + /// print method customized to value type `double`. + void print(const std::string& s, + const typename Base::LabelFormatter& labelFormatter = + &DefaultFormatter) const { + auto valueFormatter = [](const double& v) { + return (boost::format("%4.2g") % v).str(); + }; + Base::print(s, labelFormatter, valueFormatter); + } + + /// Equality method customized to value type `double`. + bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const { + // lambda for comparison of two doubles upto some tolerance. + auto compare = [tol](double a, double b) { + return std::abs(a - b) < tol; + }; + return Base::equals(other, compare); + } }; // AlgebraicDecisionTree diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index f6a64f11f..11ecbf183 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -20,21 +20,21 @@ #pragma once #include -#include +#include #include +#include #include #include -#include -using boost::assign::operator+=; +#include #include -#include - -#include #include #include +#include #include +using boost::assign::operator+=; + namespace gtsam { /*********************************************************************************/ @@ -76,23 +76,32 @@ namespace gtsam { } /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { - const Leaf* other = dynamic_cast (&q); + bool equals(const Node& q, const CompareFunc& compare) const override { + const Leaf* other = dynamic_cast(&q); if (!other) return false; - return std::abs(double(this->constant_ - other->constant_)) < tol; + return compare(this->constant_, other->constant_); } - /** print */ - void print(const std::string& s) const override { - bool showZero = true; - if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; + /** + * @brief Print method. + * + * @param s Prefix string. + * @param labelFormatter Functor to format the labels of type L. + * @param valueFormatter Functor to format the values of type Y. + */ + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const override { + std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; } - /** to graphviz file */ - void dot(std::ostream& os, bool showZero) const override { - if (showZero || constant_) os << "\"" << this->id() << "\" [label=\"" - << boost::format("%4.2g") % constant_ - << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, + /** Write graphviz format to stream `os`. */ + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const override { + std::string value = valueFormatter(constant_); + if (showZero || value.compare("0")) + os << "\"" << this->id() << "\" [label=\"" << value + << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, } /** evaluate */ @@ -151,7 +160,7 @@ namespace gtsam { /** incremental allSame */ size_t allSame_; - typedef boost::shared_ptr ChoicePtr; + using ChoicePtr = boost::shared_ptr; public: @@ -236,16 +245,19 @@ namespace gtsam { } /** print (as a tree) */ - void print(const std::string& s) const override { + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const override { std::cout << s << " Choice("; - // std::cout << this << ","; - std::cout << label_ << ") " << std::endl; + std::cout << labelFormatter(label_) << ") " << std::endl; for (size_t i = 0; i < branches_.size(); i++) - branches_[i]->print((boost::format("%s %d") % s % i).str()); + branches_[i]->print((boost::format("%s %d") % s % i).str(), + labelFormatter, valueFormatter); } /** output to graphviz (as a a graph) */ - void dot(std::ostream& os, bool showZero) const override { + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const override { os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ << "\"]\n"; size_t B = branches_.size(); @@ -255,7 +267,8 @@ namespace gtsam { // Check if zero if (!showZero) { const Leaf* leaf = dynamic_cast (branch.get()); - if (leaf && !leaf->constant()) continue; + std::string value = valueFormatter(leaf->constant()); + if (leaf && value.compare("0")) continue; } os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; @@ -264,7 +277,7 @@ namespace gtsam { if (i > 1) os << " [style=bold]"; } os << std::endl; - branch->dot(os, showZero); + branch->dot(os, labelFormatter, valueFormatter, showZero); } } @@ -278,15 +291,16 @@ namespace gtsam { return (q.isLeaf() && q.sameLeaf(*this)); } - /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { - const Choice* other = dynamic_cast (&q); + /** equality */ + bool equals(const Node& q, const CompareFunc& compare) const override { + const Choice* other = dynamic_cast(&q); if (!other) return false; if (this->label_ != other->label_) return false; if (branches_.size() != other->branches_.size()) return false; // we don't care about shared pointers being equal here for (size_t i = 0; i < branches_.size(); i++) - if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false; + if (!(branches_[i]->equals(*(other->branches_[i]), compare))) + return false; return true; } @@ -450,11 +464,25 @@ namespace gtsam { } /*********************************************************************************/ - template - template + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + std::function Y_of_X) { + // Define functor for identity mapping of node label. + auto L_of_L = [](const L& label) { return label; }; + root_ = convertFrom(other.root_, L_of_L, Y_of_X); + } + + /*********************************************************************************/ + template + template DecisionTree::DecisionTree(const DecisionTree& other, - const std::map& map, std::function op) { - root_ = convert(other.root_, map, op); + const std::map& map, + std::function Y_of_X) { + std::function L_of_M = [&map](const M& label) -> L { + return map.at(label); + }; + root_ = convertFrom(other.root_, L_of_M, Y_of_X); } /*********************************************************************************/ @@ -567,50 +595,53 @@ namespace gtsam { } /*********************************************************************************/ - template - template - typename DecisionTree::NodePtr DecisionTree::convert( - const typename DecisionTree::NodePtr& f, const std::map& map, - std::function op) { - - typedef DecisionTree MX; - typedef typename MX::Leaf MXLeaf; - typedef typename MX::Choice MXChoice; - typedef typename MX::NodePtr MXNodePtr; - typedef DecisionTree LY; + template + template + typename DecisionTree::NodePtr DecisionTree::convertFrom( + const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X) const { + using MX = DecisionTree; + using MXLeaf = typename MX::Leaf; + using MXChoice = typename MX::Choice; + using MXNodePtr = typename MX::NodePtr; + using LY = DecisionTree; // ugliness below because apparently we can't have templated virtual functions // If leaf, apply unary conversion "op" and create a unique leaf - const MXLeaf* leaf = dynamic_cast (f.get()); - if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); + auto leaf = boost::dynamic_pointer_cast(f); + if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant()))); // Check if Choice - boost::shared_ptr choice = boost::dynamic_pointer_cast (f); + auto choice = boost::dynamic_pointer_cast(f); if (!choice) throw std::invalid_argument( "DecisionTree::Convert: Invalid NodePtr"); // get new label - M oldLabel = choice->label(); - L newLabel = map.at(oldLabel); + const M oldLabel = choice->label(); + const L newLabel = L_of_M(oldLabel); // put together via Shannon expansion otherwise not sorted. std::vector functions; for(const MXNodePtr& branch: choice->branches()) { - LY converted(convert(branch, map, op)); + LY converted(convertFrom(branch, L_of_M, Y_of_X)); functions += converted; } return LY::compose(functions.begin(), functions.end(), newLabel); } /*********************************************************************************/ - template - bool DecisionTree::equals(const DecisionTree& other, double tol) const { - return root_->equals(*other.root_, tol); + template + bool DecisionTree::equals(const DecisionTree& other, + const CompareFunc& compare) const { + return root_->equals(*other.root_, compare); } - template - void DecisionTree::print(const std::string& s) const { - root_->print(s); + template + void DecisionTree::print(const std::string& s, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const { + root_->print(s, labelFormatter, valueFormatter); } template @@ -625,6 +656,11 @@ namespace gtsam { template DecisionTree DecisionTree::apply(const Unary& op) const { + // It is unclear what should happen if tree is empty: + if (empty()) { + throw std::runtime_error( + "DecisionTree::apply(unary op) undefined for empty tree."); + } return DecisionTree(root_->apply(op)); } @@ -632,6 +668,11 @@ namespace gtsam { template DecisionTree DecisionTree::apply(const DecisionTree& g, const Binary& op) const { + // It is unclear what should happen if either tree is empty: + if (empty() || g.empty()) { + throw std::runtime_error( + "DecisionTree::apply(binary op) undefined for empty trees."); + } // apply the operaton on the root of both diagrams NodePtr h = root_->apply_f_op_g(*g.root_, op); // create a new class with the resulting root "h" @@ -660,26 +701,34 @@ namespace gtsam { } /*********************************************************************************/ - template - void DecisionTree::dot(std::ostream& os, bool showZero) const { + template + void DecisionTree::dot(std::ostream& os, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { os << "digraph G {\n"; - root_->dot(os, showZero); + root_->dot(os, labelFormatter, valueFormatter, showZero); os << " [ordering=out]}" << std::endl; } - template - void DecisionTree::dot(const std::string& name, bool showZero) const { + template + void DecisionTree::dot(const std::string& name, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { std::ofstream os((name + ".dot").c_str()); - dot(os, showZero); + dot(os, labelFormatter, valueFormatter, showZero); int result = system( ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); } - template - std::string DecisionTree::dot(bool showZero) const { + template + std::string DecisionTree::dot(const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { std::stringstream ss; - dot(ss, showZero); + dot(ss, labelFormatter, valueFormatter, showZero); return ss.str(); } diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 0a78d4635..db8a12a20 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -20,13 +20,13 @@ #pragma once #include - #include #include #include #include #include +#include #include namespace gtsam { @@ -39,14 +39,24 @@ namespace gtsam { template class GTSAM_EXPORT DecisionTree { + protected: + /// Default method for comparison of two objects of type Y. + static bool DefaultCompare(const Y& a, const Y& b) { + return a == b; + } + public: + using LabelFormatter = std::function; + using ValueFormatter = std::function; + using CompareFunc = std::function; + /** Handy typedefs for unary and binary function types */ - typedef std::function Unary; - typedef std::function Binary; + using Unary = std::function; + using Binary = std::function; /** A label annotated with cardinality */ - typedef std::pair LabelC; + using LabelC = std::pair; /** DTs consist of Leaf and Choice nodes, both subclasses of Node */ class Leaf; @@ -55,7 +65,7 @@ namespace gtsam { /** ------------------------ Node base class --------------------------- */ class Node { public: - typedef boost::shared_ptr Ptr; + using Ptr = boost::shared_ptr; #ifdef DT_DEBUG_MEMORY static int nrNodes; @@ -79,11 +89,16 @@ namespace gtsam { const void* id() const { return this; } // everything else is virtual, no documentation here as internal - virtual void print(const std::string& s = "") const = 0; - virtual void dot(std::ostream& os, bool showZero) const = 0; + virtual void print(const std::string& s, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const = 0; + virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; - virtual bool equals(const Node& other, double tol = 1e-9) const = 0; + virtual bool equals(const Node& other, const CompareFunc& compare = + &DefaultCompare) const = 0; virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; @@ -97,9 +112,9 @@ namespace gtsam { public: /** A function is a shared pointer to the root of a DT */ - typedef typename Node::Ptr NodePtr; + using NodePtr = typename Node::Ptr; - /* a DecisionTree just contains the root */ + /// A DecisionTree just contains the root. TODO(dellaert): make protected. NodePtr root_; protected: @@ -108,19 +123,29 @@ namespace gtsam { template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; - /** Convert to a different type */ - template NodePtr - convert(const typename DecisionTree::NodePtr& f, const std::map& map, std::function op); + /** + * @brief Convert from a DecisionTree to DecisionTree. + * + * @tparam M The previous label type. + * @tparam X The previous value type. + * @param f The node pointer to the root of the previous DecisionTree. + * @param L_of_M Functor to convert from label type M to type L. + * @param Y_of_X Functor to convert from value type X to type Y. + * @return NodePtr + */ + template + NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X) const; - /** Default constructor */ - DecisionTree(); - - public: + public: /// @name Standard Constructors /// @{ + /** Default constructor (for serialization) */ + DecisionTree(); + /** Create a constant */ DecisionTree(const Y& y); @@ -144,20 +169,48 @@ namespace gtsam { DecisionTree(const L& label, // const DecisionTree& f0, const DecisionTree& f1); - /** Convert from a different type */ - template - DecisionTree(const DecisionTree& other, - const std::map& map, std::function op); + /** + * @brief Convert from a different value type. + * + * @tparam X The previous value type. + * @param other The DecisionTree to convert from. + * @param Y_of_X Functor to convert from value type X to type Y. + */ + template + DecisionTree(const DecisionTree& other, + std::function Y_of_X); + + /** + * @brief Convert from a different value type X to value type Y, also transate + * labels via map from type M to L. + * + * @tparam M Previous label type. + * @tparam X Previous value type. + * @param other The decision tree to convert. + * @param L_of_M Map from label type M to type L. + * @param Y_of_X Functor to convert from type X to type Y. + */ + template + DecisionTree(const DecisionTree& other, const std::map& L_of_M, + std::function Y_of_X); /// @} /// @name Testable /// @{ - /** GTSAM-style print */ - void print(const std::string& s = "DecisionTree") const; + /** + * @brief GTSAM-style print + * + * @param s Prefix string. + * @param labelFormatter Functor to format the node label. + * @param valueFormatter Functor to format the node value. + */ + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const; // Testable - bool equals(const DecisionTree& other, double tol = 1e-9) const; + bool equals(const DecisionTree& other, + const CompareFunc& compare = &DefaultCompare) const; /// @} /// @name Standard Interface @@ -167,6 +220,9 @@ namespace gtsam { virtual ~DecisionTree() { } + /// Check if tree is empty. + bool empty() const { return !root_; } + /** equality */ bool operator==(const DecisionTree& q) const; @@ -195,13 +251,17 @@ namespace gtsam { } /** output to graphviz format, stream version */ - void dot(std::ostream& os, bool showZero = true) const; + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, bool showZero = true) const; /** output to graphviz format, open a file */ - void dot(const std::string& name, bool showZero = true) const; + void dot(const std::string& name, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, bool showZero = true) const; /** output to graphviz format string */ - std::string dot(bool showZero = true) const; + std::string dot(const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero = true) const; /// @name Advanced Interface /// @{ @@ -219,13 +279,15 @@ namespace gtsam { /** free versions of apply */ - template + /// Apply unary operator `op` to DecisionTree `f`. + template DecisionTree apply(const DecisionTree& f, const typename DecisionTree::Unary& op) { return f.apply(op); } - template + /// Apply binary operator `op` to DecisionTree `f`. + template DecisionTree apply(const DecisionTree& f, const DecisionTree& g, const typename DecisionTree::Binary& op) { diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 7aed00c57..75018cf92 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -153,6 +153,31 @@ namespace gtsam { return result; } + /* ************************************************************************* */ + static std::string valueFormatter(const double& v) { + return (boost::format("%4.2g") % v).str(); + } + + /** output to graphviz format, stream version */ + void DecisionTreeFactor::dot(std::ostream& os, + const KeyFormatter& keyFormatter, + bool showZero) const { + Potentials::dot(os, keyFormatter, valueFormatter, showZero); + } + + /** output to graphviz format, open a file */ + void DecisionTreeFactor::dot(const std::string& name, + const KeyFormatter& keyFormatter, + bool showZero) const { + Potentials::dot(name, keyFormatter, valueFormatter, showZero); + } + + /** output to graphviz format string */ + std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter, + bool showZero) const { + return Potentials::dot(keyFormatter, valueFormatter, showZero); + } + /* ************************************************************************* */ std::string DecisionTreeFactor::markdown( const KeyFormatter& keyFormatter) const { diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index f90af56dd..46509db82 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -178,6 +178,20 @@ namespace gtsam { /// @name Wrapper support /// @{ + /** output to graphviz format, stream version */ + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** output to graphviz format, open a file */ + void dot(const std::string& name, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** output to graphviz format string */ + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + /// Render as markdown table. std::string markdown( const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; diff --git a/gtsam/discrete/Potentials.cpp b/gtsam/discrete/Potentials.cpp index fa491eba3..057b6a265 100644 --- a/gtsam/discrete/Potentials.cpp +++ b/gtsam/discrete/Potentials.cpp @@ -51,11 +51,11 @@ bool Potentials::equals(const Potentials& other, double tol) const { /* ************************************************************************* */ void Potentials::print(const string& s, const KeyFormatter& formatter) const { - cout << s << "\n Cardinalities: {"; + cout << s << "\n Cardinalities: { "; for (const std::pair& key : cardinalities_) cout << formatter(key.first) << ":" << key.second << ", "; cout << "}" << endl; - ADT::print(" "); + ADT::print(" ", formatter); } // // /* ************************************************************************* */ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 36caccfc8..5bd4a2913 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -46,7 +46,9 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; - string dot(bool showZero = false) const; + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + bool showZero = false) const; std::vector> enumerate() const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index 7a33810c7..910515b5c 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -136,8 +136,8 @@ ADT create(const Signature& signature) { ADT p(signature.discreteKeys(), signature.cpt()); static size_t count = 0; const DiscreteKey& key = signature.key(); - string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); - dot(p, dotfile); + string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); + dot(p, DOTfile); return p; } @@ -414,13 +414,13 @@ TEST(ADT, equality_noparser) // Check straight equality ADT pA1 = create(A % tableA); ADT pA2 = create(A % tableA); - EXPECT(pA1 == pA2); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % tableB); ADT pAB1 = apply(pA1, pB, &mul); ADT pAB2 = apply(pB, pA1, &mul); - EXPECT(pAB2 == pAB1); + EXPECT(pAB2.equals(pAB1)); } /* ************************************************************************* */ @@ -431,13 +431,13 @@ TEST(ADT, equality_parser) // Check straight equality ADT pA1 = create(A % "80/20"); ADT pA2 = create(A % "80/20"); - EXPECT(pA1 == pA2); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % "60/40"); ADT pAB1 = apply(pA1, pB, &mul); ADT pAB2 = apply(pB, pA1, &mul); - EXPECT(pAB2 == pAB1); + EXPECT(pAB2.equals(pAB1)); } /* ******************************************************************************** */ diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 96f503abc..5976ea2d4 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -40,25 +40,69 @@ void dot(const T&f, const string& filename) { #define DOT(x)(dot(x,#x)) -struct Crazy { int a; double b; }; -typedef DecisionTree CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) +struct Crazy { + int a; + double b; +}; + +struct CrazyDecisionTree : public DecisionTree { + /// print to stdout + void print(const std::string& s = "") const { + auto keyFormatter = [](const std::string& s) { return s; }; + auto valueFormatter = [](const Crazy& v) { + return (boost::format("{%d,%4.2g}") % v.a % v.b).str(); + }; + DecisionTree::print("", keyFormatter, valueFormatter); + } + /// Equality method customized to Crazy node type + bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const { + auto compare = [tol](const Crazy& v, const Crazy& w) { + return v.a == w.a && std::abs(v.b - w.b) < tol; + }; + return DecisionTree::equals(other, compare); + } +}; // traits namespace gtsam { template<> struct traits : public Testable {}; } +GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) + /* ******************************************************************************** */ // Test string labels and int range /* ******************************************************************************** */ -typedef DecisionTree DT; +struct DT : public DecisionTree { + using Base = DecisionTree; + using DecisionTree::DecisionTree; + DT() = default; + + DT(const Base& dt) : Base(dt) {} + + /// print to stdout + void print(const std::string& s = "") const { + auto keyFormatter = [](const std::string& s) { return s; }; + auto valueFormatter = [](const int& v) { + return (boost::format("%d") % v).str(); + }; + Base::print("", keyFormatter, valueFormatter); + } + /// Equality method customized to int node type + bool equals(const Base& other, double tol = 1e-9) const { + auto compare = [](const int& v, const int& w) { return v == w; }; + return Base::equals(other, compare); + } +}; // traits namespace gtsam { template<> struct traits
: public Testable
{}; } +GTSAM_CONCEPT_TESTABLE_INST(DT) + struct Ring { static inline int zero() { return 0; @@ -66,6 +110,9 @@ struct Ring { static inline int one() { return 1; } + static inline int id(const int& a) { + return a; + } static inline int add(const int& a, const int& b) { return a + b; } @@ -88,6 +135,9 @@ TEST(DT, example) x10[A] = 1, x10[B] = 0; x11[A] = 1, x11[B] = 1; + // empty + DT empty; + // A DT a(A, 0, 5); LONGS_EQUAL(0,a(x00)) @@ -106,6 +156,11 @@ TEST(DT, example) LONGS_EQUAL(5,notb(x10)) DOT(notb); + // Check supplying empty trees yields an exception + CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); + CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); + CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); + // apply, two nodes, in natural order DT anotb = apply(a, notb, &Ring::mul); LONGS_EQUAL(0,anotb(x00)) @@ -175,16 +230,37 @@ TEST(DT, example) } /* ******************************************************************************** */ -// test Conversion +// test Conversion of values +std::function bool_of_int = [](const int& y) { + return y != 0; +}; +typedef DecisionTree StringBoolTree; + +TEST(DT, ConvertValuesOnly) +{ + // Create labels + string A("A"), B("B"); + + // apply, two nodes, in natural order + DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul); + + // convert + StringBoolTree f2(f1, bool_of_int); + + // Check a value + Assignment x00; + x00["A"] = 0, x00["B"] = 0; + EXPECT(!f2(x00)); +} + +/* ******************************************************************************** */ +// test Conversion of both values and labels. enum Label { U, V, X, Y, Z }; -typedef DecisionTree BDT; -bool convert(const int& y) { - return y != 0; -} +typedef DecisionTree LabelBoolTree; -TEST(DT, conversion) +TEST(DT, ConvertBoth) { // Create labels string A("A"), B("B"); @@ -196,12 +272,9 @@ TEST(DT, conversion) map ordering; ordering[A] = X; ordering[B] = Y; - std::function op = convert; - BDT f2(f1, ordering, op); - // f1.print("f1"); - // f2.print("f2"); + LabelBoolTree f2(f1, ordering, bool_of_int); - // create a value + // Check some values Assignment