add unit test showing issue with nrAssignments
parent
e114e9f6d2
commit
6aa7d667f3
|
@ -25,6 +25,7 @@
|
|||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
|
@ -329,6 +330,9 @@ TEST(DecisionTree, Containers) {
|
|||
TEST(DecisionTree, NrAssignments) {
|
||||
const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
|
||||
DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
|
||||
|
||||
EXPECT_LONGS_EQUAL(8, tree.nrAssignments());
|
||||
|
||||
EXPECT(tree.root_->isLeaf());
|
||||
auto leaf = std::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
|
||||
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
|
||||
|
@ -348,6 +352,8 @@ TEST(DecisionTree, NrAssignments) {
|
|||
1 1 Leaf 5
|
||||
*/
|
||||
|
||||
EXPECT_LONGS_EQUAL(8, tree2.nrAssignments());
|
||||
|
||||
auto root = std::dynamic_pointer_cast<const DT::Choice>(tree2.root_);
|
||||
CHECK(root);
|
||||
auto choice0 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
|
||||
|
@ -531,6 +537,23 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
|||
EXPECT_LONGS_EQUAL(5, count);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test number of assignments.
|
||||
TEST(DecisionTree, NrAssignments2) {
|
||||
using gtsam::symbol_shorthand::M;
|
||||
|
||||
DiscreteKeys keys{{M(1), 2}, {M(0), 2}};
|
||||
std::vector<double> probs = {0, 0, 1, 2};
|
||||
DecisionTree<Key, double> dt1(keys, probs);
|
||||
|
||||
EXPECT_LONGS_EQUAL(4, dt1.nrAssignments());
|
||||
|
||||
DiscreteKeys keys2{{M(0), 2}, {M(1), 2}};
|
||||
DecisionTree<Key, double> dt2(keys2, probs);
|
||||
//TODO(Varun) The below is failing, because the number of assignments aren't being set correctly.
|
||||
EXPECT_LONGS_EQUAL(4, dt2.nrAssignments());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue