test for different error values in BN from MixtureFactor
parent
113a7f8e6b
commit
2430abb4bc
|
@ -298,7 +298,6 @@ static std::shared_ptr<Factor> createDiscreteFactor(
|
|||
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
|
||||
// We take negative of the logNormalizationConstant `log(1/k)`
|
||||
// to get `log(k)`.
|
||||
// factor->print("Discrete Separator");
|
||||
return -factor->error(kEmpty) + (-conditional->logNormalizationConstant());
|
||||
};
|
||||
|
||||
|
|
|
@ -161,6 +161,35 @@ TEST(MixtureFactor, DifferentMeans) {
|
|||
DiscreteValues{{M(1), 1}, {M(2), 0}});
|
||||
|
||||
EXPECT(assert_equal(expected, actual));
|
||||
|
||||
{
|
||||
DiscreteValues dv{{M(1), 0}, {M(2), 0}};
|
||||
VectorValues cont = bn->optimize(dv);
|
||||
double error = bn->error(HybridValues(cont, dv));
|
||||
// regression
|
||||
EXPECT_DOUBLES_EQUAL(1.77418393408, error, 1e-9);
|
||||
}
|
||||
{
|
||||
DiscreteValues dv{{M(1), 0}, {M(2), 1}};
|
||||
VectorValues cont = bn->optimize(dv);
|
||||
double error = bn->error(HybridValues(cont, dv));
|
||||
// regression
|
||||
EXPECT_DOUBLES_EQUAL(1.77418393408, error, 1e-9);
|
||||
}
|
||||
{
|
||||
DiscreteValues dv{{M(1), 1}, {M(2), 0}};
|
||||
VectorValues cont = bn->optimize(dv);
|
||||
double error = bn->error(HybridValues(cont, dv));
|
||||
// regression
|
||||
EXPECT_DOUBLES_EQUAL(1.10751726741, error, 1e-9);
|
||||
}
|
||||
{
|
||||
DiscreteValues dv{{M(1), 1}, {M(2), 1}};
|
||||
VectorValues cont = bn->optimize(dv);
|
||||
double error = bn->error(HybridValues(cont, dv));
|
||||
// regression
|
||||
EXPECT_DOUBLES_EQUAL(1.10751726741, error, 1e-9);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -217,15 +246,35 @@ TEST(MixtureFactor, DifferentCovariances) {
|
|||
|
||||
auto hbn = mixture_fg.eliminateSequential();
|
||||
|
||||
HybridValues actual_values = hbn->optimize();
|
||||
|
||||
VectorValues cv;
|
||||
cv.insert(X(1), Vector1(0.0));
|
||||
cv.insert(X(2), Vector1(0.0));
|
||||
|
||||
// Check that we get different error values at the MLE point μ.
|
||||
AlgebraicDecisionTree<Key> errorTree = hbn->errorTree(cv);
|
||||
auto cond0 = hbn->at(0)->asMixture();
|
||||
auto cond1 = hbn->at(1)->asMixture();
|
||||
auto discrete_cond = hbn->at(2)->asDiscrete();
|
||||
|
||||
HybridValues hv0(cv, DiscreteValues{{M(1), 0}});
|
||||
HybridValues hv1(cv, DiscreteValues{{M(1), 1}});
|
||||
AlgebraicDecisionTree<Key> expectedErrorTree(
|
||||
m1,
|
||||
cond0->error(hv0) // cond0(0)->logNormalizationConstant()
|
||||
// - cond0(1)->logNormalizationConstant
|
||||
+ cond1->error(hv0) + discrete_cond->error(DiscreteValues{{M(1), 0}}),
|
||||
cond0->error(hv1) // cond1(0)->logNormalizationConstant()
|
||||
// - cond1(1)->logNormalizationConstant
|
||||
+ cond1->error(hv1) +
|
||||
discrete_cond->error(DiscreteValues{{M(1), 0}}));
|
||||
EXPECT(assert_equal(expectedErrorTree, errorTree));
|
||||
|
||||
DiscreteValues dv;
|
||||
dv.insert({M(1), 1});
|
||||
HybridValues expected_values(cv, dv);
|
||||
|
||||
HybridValues actual_values = hbn->optimize();
|
||||
|
||||
EXPECT(assert_equal(expected_values, actual_values));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue