normalizationConstants returns all constants as a DecisionTreeFactor
parent
618ac28f2c
commit
34a9aef6f3
|
|
@ -170,21 +170,41 @@ KeyVector GaussianMixture::continuousParents() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
boost::shared_ptr<DecisionTreeFactor> GaussianMixture::normalizationConstants()
|
||||||
const VectorValues &frontals) const {
|
const {
|
||||||
// Check that values has all frontals
|
DecisionTree<Key, double> constants(
|
||||||
for (auto &&kv : frontals) {
|
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
|
||||||
if (frontals.find(kv.first) == frontals.end()) {
|
return conditional->normalizationConstant();
|
||||||
throw std::runtime_error("GaussianMixture: frontals missing factor key.");
|
});
|
||||||
|
// If all constants the same, return nullptr:
|
||||||
|
if (constants.nrLeaves() == 1) return nullptr;
|
||||||
|
return boost::make_shared<DecisionTreeFactor>(discreteKeys(), constants);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
bool GaussianMixture::allFrontalsGiven(const VectorValues &given) const {
|
||||||
|
for (auto &&kv : given) {
|
||||||
|
if (given.find(kv.first) == given.end()) {
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
||||||
|
const VectorValues &given) const {
|
||||||
|
if (!allFrontalsGiven(given)) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"GaussianMixture::likelihood: given values are missing some frontals.");
|
||||||
|
}
|
||||||
|
|
||||||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||||
const KeyVector continuousParentKeys = continuousParents();
|
const KeyVector continuousParentKeys = continuousParents();
|
||||||
const GaussianMixtureFactor::Factors likelihoods(
|
const GaussianMixtureFactor::Factors likelihoods(
|
||||||
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
|
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
|
||||||
return GaussianMixtureFactor::FactorAndConstant{
|
return GaussianMixtureFactor::FactorAndConstant{
|
||||||
conditional->likelihood(frontals),
|
conditional->likelihood(given),
|
||||||
conditional->logNormalizationConstant()};
|
conditional->logNormalizationConstant()};
|
||||||
});
|
});
|
||||||
return boost::make_shared<GaussianMixtureFactor>(
|
return boost::make_shared<GaussianMixtureFactor>(
|
||||||
|
|
@ -285,8 +305,7 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
|
||||||
return 1e50;
|
return 1e50;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
|
return DecisionTree<Key, double>(conditionals_, errorFunc);
|
||||||
return errorTree;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
||||||
|
|
@ -155,10 +155,16 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
/// Returns the continuous keys among the parents.
|
/// Returns the continuous keys among the parents.
|
||||||
KeyVector continuousParents() const;
|
KeyVector continuousParents() const;
|
||||||
|
|
||||||
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor
|
/// Return a discrete factor with possibly varying normalization constants.
|
||||||
// on the parents.
|
/// If there is no variation, return nullptr.
|
||||||
|
boost::shared_ptr<DecisionTreeFactor> normalizationConstants() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a likelihood factor for a Gaussian mixture, return a Mixture factor
|
||||||
|
* on the parents.
|
||||||
|
*/
|
||||||
boost::shared_ptr<GaussianMixtureFactor> likelihood(
|
boost::shared_ptr<GaussianMixtureFactor> likelihood(
|
||||||
const VectorValues &frontals) const;
|
const VectorValues &given) const;
|
||||||
|
|
||||||
/// Getter for the underlying Conditionals DecisionTree
|
/// Getter for the underlying Conditionals DecisionTree
|
||||||
const Conditionals &conditionals() const;
|
const Conditionals &conditionals() const;
|
||||||
|
|
@ -233,6 +239,9 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
/// Check whether `given` has values for all frontal keys.
|
||||||
|
bool allFrontalsGiven(const VectorValues &given) const;
|
||||||
|
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
friend class boost::serialization::access;
|
friend class boost::serialization::access;
|
||||||
template <class Archive>
|
template <class Archive>
|
||||||
|
|
|
||||||
|
|
@ -106,13 +106,16 @@ TEST(GaussianMixture, Error) {
|
||||||
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
|
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
|
||||||
X(2), S2, model);
|
X(2), S2, model);
|
||||||
|
|
||||||
// Create decision tree
|
// Create Gaussian Mixture.
|
||||||
DiscreteKey m1(M(1), 2);
|
DiscreteKey m1(M(1), 2);
|
||||||
GaussianMixture::Conditionals conditionals(
|
GaussianMixture::Conditionals conditionals(
|
||||||
{m1},
|
{m1},
|
||||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||||
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
|
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
|
||||||
|
|
||||||
|
// Check that normalizationConstants returns nullptr, as all constants equal.
|
||||||
|
CHECK(!mixture.normalizationConstants());
|
||||||
|
|
||||||
VectorValues values;
|
VectorValues values;
|
||||||
values.insert(X(1), Vector2::Ones());
|
values.insert(X(1), Vector2::Ones());
|
||||||
values.insert(X(2), Vector2::Zero());
|
values.insert(X(2), Vector2::Zero());
|
||||||
|
|
@ -163,6 +166,19 @@ TEST(GaussianMixture, ContinuousParents) {
|
||||||
EXPECT(continuousParentKeys[0] == X(0));
|
EXPECT(continuousParentKeys[0] == X(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
/// Check we can create a DecisionTreeFactor with all normalization constants.
|
||||||
|
TEST(GaussianMixture, NormalizationConstants) {
|
||||||
|
const GaussianMixture gm = createSimpleGaussianMixture();
|
||||||
|
|
||||||
|
const auto factor = gm.normalizationConstants();
|
||||||
|
|
||||||
|
// Test with 1D Gaussian normalization constants for sigma 0.5 and 3:
|
||||||
|
auto c = [](double sigma) { return 1.0 / (sqrt(2 * M_PI) * sigma); };
|
||||||
|
const DecisionTreeFactor expected({M(0), 2}, {c(0.5), c(3)});
|
||||||
|
EXPECT(assert_equal(expected, *factor));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/// Check that likelihood returns a mixture factor on the parents.
|
/// Check that likelihood returns a mixture factor on the parents.
|
||||||
TEST(GaussianMixture, Likelihood) {
|
TEST(GaussianMixture, Likelihood) {
|
||||||
|
|
@ -186,7 +202,7 @@ TEST(GaussianMixture, Likelihood) {
|
||||||
conditional->logNormalizationConstant()};
|
conditional->logNormalizationConstant()};
|
||||||
});
|
});
|
||||||
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
|
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
|
||||||
EXPECT(assert_equal(*factor, expected));
|
EXPECT(assert_equal(expected, *factor));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue