bottom-up Unique method that works much, much better
parent
23520432ec
commit
2998820d2c
|
|
@ -199,47 +199,41 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/// If all branches of a choice node f are the same, just return a branch.
|
||||
static NodePtr Unique(const ChoicePtr& f) {
|
||||
#ifndef GTSAM_DT_NO_MERGING
|
||||
// If all the branches are the same, we can merge them into one
|
||||
if (f->allSame_) {
|
||||
assert(f->branches().size() > 0);
|
||||
NodePtr f0 = f->branches_[0];
|
||||
|
||||
size_t nrAssignments = 0;
|
||||
for(auto branch: f->branches()) {
|
||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
|
||||
nrAssignments += leaf->nrAssignments();
|
||||
}
|
||||
}
|
||||
NodePtr newLeaf(
|
||||
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
|
||||
nrAssignments));
|
||||
return newLeaf;
|
||||
|
||||
} else
|
||||
// Else we recurse
|
||||
#endif
|
||||
{
|
||||
|
||||
static NodePtr Unique(const NodePtr& node) {
|
||||
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
|
||||
// Choice node, we recurse!
|
||||
// Make non-const copy
|
||||
auto ff = std::make_shared<Choice>(f->label(), f->nrChoices());
|
||||
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||
|
||||
// Iterate over all the branches
|
||||
for (size_t i = 0; i < f->nrChoices(); i++) {
|
||||
auto branch = f->branches_[i];
|
||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
|
||||
// Leaf node, simply assign
|
||||
ff->push_back(branch);
|
||||
|
||||
} else if (auto choice =
|
||||
std::dynamic_pointer_cast<const Choice>(branch)) {
|
||||
// Choice node, we recurse
|
||||
ff->push_back(Unique(choice));
|
||||
}
|
||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||
auto branch = choice->branches_[i];
|
||||
f->push_back(Unique(branch));
|
||||
}
|
||||
|
||||
return ff;
|
||||
#ifndef GTSAM_DT_NO_MERGING
|
||||
// If all the branches are the same, we can merge them into one
|
||||
if (f->allSame_) {
|
||||
assert(f->branches().size() > 0);
|
||||
NodePtr f0 = f->branches_[0];
|
||||
|
||||
// Compute total number of assignments
|
||||
size_t nrAssignments = 0;
|
||||
for (auto branch : f->branches()) {
|
||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
|
||||
nrAssignments += leaf->nrAssignments();
|
||||
}
|
||||
}
|
||||
NodePtr newLeaf(
|
||||
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
|
||||
nrAssignments));
|
||||
return newLeaf;
|
||||
}
|
||||
#endif
|
||||
return f;
|
||||
} else {
|
||||
// Leaf node, return as is
|
||||
return node;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -549,7 +543,7 @@ namespace gtsam {
|
|||
template<typename L, typename Y>
|
||||
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
|
||||
Iterator begin, Iterator end, const L& label) {
|
||||
root_ = compose(begin, end, label);
|
||||
root_ = Choice::Unique(compose(begin, end, label));
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
|
|
@ -557,7 +551,7 @@ namespace gtsam {
|
|||
DecisionTree<L, Y>::DecisionTree(const L& label,
|
||||
const DecisionTree& f0, const DecisionTree& f1) {
|
||||
const std::vector<DecisionTree> functions{f0, f1};
|
||||
root_ = compose(functions.begin(), functions.end(), label);
|
||||
root_ = Choice::Unique(compose(functions.begin(), functions.end(), label));
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
|
|
@ -608,7 +602,7 @@ namespace gtsam {
|
|||
auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
|
||||
for (Iterator it = begin; it != end; it++)
|
||||
choiceOnLabel->push_back(it->root_);
|
||||
return Choice::Unique(choiceOnLabel);
|
||||
return choiceOnLabel;
|
||||
} else {
|
||||
// Set up a new choice on the highest label
|
||||
auto choiceOnHighestLabel =
|
||||
|
|
@ -737,7 +731,7 @@ namespace gtsam {
|
|||
for (auto&& branch : choice->branches()) {
|
||||
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
||||
}
|
||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
||||
return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel));
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
|
|
|
|||
|
|
@ -15,17 +15,20 @@
|
|||
* @author Duy-Nguyen Ta
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/TestableAssertions.h>
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
using symbol_shorthand::M;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
|
||||
DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
|
||||
|
|
@ -345,6 +348,67 @@ TEST(DiscreteFactorGraph, markdown) {
|
|||
values[1] = 0;
|
||||
EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
|
||||
}
|
||||
|
||||
TEST(DiscreteFactorGraph, NrAssignments) {
|
||||
string expected_dfg = R"(
|
||||
size: 2
|
||||
factor 0: f[ (m0,2), (m1,2), (m2,2), ]
|
||||
Choice(m2)
|
||||
0 Choice(m1)
|
||||
0 0 Leaf [1] 0
|
||||
0 1 Choice(m0)
|
||||
0 1 0 Leaf [1]0.27527634
|
||||
0 1 1 Leaf [1]0.44944733
|
||||
1 Choice(m1)
|
||||
1 0 Leaf [1] 0
|
||||
1 1 Choice(m0)
|
||||
1 1 0 Leaf [1] 0
|
||||
1 1 1 Leaf [1]0.27527634
|
||||
factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ]
|
||||
Choice(m3)
|
||||
0 Choice(m2)
|
||||
0 0 Choice(m1)
|
||||
0 0 0 Leaf [2] 1
|
||||
0 0 1 Leaf [2]0.015366387
|
||||
0 1 Choice(m1)
|
||||
0 1 0 Leaf [2] 1
|
||||
0 1 1 Choice(m0)
|
||||
0 1 1 0 Leaf [1] 1
|
||||
0 1 1 1 Leaf [1]0.015365663
|
||||
1 Choice(m2)
|
||||
1 0 Choice(m1)
|
||||
1 0 0 Leaf [2] 1
|
||||
1 0 1 Choice(m0)
|
||||
1 0 1 0 Leaf [1]0.0094115739
|
||||
1 0 1 1 Leaf [1]0.0094115652
|
||||
1 1 Choice(m1)
|
||||
1 1 0 Leaf [2] 1
|
||||
1 1 1 Choice(m0)
|
||||
1 1 1 0 Leaf [1] 1
|
||||
1 1 1 1 Leaf [1]0.009321081
|
||||
)";
|
||||
|
||||
DiscreteKeys d0{{M(2), 2}, {M(1), 2}, {M(0), 2}};
|
||||
std::vector<double> p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468};
|
||||
AlgebraicDecisionTree<Key> dt(d0, p0);
|
||||
//TODO(Varun) Passing ADT to DiscreteConditional causes nrAssignments to get messed up
|
||||
// Issue seems to be in DecisionTreeFactor.cpp L104
|
||||
DiscreteConditional f0(3, DecisionTreeFactor(d0, dt));
|
||||
|
||||
DiscreteKeys d1{{M(0), 2}, {M(1), 2}, {M(2), 2}, {M(3), 2}};
|
||||
std::vector<double> p1 = {
|
||||
1, 1, 1, 1, 0.015366387, 0.0094115739, 1, 1,
|
||||
1, 1, 1, 1, 0.015366387, 0.0094115652, 0.015365663, 0.009321081};
|
||||
DecisionTreeFactor f1(d1, p1);
|
||||
DecisionTree<Key, double> dt1(d1, p1);
|
||||
|
||||
DiscreteFactorGraph dfg;
|
||||
dfg.add(f0);
|
||||
dfg.add(f1);
|
||||
|
||||
EXPECT(assert_print_equal(expected_dfg, dfg));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
Loading…
Reference in New Issue