update docs and test

release/4.3a0
Varun Agrawal 2022-03-31 10:08:02 -04:00
parent dac84e9932
commit 2da2bcbf9c
2 changed files with 29 additions and 9 deletions

View File

@ -174,9 +174,16 @@ namespace gtsam {
* @brief Prune the decision tree of discrete variables. * @brief Prune the decision tree of discrete variables.
* *
* Pruning will set the leaves to be "pruned" to 0 indicating a 0 * Pruning will set the leaves to be "pruned" to 0 indicating a 0
* probability. * probability. An assignment is pruned if it is not in the top
* An assignment is pruned if it is not in the top `maxNrAssignments` * `maxNrAssignments` values.
* values. *
* A violation can occur if there are more
* duplicate values than `maxNrAssignments`. A violation here is the need to
* un-prune the decision tree (e.g. all assignment values are 1.0). We could
* have another case where some subset of duplicates exist (e.g. for a tree
* with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is
* not a violation since the for `maxNrAssignments=5` the top values are (1,
* 0.8).
* *
* @param maxNrAssignments The maximum number of assignments to keep. * @param maxNrAssignments The maximum number of assignments to keep.
* @return DecisionTreeFactor * @return DecisionTreeFactor

View File

@ -113,18 +113,32 @@ TEST(DecisionTreeFactor, Prune) {
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8"); DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
// Only keep the leaves with the top 5 values. // Only keep the leaves with the top 5 values.
size_t maxNrLeaves = 5; size_t maxNrAssignments = 5;
auto pruned5 = f.prune(maxNrLeaves); auto pruned5 = f.prune(maxNrAssignments);
// Pruned leaves should be 0 // Pruned leaves should be 0
DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8"); DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
EXPECT(assert_equal(expected, pruned5)); EXPECT(assert_equal(expected, pruned5));
// Check for more extreme pruning where we only keep the top 2 leaves // Check for more extreme pruning where we only keep the top 2 leaves
maxNrLeaves = 2; maxNrAssignments = 2;
auto pruned2 = f.prune(maxNrLeaves); auto pruned2 = f.prune(maxNrAssignments);
DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8"); DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
EXPECT(assert_equal(expected2, pruned2)); EXPECT(assert_equal(expected2, pruned2));
DiscreteKey D(4, 2);
DecisionTreeFactor factor(
D & C & B & A,
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
DecisionTreeFactor expected3(
D & C & B & A,
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
"0.999952870000 1.0 1.0 1.0 1.0");
maxNrAssignments = 5;
auto pruned3 = factor.prune(maxNrAssignments);
EXPECT(assert_equal(expected3, pruned3));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -215,4 +229,3 @@ int main() {
return TestRegistry::runAllTests(tr); return TestRegistry::runAllTests(tr);
} }
/* ************************************************************************* */ /* ************************************************************************* */