restrict for DT
parent
f1b184aca0
commit
ee3733713f
|
@ -393,6 +393,13 @@ namespace gtsam {
|
||||||
return DecisionTree(newRoot);
|
return DecisionTree(newRoot);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Choose multiple values. */
|
||||||
|
DecisionTree restrict(const Assignment<L>& assignment) const {
|
||||||
|
NodePtr newRoot = root_;
|
||||||
|
for (const auto& [l, v] : assignment) newRoot = newRoot->choose(l, v);
|
||||||
|
return DecisionTree(newRoot);
|
||||||
|
}
|
||||||
|
|
||||||
/** combine subtrees on key with binary operation "op" */
|
/** combine subtrees on key with binary operation "op" */
|
||||||
DecisionTree combine(const L& label, size_t cardinality,
|
DecisionTree combine(const L& label, size_t cardinality,
|
||||||
const Binary& op) const;
|
const Binary& op) const;
|
||||||
|
|
|
@ -228,9 +228,9 @@ TEST(DecisionTree, Example) {
|
||||||
// Test choose 0
|
// Test choose 0
|
||||||
DT actual0 = notba.choose(A, 0);
|
DT actual0 = notba.choose(A, 0);
|
||||||
#ifdef GTSAM_DT_MERGING
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT(assert_equal(DT(0.0), actual0));
|
EXPECT(assert_equal(DT(0), actual0));
|
||||||
#else
|
#else
|
||||||
EXPECT(assert_equal(DT({0.0, 0.0}), actual0));
|
EXPECT(assert_equal(DT(B, 0, 0), actual0));
|
||||||
#endif
|
#endif
|
||||||
DOT(actual0);
|
DOT(actual0);
|
||||||
|
|
||||||
|
@ -618,6 +618,21 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test apply with assignment.
|
||||||
|
TEST(DecisionTree, Restrict) {
|
||||||
|
// Create three level tree
|
||||||
|
const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
|
||||||
|
DT::LabelC("A", 2)};
|
||||||
|
DT tree(keys, "1 2 3 4 5 6 7 8");
|
||||||
|
|
||||||
|
DT restrictedTree = tree.restrict({{"A", 0}, {"B", 1}});
|
||||||
|
EXPECT(assert_equal(DT({DT::LabelC("C", 2)}, "3 7"), restrictedTree));
|
||||||
|
|
||||||
|
DT restrictMore = tree.restrict({{"A", 1}, {"B", 1}, {"C", 1}});
|
||||||
|
EXPECT(assert_equal(DT(8), restrictMore));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue