get failing tests in testHybridBayesNet to pass

release/4.3a0
Varun Agrawal 2024-11-01 20:23:32 -04:00
parent e52970aa92
commit 44e8485360
1 changed files with 34 additions and 8 deletions

View File

@ -363,10 +363,6 @@ TEST(HybridBayesNet, Pruning) {
AlgebraicDecisionTree<Key> expected(s.modes, leaves);
EXPECT(assert_equal(expected, discretePosterior, 1e-6));
// Prune and get probabilities
auto prunedBayesNet = posterior->prune(2);
auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());
// Verify logProbability computation and check specific logProbability value
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
const HybridValues hybridValues{delta.continuous(), discrete_values};
@ -381,10 +377,21 @@ TEST(HybridBayesNet, Pruning) {
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
1e-9);
double negLogConstant = posterior->negLogConstant(discrete_values);
// The sum of all the mode densities
double normalizer =
AlgebraicDecisionTree<Key>(posterior->errorTree(delta.continuous()),
[](double error) { return exp(-error); })
.sum();
// Check agreement with discrete posterior
// double density = exp(logProbability);
// FAILS: EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values),
// 1e-6);
double density = exp(logProbability + negLogConstant) / normalizer;
EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6);
// Prune and get probabilities
auto prunedBayesNet = posterior->prune(2);
auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());
// Regression test on pruned logProbability tree
std::vector<double> pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578};
@ -392,7 +399,26 @@ TEST(HybridBayesNet, Pruning) {
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
// Regression
// FAILS: EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
double pruned_logProbability = 0;
pruned_logProbability +=
prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues);
pruned_logProbability +=
prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues);
pruned_logProbability +=
prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues);
pruned_logProbability +=
prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues);
double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values);
// The sum of all the mode densities
double pruned_normalizer =
AlgebraicDecisionTree<Key>(prunedBayesNet.errorTree(delta.continuous()),
[](double error) { return exp(-error); })
.sum();
double pruned_density =
exp(pruned_logProbability + pruned_negLogConstant) / pruned_normalizer;
EXPECT_DOUBLES_EQUAL(pruned_density, prunedTree(discrete_values), 1e-9);
}
/* ****************************************************************************/