fix hybrid tests
parent
5019153e12
commit
623bd63ec8
|
|
@ -454,7 +454,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t maxNrLeaves = 3;
|
size_t maxNrLeaves = 3;
|
||||||
auto prunedDecisionTree = joint.prune(maxNrLeaves);
|
auto prunedDecisionTree = *joint.prune(maxNrLeaves);
|
||||||
|
|
||||||
#ifdef GTSAM_DT_MERGING
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||||
|
|
|
||||||
|
|
@ -261,7 +261,8 @@ TEST(HybridGaussianConditional, Prune) {
|
||||||
potentials[i] = 1;
|
potentials[i] = 1;
|
||||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||||
// Prune the HybridGaussianConditional
|
// Prune the HybridGaussianConditional
|
||||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
const auto pruned = hgc.prune(std::make_shared<DiscreteConditional>(
|
||||||
|
keys.size(), decisionTreeFactor));
|
||||||
// Check that the pruned HybridGaussianConditional has 1 conditional
|
// Check that the pruned HybridGaussianConditional has 1 conditional
|
||||||
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
||||||
}
|
}
|
||||||
|
|
@ -271,7 +272,8 @@ TEST(HybridGaussianConditional, Prune) {
|
||||||
0, 0, 0.5, 0};
|
0, 0, 0.5, 0};
|
||||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||||
|
|
||||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
const auto pruned = hgc.prune(
|
||||||
|
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor));
|
||||||
|
|
||||||
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
||||||
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
||||||
|
|
@ -286,7 +288,8 @@ TEST(HybridGaussianConditional, Prune) {
|
||||||
0, 0, 0.5, 0};
|
0, 0, 0.5, 0};
|
||||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||||
|
|
||||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
const auto pruned = hgc.prune(
|
||||||
|
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor));
|
||||||
|
|
||||||
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
||||||
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
||||||
|
|
|
||||||
|
|
@ -368,10 +368,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
|
||||||
EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents());
|
EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents());
|
||||||
|
|
||||||
// This is now a discreteFactor
|
// This is now a discreteFactor
|
||||||
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
|
auto discreteFactor = dynamic_pointer_cast<TableFactor>(factorOnModes);
|
||||||
CHECK(discreteFactor);
|
CHECK(discreteFactor);
|
||||||
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
|
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
|
||||||
EXPECT(discreteFactor->root_->isLeaf() == false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************
|
/****************************************************************************
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue