Finish tests

release/4.3a0
Frank Dellaert 2024-09-29 22:56:32 -07:00
parent 788f4b6a19
commit 3d55fe0d37
1 changed files with 29 additions and 7 deletions

View File

@ -65,13 +65,18 @@ TEST(HybridBayesNet, Add) {
// Test API for a pure discrete Bayes net P(Asia). // Test API for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, EvaluatePureDiscrete) { TEST(HybridBayesNet, EvaluatePureDiscrete) {
HybridBayesNet bayesNet; HybridBayesNet bayesNet;
bayesNet.emplace_shared<DiscreteConditional>(Asia, "4/6"); const auto pAsia = std::make_shared<DiscreteConditional>(Asia, "4/6");
bayesNet.push_back(pAsia);
HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}}; HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}};
// choose // choose
GaussianBayesNet empty; GaussianBayesNet empty;
EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9)); EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9));
// logProbability
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
// evaluate // evaluate
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9); EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9); EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9);
@ -88,18 +93,35 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
// prune // prune
EXPECT(assert_equal(bayesNet, bayesNet.prune(2))); EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
// EXPECT(assert_equal(bayesNet, bayesNet.prune(1))); Should fail !!! EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size());
// EXPECT(assert_equal(bayesNet, bayesNet.prune(0))); Should fail !!!
// errorTree // errorTree
AlgebraicDecisionTree<Key> actual = bayesNet.errorTree({}); AlgebraicDecisionTree<Key> actual = bayesNet.errorTree({});
AlgebraicDecisionTree<Key> expected( AlgebraicDecisionTree<Key> expectedErrorTree(
{Asia}, std::vector<double>{-log(0.4), -log(0.6)}); {Asia}, std::vector<double>{-log(0.4), -log(0.6)});
EXPECT(assert_equal(expected, actual)); EXPECT(assert_equal(expectedErrorTree, actual));
// error // error
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
// logDiscretePosteriorPrime, TODO: useless as -errorTree?
AlgebraicDecisionTree<Key> expected({Asia},
std::vector<double>{log(0.4), log(0.6)});
EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime({})));
// logProbability
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
// discretePosterior
AlgebraicDecisionTree<Key> expectedPosterior({Asia},
std::vector<double>{0.4, 0.6});
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({})));
// toFactorGraph
HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({});
EXPECT(assert_equal(expectedFG, fg));
} }
/* ****************************************************************************/ /* ****************************************************************************/
@ -358,7 +380,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials); DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials);
// Prune! // Prune!
posterior->prune(maxNrLeaves); auto pruned = posterior->prune(maxNrLeaves);
// Functor to verify values against the expected_discrete_conditionals // Functor to verify values against the expected_discrete_conditionals
auto checker = [&](const Assignment<Key>& assignment, auto checker = [&](const Assignment<Key>& assignment,
@ -375,7 +397,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
}; };
// Get the pruned discrete conditionals as an AlgebraicDecisionTree // Get the pruned discrete conditionals as an AlgebraicDecisionTree
auto pruned_discrete_conditionals = posterior->at(4)->asDiscrete(); auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete();
auto discrete_conditional_tree = auto discrete_conditional_tree =
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>( std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
pruned_discrete_conditionals); pruned_discrete_conditionals);