diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 7c8d2742d..105da7394 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -393,6 +393,13 @@ namespace gtsam { return DecisionTree(newRoot); } + /** Choose multiple values. */ + DecisionTree restrict(const Assignment& assignment) const { + NodePtr newRoot = root_; + for (const auto& [l, v] : assignment) newRoot = newRoot->choose(l, v); + return DecisionTree(newRoot); + } + /** combine subtrees on key with binary operation "op" */ DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const; diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index c664fe6b5..ab2f318b7 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -228,9 +228,9 @@ TEST(DecisionTree, Example) { // Test choose 0 DT actual0 = notba.choose(A, 0); #ifdef GTSAM_DT_MERGING - EXPECT(assert_equal(DT(0.0), actual0)); + EXPECT(assert_equal(DT(0), actual0)); #else - EXPECT(assert_equal(DT({0.0, 0.0}), actual0)); + EXPECT(assert_equal(DT(B, 0, 0), actual0)); #endif DOT(actual0); @@ -618,6 +618,21 @@ TEST(DecisionTree, ApplyWithAssignment) { #endif } +/* ************************************************************************** */ +// Test apply with assignment. +TEST(DecisionTree, Restrict) { + // Create three level tree + const vector keys{DT::LabelC("C", 2), DT::LabelC("B", 2), + DT::LabelC("A", 2)}; + DT tree(keys, "1 2 3 4 5 6 7 8"); + + DT restrictedTree = tree.restrict({{"A", 0}, {"B", 1}}); + EXPECT(assert_equal(DT({DT::LabelC("C", 2)}, "3 7"), restrictedTree)); + + DT restrictMore = tree.restrict({{"A", 1}, {"B", 1}, {"C", 1}}); + EXPECT(assert_equal(DT(8), restrictMore)); +} + /* ************************************************************************* */ int main() { TestResult tr;