test for different error values in BN from MixtureFactor

release/4.3a0
Varun Agrawal 2024-08-07 09:22:20 -04:00
parent 113a7f8e6b
commit 2430abb4bc
2 changed files with 51 additions and 3 deletions

View File

@ -298,7 +298,6 @@ static std::shared_ptr<Factor> createDiscreteFactor(
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// We take negative of the logNormalizationConstant `log(1/k)`
// to get `log(k)`.
// factor->print("Discrete Separator");
return -factor->error(kEmpty) + (-conditional->logNormalizationConstant());
};

View File

@ -161,6 +161,35 @@ TEST(MixtureFactor, DifferentMeans) {
DiscreteValues{{M(1), 1}, {M(2), 0}});
EXPECT(assert_equal(expected, actual));
{
DiscreteValues dv{{M(1), 0}, {M(2), 0}};
VectorValues cont = bn->optimize(dv);
double error = bn->error(HybridValues(cont, dv));
// regression
EXPECT_DOUBLES_EQUAL(1.77418393408, error, 1e-9);
}
{
DiscreteValues dv{{M(1), 0}, {M(2), 1}};
VectorValues cont = bn->optimize(dv);
double error = bn->error(HybridValues(cont, dv));
// regression
EXPECT_DOUBLES_EQUAL(1.77418393408, error, 1e-9);
}
{
DiscreteValues dv{{M(1), 1}, {M(2), 0}};
VectorValues cont = bn->optimize(dv);
double error = bn->error(HybridValues(cont, dv));
// regression
EXPECT_DOUBLES_EQUAL(1.10751726741, error, 1e-9);
}
{
DiscreteValues dv{{M(1), 1}, {M(2), 1}};
VectorValues cont = bn->optimize(dv);
double error = bn->error(HybridValues(cont, dv));
// regression
EXPECT_DOUBLES_EQUAL(1.10751726741, error, 1e-9);
}
}
/* ************************************************************************* */
@ -217,15 +246,35 @@ TEST(MixtureFactor, DifferentCovariances) {
auto hbn = mixture_fg.eliminateSequential();
HybridValues actual_values = hbn->optimize();
VectorValues cv;
cv.insert(X(1), Vector1(0.0));
cv.insert(X(2), Vector1(0.0));
// Check that we get different error values at the MLE point μ.
AlgebraicDecisionTree<Key> errorTree = hbn->errorTree(cv);
auto cond0 = hbn->at(0)->asMixture();
auto cond1 = hbn->at(1)->asMixture();
auto discrete_cond = hbn->at(2)->asDiscrete();
HybridValues hv0(cv, DiscreteValues{{M(1), 0}});
HybridValues hv1(cv, DiscreteValues{{M(1), 1}});
AlgebraicDecisionTree<Key> expectedErrorTree(
m1,
cond0->error(hv0) // cond0(0)->logNormalizationConstant()
// - cond0(1)->logNormalizationConstant
+ cond1->error(hv0) + discrete_cond->error(DiscreteValues{{M(1), 0}}),
cond0->error(hv1) // cond1(0)->logNormalizationConstant()
// - cond1(1)->logNormalizationConstant
+ cond1->error(hv1) +
discrete_cond->error(DiscreteValues{{M(1), 0}}));
EXPECT(assert_equal(expectedErrorTree, errorTree));
DiscreteValues dv;
dv.insert({M(1), 1});
HybridValues expected_values(cv, dv);
HybridValues actual_values = hbn->optimize();
EXPECT(assert_equal(expected_values, actual_values));
}