parent
647d3c0744
commit
2db08281c6
|
@ -53,17 +53,26 @@ namespace gtsam {
|
|||
/** constant stored in this leaf */
|
||||
Y constant_;
|
||||
|
||||
/** The number of assignments contained within this leaf.
|
||||
* Particularly useful when leaves have been pruned.
|
||||
*/
|
||||
size_t nrAssignments_;
|
||||
|
||||
/// Default constructor for serialization.
|
||||
Leaf() {}
|
||||
|
||||
/// Constructor from constant
|
||||
Leaf(const Y& constant) : constant_(constant) {}
|
||||
Leaf(const Y& constant, size_t nrAssignments = 1)
|
||||
: constant_(constant), nrAssignments_(nrAssignments) {}
|
||||
|
||||
/// Return the constant
|
||||
const Y& constant() const {
|
||||
return constant_;
|
||||
}
|
||||
|
||||
/// Return the number of assignments contained within this leaf.
|
||||
size_t nrAssignments() const { return nrAssignments_; }
|
||||
|
||||
/// Leaf-Leaf equality
|
||||
bool sameLeaf(const Leaf& q) const override {
|
||||
return constant_ == q.constant_;
|
||||
|
@ -84,7 +93,8 @@ namespace gtsam {
|
|||
/// print
|
||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const override {
|
||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
||||
std::cout << s << " Leaf [" << nrAssignments() << "]"
|
||||
<< valueFormatter(constant_) << std::endl;
|
||||
}
|
||||
|
||||
/** Write graphviz format to stream `os`. */
|
||||
|
@ -104,14 +114,14 @@ namespace gtsam {
|
|||
|
||||
/** apply unary operator */
|
||||
NodePtr apply(const Unary& op) const override {
|
||||
NodePtr f(new Leaf(op(constant_)));
|
||||
NodePtr f(new Leaf(op(constant_), nrAssignments_));
|
||||
return f;
|
||||
}
|
||||
|
||||
/// Apply unary operator with assignment
|
||||
NodePtr apply(const UnaryAssignment& op,
|
||||
const Assignment<L>& assignment) const override {
|
||||
NodePtr f(new Leaf(op(assignment, constant_)));
|
||||
NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
|
||||
return f;
|
||||
}
|
||||
|
||||
|
@ -127,7 +137,9 @@ namespace gtsam {
|
|||
// Applying binary operator to two leaves results in a leaf
|
||||
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
||||
// fL op gL
|
||||
NodePtr h(new Leaf(op(fL.constant_, constant_)));
|
||||
// TODO(Varun) nrAssignments setting is not correct.
|
||||
// Depending on f and g, the nrAssignments can be different. This is a bug!
|
||||
NodePtr h(new Leaf(op(fL.constant_, constant_), fL.nrAssignments()));
|
||||
return h;
|
||||
}
|
||||
|
||||
|
@ -138,7 +150,7 @@ namespace gtsam {
|
|||
|
||||
/** choose a branch, create new memory ! */
|
||||
NodePtr choose(const L& label, size_t index) const override {
|
||||
return NodePtr(new Leaf(constant()));
|
||||
return NodePtr(new Leaf(constant(), nrAssignments()));
|
||||
}
|
||||
|
||||
bool isLeaf() const override { return true; }
|
||||
|
@ -153,6 +165,7 @@ namespace gtsam {
|
|||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar& BOOST_SERIALIZATION_NVP(constant_);
|
||||
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
|
||||
}
|
||||
#endif
|
||||
}; // Leaf
|
||||
|
@ -222,8 +235,16 @@ namespace gtsam {
|
|||
assert(f->branches().size() > 0);
|
||||
NodePtr f0 = f->branches_[0];
|
||||
|
||||
// Compute total number of assignments
|
||||
size_t nrAssignments = 0;
|
||||
for (auto branch : f->branches()) {
|
||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
|
||||
nrAssignments += leaf->nrAssignments();
|
||||
}
|
||||
}
|
||||
NodePtr newLeaf(
|
||||
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
||||
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
|
||||
nrAssignments));
|
||||
return newLeaf;
|
||||
}
|
||||
#endif
|
||||
|
@ -709,7 +730,7 @@ namespace gtsam {
|
|||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
|
||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||
return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments()));
|
||||
}
|
||||
|
||||
// Check if Choice
|
||||
|
@ -856,6 +877,16 @@ namespace gtsam {
|
|||
return total;
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
size_t DecisionTree<L, Y>::nrAssignments() const {
|
||||
size_t n = 0;
|
||||
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
|
||||
n += leaf.nrAssignments();
|
||||
});
|
||||
return n;
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
// fold is just done with a visit
|
||||
template <typename L, typename Y>
|
||||
|
|
|
@ -307,6 +307,42 @@ namespace gtsam {
|
|||
/// Return the number of leaves in the tree.
|
||||
size_t nrLeaves() const;
|
||||
|
||||
/**
|
||||
* @brief This is a convenience function which returns the total number of
|
||||
* leaf assignments in the decision tree.
|
||||
* This function is not used for anymajor operations within the discrete
|
||||
* factor graph framework.
|
||||
*
|
||||
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
|
||||
* binary tree each leaf has 2 assignments. This includes counts removed
|
||||
* from implicit pruning hence, it will always be >= nrLeaves().
|
||||
*
|
||||
* E.g. we have a decision tree as below, where each node has 2 branches:
|
||||
*
|
||||
* Choice(m1)
|
||||
* 0 Choice(m0)
|
||||
* 0 0 Leaf 0.0
|
||||
* 0 1 Leaf 0.0
|
||||
* 1 Choice(m0)
|
||||
* 1 0 Leaf 1.0
|
||||
* 1 1 Leaf 2.0
|
||||
*
|
||||
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
|
||||
* and 4 leaves.
|
||||
*
|
||||
* In the pruned form, the number of assignments is still 4 but the number
|
||||
* of leaves is now 3, as below:
|
||||
*
|
||||
* Choice(m1)
|
||||
* 0 Leaf 0.0
|
||||
* 1 Choice(m0)
|
||||
* 1 0 Leaf 1.0
|
||||
* 1 1 Leaf 2.0
|
||||
*
|
||||
* @return size_t
|
||||
*/
|
||||
size_t nrAssignments() const;
|
||||
|
||||
/**
|
||||
* @brief Fold a binary function over the tree, returning accumulator.
|
||||
*
|
||||
|
|
|
@ -328,6 +328,59 @@ TEST(DecisionTree, Containers) {
|
|||
StringContainerTree converted(stringIntTree, container_of_int);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test nrAssignments.
|
||||
TEST(DecisionTree, NrAssignments) {
|
||||
const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
|
||||
DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
|
||||
|
||||
EXPECT_LONGS_EQUAL(8, tree.nrAssignments());
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT(tree.root_->isLeaf());
|
||||
auto leaf = std::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
|
||||
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
|
||||
#endif
|
||||
|
||||
DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
|
||||
/* The tree is
|
||||
Choice(C)
|
||||
0 Choice(B)
|
||||
0 0 Leaf 1
|
||||
0 1 Choice(A)
|
||||
0 1 0 Leaf 1
|
||||
0 1 1 Leaf 2
|
||||
1 Choice(B)
|
||||
1 0 Choice(A)
|
||||
1 0 0 Leaf 3
|
||||
1 0 1 Leaf 4
|
||||
1 1 Leaf 5
|
||||
*/
|
||||
|
||||
EXPECT_LONGS_EQUAL(8, tree2.nrAssignments());
|
||||
|
||||
auto root = std::dynamic_pointer_cast<const DT::Choice>(tree2.root_);
|
||||
CHECK(root);
|
||||
auto choice0 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
|
||||
CHECK(choice0);
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT(choice0->branches()[0]->isLeaf());
|
||||
auto choice00 = std::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
|
||||
CHECK(choice00);
|
||||
EXPECT_LONGS_EQUAL(2, choice00->nrAssignments());
|
||||
|
||||
auto choice1 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[1]);
|
||||
CHECK(choice1);
|
||||
auto choice10 = std::dynamic_pointer_cast<const DT::Choice>(choice1->branches()[0]);
|
||||
CHECK(choice10);
|
||||
auto choice11 = std::dynamic_pointer_cast<const DT::Leaf>(choice1->branches()[1]);
|
||||
CHECK(choice11);
|
||||
EXPECT(choice11->isLeaf());
|
||||
EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test visit.
|
||||
TEST(DecisionTree, visit) {
|
||||
|
@ -540,6 +593,38 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
|||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test number of assignments.
|
||||
TEST(DecisionTree, NrAssignments2) {
|
||||
using gtsam::symbol_shorthand::M;
|
||||
|
||||
std::vector<double> probs = {0, 0, 1, 2};
|
||||
|
||||
/* Create the decision tree
|
||||
Choice(m1)
|
||||
0 Leaf 0.000000
|
||||
1 Choice(m0)
|
||||
1 0 Leaf 1.000000
|
||||
1 1 Leaf 2.000000
|
||||
*/
|
||||
DiscreteKeys keys{{M(1), 2}, {M(0), 2}};
|
||||
DecisionTree<Key, double> dt1(keys, probs);
|
||||
EXPECT_LONGS_EQUAL(4, dt1.nrAssignments());
|
||||
|
||||
/* Create the DecisionTree
|
||||
Choice(m1)
|
||||
0 Choice(m0)
|
||||
0 0 Leaf 0.000000
|
||||
0 1 Leaf 1.000000
|
||||
1 Choice(m0)
|
||||
1 0 Leaf 0.000000
|
||||
1 1 Leaf 2.000000
|
||||
*/
|
||||
DiscreteKeys keys2{{M(0), 2}, {M(1), 2}};
|
||||
DecisionTree<Key, double> dt2(keys2, probs);
|
||||
EXPECT_LONGS_EQUAL(4, dt2.nrAssignments());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -349,6 +349,119 @@ TEST(DiscreteFactorGraph, markdown) {
|
|||
EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
|
||||
}
|
||||
|
||||
TEST(DiscreteFactorGraph, NrAssignments) {
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
string expected_dfg = R"(
|
||||
size: 2
|
||||
factor 0: f[ (m0,2), (m1,2), (m2,2), ]
|
||||
Choice(m2)
|
||||
0 Choice(m1)
|
||||
0 0 Leaf [2] 0
|
||||
0 1 Choice(m0)
|
||||
0 1 0 Leaf [1]0.27527634
|
||||
0 1 1 Leaf [1] 0
|
||||
1 Choice(m1)
|
||||
1 0 Leaf [2] 0
|
||||
1 1 Choice(m0)
|
||||
1 1 0 Leaf [1]0.44944733
|
||||
1 1 1 Leaf [1]0.27527634
|
||||
factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ]
|
||||
Choice(m3)
|
||||
0 Choice(m2)
|
||||
0 0 Choice(m1)
|
||||
0 0 0 Leaf [2] 1
|
||||
0 0 1 Leaf [2]0.015366387
|
||||
0 1 Choice(m1)
|
||||
0 1 0 Leaf [2] 1
|
||||
0 1 1 Choice(m0)
|
||||
0 1 1 0 Leaf [1] 1
|
||||
0 1 1 1 Leaf [1]0.015365663
|
||||
1 Choice(m2)
|
||||
1 0 Choice(m1)
|
||||
1 0 0 Leaf [2] 1
|
||||
1 0 1 Choice(m0)
|
||||
1 0 1 0 Leaf [1]0.0094115739
|
||||
1 0 1 1 Leaf [1]0.0094115652
|
||||
1 1 Choice(m1)
|
||||
1 1 0 Leaf [2] 1
|
||||
1 1 1 Choice(m0)
|
||||
1 1 1 0 Leaf [1] 1
|
||||
1 1 1 1 Leaf [1]0.009321081
|
||||
)";
|
||||
#else
|
||||
string expected_dfg = R"(
|
||||
size: 2
|
||||
factor 0: f[ (m0,2), (m1,2), (m2,2), ]
|
||||
Choice(m2)
|
||||
0 Choice(m1)
|
||||
0 0 Choice(m0)
|
||||
0 0 0 Leaf [1] 0
|
||||
0 0 1 Leaf [1] 0
|
||||
0 1 Choice(m0)
|
||||
0 1 0 Leaf [1]0.27527634
|
||||
0 1 1 Leaf [1]0.44944733
|
||||
1 Choice(m1)
|
||||
1 0 Choice(m0)
|
||||
1 0 0 Leaf [1] 0
|
||||
1 0 1 Leaf [1] 0
|
||||
1 1 Choice(m0)
|
||||
1 1 0 Leaf [1] 0
|
||||
1 1 1 Leaf [1]0.27527634
|
||||
factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ]
|
||||
Choice(m3)
|
||||
0 Choice(m2)
|
||||
0 0 Choice(m1)
|
||||
0 0 0 Choice(m0)
|
||||
0 0 0 0 Leaf [1] 1
|
||||
0 0 0 1 Leaf [1] 1
|
||||
0 0 1 Choice(m0)
|
||||
0 0 1 0 Leaf [1]0.015366387
|
||||
0 0 1 1 Leaf [1]0.015366387
|
||||
0 1 Choice(m1)
|
||||
0 1 0 Choice(m0)
|
||||
0 1 0 0 Leaf [1] 1
|
||||
0 1 0 1 Leaf [1] 1
|
||||
0 1 1 Choice(m0)
|
||||
0 1 1 0 Leaf [1] 1
|
||||
0 1 1 1 Leaf [1]0.015365663
|
||||
1 Choice(m2)
|
||||
1 0 Choice(m1)
|
||||
1 0 0 Choice(m0)
|
||||
1 0 0 0 Leaf [1] 1
|
||||
1 0 0 1 Leaf [1] 1
|
||||
1 0 1 Choice(m0)
|
||||
1 0 1 0 Leaf [1]0.0094115739
|
||||
1 0 1 1 Leaf [1]0.0094115652
|
||||
1 1 Choice(m1)
|
||||
1 1 0 Choice(m0)
|
||||
1 1 0 0 Leaf [1] 1
|
||||
1 1 0 1 Leaf [1] 1
|
||||
1 1 1 Choice(m0)
|
||||
1 1 1 0 Leaf [1] 1
|
||||
1 1 1 1 Leaf [1]0.009321081
|
||||
)";
|
||||
#endif
|
||||
|
||||
DiscreteKeys d0{{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||
std::vector<double> p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468};
|
||||
AlgebraicDecisionTree<Key> dt(d0, p0);
|
||||
//TODO(Varun) Passing ADT to DiscreteConditional causes nrAssignments to get messed up
|
||||
// Issue seems to be in DecisionTreeFactor.cpp L104
|
||||
DiscreteConditional f0(3, DecisionTreeFactor(d0, dt));
|
||||
|
||||
DiscreteKeys d1{{M(0), 2}, {M(1), 2}, {M(2), 2}, {M(3), 2}};
|
||||
std::vector<double> p1 = {
|
||||
1, 1, 1, 1, 0.015366387, 0.0094115739, 1, 1,
|
||||
1, 1, 1, 1, 0.015366387, 0.0094115652, 0.015365663, 0.009321081};
|
||||
DecisionTreeFactor f1(d1, p1);
|
||||
|
||||
DiscreteFactorGraph dfg;
|
||||
dfg.add(f0);
|
||||
dfg.add(f1);
|
||||
|
||||
EXPECT(assert_print_equal(expected_dfg, dfg));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue