Finish tests

release/4.3a0
Frank Dellaert 2024-09-29 22:56:32 -07:00
parent 788f4b6a19
commit 3d55fe0d37
1 changed files with 29 additions and 7 deletions

View File

@ -65,13 +65,18 @@ TEST(HybridBayesNet, Add) {
// Test API for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, EvaluatePureDiscrete) {
HybridBayesNet bayesNet;
bayesNet.emplace_shared<DiscreteConditional>(Asia, "4/6");
const auto pAsia = std::make_shared<DiscreteConditional>(Asia, "4/6");
bayesNet.push_back(pAsia);
HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}};
// choose
GaussianBayesNet empty;
EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9));
// logProbability
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
// evaluate
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9);
@ -88,18 +93,35 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
// prune
EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
// EXPECT(assert_equal(bayesNet, bayesNet.prune(1))); Should fail !!!
// EXPECT(assert_equal(bayesNet, bayesNet.prune(0))); Should fail !!!
EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size());
// errorTree
AlgebraicDecisionTree<Key> actual = bayesNet.errorTree({});
AlgebraicDecisionTree<Key> expected(
AlgebraicDecisionTree<Key> expectedErrorTree(
{Asia}, std::vector<double>{-log(0.4), -log(0.6)});
EXPECT(assert_equal(expected, actual));
EXPECT(assert_equal(expectedErrorTree, actual));
// error
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
// logDiscretePosteriorPrime, TODO: useless as -errorTree?
AlgebraicDecisionTree<Key> expected({Asia},
std::vector<double>{log(0.4), log(0.6)});
EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime({})));
// logProbability
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
// discretePosterior
AlgebraicDecisionTree<Key> expectedPosterior({Asia},
std::vector<double>{0.4, 0.6});
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({})));
// toFactorGraph
HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({});
EXPECT(assert_equal(expectedFG, fg));
}
/* ****************************************************************************/
@ -358,7 +380,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials);
// Prune!
posterior->prune(maxNrLeaves);
auto pruned = posterior->prune(maxNrLeaves);
// Functor to verify values against the expected_discrete_conditionals
auto checker = [&](const Assignment<Key>& assignment,
@ -375,7 +397,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
};
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
auto pruned_discrete_conditionals = posterior->at(4)->asDiscrete();
auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete();
auto discrete_conditional_tree =
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
pruned_discrete_conditionals);