Fix unit tests
parent
05ad198ca6
commit
e6662b8206
|
|
@ -166,8 +166,8 @@ TEST(HybridBayesNet, Tiny) {
|
||||||
|
|
||||||
// prune
|
// prune
|
||||||
auto pruned = bayesNet.prune(1);
|
auto pruned = bayesNet.prune(1);
|
||||||
CHECK(pruned.at(1)->asHybrid());
|
CHECK(pruned.at(0)->asHybrid());
|
||||||
EXPECT_LONGS_EQUAL(1, pruned.at(1)->asHybrid()->nrComponents());
|
EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents());
|
||||||
EXPECT(!pruned.equals(bayesNet));
|
EXPECT(!pruned.equals(bayesNet));
|
||||||
|
|
||||||
// error
|
// error
|
||||||
|
|
@ -437,20 +437,21 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
|
||||||
const double pruneDeadVariables = 0.99;
|
const double pruneDeadVariables = 0.99;
|
||||||
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
|
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
|
||||||
|
|
||||||
|
// First conditional is still the same: P( x0 | x1 m0)
|
||||||
|
EXPECT(prunedBayesNet.at(0)->isHybrid());
|
||||||
|
|
||||||
|
// Check that hybrid conditional that only depend on M1
|
||||||
|
// is now Gaussian and not Hybrid
|
||||||
|
EXPECT(prunedBayesNet.at(1)->isContinuous());
|
||||||
|
|
||||||
|
// Third conditional is still Hybrid: P( x1 | m0 m1) -> P( x1 | m0)
|
||||||
|
EXPECT(prunedBayesNet.at(0)->isHybrid());
|
||||||
|
|
||||||
// Check that discrete joint only has M0 and not (M0, M1)
|
// Check that discrete joint only has M0 and not (M0, M1)
|
||||||
// since M0 is removed
|
// since M0 is removed
|
||||||
KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys();
|
auto joint = prunedBayesNet.at(3)->asDiscrete();
|
||||||
EXPECT(KeyVector{M(0)} == actual_keys);
|
EXPECT(joint);
|
||||||
|
EXPECT(joint->keys() == KeyVector{M(0)});
|
||||||
// Check that hybrid conditionals that only depend on M1
|
|
||||||
// are now Gaussian and not Hybrid
|
|
||||||
EXPECT(prunedBayesNet.at(0)->isDiscrete());
|
|
||||||
EXPECT(prunedBayesNet.at(1)->isDiscrete());
|
|
||||||
EXPECT(prunedBayesNet.at(2)->isHybrid());
|
|
||||||
// Only P(X2 | X1, M1) depends on M1,
|
|
||||||
// so it gets convert to a Gaussian P(X2 | X1)
|
|
||||||
EXPECT(prunedBayesNet.at(3)->isContinuous());
|
|
||||||
EXPECT(prunedBayesNet.at(4)->isHybrid());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
@ -479,13 +480,13 @@ TEST(HybridBayesNet, ErrorAfterPruning) {
|
||||||
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
||||||
double pruned_logProbability = 0;
|
double pruned_logProbability = 0;
|
||||||
pruned_logProbability +=
|
pruned_logProbability +=
|
||||||
prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues);
|
prunedBayesNet.at(0)->asHybrid()->logProbability(hybridValues);
|
||||||
pruned_logProbability +=
|
pruned_logProbability +=
|
||||||
prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues);
|
prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues);
|
||||||
pruned_logProbability +=
|
pruned_logProbability +=
|
||||||
prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues);
|
prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues);
|
||||||
pruned_logProbability +=
|
pruned_logProbability +=
|
||||||
prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues);
|
prunedBayesNet.at(3)->asDiscrete()->logProbability(hybridValues);
|
||||||
|
|
||||||
double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values);
|
double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values);
|
||||||
|
|
||||||
|
|
@ -548,8 +549,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
|
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
|
||||||
CHECK(pruned.at(0)->asDiscrete());
|
CHECK(pruned.at(4)->asDiscrete());
|
||||||
auto pruned_discrete_conditionals = pruned.at(0)->asDiscrete();
|
auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete();
|
||||||
auto discrete_conditional_tree =
|
auto discrete_conditional_tree =
|
||||||
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
|
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
|
||||||
pruned_discrete_conditionals);
|
pruned_discrete_conditionals);
|
||||||
|
|
|
||||||
|
|
@ -103,7 +103,7 @@ TEST(HybridSmoother, IncrementalSmoother) {
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(11,
|
EXPECT_LONGS_EQUAL(11,
|
||||||
smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues());
|
smoother.hybridBayesNet().at(5)->asDiscrete()->nrValues());
|
||||||
|
|
||||||
// Get the continuous delta update as well as
|
// Get the continuous delta update as well as
|
||||||
// the optimal discrete assignment.
|
// the optimal discrete assignment.
|
||||||
|
|
@ -157,7 +157,7 @@ TEST(HybridSmoother, ValidPruningError) {
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(14,
|
EXPECT_LONGS_EQUAL(14,
|
||||||
smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues());
|
smoother.hybridBayesNet().at(8)->asDiscrete()->nrValues());
|
||||||
|
|
||||||
// Get the continuous delta update as well as
|
// Get the continuous delta update as well as
|
||||||
// the optimal discrete assignment.
|
// the optimal discrete assignment.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue