rename logConstant_ to logNormalizer_ and add test for error with mode-dependent constants
parent
561bdcf9af
commit
9afbc019f4
|
@ -33,8 +33,10 @@ HybridGaussianFactor::FactorValuePairs GetFactorValuePairs(
|
||||||
auto func = [](const GaussianConditional::shared_ptr &conditional)
|
auto func = [](const GaussianConditional::shared_ptr &conditional)
|
||||||
-> GaussianFactorValuePair {
|
-> GaussianFactorValuePair {
|
||||||
double value = 0.0;
|
double value = 0.0;
|
||||||
if (conditional) { // Check if conditional is pruned
|
// Check if conditional is pruned
|
||||||
value = conditional->logNormalizationConstant();
|
if (conditional) {
|
||||||
|
// Assign log(|2πΣ|) = -2*log(1 / sqrt(|2πΣ|))
|
||||||
|
value = -2.0 * conditional->logNormalizationConstant();
|
||||||
}
|
}
|
||||||
return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
|
return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
|
||||||
};
|
};
|
||||||
|
@ -49,14 +51,14 @@ HybridGaussianConditional::HybridGaussianConditional(
|
||||||
discreteParents, GetFactorValuePairs(conditionals)),
|
discreteParents, GetFactorValuePairs(conditionals)),
|
||||||
BaseConditional(continuousFrontals.size()),
|
BaseConditional(continuousFrontals.size()),
|
||||||
conditionals_(conditionals) {
|
conditionals_(conditionals) {
|
||||||
// Calculate logConstant_ as the maximum of the log constants of the
|
// Calculate logNormalizer_ as the minimum of the log normalizers of the
|
||||||
// conditionals, by visiting the decision tree:
|
// conditionals, by visiting the decision tree:
|
||||||
logConstant_ = -std::numeric_limits<double>::infinity();
|
logNormalizer_ = std::numeric_limits<double>::infinity();
|
||||||
conditionals_.visit(
|
conditionals_.visit(
|
||||||
[this](const GaussianConditional::shared_ptr &conditional) {
|
[this](const GaussianConditional::shared_ptr &conditional) {
|
||||||
if (conditional) {
|
if (conditional) {
|
||||||
this->logConstant_ = std::max(
|
this->logNormalizer_ = std::min(
|
||||||
this->logConstant_, conditional->logNormalizationConstant());
|
this->logNormalizer_, -conditional->logNormalizationConstant());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -98,7 +100,7 @@ GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
|
||||||
// First check if conditional has not been pruned
|
// First check if conditional has not been pruned
|
||||||
if (gc) {
|
if (gc) {
|
||||||
const double Cgm_Kgcm =
|
const double Cgm_Kgcm =
|
||||||
this->logConstant_ - gc->logNormalizationConstant();
|
-gc->logNormalizationConstant() - this->logNormalizer_;
|
||||||
// If there is a difference in the covariances, we need to account for
|
// If there is a difference in the covariances, we need to account for
|
||||||
// that since the error is dependent on the mode.
|
// that since the error is dependent on the mode.
|
||||||
if (Cgm_Kgcm > 0.0) {
|
if (Cgm_Kgcm > 0.0) {
|
||||||
|
@ -169,7 +171,8 @@ void HybridGaussianConditional::print(const std::string &s,
|
||||||
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
||||||
}
|
}
|
||||||
std::cout << std::endl
|
std::cout << std::endl
|
||||||
<< " logNormalizationConstant: " << logConstant_ << std::endl
|
<< " logNormalizationConstant: " << logNormalizationConstant()
|
||||||
|
<< std::endl
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
conditionals_.print(
|
conditionals_.print(
|
||||||
"", [&](Key k) { return formatter(k); },
|
"", [&](Key k) { return formatter(k); },
|
||||||
|
@ -228,7 +231,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
||||||
-> GaussianFactorValuePair {
|
-> GaussianFactorValuePair {
|
||||||
const auto likelihood_m = conditional->likelihood(given);
|
const auto likelihood_m = conditional->likelihood(given);
|
||||||
const double Cgm_Kgcm =
|
const double Cgm_Kgcm =
|
||||||
logConstant_ - conditional->logNormalizationConstant();
|
-conditional->logNormalizationConstant() - logNormalizer_;
|
||||||
if (Cgm_Kgcm == 0.0) {
|
if (Cgm_Kgcm == 0.0) {
|
||||||
return {likelihood_m, 0.0};
|
return {likelihood_m, 0.0};
|
||||||
} else {
|
} else {
|
||||||
|
@ -342,7 +345,7 @@ double HybridGaussianConditional::conditionalError(
|
||||||
// Check if valid pointer
|
// Check if valid pointer
|
||||||
if (conditional) {
|
if (conditional) {
|
||||||
return conditional->error(continuousValues) + //
|
return conditional->error(continuousValues) + //
|
||||||
logConstant_ - conditional->logNormalizationConstant();
|
-conditional->logNormalizationConstant() - logNormalizer_;
|
||||||
} else {
|
} else {
|
||||||
// If not valid, pointer, it means this conditional was pruned,
|
// If not valid, pointer, it means this conditional was pruned,
|
||||||
// so we return maximum error.
|
// so we return maximum error.
|
||||||
|
|
|
@ -64,7 +64,8 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
|
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
|
||||||
double logConstant_; ///< log of the normalization constant.
|
double logNormalizer_; ///< log of the normalization constant
|
||||||
|
///< (log(\sqrt(|2πΣ|))).
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert a HybridGaussianConditional of conditionals into
|
* @brief Convert a HybridGaussianConditional of conditionals into
|
||||||
|
@ -149,7 +150,7 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
|
|
||||||
/// The log normalization constant is max of the the individual
|
/// The log normalization constant is max of the the individual
|
||||||
/// log-normalization constants.
|
/// log-normalization constants.
|
||||||
double logNormalizationConstant() const override { return logConstant_; }
|
double logNormalizationConstant() const override { return -logNormalizer_; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a likelihood factor for a hybrid Gaussian conditional,
|
* Create a likelihood factor for a hybrid Gaussian conditional,
|
||||||
|
|
|
@ -100,7 +100,7 @@ TEST(HybridGaussianConditional, Error) {
|
||||||
auto actual = hybrid_conditional.errorTree(vv);
|
auto actual = hybrid_conditional.errorTree(vv);
|
||||||
|
|
||||||
// Check result.
|
// Check result.
|
||||||
std::vector<DiscreteKey> discrete_keys = {mode};
|
DiscreteKeys discrete_keys{mode};
|
||||||
std::vector<double> leaves = {conditionals[0]->error(vv),
|
std::vector<double> leaves = {conditionals[0]->error(vv),
|
||||||
conditionals[1]->error(vv)};
|
conditionals[1]->error(vv)};
|
||||||
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
||||||
|
@ -172,6 +172,37 @@ TEST(HybridGaussianConditional, ContinuousParents) {
|
||||||
EXPECT(continuousParentKeys[0] == X(0));
|
EXPECT(continuousParentKeys[0] == X(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
/// Check error with mode dependent constants.
|
||||||
|
TEST(HybridGaussianConditional, Error2) {
|
||||||
|
using namespace mode_dependent_constants;
|
||||||
|
auto actual = mixture.errorTree(vv);
|
||||||
|
|
||||||
|
// Check result.
|
||||||
|
DiscreteKeys discrete_keys{mode};
|
||||||
|
double logNormalizer0 = -conditionals[0]->logNormalizationConstant();
|
||||||
|
double logNormalizer1 = -conditionals[1]->logNormalizationConstant();
|
||||||
|
double minLogNormalizer = std::min(logNormalizer0, logNormalizer1);
|
||||||
|
|
||||||
|
// Expected error is e(X) + log(|2πΣ|).
|
||||||
|
// We normalize log(|2πΣ|) with min(logNormalizers) so it is non-negative.
|
||||||
|
std::vector<double> leaves = {
|
||||||
|
conditionals[0]->error(vv) + logNormalizer0 - minLogNormalizer,
|
||||||
|
conditionals[1]->error(vv) + logNormalizer1 - minLogNormalizer};
|
||||||
|
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected, actual, 1e-6));
|
||||||
|
|
||||||
|
// Check for non-tree version.
|
||||||
|
for (size_t mode : {0, 1}) {
|
||||||
|
const HybridValues hv{vv, {{M(0), mode}}};
|
||||||
|
EXPECT_DOUBLES_EQUAL(conditionals[mode]->error(vv) -
|
||||||
|
conditionals[mode]->logNormalizationConstant() -
|
||||||
|
minLogNormalizer,
|
||||||
|
mixture.error(hv), 1e-8);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/// Check that the likelihood is proportional to the conditional density given
|
/// Check that the likelihood is proportional to the conditional density given
|
||||||
/// the measurements.
|
/// the measurements.
|
||||||
|
|
Loading…
Reference in New Issue