From 26c48a8c1fc33f9bcebfbcfd0c813ea30c551811 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 30 Dec 2021 15:28:07 -0500 Subject: [PATCH] address more review comments # Conflicts: # gtsam/hybrid/DCGaussianMixtureFactor.h # gtsam/hybrid/DCMixtureFactor.h --- gtsam/discrete/AlgebraicDecisionTree.h | 17 ++++++----------- gtsam/discrete/DecisionTree-inl.h | 21 ++++++++------------- gtsam/discrete/DecisionTree.h | 21 +++++++++++---------- 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 3469ac5cf..f47e01668 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -30,14 +30,7 @@ namespace gtsam { template class AlgebraicDecisionTree: public DecisionTree { - /** - * @brief Default method for comparison of two doubles upto some tolerance. - */ - static bool DefaultComparator(double a, double b, double tol) { - return std::abs(a - b) < tol; - } - - public: + public: typedef DecisionTree Super; @@ -146,9 +139,11 @@ namespace gtsam { } /// Equality method customized to node type `double`. - bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9, - const std::function& comparator = - &DefaultComparator) const { + bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const { + // lambda for comparison of two doubles upto some tolerance. + auto comparator = [](double a, double b, double tol) { + return std::abs(a - b) < tol; + }; return this->root_->equals(*other.root_, tol, comparator); } }; diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index f531e2b98..fb6c78148 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -19,7 +19,6 @@ #pragma once -#include #include #include @@ -78,8 +77,7 @@ namespace gtsam { /** equality up to tolerance */ bool equals(const Node& q, double tol, - const std::function& - comparator) const override { + const ComparatorFunc& comparator) const override { const Leaf* other = dynamic_cast(&q); if (!other) return false; return comparator(this->constant_, other->constant_, tol); @@ -87,7 +85,7 @@ namespace gtsam { /** print */ void print(const std::string& s, - const std::function& formatter) const override { + const FormatterFunc& formatter) const override { bool showZero = true; if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; } @@ -241,7 +239,7 @@ namespace gtsam { /** print (as a tree) */ void print(const std::string& s, - const std::function& formatter) const override { + const FormatterFunc& formatter) const override { std::cout << s << " Choice("; std::cout << formatter(label_) << ") " << std::endl; for (size_t i = 0; i < branches_.size(); i++) @@ -284,8 +282,7 @@ namespace gtsam { /** equality up to tolerance */ bool equals(const Node& q, double tol, - const std::function& - comparator) const override { + const ComparatorFunc& comparator) const override { const Choice* other = dynamic_cast(&q); if (!other) return false; if (this->label_ != other->label_) return false; @@ -651,16 +648,14 @@ namespace gtsam { /*********************************************************************************/ template - bool DecisionTree::equals( - const DecisionTree& other, double tol, - const std::function& comparator) const { + bool DecisionTree::equals(const DecisionTree& other, double tol, + const ComparatorFunc& comparator) const { return root_->equals(*other.root_, tol, comparator); } template - void DecisionTree::print( - const std::string& s, - const std::function& formatter) const { + void DecisionTree::print(const std::string& s, + const FormatterFunc& formatter) const { root_->print(s, formatter); } diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index eb23bc5ce..8e6c0e4d7 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -53,6 +53,9 @@ namespace gtsam { public: + using FormatterFunc = std::function; + using ComparatorFunc = std::function; + /** Handy typedefs for unary and binary function types */ typedef std::function Unary; typedef std::function Binary; @@ -91,15 +94,15 @@ namespace gtsam { const void* id() const { return this; } // everything else is virtual, no documentation here as internal - virtual void print(const std::string& s = "", - const std::function& formatter = - &DefaultFormatter) const = 0; + virtual void print( + const std::string& s = "", + const FormatterFunc& formatter = &DefaultFormatter) const = 0; virtual void dot(std::ostream& os, bool showZero) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; - virtual bool equals(const Node& other, double tol = 1e-9, - const std::function& - comparator = &DefaultComparator) const = 0; + virtual bool equals( + const Node& other, double tol = 1e-9, + const ComparatorFunc& comparator = &DefaultComparator) const = 0; virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; @@ -181,13 +184,11 @@ namespace gtsam { /** GTSAM-style print */ void print(const std::string& s = "DecisionTree", - const std::function& formatter = - &DefaultFormatter) const; + const FormatterFunc& formatter = &DefaultFormatter) const; // Testable bool equals(const DecisionTree& other, double tol = 1e-9, - const std::function& - comparator = &DefaultComparator) const; + const ComparatorFunc& comparator = &DefaultComparator) const; /// @} /// @name Standard Interface