diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 931e603a7..e39339dd8 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -501,6 +501,10 @@ namespace gtsam { }; this->visitWith(op); + // If total number of hypotheses is less than N, return 0.0 + if (min_heap.size() < N) { + return 0.0; + } return min_heap.top(); } diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index f4523cd93..94c9624ca 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -132,7 +132,6 @@ TEST(DecisionTreeFactor, Divide) { KeySet keys(joint.keys()); keys.insert(pA.keys().begin(), pA.keys().end()); EXPECT(assert_inequal(KeySet(pS.keys()), keys)); - } /* ************************************************************************* */ @@ -234,6 +233,12 @@ TEST(DecisionTreeFactor, Prune) { maxNrAssignments = 5; auto pruned3 = factor.prune(maxNrAssignments); EXPECT(assert_equal(expected3, pruned3)); + + // Edge case where the number of hypotheses are less than maxNrAssignments + DecisionTreeFactor f(A, "0.50001 0.49999"); + auto pruned4 = f.prune(10); + DecisionTreeFactor expected4(A, "0.50001 0.49999"); + EXPECT(assert_equal(expected4, pruned4)); } /* ************************************************************************** */