likelihood method (as well as continuousParents)

release/4.3a0
Frank Dellaert 2022-12-29 13:28:20 -05:00
parent 611f61c7f4
commit 7ba5392525
3 changed files with 52 additions and 49 deletions

View File

@ -21,6 +21,7 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h>
@ -128,6 +129,36 @@ void GaussianMixture::print(const std::string &s,
});
}
/* ************************************************************************* */
KeyVector GaussianMixture::continuousParents() const {
// Get all parent keys:
const auto range = parents();
KeyVector continuousParentKeys(range.begin(), range.end());
// Loop over all discrete keys:
for (const auto &discreteKey : discreteKeys()) {
const Key key = discreteKey.first;
// remove that key from continuousParentKeys:
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
continuousParentKeys.end(), key),
continuousParentKeys.end());
}
return continuousParentKeys;
}
/* ************************************************************************* */
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const VectorValues &frontals) const {
// TODO(dellaert): check that values has all frontals
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->likelihood(frontals);
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
}
/* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> s;

View File

@ -29,6 +29,8 @@
namespace gtsam {
class GaussianMixtureFactor;
/**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
* part of a Bayes Network. This is the result of the elimination of a
@ -117,16 +119,6 @@ class GTSAM_EXPORT GaussianMixture
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals);
/// @}
/// @name Standard API
/// @{
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;
/// Returns the total number of continuous components
size_t nrComponents() const;
/// @}
/// @name Testable
/// @{
@ -140,6 +132,22 @@ class GTSAM_EXPORT GaussianMixture
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard API
/// @{
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;
/// Returns the total number of continuous components
size_t nrComponents() const;
/// Returns the continuous keys among the parents.
KeyVector continuousParents() const;
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor
// on the parents.
boost::shared_ptr<GaussianMixtureFactor> likelihood(
const VectorValues &frontals) const;
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals() const;
@ -181,6 +189,7 @@ class GTSAM_EXPORT GaussianMixture
* @return Sum
*/
Sum add(const Sum &sum) const;
/// @}
};
/// Return the DiscreteKey vector as a set.

View File

@ -154,53 +154,16 @@ static GaussianMixture createSimpleGaussianMixture() {
}
/* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys& dkeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
return s;
}
// Get only the continuous parent keys as a KeyVector:
KeyVector continuousParents(const GaussianMixture& gm) {
// Get all parent keys:
const auto range = gm.parents();
KeyVector continuousParentKeys(range.begin(), range.end());
// Loop over all discrete keys:
for (const auto& discreteKey : gm.discreteKeys()) {
const Key key = discreteKey.first;
// remove that key from continuousParentKeys:
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
continuousParentKeys.end(), key),
continuousParentKeys.end());
}
return continuousParentKeys;
}
// Create a test for continuousParents.
TEST(GaussianMixture, ContinuousParents) {
const GaussianMixture gm = createSimpleGaussianMixture();
const KeyVector continuousParentKeys = continuousParents(gm);
const KeyVector continuousParentKeys = gm.continuousParents();
// Check that the continuous parent keys are correct:
EXPECT(continuousParentKeys.size() == 1);
EXPECT(continuousParentKeys[0] == X(0));
}
/* ************************************************************************* */
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor.
GaussianMixtureFactor::shared_ptr likelihood(const GaussianMixture& gm,
const VectorValues& frontals) {
// TODO(dellaert): check that values has all frontals
const DiscreteKeys discreteParentKeys = gm.discreteKeys();
const KeyVector continuousParentKeys = continuousParents(gm);
const GaussianMixtureFactor::Factors likelihoods(
gm.conditionals(),
[&](const GaussianConditional::shared_ptr& conditional) {
return conditional->likelihood(frontals);
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
}
/// Check that likelihood returns a mixture factor on the parents.
TEST(GaussianMixture, Likelihood) {
const GaussianMixture gm = createSimpleGaussianMixture();
@ -208,7 +171,7 @@ TEST(GaussianMixture, Likelihood) {
// Call the likelihood function:
VectorValues measurements;
measurements.insert(Z(0), Vector1(0));
const auto factor = likelihood(gm, measurements);
const auto factor = gm.likelihood(measurements);
// Check that the factor is a mixture factor on the parents.
// Loop over all discrete assignments over the discrete parents: