Checking mixture invariants, WIP
parent
693d18233a
commit
ab439bfbb0
|
@ -298,9 +298,14 @@ double GaussianMixture::error(const HybridValues &values) const {
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double GaussianMixture::logProbability(const HybridValues &values) const {
|
double GaussianMixture::logProbability(const HybridValues &values) const {
|
||||||
// Directly index to get the conditional, no need to build the whole tree.
|
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto conditional = conditionals_(values.discrete());
|
||||||
return conditional->logProbability(values.continuous());
|
return conditional->logProbability(values.continuous());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
double GaussianMixture::evaluate(const HybridValues &values) const {
|
||||||
|
auto conditional = conditionals_(values.discrete());
|
||||||
|
return conditional->evaluate(values.continuous());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -175,7 +175,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the error of this Gaussian Mixture.
|
* @brief Compute the error of this Gaussian Mixture.
|
||||||
*
|
*
|
||||||
* log(probability(x)) = K - error(x)
|
* log(probability(x)) = K - error(x)
|
||||||
*
|
*
|
||||||
* @param values Continuous values and discrete assignment.
|
* @param values Continuous values and discrete assignment.
|
||||||
|
@ -191,12 +191,13 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
*/
|
*/
|
||||||
double logProbability(const HybridValues &values) const override;
|
double logProbability(const HybridValues &values) const override;
|
||||||
|
|
||||||
// /// Calculate probability density for given values `x`.
|
/// Calculate probability density for given `values`.
|
||||||
// double evaluate(const HybridValues &values) const;
|
double evaluate(const HybridValues &values) const override;
|
||||||
|
|
||||||
// /// Evaluate probability density, sugar.
|
/// Evaluate probability density, sugar.
|
||||||
// double operator()(const HybridValues &values) const { return
|
double operator()(const HybridValues &values) const {
|
||||||
// evaluate(values); }
|
return evaluate(values);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||||
|
|
|
@ -151,4 +151,24 @@ double HybridConditional::logProbability(const HybridValues &values) const {
|
||||||
"HybridConditional::logProbability: conditional type not handled");
|
"HybridConditional::logProbability: conditional type not handled");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double HybridConditional::logNormalizationConstant() const {
|
||||||
|
if (auto gc = asGaussian()) {
|
||||||
|
return gc->logNormalizationConstant();
|
||||||
|
}
|
||||||
|
if (auto gm = asMixture()) {
|
||||||
|
return gm->logNormalizationConstant(); // 0.0!
|
||||||
|
}
|
||||||
|
if (auto dc = asDiscrete()) {
|
||||||
|
return dc->logNormalizationConstant(); // 0.0!
|
||||||
|
}
|
||||||
|
throw std::runtime_error(
|
||||||
|
"HybridConditional::logProbability: conditional type not handled");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double HybridConditional::evaluate(const HybridValues &values) const {
|
||||||
|
return std::exp(logProbability(values));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -179,9 +179,19 @@ class GTSAM_EXPORT HybridConditional
|
||||||
/// Return the error of the underlying conditional.
|
/// Return the error of the underlying conditional.
|
||||||
double error(const HybridValues& values) const override;
|
double error(const HybridValues& values) const override;
|
||||||
|
|
||||||
/// Return the logProbability of the underlying conditional.
|
/// Return the log-probability (or density) of the underlying conditional.
|
||||||
double logProbability(const HybridValues& values) const override;
|
double logProbability(const HybridValues& values) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the log normalization constant.
|
||||||
|
* Note this is 0.0 for discrete and hybrid conditionals, but depends
|
||||||
|
* on the continuous parameters for Gaussian conditionals.
|
||||||
|
*/
|
||||||
|
double logNormalizationConstant() const override;
|
||||||
|
|
||||||
|
/// Return the probability (or density) of the underlying conditional.
|
||||||
|
double evaluate(const HybridValues& values) const override;
|
||||||
|
|
||||||
/// Check if VectorValues `measurements` contains all frontal keys.
|
/// Check if VectorValues `measurements` contains all frontal keys.
|
||||||
bool frontalsIn(const VectorValues& measurements) const {
|
bool frontalsIn(const VectorValues& measurements) const {
|
||||||
for (Key key : frontals()) {
|
for (Key key : frontals()) {
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file testHybridConditional.cpp
|
||||||
|
* @brief Unit tests for HybridConditional class
|
||||||
|
* @date January 2023
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
|
|
||||||
|
#include "TinyHybridExample.h"
|
||||||
|
|
||||||
|
// Include for test suite
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
using namespace gtsam;
|
||||||
|
|
||||||
|
using symbol_shorthand::M;
|
||||||
|
using symbol_shorthand::X;
|
||||||
|
using symbol_shorthand::Z;
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Check invariants for all conditionals in a tiny Bayes net.
|
||||||
|
TEST(HybridConditional, Invariants) {
|
||||||
|
// Create hybrid Bayes net p(z|x,m)p(x)P(m)
|
||||||
|
auto bn = tiny::createHybridBayesNet();
|
||||||
|
|
||||||
|
// Create values to check invariants.
|
||||||
|
const VectorValues c{{X(0), Vector1(5.1)}, {Z(0), Vector1(4.9)}};
|
||||||
|
const DiscreteValues d{{M(0), 1}};
|
||||||
|
const HybridValues values{c, d};
|
||||||
|
|
||||||
|
// Check invariants for p(z|x,m)
|
||||||
|
auto hc1 = bn.at(0);
|
||||||
|
CHECK(hc1->isHybrid());
|
||||||
|
GTSAM_PRINT(*hc1);
|
||||||
|
|
||||||
|
// Check invariants as a GaussianMixture.
|
||||||
|
const auto mixture = hc1->asMixture();
|
||||||
|
double probability = mixture->evaluate(values);
|
||||||
|
CHECK(probability >= 0.0);
|
||||||
|
EXPECT_DOUBLES_EQUAL(probability, (*mixture)(values), 1e-9);
|
||||||
|
double logProb = mixture->logProbability(values);
|
||||||
|
EXPECT_DOUBLES_EQUAL(probability, std::exp(logProb), 1e-9);
|
||||||
|
double expected =
|
||||||
|
mixture->logNormalizationConstant() - mixture->error(values);
|
||||||
|
EXPECT_DOUBLES_EQUAL(logProb, expected, 1e-9);
|
||||||
|
EXPECT(GaussianMixture::CheckInvariants(*mixture, values));
|
||||||
|
|
||||||
|
// Check invariants as a HybridConditional.
|
||||||
|
probability = hc1->evaluate(values);
|
||||||
|
CHECK(probability >= 0.0);
|
||||||
|
EXPECT_DOUBLES_EQUAL(probability, (*hc1)(values), 1e-9);
|
||||||
|
logProb = hc1->logProbability(values);
|
||||||
|
EXPECT_DOUBLES_EQUAL(probability, std::exp(logProb), 1e-9);
|
||||||
|
expected = hc1->logNormalizationConstant() - hc1->error(values);
|
||||||
|
EXPECT_DOUBLES_EQUAL(logProb, expected, 1e-9);
|
||||||
|
EXPECT(HybridConditional::CheckInvariants(*hc1, values));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
Loading…
Reference in New Issue