add formatting capabilities to DecisionTree

release/4.3a0
Varun Agrawal 2021-12-29 14:14:13 -05:00 committed by Frank Dellaert
parent 94f21358f4
commit ddaf9608d0
3 changed files with 28 additions and 14 deletions

View File

@ -83,7 +83,8 @@ namespace gtsam {
} }
/** print */ /** print */
void print(const std::string& s) const override { void print(const std::string& s,
const std::function<std::string(L)> formatter) const override {
bool showZero = true; bool showZero = true;
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl;
} }
@ -236,12 +237,11 @@ namespace gtsam {
} }
/** print (as a tree) */ /** print (as a tree) */
void print(const std::string& s) const override { void print(const std::string& s, const std::function<std::string(L)> formatter) const override {
std::cout << s << " Choice("; std::cout << s << " Choice(";
// std::cout << this << ","; std::cout << formatter(label_) << ") " << std::endl;
std::cout << label_ << ") " << std::endl;
for (size_t i = 0; i < branches_.size(); i++) for (size_t i = 0; i < branches_.size(); i++)
branches_[i]->print((boost::format("%s %d") % s % i).str()); branches_[i]->print((boost::format("%s %d") % s % i).str(), formatter);
} }
/** output to graphviz (as a a graph) */ /** output to graphviz (as a a graph) */
@ -591,7 +591,7 @@ namespace gtsam {
// get new label // get new label
M oldLabel = choice->label(); M oldLabel = choice->label();
L newLabel = map.at(oldLabel); L newLabel = oldLabel; //map.at(oldLabel);
// put together via Shannon expansion otherwise not sorted. // put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions; std::vector<LY> functions;
@ -609,8 +609,10 @@ namespace gtsam {
} }
template <typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::print(const std::string& s) const { void DecisionTree<L, Y>::print(
root_->print(s); const std::string& s,
const std::function<std::string(L)> formatter) const {
root_->print(s, formatter);
} }
template<typename L, typename Y> template<typename L, typename Y>

View File

@ -20,13 +20,13 @@
#pragma once #pragma once
#include <gtsam/base/types.h> #include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/Assignment.h>
#include <boost/function.hpp> #include <boost/function.hpp>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <sstream>
#include <vector> #include <vector>
namespace gtsam { namespace gtsam {
@ -79,7 +79,13 @@ namespace gtsam {
const void* id() const { return this; } const void* id() const { return this; }
// everything else is virtual, no documentation here as internal // everything else is virtual, no documentation here as internal
virtual void print(const std::string& s = "") const = 0; virtual void print(
const std::string& s = "",
const std::function<std::string(L)> formatter = [](const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}) const = 0;
virtual void dot(std::ostream& os, bool showZero) const = 0; virtual void dot(std::ostream& os, bool showZero) const = 0;
virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0;
virtual bool sameLeaf(const Node& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0;
@ -154,7 +160,13 @@ namespace gtsam {
/// @{ /// @{
/** GTSAM-style print */ /** GTSAM-style print */
void print(const std::string& s = "DecisionTree") const; void print(
const std::string& s = "DecisionTree",
const std::function<std::string(L)> formatter = [](const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}) const;
// Testable // Testable
bool equals(const DecisionTree& other, double tol = 1e-9) const; bool equals(const DecisionTree& other, double tol = 1e-9) const;

View File

@ -55,7 +55,7 @@ void Potentials::print(const string& s, const KeyFormatter& formatter) const {
for (const std::pair<const Key,size_t>& key : cardinalities_) for (const std::pair<const Key,size_t>& key : cardinalities_)
cout << formatter(key.first) << ":" << key.second << ", "; cout << formatter(key.first) << ":" << key.second << ", ";
cout << "}" << endl; cout << "}" << endl;
ADT::print(" "); ADT::print(" ", formatter);
} }
// //
// /* ************************************************************************* */ // /* ************************************************************************* */