break big error unit test in HBN to two smaller ones
parent
d5f304ef50
commit
5d7089a5a9
|
@ -343,11 +343,20 @@ TEST(HybridBayesNet, Optimize) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test Bayes net error
|
namespace hbn_error {
|
||||||
TEST(HybridBayesNet, Pruning) {
|
// Create switching network with three continuous variables and two discrete:
|
||||||
// Create switching network with three continuous variables and two discrete:
|
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
|
||||||
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
|
Switching s(3);
|
||||||
Switching s(3);
|
|
||||||
|
// The true discrete assignment
|
||||||
|
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||||
|
|
||||||
|
} // namespace hbn_error
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test Bayes net error and log-probability
|
||||||
|
TEST(HybridBayesNet, Error) {
|
||||||
|
using namespace hbn_error;
|
||||||
|
|
||||||
HybridBayesNet::shared_ptr posterior =
|
HybridBayesNet::shared_ptr posterior =
|
||||||
s.linearizedFactorGraph().eliminateSequential();
|
s.linearizedFactorGraph().eliminateSequential();
|
||||||
|
@ -366,7 +375,6 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
EXPECT(assert_equal(expected, discretePosterior, 1e-6));
|
EXPECT(assert_equal(expected, discretePosterior, 1e-6));
|
||||||
|
|
||||||
// Verify logProbability computation and check specific logProbability value
|
// Verify logProbability computation and check specific logProbability value
|
||||||
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
|
||||||
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
||||||
double logProbability = 0;
|
double logProbability = 0;
|
||||||
logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues);
|
logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues);
|
||||||
|
@ -390,17 +398,32 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
// Check agreement with discrete posterior
|
// Check agreement with discrete posterior
|
||||||
double density = exp(logProbability + negLogConstant) / normalizer;
|
double density = exp(logProbability + negLogConstant) / normalizer;
|
||||||
EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6);
|
EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test Bayes net error and log-probability after pruning
|
||||||
|
TEST(HybridBayesNet, ErrorAfterPruning) {
|
||||||
|
using namespace hbn_error;
|
||||||
|
|
||||||
|
HybridBayesNet::shared_ptr posterior =
|
||||||
|
s.linearizedFactorGraph().eliminateSequential();
|
||||||
|
EXPECT_LONGS_EQUAL(5, posterior->size());
|
||||||
|
|
||||||
|
// Optimize
|
||||||
|
HybridValues delta = posterior->optimize();
|
||||||
|
|
||||||
// Prune and get probabilities
|
// Prune and get probabilities
|
||||||
auto prunedBayesNet = posterior->prune(2);
|
HybridBayesNet prunedBayesNet = posterior->prune(2);
|
||||||
auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());
|
AlgebraicDecisionTree<Key> prunedTree =
|
||||||
|
prunedBayesNet.discretePosterior(delta.continuous());
|
||||||
|
|
||||||
// Regression test on pruned logProbability tree
|
// Regression test on pruned probability tree
|
||||||
std::vector<double> pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578};
|
std::vector<double> pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578};
|
||||||
AlgebraicDecisionTree<Key> expected_pruned(s.modes, pruned_leaves);
|
AlgebraicDecisionTree<Key> expected_pruned(s.modes, pruned_leaves);
|
||||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||||
|
|
||||||
// Regression
|
// Regression to check specific logProbability value
|
||||||
|
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
||||||
double pruned_logProbability = 0;
|
double pruned_logProbability = 0;
|
||||||
pruned_logProbability +=
|
pruned_logProbability +=
|
||||||
prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues);
|
prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues);
|
||||||
|
@ -423,24 +446,6 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
EXPECT_DOUBLES_EQUAL(pruned_density, prunedTree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(pruned_density, prunedTree(discrete_values), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
|
||||||
// Test Bayes net pruning
|
|
||||||
TEST(HybridBayesNet, Prune) {
|
|
||||||
Switching s(4);
|
|
||||||
|
|
||||||
HybridBayesNet::shared_ptr posterior =
|
|
||||||
s.linearizedFactorGraph().eliminateSequential();
|
|
||||||
EXPECT_LONGS_EQUAL(7, posterior->size());
|
|
||||||
|
|
||||||
HybridValues delta = posterior->optimize();
|
|
||||||
|
|
||||||
auto prunedBayesNet = posterior->prune(2);
|
|
||||||
HybridValues pruned_delta = prunedBayesNet.optimize();
|
|
||||||
|
|
||||||
EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
|
|
||||||
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test Bayes net updateDiscreteConditionals
|
// Test Bayes net updateDiscreteConditionals
|
||||||
TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
|
|
Loading…
Reference in New Issue