restrict for DT
parent
f1b184aca0
commit
ee3733713f
|
@ -393,6 +393,13 @@ namespace gtsam {
|
|||
return DecisionTree(newRoot);
|
||||
}
|
||||
|
||||
/** Choose multiple values. */
|
||||
DecisionTree restrict(const Assignment<L>& assignment) const {
|
||||
NodePtr newRoot = root_;
|
||||
for (const auto& [l, v] : assignment) newRoot = newRoot->choose(l, v);
|
||||
return DecisionTree(newRoot);
|
||||
}
|
||||
|
||||
/** combine subtrees on key with binary operation "op" */
|
||||
DecisionTree combine(const L& label, size_t cardinality,
|
||||
const Binary& op) const;
|
||||
|
|
|
@ -228,9 +228,9 @@ TEST(DecisionTree, Example) {
|
|||
// Test choose 0
|
||||
DT actual0 = notba.choose(A, 0);
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT(assert_equal(DT(0.0), actual0));
|
||||
EXPECT(assert_equal(DT(0), actual0));
|
||||
#else
|
||||
EXPECT(assert_equal(DT({0.0, 0.0}), actual0));
|
||||
EXPECT(assert_equal(DT(B, 0, 0), actual0));
|
||||
#endif
|
||||
DOT(actual0);
|
||||
|
||||
|
@ -618,6 +618,21 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
|||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test apply with assignment.
|
||||
TEST(DecisionTree, Restrict) {
|
||||
// Create three level tree
|
||||
const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
|
||||
DT::LabelC("A", 2)};
|
||||
DT tree(keys, "1 2 3 4 5 6 7 8");
|
||||
|
||||
DT restrictedTree = tree.restrict({{"A", 0}, {"B", 1}});
|
||||
EXPECT(assert_equal(DT({DT::LabelC("C", 2)}, "3 7"), restrictedTree));
|
||||
|
||||
DT restrictMore = tree.restrict({{"A", 1}, {"B", 1}, {"C", 1}});
|
||||
EXPECT(assert_equal(DT(8), restrictMore));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue