diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 7a227d8dd..3e85ba70a 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -199,47 +199,41 @@ namespace gtsam { } /// If all branches of a choice node f are the same, just return a branch. - static NodePtr Unique(const ChoicePtr& f) { -#ifndef GTSAM_DT_NO_MERGING - // If all the branches are the same, we can merge them into one - if (f->allSame_) { - assert(f->branches().size() > 0); - NodePtr f0 = f->branches_[0]; - - size_t nrAssignments = 0; - for(auto branch: f->branches()) { - if (auto leaf = std::dynamic_pointer_cast(branch)) { - nrAssignments += leaf->nrAssignments(); - } - } - NodePtr newLeaf( - new Leaf(std::dynamic_pointer_cast(f0)->constant(), - nrAssignments)); - return newLeaf; - - } else - // Else we recurse -#endif - { - + static NodePtr Unique(const NodePtr& node) { + if (auto choice = std::dynamic_pointer_cast(node)) { + // Choice node, we recurse! // Make non-const copy - auto ff = std::make_shared(f->label(), f->nrChoices()); + auto f = std::make_shared(choice->label(), choice->nrChoices()); // Iterate over all the branches - for (size_t i = 0; i < f->nrChoices(); i++) { - auto branch = f->branches_[i]; - if (auto leaf = std::dynamic_pointer_cast(branch)) { - // Leaf node, simply assign - ff->push_back(branch); - - } else if (auto choice = - std::dynamic_pointer_cast(branch)) { - // Choice node, we recurse - ff->push_back(Unique(choice)); - } + for (size_t i = 0; i < choice->nrChoices(); i++) { + auto branch = choice->branches_[i]; + f->push_back(Unique(branch)); } - return ff; +#ifndef GTSAM_DT_NO_MERGING + // If all the branches are the same, we can merge them into one + if (f->allSame_) { + assert(f->branches().size() > 0); + NodePtr f0 = f->branches_[0]; + + // Compute total number of assignments + size_t nrAssignments = 0; + for (auto branch : f->branches()) { + if (auto leaf = std::dynamic_pointer_cast(branch)) { + nrAssignments += leaf->nrAssignments(); + } + } + NodePtr newLeaf( + new Leaf(std::dynamic_pointer_cast(f0)->constant(), + nrAssignments)); + return newLeaf; + } +#endif + return f; + } else { + // Leaf node, return as is + return node; } } @@ -549,7 +543,7 @@ namespace gtsam { template template DecisionTree::DecisionTree( Iterator begin, Iterator end, const L& label) { - root_ = compose(begin, end, label); + root_ = Choice::Unique(compose(begin, end, label)); } /****************************************************************************/ @@ -557,7 +551,7 @@ namespace gtsam { DecisionTree::DecisionTree(const L& label, const DecisionTree& f0, const DecisionTree& f1) { const std::vector functions{f0, f1}; - root_ = compose(functions.begin(), functions.end(), label); + root_ = Choice::Unique(compose(functions.begin(), functions.end(), label)); } /****************************************************************************/ @@ -608,7 +602,7 @@ namespace gtsam { auto choiceOnLabel = std::make_shared(label, end - begin); for (Iterator it = begin; it != end; it++) choiceOnLabel->push_back(it->root_); - return Choice::Unique(choiceOnLabel); + return choiceOnLabel; } else { // Set up a new choice on the highest label auto choiceOnHighestLabel = @@ -737,7 +731,7 @@ namespace gtsam { for (auto&& branch : choice->branches()) { functions.emplace_back(convertFrom(branch, L_of_M, Y_of_X)); } - return LY::compose(functions.begin(), functions.end(), newLabel); + return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel)); } /****************************************************************************/ diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index bbce5e8ce..f148cf1d8 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -15,17 +15,20 @@ * @author Duy-Nguyen Ta */ +#include +#include +#include +#include #include #include -#include -#include #include - -#include +#include using namespace std; using namespace gtsam; +using symbol_shorthand::M; + /* ************************************************************************* */ TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) { DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3); @@ -345,6 +348,67 @@ TEST(DiscreteFactorGraph, markdown) { values[1] = 0; EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9); } + +TEST(DiscreteFactorGraph, NrAssignments) { + string expected_dfg = R"( +size: 2 +factor 0: f[ (m0,2), (m1,2), (m2,2), ] + Choice(m2) + 0 Choice(m1) + 0 0 Leaf [1] 0 + 0 1 Choice(m0) + 0 1 0 Leaf [1]0.27527634 + 0 1 1 Leaf [1]0.44944733 + 1 Choice(m1) + 1 0 Leaf [1] 0 + 1 1 Choice(m0) + 1 1 0 Leaf [1] 0 + 1 1 1 Leaf [1]0.27527634 +factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ] + Choice(m3) + 0 Choice(m2) + 0 0 Choice(m1) + 0 0 0 Leaf [2] 1 + 0 0 1 Leaf [2]0.015366387 + 0 1 Choice(m1) + 0 1 0 Leaf [2] 1 + 0 1 1 Choice(m0) + 0 1 1 0 Leaf [1] 1 + 0 1 1 1 Leaf [1]0.015365663 + 1 Choice(m2) + 1 0 Choice(m1) + 1 0 0 Leaf [2] 1 + 1 0 1 Choice(m0) + 1 0 1 0 Leaf [1]0.0094115739 + 1 0 1 1 Leaf [1]0.0094115652 + 1 1 Choice(m1) + 1 1 0 Leaf [2] 1 + 1 1 1 Choice(m0) + 1 1 1 0 Leaf [1] 1 + 1 1 1 1 Leaf [1]0.009321081 +)"; + + DiscreteKeys d0{{M(2), 2}, {M(1), 2}, {M(0), 2}}; + std::vector p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468}; + AlgebraicDecisionTree dt(d0, p0); + //TODO(Varun) Passing ADT to DiscreteConditional causes nrAssignments to get messed up + // Issue seems to be in DecisionTreeFactor.cpp L104 + DiscreteConditional f0(3, DecisionTreeFactor(d0, dt)); + + DiscreteKeys d1{{M(0), 2}, {M(1), 2}, {M(2), 2}, {M(3), 2}}; + std::vector p1 = { + 1, 1, 1, 1, 0.015366387, 0.0094115739, 1, 1, + 1, 1, 1, 1, 0.015366387, 0.0094115652, 0.015365663, 0.009321081}; + DecisionTreeFactor f1(d1, p1); + DecisionTree dt1(d1, p1); + + DiscreteFactorGraph dfg; + dfg.add(f0); + dfg.add(f1); + + EXPECT(assert_print_equal(expected_dfg, dfg)); +} + /* ************************************************************************* */ int main() { TestResult tr;