rename logConstant_ to logNormalizer_ and add test for error with mode-dependent constants
parent
561bdcf9af
commit
9afbc019f4
|
@ -33,8 +33,10 @@ HybridGaussianFactor::FactorValuePairs GetFactorValuePairs(
|
|||
auto func = [](const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianFactorValuePair {
|
||||
double value = 0.0;
|
||||
if (conditional) { // Check if conditional is pruned
|
||||
value = conditional->logNormalizationConstant();
|
||||
// Check if conditional is pruned
|
||||
if (conditional) {
|
||||
// Assign log(|2πΣ|) = -2*log(1 / sqrt(|2πΣ|))
|
||||
value = -2.0 * conditional->logNormalizationConstant();
|
||||
}
|
||||
return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
|
||||
};
|
||||
|
@ -49,14 +51,14 @@ HybridGaussianConditional::HybridGaussianConditional(
|
|||
discreteParents, GetFactorValuePairs(conditionals)),
|
||||
BaseConditional(continuousFrontals.size()),
|
||||
conditionals_(conditionals) {
|
||||
// Calculate logConstant_ as the maximum of the log constants of the
|
||||
// Calculate logNormalizer_ as the minimum of the log normalizers of the
|
||||
// conditionals, by visiting the decision tree:
|
||||
logConstant_ = -std::numeric_limits<double>::infinity();
|
||||
logNormalizer_ = std::numeric_limits<double>::infinity();
|
||||
conditionals_.visit(
|
||||
[this](const GaussianConditional::shared_ptr &conditional) {
|
||||
if (conditional) {
|
||||
this->logConstant_ = std::max(
|
||||
this->logConstant_, conditional->logNormalizationConstant());
|
||||
this->logNormalizer_ = std::min(
|
||||
this->logNormalizer_, -conditional->logNormalizationConstant());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -98,7 +100,7 @@ GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
|
|||
// First check if conditional has not been pruned
|
||||
if (gc) {
|
||||
const double Cgm_Kgcm =
|
||||
this->logConstant_ - gc->logNormalizationConstant();
|
||||
-gc->logNormalizationConstant() - this->logNormalizer_;
|
||||
// If there is a difference in the covariances, we need to account for
|
||||
// that since the error is dependent on the mode.
|
||||
if (Cgm_Kgcm > 0.0) {
|
||||
|
@ -169,7 +171,8 @@ void HybridGaussianConditional::print(const std::string &s,
|
|||
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
||||
}
|
||||
std::cout << std::endl
|
||||
<< " logNormalizationConstant: " << logConstant_ << std::endl
|
||||
<< " logNormalizationConstant: " << logNormalizationConstant()
|
||||
<< std::endl
|
||||
<< std::endl;
|
||||
conditionals_.print(
|
||||
"", [&](Key k) { return formatter(k); },
|
||||
|
@ -228,7 +231,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
|||
-> GaussianFactorValuePair {
|
||||
const auto likelihood_m = conditional->likelihood(given);
|
||||
const double Cgm_Kgcm =
|
||||
logConstant_ - conditional->logNormalizationConstant();
|
||||
-conditional->logNormalizationConstant() - logNormalizer_;
|
||||
if (Cgm_Kgcm == 0.0) {
|
||||
return {likelihood_m, 0.0};
|
||||
} else {
|
||||
|
@ -342,7 +345,7 @@ double HybridGaussianConditional::conditionalError(
|
|||
// Check if valid pointer
|
||||
if (conditional) {
|
||||
return conditional->error(continuousValues) + //
|
||||
logConstant_ - conditional->logNormalizationConstant();
|
||||
-conditional->logNormalizationConstant() - logNormalizer_;
|
||||
} else {
|
||||
// If not valid, pointer, it means this conditional was pruned,
|
||||
// so we return maximum error.
|
||||
|
|
|
@ -64,7 +64,8 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
|
||||
private:
|
||||
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
|
||||
double logConstant_; ///< log of the normalization constant.
|
||||
double logNormalizer_; ///< log of the normalization constant
|
||||
///< (log(\sqrt(|2πΣ|))).
|
||||
|
||||
/**
|
||||
* @brief Convert a HybridGaussianConditional of conditionals into
|
||||
|
@ -149,7 +150,7 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
|
||||
/// The log normalization constant is max of the the individual
|
||||
/// log-normalization constants.
|
||||
double logNormalizationConstant() const override { return logConstant_; }
|
||||
double logNormalizationConstant() const override { return -logNormalizer_; }
|
||||
|
||||
/**
|
||||
* Create a likelihood factor for a hybrid Gaussian conditional,
|
||||
|
|
|
@ -100,7 +100,7 @@ TEST(HybridGaussianConditional, Error) {
|
|||
auto actual = hybrid_conditional.errorTree(vv);
|
||||
|
||||
// Check result.
|
||||
std::vector<DiscreteKey> discrete_keys = {mode};
|
||||
DiscreteKeys discrete_keys{mode};
|
||||
std::vector<double> leaves = {conditionals[0]->error(vv),
|
||||
conditionals[1]->error(vv)};
|
||||
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
||||
|
@ -172,6 +172,37 @@ TEST(HybridGaussianConditional, ContinuousParents) {
|
|||
EXPECT(continuousParentKeys[0] == X(0));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/// Check error with mode dependent constants.
|
||||
TEST(HybridGaussianConditional, Error2) {
|
||||
using namespace mode_dependent_constants;
|
||||
auto actual = mixture.errorTree(vv);
|
||||
|
||||
// Check result.
|
||||
DiscreteKeys discrete_keys{mode};
|
||||
double logNormalizer0 = -conditionals[0]->logNormalizationConstant();
|
||||
double logNormalizer1 = -conditionals[1]->logNormalizationConstant();
|
||||
double minLogNormalizer = std::min(logNormalizer0, logNormalizer1);
|
||||
|
||||
// Expected error is e(X) + log(|2πΣ|).
|
||||
// We normalize log(|2πΣ|) with min(logNormalizers) so it is non-negative.
|
||||
std::vector<double> leaves = {
|
||||
conditionals[0]->error(vv) + logNormalizer0 - minLogNormalizer,
|
||||
conditionals[1]->error(vv) + logNormalizer1 - minLogNormalizer};
|
||||
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
||||
|
||||
EXPECT(assert_equal(expected, actual, 1e-6));
|
||||
|
||||
// Check for non-tree version.
|
||||
for (size_t mode : {0, 1}) {
|
||||
const HybridValues hv{vv, {{M(0), mode}}};
|
||||
EXPECT_DOUBLES_EQUAL(conditionals[mode]->error(vv) -
|
||||
conditionals[mode]->logNormalizationConstant() -
|
||||
minLogNormalizer,
|
||||
mixture.error(hv), 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/// Check that the likelihood is proportional to the conditional density given
|
||||
/// the measurements.
|
||||
|
|
Loading…
Reference in New Issue