likelihood method (as well as continuousParents)
parent
611f61c7f4
commit
7ba5392525
|
|
@ -21,6 +21,7 @@
|
||||||
#include <gtsam/base/utilities.h>
|
#include <gtsam/base/utilities.h>
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/hybrid/GaussianMixture.h>
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
|
||||||
|
|
@ -128,6 +129,36 @@ void GaussianMixture::print(const std::string &s,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
KeyVector GaussianMixture::continuousParents() const {
|
||||||
|
// Get all parent keys:
|
||||||
|
const auto range = parents();
|
||||||
|
KeyVector continuousParentKeys(range.begin(), range.end());
|
||||||
|
// Loop over all discrete keys:
|
||||||
|
for (const auto &discreteKey : discreteKeys()) {
|
||||||
|
const Key key = discreteKey.first;
|
||||||
|
// remove that key from continuousParentKeys:
|
||||||
|
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
|
||||||
|
continuousParentKeys.end(), key),
|
||||||
|
continuousParentKeys.end());
|
||||||
|
}
|
||||||
|
return continuousParentKeys;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
||||||
|
const VectorValues &frontals) const {
|
||||||
|
// TODO(dellaert): check that values has all frontals
|
||||||
|
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||||
|
const KeyVector continuousParentKeys = continuousParents();
|
||||||
|
const GaussianMixtureFactor::Factors likelihoods(
|
||||||
|
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
|
||||||
|
return conditional->likelihood(frontals);
|
||||||
|
});
|
||||||
|
return boost::make_shared<GaussianMixtureFactor>(
|
||||||
|
continuousParentKeys, discreteParentKeys, likelihoods);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
||||||
std::set<DiscreteKey> s;
|
std::set<DiscreteKey> s;
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,8 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
class GaussianMixtureFactor;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
|
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
|
||||||
* part of a Bayes Network. This is the result of the elimination of a
|
* part of a Bayes Network. This is the result of the elimination of a
|
||||||
|
|
@ -117,16 +119,6 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
const DiscreteKeys &discreteParents,
|
const DiscreteKeys &discreteParents,
|
||||||
const std::vector<GaussianConditional::shared_ptr> &conditionals);
|
const std::vector<GaussianConditional::shared_ptr> &conditionals);
|
||||||
|
|
||||||
/// @}
|
|
||||||
/// @name Standard API
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
GaussianConditional::shared_ptr operator()(
|
|
||||||
const DiscreteValues &discreteValues) const;
|
|
||||||
|
|
||||||
/// Returns the total number of continuous components
|
|
||||||
size_t nrComponents() const;
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
@ -140,6 +132,22 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
/// @name Standard API
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
GaussianConditional::shared_ptr operator()(
|
||||||
|
const DiscreteValues &discreteValues) const;
|
||||||
|
|
||||||
|
/// Returns the total number of continuous components
|
||||||
|
size_t nrComponents() const;
|
||||||
|
|
||||||
|
/// Returns the continuous keys among the parents.
|
||||||
|
KeyVector continuousParents() const;
|
||||||
|
|
||||||
|
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor
|
||||||
|
// on the parents.
|
||||||
|
boost::shared_ptr<GaussianMixtureFactor> likelihood(
|
||||||
|
const VectorValues &frontals) const;
|
||||||
|
|
||||||
/// Getter for the underlying Conditionals DecisionTree
|
/// Getter for the underlying Conditionals DecisionTree
|
||||||
const Conditionals &conditionals() const;
|
const Conditionals &conditionals() const;
|
||||||
|
|
@ -181,6 +189,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
* @return Sum
|
* @return Sum
|
||||||
*/
|
*/
|
||||||
Sum add(const Sum &sum) const;
|
Sum add(const Sum &sum) const;
|
||||||
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Return the DiscreteKey vector as a set.
|
/// Return the DiscreteKey vector as a set.
|
||||||
|
|
|
||||||
|
|
@ -154,53 +154,16 @@ static GaussianMixture createSimpleGaussianMixture() {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys& dkeys) {
|
|
||||||
std::set<DiscreteKey> s;
|
|
||||||
s.insert(dkeys.begin(), dkeys.end());
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get only the continuous parent keys as a KeyVector:
|
|
||||||
KeyVector continuousParents(const GaussianMixture& gm) {
|
|
||||||
// Get all parent keys:
|
|
||||||
const auto range = gm.parents();
|
|
||||||
KeyVector continuousParentKeys(range.begin(), range.end());
|
|
||||||
// Loop over all discrete keys:
|
|
||||||
for (const auto& discreteKey : gm.discreteKeys()) {
|
|
||||||
const Key key = discreteKey.first;
|
|
||||||
// remove that key from continuousParentKeys:
|
|
||||||
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
|
|
||||||
continuousParentKeys.end(), key),
|
|
||||||
continuousParentKeys.end());
|
|
||||||
}
|
|
||||||
return continuousParentKeys;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a test for continuousParents.
|
// Create a test for continuousParents.
|
||||||
TEST(GaussianMixture, ContinuousParents) {
|
TEST(GaussianMixture, ContinuousParents) {
|
||||||
const GaussianMixture gm = createSimpleGaussianMixture();
|
const GaussianMixture gm = createSimpleGaussianMixture();
|
||||||
const KeyVector continuousParentKeys = continuousParents(gm);
|
const KeyVector continuousParentKeys = gm.continuousParents();
|
||||||
// Check that the continuous parent keys are correct:
|
// Check that the continuous parent keys are correct:
|
||||||
EXPECT(continuousParentKeys.size() == 1);
|
EXPECT(continuousParentKeys.size() == 1);
|
||||||
EXPECT(continuousParentKeys[0] == X(0));
|
EXPECT(continuousParentKeys[0] == X(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor.
|
|
||||||
GaussianMixtureFactor::shared_ptr likelihood(const GaussianMixture& gm,
|
|
||||||
const VectorValues& frontals) {
|
|
||||||
// TODO(dellaert): check that values has all frontals
|
|
||||||
const DiscreteKeys discreteParentKeys = gm.discreteKeys();
|
|
||||||
const KeyVector continuousParentKeys = continuousParents(gm);
|
|
||||||
const GaussianMixtureFactor::Factors likelihoods(
|
|
||||||
gm.conditionals(),
|
|
||||||
[&](const GaussianConditional::shared_ptr& conditional) {
|
|
||||||
return conditional->likelihood(frontals);
|
|
||||||
});
|
|
||||||
return boost::make_shared<GaussianMixtureFactor>(
|
|
||||||
continuousParentKeys, discreteParentKeys, likelihoods);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check that likelihood returns a mixture factor on the parents.
|
/// Check that likelihood returns a mixture factor on the parents.
|
||||||
TEST(GaussianMixture, Likelihood) {
|
TEST(GaussianMixture, Likelihood) {
|
||||||
const GaussianMixture gm = createSimpleGaussianMixture();
|
const GaussianMixture gm = createSimpleGaussianMixture();
|
||||||
|
|
@ -208,7 +171,7 @@ TEST(GaussianMixture, Likelihood) {
|
||||||
// Call the likelihood function:
|
// Call the likelihood function:
|
||||||
VectorValues measurements;
|
VectorValues measurements;
|
||||||
measurements.insert(Z(0), Vector1(0));
|
measurements.insert(Z(0), Vector1(0));
|
||||||
const auto factor = likelihood(gm, measurements);
|
const auto factor = gm.likelihood(measurements);
|
||||||
|
|
||||||
// Check that the factor is a mixture factor on the parents.
|
// Check that the factor is a mixture factor on the parents.
|
||||||
// Loop over all discrete assignments over the discrete parents:
|
// Loop over all discrete assignments over the discrete parents:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue