From 2da2bcbf9c0c4ad9d29090dff48d72434495cf84 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 31 Mar 2022 10:08:02 -0400 Subject: [PATCH] update docs and test --- gtsam/discrete/DecisionTreeFactor.h | 13 +++++++--- .../discrete/tests/testDecisionTreeFactor.cpp | 25 ++++++++++++++----- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 286571ffc..86fa44649 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -174,9 +174,16 @@ namespace gtsam { * @brief Prune the decision tree of discrete variables. * * Pruning will set the leaves to be "pruned" to 0 indicating a 0 - * probability. - * An assignment is pruned if it is not in the top `maxNrAssignments` - * values. + * probability. An assignment is pruned if it is not in the top + * `maxNrAssignments` values. + * + * A violation can occur if there are more + * duplicate values than `maxNrAssignments`. A violation here is the need to + * un-prune the decision tree (e.g. all assignment values are 1.0). We could + * have another case where some subset of duplicates exists (e.g. for a tree + * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is + * not a violation since for `maxNrAssignments=5` the top values are (1, + * 0.8). * * @param maxNrAssignments The maximum number of assignments to keep. * @return DecisionTreeFactor diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 83b586bbb..84e45a0f5 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -113,18 +113,32 @@ TEST(DecisionTreeFactor, Prune) { DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8"); // Only keep the leaves with the top 5 values. 
- size_t maxNrLeaves = 5; - auto pruned5 = f.prune(maxNrLeaves); + size_t maxNrAssignments = 5; + auto pruned5 = f.prune(maxNrAssignments); // Pruned leaves should be 0 DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8"); EXPECT(assert_equal(expected, pruned5)); // Check for more extreme pruning where we only keep the top 2 leaves - maxNrLeaves = 2; - auto pruned2 = f.prune(maxNrLeaves); + maxNrAssignments = 2; + auto pruned2 = f.prune(maxNrAssignments); DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8"); EXPECT(assert_equal(expected2, pruned2)); + + DiscreteKey D(4, 2); + DecisionTreeFactor factor( + D & C & B & A, + "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " + "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); + + DecisionTreeFactor expected3( + D & C & B & A, + "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " + "0.999952870000 1.0 1.0 1.0 1.0"); + maxNrAssignments = 5; + auto pruned3 = factor.prune(maxNrAssignments); + EXPECT(assert_equal(expected3, pruned3)); } /* ************************************************************************* */ @@ -133,7 +147,7 @@ TEST(DecisionTreeFactor, DotWithNames) { DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; - for (bool showZero:{true, false}) { + for (bool showZero:{true, false}) { string actual = f.dot(formatter, showZero); // pretty weak test, as ids are pointers and not stable across platforms. string expected = "digraph G {"; @@ -215,4 +229,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ -