break big error unit test in HBN to two smaller ones

release/4.3a0
Varun Agrawal 2025-01-21 15:30:03 -05:00
parent d5f304ef50
commit 5d7089a5a9
1 changed files with 33 additions and 28 deletions

View File

@ -343,12 +343,21 @@ TEST(HybridBayesNet, Optimize) {
}
/* ****************************************************************************/
// Test Bayes net error
TEST(HybridBayesNet, Pruning) {
namespace hbn_error {
// Create switching network with three continuous variables and two discrete:
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
Switching s(3);
// The true discrete assignment
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
} // namespace hbn_error
/* ****************************************************************************/
// Test Bayes net error and log-probability
TEST(HybridBayesNet, Error) {
using namespace hbn_error;
HybridBayesNet::shared_ptr posterior =
s.linearizedFactorGraph().eliminateSequential();
EXPECT_LONGS_EQUAL(5, posterior->size());
@ -366,7 +375,6 @@ TEST(HybridBayesNet, Pruning) {
EXPECT(assert_equal(expected, discretePosterior, 1e-6));
// Verify logProbability computation and check specific logProbability value
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
const HybridValues hybridValues{delta.continuous(), discrete_values};
double logProbability = 0;
logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues);
@ -390,17 +398,32 @@ TEST(HybridBayesNet, Pruning) {
// Check agreement with discrete posterior
double density = exp(logProbability + negLogConstant) / normalizer;
EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6);
}
/* ****************************************************************************/
// Test Bayes net error and log-probability after pruning
TEST(HybridBayesNet, ErrorAfterPruning) {
using namespace hbn_error;
HybridBayesNet::shared_ptr posterior =
s.linearizedFactorGraph().eliminateSequential();
EXPECT_LONGS_EQUAL(5, posterior->size());
// Optimize
HybridValues delta = posterior->optimize();
// Prune and get probabilities
auto prunedBayesNet = posterior->prune(2);
auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());
HybridBayesNet prunedBayesNet = posterior->prune(2);
AlgebraicDecisionTree<Key> prunedTree =
prunedBayesNet.discretePosterior(delta.continuous());
// Regression test on pruned logProbability tree
// Regression test on pruned probability tree
std::vector<double> pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578};
AlgebraicDecisionTree<Key> expected_pruned(s.modes, pruned_leaves);
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
// Regression
// Regression to check specific logProbability value
const HybridValues hybridValues{delta.continuous(), discrete_values};
double pruned_logProbability = 0;
pruned_logProbability +=
prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues);
@ -423,24 +446,6 @@ TEST(HybridBayesNet, Pruning) {
EXPECT_DOUBLES_EQUAL(pruned_density, prunedTree(discrete_values), 1e-9);
}
/* ****************************************************************************/
// Test Bayes net pruning
TEST(HybridBayesNet, Prune) {
Switching s(4);
HybridBayesNet::shared_ptr posterior =
s.linearizedFactorGraph().eliminateSequential();
EXPECT_LONGS_EQUAL(7, posterior->size());
HybridValues delta = posterior->optimize();
auto prunedBayesNet = posterior->prune(2);
HybridValues pruned_delta = prunedBayesNet.optimize();
EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
}
/* ****************************************************************************/
// Test Bayes net updateDiscreteConditionals
TEST(HybridBayesNet, UpdateDiscreteConditionals) {