proto code for likelihood

release/4.3a0
Frank Dellaert 2022-12-29 13:21:20 -05:00
parent 364417e4aa
commit 611f61c7f4
1 changed files with 70 additions and 15 deletions

View File

@ -135,19 +135,12 @@ TEST(GaussianMixture, Error) {
}
/* ************************************************************************* */
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor on
// the parents.
GaussianMixtureFactor::shared_ptr likelihood(const HybridValues& values) {
GaussianMixtureFactor::shared_ptr factor;
return factor;
}
/// Check that likelihood returns a mixture factor on the parents.
TEST(GaussianMixture, Likelihood) {
// Create mode key: 0 is low-noise, 1 is high-noise.
Key modeKey = M(0);
DiscreteKey mode(modeKey, 2);
// Create mode key: 0 is low-noise, 1 is high-noise.
static const Key modeKey = M(0);
static const DiscreteKey mode(modeKey, 2);
// Create a simple GaussianMixture
static GaussianMixture createSimpleGaussianMixture() {
// Create Gaussian mixture Z(0) = X(0) + noise.
// TODO(dellaert): making copies below is not ideal !
Matrix1 I = Matrix1::Identity();
@ -157,15 +150,77 @@ TEST(GaussianMixture, Likelihood) {
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3));
const auto gm = GaussianMixture::FromConditionals(
{Z(0)}, {X(0)}, {mode}, {conditional0, conditional1});
return gm;
}
/* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys& dkeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
return s;
}
// Get only the continuous parent keys as a KeyVector:
KeyVector continuousParents(const GaussianMixture& gm) {
// Get all parent keys:
const auto range = gm.parents();
KeyVector continuousParentKeys(range.begin(), range.end());
// Loop over all discrete keys:
for (const auto& discreteKey : gm.discreteKeys()) {
const Key key = discreteKey.first;
// remove that key from continuousParentKeys:
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
continuousParentKeys.end(), key),
continuousParentKeys.end());
}
return continuousParentKeys;
}
// Create a test for continuousParents.
TEST(GaussianMixture, ContinuousParents) {
const GaussianMixture gm = createSimpleGaussianMixture();
const KeyVector continuousParentKeys = continuousParents(gm);
// Check that the continuous parent keys are correct:
EXPECT(continuousParentKeys.size() == 1);
EXPECT(continuousParentKeys[0] == X(0));
}
/* ************************************************************************* */
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor.
GaussianMixtureFactor::shared_ptr likelihood(const GaussianMixture& gm,
const VectorValues& frontals) {
// TODO(dellaert): check that values has all frontals
const DiscreteKeys discreteParentKeys = gm.discreteKeys();
const KeyVector continuousParentKeys = continuousParents(gm);
const GaussianMixtureFactor::Factors likelihoods(
gm.conditionals(),
[&](const GaussianConditional::shared_ptr& conditional) {
return conditional->likelihood(frontals);
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
}
/// Check that likelihood returns a mixture factor on the parents.
TEST(GaussianMixture, Likelihood) {
const GaussianMixture gm = createSimpleGaussianMixture();
// Call the likelihood function:
VectorValues measurements;
measurements.insert(Z(0), Vector1(0));
HybridValues values(DiscreteValues(), measurements);
const auto factor = likelihood(values);
const auto factor = likelihood(gm, measurements);
// Check that the factor is a mixture factor on the parents.
const GaussianMixtureFactor expected = GaussianMixtureFactor();
// Loop over all discrete assignments over the discrete parents:
const DiscreteKeys discreteParentKeys = gm.discreteKeys();
// Apply the likelihood function to all conditionals:
const GaussianMixtureFactor::Factors factors(
gm.conditionals(),
[measurements](const GaussianConditional::shared_ptr& conditional) {
return conditional->likelihood(measurements);
});
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
EXPECT(assert_equal(*factor, expected));
}