Finish tests
parent
788f4b6a19
commit
3d55fe0d37
|
@ -65,13 +65,18 @@ TEST(HybridBayesNet, Add) {
|
||||||
// Test API for a pure discrete Bayes net P(Asia).
|
// Test API for a pure discrete Bayes net P(Asia).
|
||||||
TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
||||||
HybridBayesNet bayesNet;
|
HybridBayesNet bayesNet;
|
||||||
bayesNet.emplace_shared<DiscreteConditional>(Asia, "4/6");
|
const auto pAsia = std::make_shared<DiscreteConditional>(Asia, "4/6");
|
||||||
|
bayesNet.push_back(pAsia);
|
||||||
HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}};
|
HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}};
|
||||||
|
|
||||||
// choose
|
// choose
|
||||||
GaussianBayesNet empty;
|
GaussianBayesNet empty;
|
||||||
EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9));
|
EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9));
|
||||||
|
|
||||||
|
// logProbability
|
||||||
|
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
|
||||||
|
|
||||||
// evaluate
|
// evaluate
|
||||||
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9);
|
||||||
|
@ -88,18 +93,35 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
||||||
|
|
||||||
// prune
|
// prune
|
||||||
EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
|
EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
|
||||||
// EXPECT(assert_equal(bayesNet, bayesNet.prune(1))); Should fail !!!
|
EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size());
|
||||||
// EXPECT(assert_equal(bayesNet, bayesNet.prune(0))); Should fail !!!
|
|
||||||
|
|
||||||
// errorTree
|
// errorTree
|
||||||
AlgebraicDecisionTree<Key> actual = bayesNet.errorTree({});
|
AlgebraicDecisionTree<Key> actual = bayesNet.errorTree({});
|
||||||
AlgebraicDecisionTree<Key> expected(
|
AlgebraicDecisionTree<Key> expectedErrorTree(
|
||||||
{Asia}, std::vector<double>{-log(0.4), -log(0.6)});
|
{Asia}, std::vector<double>{-log(0.4), -log(0.6)});
|
||||||
EXPECT(assert_equal(expected, actual));
|
EXPECT(assert_equal(expectedErrorTree, actual));
|
||||||
|
|
||||||
// error
|
// error
|
||||||
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
|
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
|
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
|
||||||
|
|
||||||
|
// logDiscretePosteriorPrime, TODO: useless as -errorTree?
|
||||||
|
AlgebraicDecisionTree<Key> expected({Asia},
|
||||||
|
std::vector<double>{log(0.4), log(0.6)});
|
||||||
|
EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime({})));
|
||||||
|
|
||||||
|
// logProbability
|
||||||
|
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
|
||||||
|
|
||||||
|
// discretePosterior
|
||||||
|
AlgebraicDecisionTree<Key> expectedPosterior({Asia},
|
||||||
|
std::vector<double>{0.4, 0.6});
|
||||||
|
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({})));
|
||||||
|
|
||||||
|
// toFactorGraph
|
||||||
|
HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({});
|
||||||
|
EXPECT(assert_equal(expectedFG, fg));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
@ -358,7 +380,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials);
|
DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials);
|
||||||
|
|
||||||
// Prune!
|
// Prune!
|
||||||
posterior->prune(maxNrLeaves);
|
auto pruned = posterior->prune(maxNrLeaves);
|
||||||
|
|
||||||
// Functor to verify values against the expected_discrete_conditionals
|
// Functor to verify values against the expected_discrete_conditionals
|
||||||
auto checker = [&](const Assignment<Key>& assignment,
|
auto checker = [&](const Assignment<Key>& assignment,
|
||||||
|
@ -375,7 +397,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
|
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
|
||||||
auto pruned_discrete_conditionals = posterior->at(4)->asDiscrete();
|
auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete();
|
||||||
auto discrete_conditional_tree =
|
auto discrete_conditional_tree =
|
||||||
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
|
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
|
||||||
pruned_discrete_conditionals);
|
pruned_discrete_conditionals);
|
||||||
|
|
Loading…
Reference in New Issue