break big error unit test in HBN to two smaller ones
parent
d5f304ef50
commit
5d7089a5a9
|
@ -343,12 +343,21 @@ TEST(HybridBayesNet, Optimize) {
|
|||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test Bayes net error
|
||||
TEST(HybridBayesNet, Pruning) {
|
||||
namespace hbn_error {
|
||||
// Create switching network with three continuous variables and two discrete:
|
||||
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
|
||||
Switching s(3);
|
||||
|
||||
// The true discrete assignment
|
||||
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||
|
||||
} // namespace hbn_error
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test Bayes net error and log-probability
|
||||
TEST(HybridBayesNet, Error) {
|
||||
using namespace hbn_error;
|
||||
|
||||
HybridBayesNet::shared_ptr posterior =
|
||||
s.linearizedFactorGraph().eliminateSequential();
|
||||
EXPECT_LONGS_EQUAL(5, posterior->size());
|
||||
|
@ -366,7 +375,6 @@ TEST(HybridBayesNet, Pruning) {
|
|||
EXPECT(assert_equal(expected, discretePosterior, 1e-6));
|
||||
|
||||
// Verify logProbability computation and check specific logProbability value
|
||||
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
||||
double logProbability = 0;
|
||||
logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues);
|
||||
|
@ -390,17 +398,32 @@ TEST(HybridBayesNet, Pruning) {
|
|||
// Check agreement with discrete posterior
|
||||
double density = exp(logProbability + negLogConstant) / normalizer;
|
||||
EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test Bayes net error and log-probability after pruning
|
||||
TEST(HybridBayesNet, ErrorAfterPruning) {
|
||||
using namespace hbn_error;
|
||||
|
||||
HybridBayesNet::shared_ptr posterior =
|
||||
s.linearizedFactorGraph().eliminateSequential();
|
||||
EXPECT_LONGS_EQUAL(5, posterior->size());
|
||||
|
||||
// Optimize
|
||||
HybridValues delta = posterior->optimize();
|
||||
|
||||
// Prune and get probabilities
|
||||
auto prunedBayesNet = posterior->prune(2);
|
||||
auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());
|
||||
HybridBayesNet prunedBayesNet = posterior->prune(2);
|
||||
AlgebraicDecisionTree<Key> prunedTree =
|
||||
prunedBayesNet.discretePosterior(delta.continuous());
|
||||
|
||||
// Regression test on pruned logProbability tree
|
||||
// Regression test on pruned probability tree
|
||||
std::vector<double> pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578};
|
||||
AlgebraicDecisionTree<Key> expected_pruned(s.modes, pruned_leaves);
|
||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||
|
||||
// Regression
|
||||
// Regression to check specific logProbability value
|
||||
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
||||
double pruned_logProbability = 0;
|
||||
pruned_logProbability +=
|
||||
prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues);
|
||||
|
@ -423,24 +446,6 @@ TEST(HybridBayesNet, Pruning) {
|
|||
EXPECT_DOUBLES_EQUAL(pruned_density, prunedTree(discrete_values), 1e-9);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test Bayes net pruning
|
||||
TEST(HybridBayesNet, Prune) {
|
||||
Switching s(4);
|
||||
|
||||
HybridBayesNet::shared_ptr posterior =
|
||||
s.linearizedFactorGraph().eliminateSequential();
|
||||
EXPECT_LONGS_EQUAL(7, posterior->size());
|
||||
|
||||
HybridValues delta = posterior->optimize();
|
||||
|
||||
auto prunedBayesNet = posterior->prune(2);
|
||||
HybridValues pruned_delta = prunedBayesNet.optimize();
|
||||
|
||||
EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
|
||||
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test Bayes net updateDiscreteConditionals
|
||||
TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||
|
|
Loading…
Reference in New Issue