WIP for debugging nrAssignments issue
parent
0cd36db4d9
commit
73b563a9aa
|
@ -93,7 +93,7 @@ namespace gtsam {
|
||||||
/// print
|
/// print
|
||||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const ValueFormatter& valueFormatter) const override {
|
const ValueFormatter& valueFormatter) const override {
|
||||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
std::cout << s << " Leaf " << valueFormatter(constant_) << " | nrAssignments: " << nrAssignments_ << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Write graphviz format to stream `os`. */
|
/** Write graphviz format to stream `os`. */
|
||||||
|
@ -207,9 +207,9 @@ namespace gtsam {
|
||||||
|
|
||||||
size_t nrAssignments = 0;
|
size_t nrAssignments = 0;
|
||||||
for(auto branch: f->branches()) {
|
for(auto branch: f->branches()) {
|
||||||
assert(branch->isLeaf());
|
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
|
||||||
nrAssignments +=
|
nrAssignments += leaf->nrAssignments();
|
||||||
std::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
|
}
|
||||||
}
|
}
|
||||||
NodePtr newLeaf(
|
NodePtr newLeaf(
|
||||||
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
|
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
|
||||||
|
@ -217,9 +217,35 @@ namespace gtsam {
|
||||||
return newLeaf;
|
return newLeaf;
|
||||||
} else
|
} else
|
||||||
#endif
|
#endif
|
||||||
|
// {
|
||||||
|
// Choice choice_node;
|
||||||
|
|
||||||
|
// for (auto branch : f->branches()) {
|
||||||
|
// if (auto choice = std::dynamic_pointer_cast<const
|
||||||
|
// Choice>(branch)) {
|
||||||
|
// // `branch` is a Choice node so we apply Unique to it.
|
||||||
|
// choice_node.push_back(Unique(choice));
|
||||||
|
|
||||||
|
// } else if (auto leaf =
|
||||||
|
// std::dynamic_pointer_cast<const Leaf>(branch)) {
|
||||||
|
// choice_node.push_back(leaf);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// return std::make_shared<const Choice>(choice_node);
|
||||||
|
// }
|
||||||
return f;
|
return f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static NodePtr UpdateNrAssignments(const NodePtr& f) {
|
||||||
|
if (auto choice = std::dynamic_pointer_cast<const Choice>(f)) {
|
||||||
|
// `f` is a Choice node so we recurse.
|
||||||
|
return UpdateNrAssignments(f);
|
||||||
|
|
||||||
|
} else if (auto leaf = std::dynamic_pointer_cast<const Leaf>(f)) {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool isLeaf() const override { return false; }
|
bool isLeaf() const override { return false; }
|
||||||
|
|
||||||
/// Constructor, given choice label and mandatory expected branch count.
|
/// Constructor, given choice label and mandatory expected branch count.
|
||||||
|
@ -282,7 +308,7 @@ namespace gtsam {
|
||||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const ValueFormatter& valueFormatter) const override {
|
const ValueFormatter& valueFormatter) const override {
|
||||||
std::cout << s << " Choice(";
|
std::cout << s << " Choice(";
|
||||||
std::cout << labelFormatter(label_) << ") " << std::endl;
|
std::cout << labelFormatter(label_) << ") " << " | All Same: " << allSame_ << " | nrBranches: " << branches_.size() << std::endl;
|
||||||
for (size_t i = 0; i < branches_.size(); i++) {
|
for (size_t i = 0; i < branches_.size(); i++) {
|
||||||
branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter);
|
branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter);
|
||||||
}
|
}
|
||||||
|
@ -569,16 +595,16 @@ namespace gtsam {
|
||||||
// find highest label among branches
|
// find highest label among branches
|
||||||
std::optional<L> highestLabel;
|
std::optional<L> highestLabel;
|
||||||
size_t nrChoices = 0;
|
size_t nrChoices = 0;
|
||||||
for (Iterator it = begin; it != end; it++) {
|
// for (Iterator it = begin; it != end; it++) {
|
||||||
if (it->root_->isLeaf())
|
// if (it->root_->isLeaf())
|
||||||
continue;
|
// continue;
|
||||||
std::shared_ptr<const Choice> c =
|
// std::shared_ptr<const Choice> c =
|
||||||
std::dynamic_pointer_cast<const Choice>(it->root_);
|
// std::dynamic_pointer_cast<const Choice>(it->root_);
|
||||||
if (!highestLabel || c->label() > *highestLabel) {
|
// if (!highestLabel || c->label() > *highestLabel) {
|
||||||
highestLabel = c->label();
|
// highestLabel = c->label();
|
||||||
nrChoices = c->nrChoices();
|
// nrChoices = c->nrChoices();
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
// if label is already in correct order, just put together a choice on label
|
// if label is already in correct order, just put together a choice on label
|
||||||
if (!nrChoices || !highestLabel || label > *highestLabel) {
|
if (!nrChoices || !highestLabel || label > *highestLabel) {
|
||||||
|
@ -604,6 +630,7 @@ namespace gtsam {
|
||||||
NodePtr fi = compose(functions.begin(), functions.end(), label);
|
NodePtr fi = compose(functions.begin(), functions.end(), label);
|
||||||
choiceOnHighestLabel->push_back(fi);
|
choiceOnHighestLabel->push_back(fi);
|
||||||
}
|
}
|
||||||
|
// return Choice::ComputeNrAssignments(Choice::Unique(choiceOnHighestLabel));
|
||||||
return Choice::Unique(choiceOnHighestLabel);
|
return Choice::Unique(choiceOnHighestLabel);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,7 +121,7 @@ struct Ring {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test DT
|
// test DT
|
||||||
TEST(DecisionTree, example) {
|
TEST_DISABLED(DecisionTree, example) {
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
|
@ -231,7 +231,7 @@ TEST(DecisionTree, example) {
|
||||||
bool bool_of_int(const int& y) { return y != 0; };
|
bool bool_of_int(const int& y) { return y != 0; };
|
||||||
typedef DecisionTree<string, bool> StringBoolTree;
|
typedef DecisionTree<string, bool> StringBoolTree;
|
||||||
|
|
||||||
TEST(DecisionTree, ConvertValuesOnly) {
|
TEST_DISABLED(DecisionTree, ConvertValuesOnly) {
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
@ -252,7 +252,7 @@ TEST(DecisionTree, ConvertValuesOnly) {
|
||||||
enum Label { U, V, X, Y, Z };
|
enum Label { U, V, X, Y, Z };
|
||||||
typedef DecisionTree<Label, bool> LabelBoolTree;
|
typedef DecisionTree<Label, bool> LabelBoolTree;
|
||||||
|
|
||||||
TEST(DecisionTree, ConvertBoth) {
|
TEST_DISABLED(DecisionTree, ConvertBoth) {
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
@ -279,7 +279,7 @@ TEST(DecisionTree, ConvertBoth) {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test Compose expansion
|
// test Compose expansion
|
||||||
TEST(DecisionTree, Compose) {
|
TEST_DISABLED(DecisionTree, Compose) {
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
|
@ -305,7 +305,7 @@ TEST(DecisionTree, Compose) {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Check we can create a decision tree of containers.
|
// Check we can create a decision tree of containers.
|
||||||
TEST(DecisionTree, Containers) {
|
TEST_DISABLED(DecisionTree, Containers) {
|
||||||
using Container = std::vector<double>;
|
using Container = std::vector<double>;
|
||||||
using StringContainerTree = DecisionTree<string, Container>;
|
using StringContainerTree = DecisionTree<string, Container>;
|
||||||
|
|
||||||
|
@ -327,7 +327,7 @@ TEST(DecisionTree, Containers) {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test nrAssignments.
|
// Test nrAssignments.
|
||||||
TEST(DecisionTree, NrAssignments) {
|
TEST_DISABLED(DecisionTree, NrAssignments) {
|
||||||
const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
|
const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
|
||||||
DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
|
DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
|
||||||
|
|
||||||
|
@ -375,7 +375,7 @@ TEST(DecisionTree, NrAssignments) {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test visit.
|
// Test visit.
|
||||||
TEST(DecisionTree, visit) {
|
TEST_DISABLED(DecisionTree, visit) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
@ -387,7 +387,7 @@ TEST(DecisionTree, visit) {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test visit, with Choices argument.
|
// Test visit, with Choices argument.
|
||||||
TEST(DecisionTree, visitWith) {
|
TEST_DISABLED(DecisionTree, visitWith) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
@ -399,7 +399,7 @@ TEST(DecisionTree, visitWith) {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test visit, with Choices argument.
|
// Test visit, with Choices argument.
|
||||||
TEST(DecisionTree, VisitWithPruned) {
|
TEST_DISABLED(DecisionTree, VisitWithPruned) {
|
||||||
// Create pruned tree
|
// Create pruned tree
|
||||||
std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
|
std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
|
||||||
std::vector<std::pair<string, size_t>> labels = {C, B, A};
|
std::vector<std::pair<string, size_t>> labels = {C, B, A};
|
||||||
|
@ -437,7 +437,7 @@ TEST(DecisionTree, VisitWithPruned) {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test fold.
|
// Test fold.
|
||||||
TEST(DecisionTree, fold) {
|
TEST_DISABLED(DecisionTree, fold) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
|
||||||
|
@ -448,7 +448,7 @@ TEST(DecisionTree, fold) {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test retrieving all labels.
|
// Test retrieving all labels.
|
||||||
TEST(DecisionTree, labels) {
|
TEST_DISABLED(DecisionTree, labels) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
@ -458,7 +458,7 @@ TEST(DecisionTree, labels) {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test unzip method.
|
// Test unzip method.
|
||||||
TEST(DecisionTree, unzip) {
|
TEST_DISABLED(DecisionTree, unzip) {
|
||||||
using DTP = DecisionTree<string, std::pair<int, string>>;
|
using DTP = DecisionTree<string, std::pair<int, string>>;
|
||||||
using DT1 = DecisionTree<string, int>;
|
using DT1 = DecisionTree<string, int>;
|
||||||
using DT2 = DecisionTree<string, string>;
|
using DT2 = DecisionTree<string, string>;
|
||||||
|
@ -479,7 +479,7 @@ TEST(DecisionTree, unzip) {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test thresholding.
|
// Test thresholding.
|
||||||
TEST(DecisionTree, threshold) {
|
TEST_DISABLED(DecisionTree, threshold) {
|
||||||
// Create three level tree
|
// Create three level tree
|
||||||
const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
|
const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
|
||||||
DT::LabelC("A", 2)};
|
DT::LabelC("A", 2)};
|
||||||
|
@ -524,6 +524,8 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
||||||
DT prunedTree = tree.apply(pruner);
|
DT prunedTree = tree.apply(pruner);
|
||||||
|
|
||||||
DT expectedTree(keys, "0 0 0 0 5 6 7 8");
|
DT expectedTree(keys, "0 0 0 0 5 6 7 8");
|
||||||
|
// expectedTree.print();
|
||||||
|
// prunedTree.print();
|
||||||
EXPECT(assert_equal(expectedTree, prunedTree));
|
EXPECT(assert_equal(expectedTree, prunedTree));
|
||||||
|
|
||||||
size_t count = 0;
|
size_t count = 0;
|
||||||
|
@ -542,16 +544,27 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
||||||
TEST(DecisionTree, NrAssignments2) {
|
TEST(DecisionTree, NrAssignments2) {
|
||||||
using gtsam::symbol_shorthand::M;
|
using gtsam::symbol_shorthand::M;
|
||||||
|
|
||||||
DiscreteKeys keys{{M(1), 2}, {M(0), 2}};
|
|
||||||
std::vector<double> probs = {0, 0, 1, 2};
|
std::vector<double> probs = {0, 0, 1, 2};
|
||||||
|
|
||||||
|
DiscreteKeys keys{{M(1), 2}, {M(0), 2}};
|
||||||
DecisionTree<Key, double> dt1(keys, probs);
|
DecisionTree<Key, double> dt1(keys, probs);
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(4, dt1.nrAssignments());
|
EXPECT_LONGS_EQUAL(4, dt1.nrAssignments());
|
||||||
|
dt1.print("", DefaultKeyFormatter, [](double x) { return std::to_string(x);});
|
||||||
|
|
||||||
|
/* Create the DecisionTree
|
||||||
|
Choice(m1)
|
||||||
|
0 Choice(m0)
|
||||||
|
0 0 Leaf 0.000000
|
||||||
|
0 1 Leaf 1.000000
|
||||||
|
1 Choice(m0)
|
||||||
|
1 0 Leaf 0.000000
|
||||||
|
1 1 Leaf 2.000000
|
||||||
|
*/
|
||||||
DiscreteKeys keys2{{M(0), 2}, {M(1), 2}};
|
DiscreteKeys keys2{{M(0), 2}, {M(1), 2}};
|
||||||
DecisionTree<Key, double> dt2(keys2, probs);
|
DecisionTree<Key, double> dt2(keys2, probs);
|
||||||
//TODO(Varun) The below is failing, because the number of assignments aren't being set correctly.
|
std::cout << "\n\n" << std::endl;
|
||||||
EXPECT_LONGS_EQUAL(4, dt2.nrAssignments());
|
dt2.print("", DefaultKeyFormatter, [](double x) { return std::to_string(x);});
|
||||||
|
// EXPECT_LONGS_EQUAL(4, dt2.nrAssignments());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue