restrict for DT

release/4.3a0
Frank Dellaert 2025-02-01 01:14:52 -05:00
parent f1b184aca0
commit ee3733713f
2 changed files with 24 additions and 2 deletions

View File

@ -393,6 +393,13 @@ namespace gtsam {
return DecisionTree(newRoot); return DecisionTree(newRoot);
} }
/** Choose multiple values. */
DecisionTree restrict(const Assignment<L>& assignment) const {
NodePtr newRoot = root_;
for (const auto& [l, v] : assignment) newRoot = newRoot->choose(l, v);
return DecisionTree(newRoot);
}
/** combine subtrees on key with binary operation "op" */ /** combine subtrees on key with binary operation "op" */
DecisionTree combine(const L& label, size_t cardinality, DecisionTree combine(const L& label, size_t cardinality,
const Binary& op) const; const Binary& op) const;

View File

@ -228,9 +228,9 @@ TEST(DecisionTree, Example) {
// Test choose 0 // Test choose 0
DT actual0 = notba.choose(A, 0); DT actual0 = notba.choose(A, 0);
#ifdef GTSAM_DT_MERGING #ifdef GTSAM_DT_MERGING
EXPECT(assert_equal(DT(0.0), actual0)); EXPECT(assert_equal(DT(0), actual0));
#else #else
EXPECT(assert_equal(DT({0.0, 0.0}), actual0)); EXPECT(assert_equal(DT(B, 0, 0), actual0));
#endif #endif
DOT(actual0); DOT(actual0);
@ -618,6 +618,21 @@ TEST(DecisionTree, ApplyWithAssignment) {
#endif #endif
} }
/* ************************************************************************** */
// Test apply with assignment.
TEST(DecisionTree, Restrict) {
// Create three level tree
const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
DT::LabelC("A", 2)};
DT tree(keys, "1 2 3 4 5 6 7 8");
DT restrictedTree = tree.restrict({{"A", 0}, {"B", 1}});
EXPECT(assert_equal(DT({DT::LabelC("C", 2)}, "3 7"), restrictedTree));
DT restrictMore = tree.restrict({{"A", 1}, {"B", 1}, {"C", 1}});
EXPECT(assert_equal(DT(8), restrictMore));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;