test for different error values in BN from MixtureFactor
parent
113a7f8e6b
commit
2430abb4bc
|
@ -298,7 +298,6 @@ static std::shared_ptr<Factor> createDiscreteFactor(
|
||||||
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
|
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
|
||||||
// We take negative of the logNormalizationConstant `log(1/k)`
|
// We take negative of the logNormalizationConstant `log(1/k)`
|
||||||
// to get `log(k)`.
|
// to get `log(k)`.
|
||||||
// factor->print("Discrete Separator");
|
|
||||||
return -factor->error(kEmpty) + (-conditional->logNormalizationConstant());
|
return -factor->error(kEmpty) + (-conditional->logNormalizationConstant());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -161,6 +161,35 @@ TEST(MixtureFactor, DifferentMeans) {
|
||||||
DiscreteValues{{M(1), 1}, {M(2), 0}});
|
DiscreteValues{{M(1), 1}, {M(2), 0}});
|
||||||
|
|
||||||
EXPECT(assert_equal(expected, actual));
|
EXPECT(assert_equal(expected, actual));
|
||||||
|
|
||||||
|
{
|
||||||
|
DiscreteValues dv{{M(1), 0}, {M(2), 0}};
|
||||||
|
VectorValues cont = bn->optimize(dv);
|
||||||
|
double error = bn->error(HybridValues(cont, dv));
|
||||||
|
// regression
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.77418393408, error, 1e-9);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
DiscreteValues dv{{M(1), 0}, {M(2), 1}};
|
||||||
|
VectorValues cont = bn->optimize(dv);
|
||||||
|
double error = bn->error(HybridValues(cont, dv));
|
||||||
|
// regression
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.77418393408, error, 1e-9);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
DiscreteValues dv{{M(1), 1}, {M(2), 0}};
|
||||||
|
VectorValues cont = bn->optimize(dv);
|
||||||
|
double error = bn->error(HybridValues(cont, dv));
|
||||||
|
// regression
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.10751726741, error, 1e-9);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
DiscreteValues dv{{M(1), 1}, {M(2), 1}};
|
||||||
|
VectorValues cont = bn->optimize(dv);
|
||||||
|
double error = bn->error(HybridValues(cont, dv));
|
||||||
|
// regression
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.10751726741, error, 1e-9);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -217,15 +246,35 @@ TEST(MixtureFactor, DifferentCovariances) {
|
||||||
|
|
||||||
auto hbn = mixture_fg.eliminateSequential();
|
auto hbn = mixture_fg.eliminateSequential();
|
||||||
|
|
||||||
HybridValues actual_values = hbn->optimize();
|
|
||||||
|
|
||||||
VectorValues cv;
|
VectorValues cv;
|
||||||
cv.insert(X(1), Vector1(0.0));
|
cv.insert(X(1), Vector1(0.0));
|
||||||
cv.insert(X(2), Vector1(0.0));
|
cv.insert(X(2), Vector1(0.0));
|
||||||
|
|
||||||
|
// Check that we get different error values at the MLE point μ.
|
||||||
|
AlgebraicDecisionTree<Key> errorTree = hbn->errorTree(cv);
|
||||||
|
auto cond0 = hbn->at(0)->asMixture();
|
||||||
|
auto cond1 = hbn->at(1)->asMixture();
|
||||||
|
auto discrete_cond = hbn->at(2)->asDiscrete();
|
||||||
|
|
||||||
|
HybridValues hv0(cv, DiscreteValues{{M(1), 0}});
|
||||||
|
HybridValues hv1(cv, DiscreteValues{{M(1), 1}});
|
||||||
|
AlgebraicDecisionTree<Key> expectedErrorTree(
|
||||||
|
m1,
|
||||||
|
cond0->error(hv0) // cond0(0)->logNormalizationConstant()
|
||||||
|
// - cond0(1)->logNormalizationConstant
|
||||||
|
+ cond1->error(hv0) + discrete_cond->error(DiscreteValues{{M(1), 0}}),
|
||||||
|
cond0->error(hv1) // cond1(0)->logNormalizationConstant()
|
||||||
|
// - cond1(1)->logNormalizationConstant
|
||||||
|
+ cond1->error(hv1) +
|
||||||
|
discrete_cond->error(DiscreteValues{{M(1), 0}}));
|
||||||
|
EXPECT(assert_equal(expectedErrorTree, errorTree));
|
||||||
|
|
||||||
DiscreteValues dv;
|
DiscreteValues dv;
|
||||||
dv.insert({M(1), 1});
|
dv.insert({M(1), 1});
|
||||||
HybridValues expected_values(cv, dv);
|
HybridValues expected_values(cv, dv);
|
||||||
|
|
||||||
|
HybridValues actual_values = hbn->optimize();
|
||||||
|
|
||||||
EXPECT(assert_equal(expected_values, actual_values));
|
EXPECT(assert_equal(expected_values, actual_values));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue