get failing tests in testHybridBayesNet to pass
parent
e52970aa92
commit
44e8485360
|
@ -363,10 +363,6 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
AlgebraicDecisionTree<Key> expected(s.modes, leaves);
|
AlgebraicDecisionTree<Key> expected(s.modes, leaves);
|
||||||
EXPECT(assert_equal(expected, discretePosterior, 1e-6));
|
EXPECT(assert_equal(expected, discretePosterior, 1e-6));
|
||||||
|
|
||||||
// Prune and get probabilities
|
|
||||||
auto prunedBayesNet = posterior->prune(2);
|
|
||||||
auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());
|
|
||||||
|
|
||||||
// Verify logProbability computation and check specific logProbability value
|
// Verify logProbability computation and check specific logProbability value
|
||||||
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||||
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
||||||
|
@ -381,10 +377,21 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
||||||
1e-9);
|
1e-9);
|
||||||
|
|
||||||
|
double negLogConstant = posterior->negLogConstant(discrete_values);
|
||||||
|
|
||||||
|
// The sum of all the mode densities
|
||||||
|
double normalizer =
|
||||||
|
AlgebraicDecisionTree<Key>(posterior->errorTree(delta.continuous()),
|
||||||
|
[](double error) { return exp(-error); })
|
||||||
|
.sum();
|
||||||
|
|
||||||
// Check agreement with discrete posterior
|
// Check agreement with discrete posterior
|
||||||
// double density = exp(logProbability);
|
double density = exp(logProbability + negLogConstant) / normalizer;
|
||||||
// FAILS: EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values),
|
EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6);
|
||||||
// 1e-6);
|
|
||||||
|
// Prune and get probabilities
|
||||||
|
auto prunedBayesNet = posterior->prune(2);
|
||||||
|
auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());
|
||||||
|
|
||||||
// Regression test on pruned logProbability tree
|
// Regression test on pruned logProbability tree
|
||||||
std::vector<double> pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578};
|
std::vector<double> pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578};
|
||||||
|
@ -392,7 +399,26 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||||
|
|
||||||
// Regression
|
// Regression
|
||||||
// FAILS: EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
double pruned_logProbability = 0;
|
||||||
|
pruned_logProbability +=
|
||||||
|
prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues);
|
||||||
|
pruned_logProbability +=
|
||||||
|
prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues);
|
||||||
|
pruned_logProbability +=
|
||||||
|
prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues);
|
||||||
|
pruned_logProbability +=
|
||||||
|
prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues);
|
||||||
|
|
||||||
|
double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values);
|
||||||
|
|
||||||
|
// The sum of all the mode densities
|
||||||
|
double pruned_normalizer =
|
||||||
|
AlgebraicDecisionTree<Key>(prunedBayesNet.errorTree(delta.continuous()),
|
||||||
|
[](double error) { return exp(-error); })
|
||||||
|
.sum();
|
||||||
|
double pruned_density =
|
||||||
|
exp(pruned_logProbability + pruned_negLogConstant) / pruned_normalizer;
|
||||||
|
EXPECT_DOUBLES_EQUAL(pruned_density, prunedTree(discrete_values), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
Loading…
Reference in New Issue