likelihood method (as well as continuousParents)
parent
611f61c7f4
commit
7ba5392525
|
|
@ -21,6 +21,7 @@
|
|||
#include <gtsam/base/utilities.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/GaussianMixture.h>
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
|
|
@ -128,6 +129,36 @@ void GaussianMixture::print(const std::string &s,
|
|||
});
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
KeyVector GaussianMixture::continuousParents() const {
|
||||
// Get all parent keys:
|
||||
const auto range = parents();
|
||||
KeyVector continuousParentKeys(range.begin(), range.end());
|
||||
// Loop over all discrete keys:
|
||||
for (const auto &discreteKey : discreteKeys()) {
|
||||
const Key key = discreteKey.first;
|
||||
// remove that key from continuousParentKeys:
|
||||
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
|
||||
continuousParentKeys.end(), key),
|
||||
continuousParentKeys.end());
|
||||
}
|
||||
return continuousParentKeys;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
||||
const VectorValues &frontals) const {
|
||||
// TODO(dellaert): check that values has all frontals
|
||||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||
const KeyVector continuousParentKeys = continuousParents();
|
||||
const GaussianMixtureFactor::Factors likelihoods(
|
||||
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
|
||||
return conditional->likelihood(frontals);
|
||||
});
|
||||
return boost::make_shared<GaussianMixtureFactor>(
|
||||
continuousParentKeys, discreteParentKeys, likelihoods);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
||||
std::set<DiscreteKey> s;
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
class GaussianMixtureFactor;
|
||||
|
||||
/**
|
||||
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
|
||||
* part of a Bayes Network. This is the result of the elimination of a
|
||||
|
|
@ -117,16 +119,6 @@ class GTSAM_EXPORT GaussianMixture
|
|||
const DiscreteKeys &discreteParents,
|
||||
const std::vector<GaussianConditional::shared_ptr> &conditionals);
|
||||
|
||||
/// @}
|
||||
/// @name Standard API
|
||||
/// @{
|
||||
|
||||
GaussianConditional::shared_ptr operator()(
|
||||
const DiscreteValues &discreteValues) const;
|
||||
|
||||
/// Returns the total number of continuous components
|
||||
size_t nrComponents() const;
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
|
@ -140,6 +132,22 @@ class GTSAM_EXPORT GaussianMixture
|
|||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard API
|
||||
/// @{
|
||||
|
||||
GaussianConditional::shared_ptr operator()(
|
||||
const DiscreteValues &discreteValues) const;
|
||||
|
||||
/// Returns the total number of continuous components
|
||||
size_t nrComponents() const;
|
||||
|
||||
/// Returns the continuous keys among the parents.
|
||||
KeyVector continuousParents() const;
|
||||
|
||||
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor
|
||||
// on the parents.
|
||||
boost::shared_ptr<GaussianMixtureFactor> likelihood(
|
||||
const VectorValues &frontals) const;
|
||||
|
||||
/// Getter for the underlying Conditionals DecisionTree
|
||||
const Conditionals &conditionals() const;
|
||||
|
|
@ -181,6 +189,7 @@ class GTSAM_EXPORT GaussianMixture
|
|||
* @return Sum
|
||||
*/
|
||||
Sum add(const Sum &sum) const;
|
||||
/// @}
|
||||
};
|
||||
|
||||
/// Return the DiscreteKey vector as a set.
|
||||
|
|
|
|||
|
|
@ -154,53 +154,16 @@ static GaussianMixture createSimpleGaussianMixture() {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys& dkeys) {
|
||||
std::set<DiscreteKey> s;
|
||||
s.insert(dkeys.begin(), dkeys.end());
|
||||
return s;
|
||||
}
|
||||
|
||||
// Get only the continuous parent keys as a KeyVector:
|
||||
KeyVector continuousParents(const GaussianMixture& gm) {
|
||||
// Get all parent keys:
|
||||
const auto range = gm.parents();
|
||||
KeyVector continuousParentKeys(range.begin(), range.end());
|
||||
// Loop over all discrete keys:
|
||||
for (const auto& discreteKey : gm.discreteKeys()) {
|
||||
const Key key = discreteKey.first;
|
||||
// remove that key from continuousParentKeys:
|
||||
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
|
||||
continuousParentKeys.end(), key),
|
||||
continuousParentKeys.end());
|
||||
}
|
||||
return continuousParentKeys;
|
||||
}
|
||||
|
||||
// Create a test for continuousParents.
|
||||
TEST(GaussianMixture, ContinuousParents) {
|
||||
const GaussianMixture gm = createSimpleGaussianMixture();
|
||||
const KeyVector continuousParentKeys = continuousParents(gm);
|
||||
const KeyVector continuousParentKeys = gm.continuousParents();
|
||||
// Check that the continuous parent keys are correct:
|
||||
EXPECT(continuousParentKeys.size() == 1);
|
||||
EXPECT(continuousParentKeys[0] == X(0));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor.
|
||||
GaussianMixtureFactor::shared_ptr likelihood(const GaussianMixture& gm,
|
||||
const VectorValues& frontals) {
|
||||
// TODO(dellaert): check that values has all frontals
|
||||
const DiscreteKeys discreteParentKeys = gm.discreteKeys();
|
||||
const KeyVector continuousParentKeys = continuousParents(gm);
|
||||
const GaussianMixtureFactor::Factors likelihoods(
|
||||
gm.conditionals(),
|
||||
[&](const GaussianConditional::shared_ptr& conditional) {
|
||||
return conditional->likelihood(frontals);
|
||||
});
|
||||
return boost::make_shared<GaussianMixtureFactor>(
|
||||
continuousParentKeys, discreteParentKeys, likelihoods);
|
||||
}
|
||||
|
||||
/// Check that likelihood returns a mixture factor on the parents.
|
||||
TEST(GaussianMixture, Likelihood) {
|
||||
const GaussianMixture gm = createSimpleGaussianMixture();
|
||||
|
|
@ -208,7 +171,7 @@ TEST(GaussianMixture, Likelihood) {
|
|||
// Call the likelihood function:
|
||||
VectorValues measurements;
|
||||
measurements.insert(Z(0), Vector1(0));
|
||||
const auto factor = likelihood(gm, measurements);
|
||||
const auto factor = gm.likelihood(measurements);
|
||||
|
||||
// Check that the factor is a mixture factor on the parents.
|
||||
// Loop over all discrete assignments over the discrete parents:
|
||||
|
|
|
|||
Loading…
Reference in New Issue