update docs and test
parent
dac84e9932
commit
2da2bcbf9c
|
@ -174,9 +174,16 @@ namespace gtsam {
|
||||||
* @brief Prune the decision tree of discrete variables.
|
* @brief Prune the decision tree of discrete variables.
|
||||||
*
|
*
|
||||||
* Pruning will set the leaves to be "pruned" to 0 indicating a 0
|
* Pruning will set the leaves to be "pruned" to 0 indicating a 0
|
||||||
* probability.
|
* probability. An assignment is pruned if it is not in the top
|
||||||
* An assignment is pruned if it is not in the top `maxNrAssignments`
|
* `maxNrAssignments` values.
|
||||||
* values.
|
*
|
||||||
|
* A violation can occur if there are more
|
||||||
|
* duplicate values than `maxNrAssignments`. A violation here is the need to
|
||||||
|
* un-prune the decision tree (e.g. all assignment values are 1.0). We could
|
||||||
|
* have another case where some subset of duplicates exist (e.g. for a tree
|
||||||
|
* with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is
|
||||||
|
* not a violation since the for `maxNrAssignments=5` the top values are (1,
|
||||||
|
* 0.8).
|
||||||
*
|
*
|
||||||
* @param maxNrAssignments The maximum number of assignments to keep.
|
* @param maxNrAssignments The maximum number of assignments to keep.
|
||||||
* @return DecisionTreeFactor
|
* @return DecisionTreeFactor
|
||||||
|
|
|
@ -113,18 +113,32 @@ TEST(DecisionTreeFactor, Prune) {
|
||||||
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
|
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
|
||||||
|
|
||||||
// Only keep the leaves with the top 5 values.
|
// Only keep the leaves with the top 5 values.
|
||||||
size_t maxNrLeaves = 5;
|
size_t maxNrAssignments = 5;
|
||||||
auto pruned5 = f.prune(maxNrLeaves);
|
auto pruned5 = f.prune(maxNrAssignments);
|
||||||
|
|
||||||
// Pruned leaves should be 0
|
// Pruned leaves should be 0
|
||||||
DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
|
DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
|
||||||
EXPECT(assert_equal(expected, pruned5));
|
EXPECT(assert_equal(expected, pruned5));
|
||||||
|
|
||||||
// Check for more extreme pruning where we only keep the top 2 leaves
|
// Check for more extreme pruning where we only keep the top 2 leaves
|
||||||
maxNrLeaves = 2;
|
maxNrAssignments = 2;
|
||||||
auto pruned2 = f.prune(maxNrLeaves);
|
auto pruned2 = f.prune(maxNrAssignments);
|
||||||
DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
|
DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
|
||||||
EXPECT(assert_equal(expected2, pruned2));
|
EXPECT(assert_equal(expected2, pruned2));
|
||||||
|
|
||||||
|
DiscreteKey D(4, 2);
|
||||||
|
DecisionTreeFactor factor(
|
||||||
|
D & C & B & A,
|
||||||
|
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
|
||||||
|
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
|
||||||
|
|
||||||
|
DecisionTreeFactor expected3(
|
||||||
|
D & C & B & A,
|
||||||
|
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
|
||||||
|
"0.999952870000 1.0 1.0 1.0 1.0");
|
||||||
|
maxNrAssignments = 5;
|
||||||
|
auto pruned3 = factor.prune(maxNrAssignments);
|
||||||
|
EXPECT(assert_equal(expected3, pruned3));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -215,4 +229,3 @@ int main() {
|
||||||
return TestRegistry::runAllTests(tr);
|
return TestRegistry::runAllTests(tr);
|
||||||
}
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue