add formatting capabilities to DecisionTree

release/4.3a0
Varun Agrawal 2021-12-29 14:14:13 -05:00 committed by Frank Dellaert
parent 94f21358f4
commit ddaf9608d0
3 changed files with 28 additions and 14 deletions

View File

@ -83,7 +83,8 @@ namespace gtsam {
}
/** print */
void print(const std::string& s) const override {
void print(const std::string& s,
const std::function<std::string(L)> formatter) const override {
bool showZero = true;
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl;
}
@ -236,12 +237,11 @@ namespace gtsam {
}
/** print (as a tree) */
void print(const std::string& s) const override {
void print(const std::string& s, const std::function<std::string(L)> formatter) const override {
std::cout << s << " Choice(";
// std::cout << this << ",";
std::cout << label_ << ") " << std::endl;
std::cout << formatter(label_) << ") " << std::endl;
for (size_t i = 0; i < branches_.size(); i++)
branches_[i]->print((boost::format("%s %d") % s % i).str());
branches_[i]->print((boost::format("%s %d") % s % i).str(), formatter);
}
/** output to graphviz (as a a graph) */
@ -591,7 +591,7 @@ namespace gtsam {
// get new label
M oldLabel = choice->label();
L newLabel = map.at(oldLabel);
L newLabel = oldLabel; //map.at(oldLabel);
// put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions;
@ -608,9 +608,11 @@ namespace gtsam {
return root_->equals(*other.root_, tol);
}
template<typename L, typename Y>
void DecisionTree<L, Y>::print(const std::string& s) const {
root_->print(s);
template <typename L, typename Y>
void DecisionTree<L, Y>::print(
const std::string& s,
const std::function<std::string(L)> formatter) const {
root_->print(s, formatter);
}
template<typename L, typename Y>

View File

@ -20,13 +20,13 @@
#pragma once
#include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h>
#include <boost/function.hpp>
#include <functional>
#include <iostream>
#include <map>
#include <sstream>
#include <vector>
namespace gtsam {
@ -79,7 +79,13 @@ namespace gtsam {
const void* id() const { return this; }
// everything else is virtual, no documentation here as internal
virtual void print(const std::string& s = "") const = 0;
virtual void print(
const std::string& s = "",
const std::function<std::string(L)> formatter = [](const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}) const = 0;
virtual void dot(std::ostream& os, bool showZero) const = 0;
virtual bool sameLeaf(const Leaf& q) const = 0;
virtual bool sameLeaf(const Node& q) const = 0;
@ -154,7 +160,13 @@ namespace gtsam {
/// @{
/** GTSAM-style print */
void print(const std::string& s = "DecisionTree") const;
void print(
const std::string& s = "DecisionTree",
const std::function<std::string(L)> formatter = [](const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}) const;
// Testable
bool equals(const DecisionTree& other, double tol = 1e-9) const;

View File

@ -51,11 +51,11 @@ bool Potentials::equals(const Potentials& other, double tol) const {
/* ************************************************************************* */
void Potentials::print(const string& s, const KeyFormatter& formatter) const {
cout << s << "\n Cardinalities: {";
cout << s << "\n Cardinalities: { ";
for (const std::pair<const Key,size_t>& key : cardinalities_)
cout << formatter(key.first) << ":" << key.second << ", ";
cout << "}" << endl;
ADT::print(" ");
ADT::print(" ", formatter);
}
//
// /* ************************************************************************* */