From 4c966b9753737b9a6d4ef46029a4cd35ccd647cc Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 27 Mar 2022 13:25:35 -0400 Subject: [PATCH 1/3] New prune method for DecisionTreeFactor --- gtsam/discrete/DecisionTreeFactor.cpp | 38 +++++++++++++++++++ gtsam/discrete/DecisionTreeFactor.h | 12 ++++++ .../discrete/tests/testDecisionTreeFactor.cpp | 21 ++++++++++ 3 files changed, 71 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index e95b8fe37..cfbfa976f 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -286,5 +286,43 @@ namespace gtsam { AlgebraicDecisionTree(keys, table), cardinalities_(keys.cardinalities()) {} + /* ************************************************************************ */ + DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const { + const size_t N = maxNrLeaves; + + // Let's assume that the structure of the last discrete density will be the + // same as the last continuous + std::vector probabilities; + // number of choices + this->visit([&](const double& prob) { probabilities.emplace_back(prob); }); + + // The number of probabilities can be lower than max_leaves + if (probabilities.size() <= N) { + return *this; + } + + std::sort(probabilities.begin(), probabilities.end(), + std::greater{}); + + double threshold = probabilities[N - 1]; + + // Now threshold the decision tree + size_t total = 0; + auto thresholdFunc = [threshold, &total, N](const double& value) { + if (value < threshold || total >= N) { + return 0.0; + } else { + total += 1; + return value; + } + }; + DecisionTree thresholded(*this, thresholdFunc); + + // Create pruned decision tree factor + DecisionTreeFactor prunedDiscreteFactor(this->discreteKeys(), thresholded); + + return prunedDiscreteFactor; + } + /* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 91fa7c484..3cfdbe968 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -170,6 +170,18 @@ namespace gtsam { /// Return all the discrete keys associated with this factor. DiscreteKeys discreteKeys() const; + /** + * @brief Prune the decision tree of discrete variables. + * + * Pruning will set the leaves to be "pruned" to 0 indicating a 0 + * probability. + * A leaf is pruned if it is not in the top `maxNrLeaves` values. + * + * @param maxNrLeaves The maximum number of leaves to keep. + * @return DecisionTreeFactor::shared_ptr + */ + DecisionTreeFactor prune(size_t maxNrLeaves) const; + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 846653c38..83b586bbb 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -106,6 +106,27 @@ TEST(DecisionTreeFactor, enumerate) { EXPECT(actual == expected); } +/* ************************************************************************* */ +// Check pruning of the decision tree works as expected. +TEST(DecisionTreeFactor, Prune) { + DiscreteKey A(1, 2), B(2, 2), C(3, 2); + DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8"); + + // Only keep the leaves with the top 5 values. + size_t maxNrLeaves = 5; + auto pruned5 = f.prune(maxNrLeaves); + + // Pruned leaves should be 0 + DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8"); + EXPECT(assert_equal(expected, pruned5)); + + // Check for more extreme pruning where we only keep the top 2 leaves + maxNrLeaves = 2; + auto pruned2 = f.prune(maxNrLeaves); + DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8"); + EXPECT(assert_equal(expected2, pruned2)); +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, DotWithNames) { DiscreteKey A(12, 3), B(5, 2); From cd8d5590e6d1320b4d07e312a26f1b93fb90c443 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 27 Mar 2022 13:34:42 -0400 Subject: [PATCH 2/3] doc fix --- gtsam/discrete/DecisionTreeFactor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 3cfdbe968..1f3d69292 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -178,7 +178,7 @@ namespace gtsam { * A leaf is pruned if it is not in the top `maxNrLeaves` values. * * @param maxNrLeaves The maximum number of leaves to keep. - * @return DecisionTreeFactor::shared_ptr + * @return DecisionTreeFactor */ DecisionTreeFactor prune(size_t maxNrLeaves) const; From 365473fc5af2adc0c16e2433c85d095ea68e8e7b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 27 Mar 2022 14:35:19 -0400 Subject: [PATCH 3/3] address review comments --- gtsam/discrete/DecisionTreeFactor.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index cfbfa976f..acd4d4af2 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -290,10 +290,8 @@ namespace gtsam { DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const { const size_t N = maxNrLeaves; - // Let's assume that the structure of the last discrete density will be the - // same as the last continuous + // Get the probabilities in the decision tree so we can threshold. std::vector probabilities; - // number of choices this->visit([&](const double& prob) { probabilities.emplace_back(prob); }); // The number of probabilities can be lower than max_leaves @@ -318,10 +316,8 @@ namespace gtsam { }; DecisionTree thresholded(*this, thresholdFunc); - // Create pruned decision tree factor - DecisionTreeFactor prunedDiscreteFactor(this->discreteKeys(), thresholded); - - return prunedDiscreteFactor; + // Create pruned decision tree factor and return. + return DecisionTreeFactor(this->discreteKeys(), thresholded); } /* ************************************************************************ */