106 lines
4.0 KiB
C++
106 lines
4.0 KiB
C++
/* ----------------------------------------------------------------------------
|
|
|
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
* Atlanta, Georgia 30332-0415
|
|
* All Rights Reserved
|
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
|
|
* See LICENSE for the license information
|
|
|
|
* -------------------------------------------------------------------------- */
|
|
|
|
/**
|
|
 * @file Conditional-inst.h
|
|
* @brief Base class for conditional densities
|
|
* @author Frank Dellaert
|
|
*/
|
|
|
|
// \callgraph
|
|
#pragma once
|
|
|
|
#include <gtsam/inference/Conditional.h>

#include <cmath>
#include <iostream>
#include <stdexcept>
#include <string>
|
|
|
|
namespace gtsam {
|
|
|
|
/* ************************************************************************* */
|
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
|
void Conditional<FACTOR, DERIVEDCONDITIONAL>::print(
|
|
const std::string& s, const KeyFormatter& formatter) const {
|
|
std::cout << s << " P(";
|
|
for (Key key : frontals()) std::cout << " " << formatter(key);
|
|
if (nrParents() > 0) std::cout << " |";
|
|
for (Key parent : parents()) std::cout << " " << formatter(parent);
|
|
std::cout << ")" << std::endl;
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
|
bool Conditional<FACTOR, DERIVEDCONDITIONAL>::equals(const This& c,
|
|
double tol) const {
|
|
return nrFrontals_ == c.nrFrontals_;
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
|
double Conditional<FACTOR, DERIVEDCONDITIONAL>::logProbability(
|
|
const HybridValues& c) const {
|
|
throw std::runtime_error("Conditional::logProbability is not implemented");
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
|
double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
|
|
const HybridValues& c) const {
|
|
throw std::runtime_error("Conditional::evaluate is not implemented");
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
|
double Conditional<FACTOR, DERIVEDCONDITIONAL>::errorConstant()
|
|
const {
|
|
throw std::runtime_error(
|
|
"Conditional::errorConstant is not implemented");
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
|
double Conditional<FACTOR, DERIVEDCONDITIONAL>::logNormalizationConstant()
|
|
const {
|
|
return -errorConstant();
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
|
double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const {
|
|
return std::exp(logNormalizationConstant());
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
|
template <class VALUES>
|
|
bool Conditional<FACTOR, DERIVEDCONDITIONAL>::CheckInvariants(
|
|
const DERIVEDCONDITIONAL& conditional, const VALUES& values) {
|
|
const double prob_or_density = conditional.evaluate(values);
|
|
if (prob_or_density < 0.0) return false; // prob_or_density is negative.
|
|
if (std::abs(prob_or_density - conditional(values)) > 1e-9)
|
|
return false; // operator and evaluate differ
|
|
const double logProb = conditional.logProbability(values);
|
|
if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9)
|
|
return false; // logProb is not consistent with prob_or_density
|
|
if (std::abs(conditional.logNormalizationConstant() -
|
|
std::log(conditional.normalizationConstant())) > 1e-9)
|
|
return false; // log normalization constant is not consistent with
|
|
// normalization constant
|
|
const double error = conditional.error(values);
|
|
if (error < 0.0) return false; // prob_or_density is negative.
|
|
const double expected = -(conditional.errorConstant() + error);
|
|
if (std::abs(logProb - expected) > 1e-9)
|
|
return false; // logProb is not consistent with error
|
|
return true;
|
|
}
|
|
|
|
} // namespace gtsam
|