commit
618ac28f2c
|
@ -20,7 +20,7 @@
|
||||||
#include <gtsam/base/debug.h>
|
#include <gtsam/base/debug.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
@ -510,6 +510,10 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double DiscreteConditional::evaluate(const HybridValues& x) const{
|
||||||
|
return this->evaluate(x.discrete());
|
||||||
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -160,10 +160,13 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Evaluate, just look up in AlgebraicDecisonTree
|
/// Evaluate, just look up in AlgebraicDecisonTree
|
||||||
double operator()(const DiscreteValues& values) const override {
|
double evaluate(const DiscreteValues& values) const {
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
using DecisionTreeFactor::error; ///< DiscreteValues version
|
||||||
|
using DecisionTreeFactor::operator(); ///< DiscreteValues version
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief restrict to given *parent* values.
|
* @brief restrict to given *parent* values.
|
||||||
*
|
*
|
||||||
|
@ -235,6 +238,14 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/// @name HybridValues methods.
|
/// @name HybridValues methods.
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate probability for HybridValues `x`.
|
||||||
|
* Dispatches to DiscreteValues version.
|
||||||
|
*/
|
||||||
|
double evaluate(const HybridValues& x) const override;
|
||||||
|
|
||||||
|
using BaseConditional::operator(); ///< HybridValues version
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
|
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
|
||||||
* This is actually just -error(x).
|
* This is actually just -error(x).
|
||||||
|
@ -243,8 +254,6 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
return -error(x);
|
return -error(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
using DecisionTreeFactor::evaluate;
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
|
|
@ -82,6 +82,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
DiscreteConditional();
|
DiscreteConditional();
|
||||||
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
||||||
|
@ -95,9 +96,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||||
const gtsam::DecisionTreeFactor& marginal,
|
const gtsam::DecisionTreeFactor& marginal,
|
||||||
const gtsam::Ordering& orderedKeys);
|
const gtsam::Ordering& orderedKeys);
|
||||||
|
|
||||||
|
// Standard interface
|
||||||
|
double logNormalizationConstant() const;
|
||||||
double logProbability(const gtsam::DiscreteValues& values) const;
|
double logProbability(const gtsam::DiscreteValues& values) const;
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double error(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteConditional operator*(
|
gtsam::DiscreteConditional operator*(
|
||||||
const gtsam::DiscreteConditional& other) const;
|
const gtsam::DiscreteConditional& other) const;
|
||||||
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
|
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
|
||||||
|
@ -119,6 +123,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
size_t sample(size_t value) const;
|
size_t sample(size_t value) const;
|
||||||
size_t sample() const;
|
size_t sample() const;
|
||||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
|
|
||||||
|
// Markdown and HTML
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
@ -127,6 +133,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string html(const gtsam::KeyFormatter& keyFormatter,
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
|
||||||
|
// Expose HybridValues versions
|
||||||
|
double logProbability(const gtsam::HybridValues& x) const;
|
||||||
|
double evaluate(const gtsam::HybridValues& x) const;
|
||||||
|
double error(const gtsam::HybridValues& x) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
|
|
@ -96,6 +96,7 @@ TEST(DiscreteConditional, PriorProbability) {
|
||||||
DiscreteConditional dc(Asia, "4/6");
|
DiscreteConditional dc(Asia, "4/6");
|
||||||
DiscreteValues values{{asiaKey, 0}};
|
DiscreteValues values{{asiaKey, 0}};
|
||||||
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
|
||||||
|
EXPECT(DiscreteConditional::CheckInvariants(dc, values));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -109,6 +110,7 @@ TEST(DiscreteConditional, probability) {
|
||||||
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
|
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
|
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
|
||||||
|
EXPECT(DiscreteConditional::CheckInvariants(C_given_DE, given));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -290,10 +290,22 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double GaussianMixture::logProbability(const HybridValues &values) const {
|
double GaussianMixture::error(const HybridValues &values) const {
|
||||||
// Directly index to get the conditional, no need to build the whole tree.
|
// Directly index to get the conditional, no need to build the whole tree.
|
||||||
|
auto conditional = conditionals_(values.discrete());
|
||||||
|
return conditional->error(values.continuous()) - conditional->logNormalizationConstant();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
double GaussianMixture::logProbability(const HybridValues &values) const {
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto conditional = conditionals_(values.discrete());
|
||||||
return conditional->logProbability(values.continuous());
|
return conditional->logProbability(values.continuous());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
double GaussianMixture::evaluate(const HybridValues &values) const {
|
||||||
|
auto conditional = conditionals_(values.discrete());
|
||||||
|
return conditional->evaluate(values.continuous());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -174,20 +174,44 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
const VectorValues &continuousValues) const;
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the logProbability of this Gaussian Mixture given the
|
* @brief Compute the error of this Gaussian Mixture.
|
||||||
* continuous values and a discrete assignment.
|
*
|
||||||
|
* This requires some care, as different mixture components may have
|
||||||
|
* different normalization constants. Let's consider p(x|y,m), where m is
|
||||||
|
* discrete. We need the error to satisfy the invariant:
|
||||||
|
*
|
||||||
|
* error(x;y,m) = K - log(probability(x;y,m))
|
||||||
|
*
|
||||||
|
* For all x,y,m. But note that K, for the GaussianMixture, cannot depend on
|
||||||
|
* any arguments. Hence, we delegate to the underlying Gaussian
|
||||||
|
* conditionals, indexed by m, which do satisfy:
|
||||||
|
*
|
||||||
|
* log(probability_m(x;y)) = K_m - error_m(x;y)
|
||||||
|
*
|
||||||
|
* We resolve by having K == 0.0 and
|
||||||
|
*
|
||||||
|
* error(x;y,m) = error_m(x;y) - K_m
|
||||||
|
*
|
||||||
|
* @param values Continuous values and discrete assignment.
|
||||||
|
* @return double
|
||||||
|
*/
|
||||||
|
double error(const HybridValues &values) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute the logProbability of this Gaussian Mixture.
|
||||||
*
|
*
|
||||||
* @param values Continuous values and discrete assignment.
|
* @param values Continuous values and discrete assignment.
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double logProbability(const HybridValues &values) const override;
|
double logProbability(const HybridValues &values) const override;
|
||||||
|
|
||||||
// /// Calculate probability density for given values `x`.
|
/// Calculate probability density for given `values`.
|
||||||
// double evaluate(const HybridValues &values) const;
|
double evaluate(const HybridValues &values) const override;
|
||||||
|
|
||||||
// /// Evaluate probability density, sugar.
|
/// Evaluate probability density, sugar.
|
||||||
// double operator()(const HybridValues &values) const { return
|
double operator()(const HybridValues &values) const {
|
||||||
// evaluate(values); }
|
return evaluate(values);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||||
|
|
|
@ -121,6 +121,21 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
||||||
: !(e->inner_);
|
: !(e->inner_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double HybridConditional::error(const HybridValues &values) const {
|
||||||
|
if (auto gc = asGaussian()) {
|
||||||
|
return gc->error(values.continuous());
|
||||||
|
}
|
||||||
|
if (auto gm = asMixture()) {
|
||||||
|
return gm->error(values);
|
||||||
|
}
|
||||||
|
if (auto dc = asDiscrete()) {
|
||||||
|
return dc->error(values.discrete());
|
||||||
|
}
|
||||||
|
throw std::runtime_error(
|
||||||
|
"HybridConditional::error: conditional type not handled");
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double HybridConditional::logProbability(const HybridValues &values) const {
|
double HybridConditional::logProbability(const HybridValues &values) const {
|
||||||
if (auto gc = asGaussian()) {
|
if (auto gc = asGaussian()) {
|
||||||
|
@ -136,4 +151,24 @@ double HybridConditional::logProbability(const HybridValues &values) const {
|
||||||
"HybridConditional::logProbability: conditional type not handled");
|
"HybridConditional::logProbability: conditional type not handled");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double HybridConditional::logNormalizationConstant() const {
|
||||||
|
if (auto gc = asGaussian()) {
|
||||||
|
return gc->logNormalizationConstant();
|
||||||
|
}
|
||||||
|
if (auto gm = asMixture()) {
|
||||||
|
return gm->logNormalizationConstant(); // 0.0!
|
||||||
|
}
|
||||||
|
if (auto dc = asDiscrete()) {
|
||||||
|
return dc->logNormalizationConstant(); // 0.0!
|
||||||
|
}
|
||||||
|
throw std::runtime_error(
|
||||||
|
"HybridConditional::logProbability: conditional type not handled");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double HybridConditional::evaluate(const HybridValues &values) const {
|
||||||
|
return std::exp(logProbability(values));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -176,9 +176,22 @@ class GTSAM_EXPORT HybridConditional
|
||||||
/// Get the type-erased pointer to the inner type
|
/// Get the type-erased pointer to the inner type
|
||||||
boost::shared_ptr<Factor> inner() const { return inner_; }
|
boost::shared_ptr<Factor> inner() const { return inner_; }
|
||||||
|
|
||||||
/// Return the logProbability of the underlying conditional.
|
/// Return the error of the underlying conditional.
|
||||||
|
double error(const HybridValues& values) const override;
|
||||||
|
|
||||||
|
/// Return the log-probability (or density) of the underlying conditional.
|
||||||
double logProbability(const HybridValues& values) const override;
|
double logProbability(const HybridValues& values) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the log normalization constant.
|
||||||
|
* Note this is 0.0 for discrete and hybrid conditionals, but depends
|
||||||
|
* on the continuous parameters for Gaussian conditionals.
|
||||||
|
*/
|
||||||
|
double logNormalizationConstant() const override;
|
||||||
|
|
||||||
|
/// Return the probability (or density) of the underlying conditional.
|
||||||
|
double evaluate(const HybridValues& values) const override;
|
||||||
|
|
||||||
/// Check if VectorValues `measurements` contains all frontal keys.
|
/// Check if VectorValues `measurements` contains all frontal keys.
|
||||||
bool frontalsIn(const VectorValues& measurements) const {
|
bool frontalsIn(const VectorValues& measurements) const {
|
||||||
for (Key key : frontals()) {
|
for (Key key : frontals()) {
|
||||||
|
|
|
@ -61,6 +61,7 @@ virtual class HybridConditional {
|
||||||
size_t nrParents() const;
|
size_t nrParents() const;
|
||||||
|
|
||||||
// Standard interface:
|
// Standard interface:
|
||||||
|
double logNormalizationConstant() const;
|
||||||
double logProbability(const gtsam::HybridValues& values) const;
|
double logProbability(const gtsam::HybridValues& values) const;
|
||||||
double evaluate(const gtsam::HybridValues& values) const;
|
double evaluate(const gtsam::HybridValues& values) const;
|
||||||
double operator()(const gtsam::HybridValues& values) const;
|
double operator()(const gtsam::HybridValues& values) const;
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file testHybridConditional.cpp
|
||||||
|
* @brief Unit tests for HybridConditional class
|
||||||
|
* @date January 2023
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
|
|
||||||
|
#include "TinyHybridExample.h"
|
||||||
|
|
||||||
|
// Include for test suite
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
using namespace gtsam;
|
||||||
|
|
||||||
|
using symbol_shorthand::M;
|
||||||
|
using symbol_shorthand::X;
|
||||||
|
using symbol_shorthand::Z;
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Check invariants for all conditionals in a tiny Bayes net.
|
||||||
|
TEST(HybridConditional, Invariants) {
|
||||||
|
// Create hybrid Bayes net p(z|x,m)p(x)P(m)
|
||||||
|
auto bn = tiny::createHybridBayesNet();
|
||||||
|
|
||||||
|
// Create values to check invariants.
|
||||||
|
const VectorValues c{{X(0), Vector1(5.1)}, {Z(0), Vector1(4.9)}};
|
||||||
|
const DiscreteValues d{{M(0), 1}};
|
||||||
|
const HybridValues values{c, d};
|
||||||
|
|
||||||
|
// Check invariants for p(z|x,m)
|
||||||
|
auto hc0 = bn.at(0);
|
||||||
|
CHECK(hc0->isHybrid());
|
||||||
|
|
||||||
|
// Check invariants as a GaussianMixture.
|
||||||
|
const auto mixture = hc0->asMixture();
|
||||||
|
EXPECT(GaussianMixture::CheckInvariants(*mixture, values));
|
||||||
|
|
||||||
|
// Check invariants as a HybridConditional.
|
||||||
|
EXPECT(HybridConditional::CheckInvariants(*hc0, values));
|
||||||
|
|
||||||
|
// Check invariants for p(x)
|
||||||
|
auto hc1 = bn.at(1);
|
||||||
|
CHECK(hc1->isContinuous());
|
||||||
|
|
||||||
|
// Check invariants as a GaussianConditional.
|
||||||
|
const auto gaussian = hc1->asGaussian();
|
||||||
|
EXPECT(GaussianConditional::CheckInvariants(*gaussian, c));
|
||||||
|
EXPECT(GaussianConditional::CheckInvariants(*gaussian, values));
|
||||||
|
|
||||||
|
// Check invariants as a HybridConditional.
|
||||||
|
EXPECT(HybridConditional::CheckInvariants(*hc1, values));
|
||||||
|
|
||||||
|
// Check invariants for p(m)
|
||||||
|
auto hc2 = bn.at(2);
|
||||||
|
CHECK(hc2->isDiscrete());
|
||||||
|
|
||||||
|
// Check invariants as a DiscreteConditional.
|
||||||
|
const auto discrete = hc2->asDiscrete();
|
||||||
|
EXPECT(DiscreteConditional::CheckInvariants(*discrete, d));
|
||||||
|
EXPECT(DiscreteConditional::CheckInvariants(*discrete, values));
|
||||||
|
|
||||||
|
// Check invariants as a HybridConditional.
|
||||||
|
EXPECT(HybridConditional::CheckInvariants(*hc2, values));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
|
@ -56,4 +56,30 @@ double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
|
||||||
const HybridValues& c) const {
|
const HybridValues& c) const {
|
||||||
throw std::runtime_error("Conditional::evaluate is not implemented");
|
throw std::runtime_error("Conditional::evaluate is not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||||
|
double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const {
|
||||||
|
return std::exp(logNormalizationConstant());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||||
|
template <class VALUES>
|
||||||
|
bool Conditional<FACTOR, DERIVEDCONDITIONAL>::CheckInvariants(
|
||||||
|
const DERIVEDCONDITIONAL& conditional, const VALUES& values) {
|
||||||
|
const double prob_or_density = conditional.evaluate(values);
|
||||||
|
if (prob_or_density < 0.0) return false; // prob_or_density is negative.
|
||||||
|
if (std::abs(prob_or_density - conditional(values)) > 1e-9)
|
||||||
|
return false; // operator and evaluate differ
|
||||||
|
const double logProb = conditional.logProbability(values);
|
||||||
|
if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9)
|
||||||
|
return false; // logProb is not consistent with prob_or_density
|
||||||
|
const double expected =
|
||||||
|
conditional.logNormalizationConstant() - conditional.error(values);
|
||||||
|
if (std::abs(logProb - expected) > 1e-9)
|
||||||
|
return false; // logProb is not consistent with error
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -141,6 +141,15 @@ namespace gtsam {
|
||||||
return evaluate(x);
|
return evaluate(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* By default, log normalization constant = 0.0.
|
||||||
|
* Override if this depends on the parameters.
|
||||||
|
*/
|
||||||
|
virtual double logNormalizationConstant() const { return 0.0; }
|
||||||
|
|
||||||
|
/** Non-virtual, exponentiate logNormalizationConstant. */
|
||||||
|
double normalizationConstant() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -172,7 +181,17 @@ namespace gtsam {
|
||||||
/** Mutable iterator pointing past the last parent key. */
|
/** Mutable iterator pointing past the last parent key. */
|
||||||
typename FACTOR::iterator endParents() { return asFactor().end(); }
|
typename FACTOR::iterator endParents() { return asFactor().end(); }
|
||||||
|
|
||||||
|
template <class VALUES>
|
||||||
|
static bool CheckInvariants(const DERIVEDCONDITIONAL& conditional,
|
||||||
|
const VALUES& values);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
|
/// @name Serialization
|
||||||
|
/// @{
|
||||||
|
|
||||||
// Cast to factor type (non-const) (casts down to derived conditional type, then up to factor type)
|
// Cast to factor type (non-const) (casts down to derived conditional type, then up to factor type)
|
||||||
FACTOR& asFactor() { return static_cast<FACTOR&>(static_cast<DERIVEDCONDITIONAL&>(*this)); }
|
FACTOR& asFactor() { return static_cast<FACTOR&>(static_cast<DERIVEDCONDITIONAL&>(*this)); }
|
||||||
|
|
||||||
|
|
|
@ -205,9 +205,14 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double GaussianConditional::evaluate(const VectorValues& c) const {
|
double GaussianConditional::evaluate(const VectorValues& x) const {
|
||||||
return exp(logProbability(c));
|
return exp(logProbability(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
double GaussianConditional::evaluate(const HybridValues& x) const {
|
||||||
|
return evaluate(x.continuous());
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
VectorValues GaussianConditional::solve(const VectorValues& x) const {
|
VectorValues GaussianConditional::solve(const VectorValues& x) const {
|
||||||
// Concatenate all vector values that correspond to parent variables
|
// Concatenate all vector values that correspond to parent variables
|
||||||
|
|
|
@ -34,7 +34,7 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* A GaussianConditional functions as the node in a Bayes network.
|
* A GaussianConditional functions as the node in a Bayes network.
|
||||||
* It has a set of parents y,z, etc. and implements a probability density on x.
|
* It has a set of parents y,z, etc. and implements a probability density on x.
|
||||||
* The negative log-probability is given by \f$ \frac{1}{2} |Rx - (d - Sy - Tz - ...)|^2 \f$
|
* The negative log-density is given by \f$ \frac{1}{2} |Rx - (d - Sy - Tz - ...)|^2 \f$
|
||||||
* @ingroup linear
|
* @ingroup linear
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT GaussianConditional :
|
class GTSAM_EXPORT GaussianConditional :
|
||||||
|
@ -136,14 +136,7 @@ namespace gtsam {
|
||||||
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||||
* log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
|
* log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
|
||||||
*/
|
*/
|
||||||
double logNormalizationConstant() const;
|
double logNormalizationConstant() const override;
|
||||||
|
|
||||||
/**
|
|
||||||
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
|
||||||
*/
|
|
||||||
inline double normalizationConstant() const {
|
|
||||||
return exp(logNormalizationConstant());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate log-probability log(evaluate(x)) for given values `x`:
|
* Calculate log-probability log(evaluate(x)) for given values `x`:
|
||||||
|
@ -269,9 +262,14 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
double logProbability(const HybridValues& x) const override;
|
double logProbability(const HybridValues& x) const override;
|
||||||
|
|
||||||
using Conditional::evaluate; // Expose evaluate(const HybridValues&) method..
|
/**
|
||||||
|
* Calculate probability for HybridValues `x`.
|
||||||
|
* Simply dispatches to VectorValues version.
|
||||||
|
*/
|
||||||
|
double evaluate(const HybridValues& x) const override;
|
||||||
|
|
||||||
using Conditional::operator(); // Expose evaluate(const HybridValues&) method..
|
using Conditional::operator(); // Expose evaluate(const HybridValues&) method..
|
||||||
using Base::error; // Expose error(const HybridValues&) method..
|
using JacobianFactor::error; // Expose error(const HybridValues&) method..
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
|
@ -196,6 +196,9 @@ namespace gtsam {
|
||||||
/** Compare to another factor for testing (implementing Testable) */
|
/** Compare to another factor for testing (implementing Testable) */
|
||||||
bool equals(const GaussianFactor& lf, double tol = 1e-9) const override;
|
bool equals(const GaussianFactor& lf, double tol = 1e-9) const override;
|
||||||
|
|
||||||
|
/// HybridValues simply extracts the \class VectorValues and calls error.
|
||||||
|
using GaussianFactor::error;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Evaluate the factor error f(x).
|
* Evaluate the factor error f(x).
|
||||||
* returns 0.5*[x -1]'*H*[x -1] (also see constructor documentation)
|
* returns 0.5*[x -1]'*H*[x -1] (also see constructor documentation)
|
||||||
|
|
|
@ -198,7 +198,12 @@ namespace gtsam {
|
||||||
|
|
||||||
Vector unweighted_error(const VectorValues& c) const; /** (A*x-b) */
|
Vector unweighted_error(const VectorValues& c) const; /** (A*x-b) */
|
||||||
Vector error_vector(const VectorValues& c) const; /** (A*x-b)/sigma */
|
Vector error_vector(const VectorValues& c) const; /** (A*x-b)/sigma */
|
||||||
double error(const VectorValues& c) const override; /** 0.5*(A*x-b)'*D*(A*x-b) */
|
|
||||||
|
/// HybridValues simply extracts the \class VectorValues and calls error.
|
||||||
|
using GaussianFactor::error;
|
||||||
|
|
||||||
|
//// 0.5*(A*x-b)'*D*(A*x-b).
|
||||||
|
double error(const VectorValues& c) const override;
|
||||||
|
|
||||||
/** Return the augmented information matrix represented by this GaussianFactor.
|
/** Return the augmented information matrix represented by this GaussianFactor.
|
||||||
* The augmented information matrix contains the information matrix with an
|
* The augmented information matrix contains the information matrix with an
|
||||||
|
|
|
@ -456,6 +456,7 @@ class GaussianFactorGraph {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
virtual class GaussianConditional : gtsam::JacobianFactor {
|
virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||||
// Constructors
|
// Constructors
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R,
|
GaussianConditional(size_t key, Vector d, Matrix R,
|
||||||
|
@ -497,6 +498,7 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||||
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
|
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
|
||||||
|
|
||||||
// Standard Interface
|
// Standard Interface
|
||||||
|
double logNormalizationConstant() const;
|
||||||
double logProbability(const gtsam::VectorValues& x) const;
|
double logProbability(const gtsam::VectorValues& x) const;
|
||||||
double evaluate(const gtsam::VectorValues& x) const;
|
double evaluate(const gtsam::VectorValues& x) const;
|
||||||
double error(const gtsam::VectorValues& x) const;
|
double error(const gtsam::VectorValues& x) const;
|
||||||
|
@ -518,6 +520,11 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||||
|
|
||||||
// enabling serialization functionality
|
// enabling serialization functionality
|
||||||
void serialize() const;
|
void serialize() const;
|
||||||
|
|
||||||
|
// Expose HybridValues versions
|
||||||
|
double logProbability(const gtsam::HybridValues& x) const;
|
||||||
|
double evaluate(const gtsam::HybridValues& x) const;
|
||||||
|
double error(const gtsam::HybridValues& x) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianDensity.h>
|
#include <gtsam/linear/GaussianDensity.h>
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
#include <gtsam/linear/GaussianDensity.h>
|
#include <gtsam/linear/GaussianDensity.h>
|
||||||
#include <gtsam/linear/GaussianBayesNet.h>
|
#include <gtsam/linear/GaussianBayesNet.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
|
||||||
|
@ -154,6 +155,13 @@ TEST(GaussianConditional, Evaluate1) {
|
||||||
using density::key;
|
using density::key;
|
||||||
using density::sigma;
|
using density::sigma;
|
||||||
|
|
||||||
|
// Check Invariants at the mean and a different value
|
||||||
|
for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) {
|
||||||
|
EXPECT(GaussianConditional::CheckInvariants(density::unitPrior, vv));
|
||||||
|
EXPECT(GaussianConditional::CheckInvariants(density::unitPrior,
|
||||||
|
HybridValues{vv, {}, {}}));
|
||||||
|
}
|
||||||
|
|
||||||
// Let's numerically integrate and see that we integrate to 1.0.
|
// Let's numerically integrate and see that we integrate to 1.0.
|
||||||
double integral = 0.0;
|
double integral = 0.0;
|
||||||
// Loop from -5*sigma to 5*sigma in 0.1*sigma steps:
|
// Loop from -5*sigma to 5*sigma in 0.1*sigma steps:
|
||||||
|
@ -180,6 +188,13 @@ TEST(GaussianConditional, Evaluate2) {
|
||||||
using density::key;
|
using density::key;
|
||||||
using density::sigma;
|
using density::sigma;
|
||||||
|
|
||||||
|
// Check Invariants at the mean and a different value
|
||||||
|
for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) {
|
||||||
|
EXPECT(GaussianConditional::CheckInvariants(density::widerPrior, vv));
|
||||||
|
EXPECT(GaussianConditional::CheckInvariants(density::widerPrior,
|
||||||
|
HybridValues{vv, {}, {}}));
|
||||||
|
}
|
||||||
|
|
||||||
// Let's numerically integrate and see that we integrate to 1.0.
|
// Let's numerically integrate and see that we integrate to 1.0.
|
||||||
double integral = 0.0;
|
double integral = 0.0;
|
||||||
// Loop from -5*sigma to 5*sigma in 0.1*sigma steps:
|
// Loop from -5*sigma to 5*sigma in 0.1*sigma steps:
|
||||||
|
@ -384,17 +399,18 @@ TEST(GaussianConditional, FromMeanAndStddev) {
|
||||||
double expected1 = 0.5 * e1.dot(e1);
|
double expected1 = 0.5 * e1.dot(e1);
|
||||||
EXPECT_DOUBLES_EQUAL(expected1, conditional1.error(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(expected1, conditional1.error(values), 1e-9);
|
||||||
|
|
||||||
double expected2 = conditional1.logNormalizationConstant() - 0.5 * e1.dot(e1);
|
|
||||||
EXPECT_DOUBLES_EQUAL(expected2, conditional1.logProbability(values), 1e-9);
|
|
||||||
|
|
||||||
auto conditional2 = GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), A2,
|
auto conditional2 = GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), A2,
|
||||||
X(2), b, sigma);
|
X(2), b, sigma);
|
||||||
Vector2 e2 = (x0 - (A1 * x1 + A2 * x2 + b)) / sigma;
|
Vector2 e2 = (x0 - (A1 * x1 + A2 * x2 + b)) / sigma;
|
||||||
double expected3 = 0.5 * e2.dot(e2);
|
double expected2 = 0.5 * e2.dot(e2);
|
||||||
EXPECT_DOUBLES_EQUAL(expected3, conditional2.error(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(expected2, conditional2.error(values), 1e-9);
|
||||||
|
|
||||||
double expected4 = conditional2.logNormalizationConstant() - 0.5 * e2.dot(e2);
|
// Check Invariants for both conditionals
|
||||||
EXPECT_DOUBLES_EQUAL(expected4, conditional2.logProbability(values), 1e-9);
|
for (auto conditional : {conditional1, conditional2}) {
|
||||||
|
EXPECT(GaussianConditional::CheckInvariants(conditional, values));
|
||||||
|
EXPECT(GaussianConditional::CheckInvariants(conditional,
|
||||||
|
HybridValues{values, {}, {}}));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -17,8 +17,8 @@ import numpy as np
|
||||||
from gtsam.symbol_shorthand import A, X
|
from gtsam.symbol_shorthand import A, X
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
|
from gtsam import (DiscreteConditional, DiscreteKeys, DiscreteValues, GaussianConditional,
|
||||||
GaussianMixture, HybridBayesNet, HybridValues, noiseModel)
|
GaussianMixture, HybridBayesNet, HybridValues, noiseModel, VectorValues)
|
||||||
|
|
||||||
|
|
||||||
class TestHybridBayesNet(GtsamTestCase):
|
class TestHybridBayesNet(GtsamTestCase):
|
||||||
|
@ -53,9 +53,13 @@ class TestHybridBayesNet(GtsamTestCase):
|
||||||
|
|
||||||
# Create values at which to evaluate.
|
# Create values at which to evaluate.
|
||||||
values = HybridValues()
|
values = HybridValues()
|
||||||
values.insert(asiaKey, 0)
|
continuous = VectorValues()
|
||||||
values.insert(X(0), [-6])
|
continuous.insert(X(0), [-6])
|
||||||
values.insert(X(1), [1])
|
continuous.insert(X(1), [1])
|
||||||
|
values.insert(continuous)
|
||||||
|
discrete = DiscreteValues()
|
||||||
|
discrete[asiaKey] = 0
|
||||||
|
values.insert(discrete)
|
||||||
|
|
||||||
conditionalProbability = conditional.evaluate(values.continuous())
|
conditionalProbability = conditional.evaluate(values.continuous())
|
||||||
mixtureProbability = conditional0.evaluate(values.continuous())
|
mixtureProbability = conditional0.evaluate(values.continuous())
|
||||||
|
@ -68,6 +72,26 @@ class TestHybridBayesNet(GtsamTestCase):
|
||||||
self.assertAlmostEqual(bayesNet.logProbability(values),
|
self.assertAlmostEqual(bayesNet.logProbability(values),
|
||||||
math.log(bayesNet.evaluate(values)))
|
math.log(bayesNet.evaluate(values)))
|
||||||
|
|
||||||
|
# Check invariance for all conditionals:
|
||||||
|
self.check_invariance(bayesNet.at(0).asGaussian(), continuous)
|
||||||
|
self.check_invariance(bayesNet.at(0).asGaussian(), values)
|
||||||
|
self.check_invariance(bayesNet.at(0), values)
|
||||||
|
|
||||||
|
self.check_invariance(bayesNet.at(1), values)
|
||||||
|
|
||||||
|
self.check_invariance(bayesNet.at(2).asDiscrete(), discrete)
|
||||||
|
self.check_invariance(bayesNet.at(2).asDiscrete(), values)
|
||||||
|
self.check_invariance(bayesNet.at(2), values)
|
||||||
|
|
||||||
|
def check_invariance(self, conditional, values):
|
||||||
|
"""Check invariance for given conditional."""
|
||||||
|
probability = conditional.evaluate(values)
|
||||||
|
self.assertTrue(probability >= 0.0)
|
||||||
|
logProb = conditional.logProbability(values)
|
||||||
|
self.assertAlmostEqual(probability, np.exp(logProb))
|
||||||
|
expected = conditional.logNormalizationConstant() - conditional.error(values)
|
||||||
|
self.assertAlmostEqual(logProb, expected)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue