likelihood method (as well as continuousParents)

release/4.3a0
Frank Dellaert 2022-12-29 13:28:20 -05:00
parent 611f61c7f4
commit 7ba5392525
3 changed files with 52 additions and 49 deletions

View File

@ -21,6 +21,7 @@
#include <gtsam/base/utilities.h> #include <gtsam/base/utilities.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
@ -128,6 +129,36 @@ void GaussianMixture::print(const std::string &s,
}); });
} }
/* ************************************************************************* */
KeyVector GaussianMixture::continuousParents() const {
// Get all parent keys:
const auto range = parents();
KeyVector continuousParentKeys(range.begin(), range.end());
// Loop over all discrete keys:
for (const auto &discreteKey : discreteKeys()) {
const Key key = discreteKey.first;
// remove that key from continuousParentKeys:
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
continuousParentKeys.end(), key),
continuousParentKeys.end());
}
return continuousParentKeys;
}
/* ************************************************************************* */
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const VectorValues &frontals) const {
// TODO(dellaert): check that values has all frontals
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->likelihood(frontals);
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
}
/* ************************************************************************* */ /* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) { std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> s; std::set<DiscreteKey> s;

View File

@ -29,6 +29,8 @@
namespace gtsam { namespace gtsam {
class GaussianMixtureFactor;
/** /**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as * @brief A conditional of gaussian mixtures indexed by discrete variables, as
* part of a Bayes Network. This is the result of the elimination of a * part of a Bayes Network. This is the result of the elimination of a
@ -117,16 +119,6 @@ class GTSAM_EXPORT GaussianMixture
const DiscreteKeys &discreteParents, const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals); const std::vector<GaussianConditional::shared_ptr> &conditionals);
/// @}
/// @name Standard API
/// @{
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;
/// Returns the total number of continuous components
size_t nrComponents() const;
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -140,6 +132,22 @@ class GTSAM_EXPORT GaussianMixture
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Standard API
/// @{
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;
/// Returns the total number of continuous components
size_t nrComponents() const;
/// Returns the continuous keys among the parents.
KeyVector continuousParents() const;
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor
// on the parents.
boost::shared_ptr<GaussianMixtureFactor> likelihood(
const VectorValues &frontals) const;
/// Getter for the underlying Conditionals DecisionTree /// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals() const; const Conditionals &conditionals() const;
@ -181,6 +189,7 @@ class GTSAM_EXPORT GaussianMixture
* @return Sum * @return Sum
*/ */
Sum add(const Sum &sum) const; Sum add(const Sum &sum) const;
/// @}
}; };
/// Return the DiscreteKey vector as a set. /// Return the DiscreteKey vector as a set.

View File

@ -154,53 +154,16 @@ static GaussianMixture createSimpleGaussianMixture() {
} }
/* ************************************************************************* */ /* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys& dkeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
return s;
}
// Get only the continuous parent keys as a KeyVector:
KeyVector continuousParents(const GaussianMixture& gm) {
// Get all parent keys:
const auto range = gm.parents();
KeyVector continuousParentKeys(range.begin(), range.end());
// Loop over all discrete keys:
for (const auto& discreteKey : gm.discreteKeys()) {
const Key key = discreteKey.first;
// remove that key from continuousParentKeys:
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
continuousParentKeys.end(), key),
continuousParentKeys.end());
}
return continuousParentKeys;
}
// Create a test for continuousParents. // Create a test for continuousParents.
TEST(GaussianMixture, ContinuousParents) { TEST(GaussianMixture, ContinuousParents) {
const GaussianMixture gm = createSimpleGaussianMixture(); const GaussianMixture gm = createSimpleGaussianMixture();
const KeyVector continuousParentKeys = continuousParents(gm); const KeyVector continuousParentKeys = gm.continuousParents();
// Check that the continuous parent keys are correct: // Check that the continuous parent keys are correct:
EXPECT(continuousParentKeys.size() == 1); EXPECT(continuousParentKeys.size() == 1);
EXPECT(continuousParentKeys[0] == X(0)); EXPECT(continuousParentKeys[0] == X(0));
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor.
GaussianMixtureFactor::shared_ptr likelihood(const GaussianMixture& gm,
const VectorValues& frontals) {
// TODO(dellaert): check that values has all frontals
const DiscreteKeys discreteParentKeys = gm.discreteKeys();
const KeyVector continuousParentKeys = continuousParents(gm);
const GaussianMixtureFactor::Factors likelihoods(
gm.conditionals(),
[&](const GaussianConditional::shared_ptr& conditional) {
return conditional->likelihood(frontals);
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
}
/// Check that likelihood returns a mixture factor on the parents. /// Check that likelihood returns a mixture factor on the parents.
TEST(GaussianMixture, Likelihood) { TEST(GaussianMixture, Likelihood) {
const GaussianMixture gm = createSimpleGaussianMixture(); const GaussianMixture gm = createSimpleGaussianMixture();
@ -208,7 +171,7 @@ TEST(GaussianMixture, Likelihood) {
// Call the likelihood function: // Call the likelihood function:
VectorValues measurements; VectorValues measurements;
measurements.insert(Z(0), Vector1(0)); measurements.insert(Z(0), Vector1(0));
const auto factor = likelihood(gm, measurements); const auto factor = gm.likelihood(measurements);
// Check that the factor is a mixture factor on the parents. // Check that the factor is a mixture factor on the parents.
// Loop over all discrete assignments over the discrete parents: // Loop over all discrete assignments over the discrete parents: