Merge pull request #1380 from borglab/feature/uniform_error

release/4.3a0
Frank Dellaert 2023-01-12 07:48:01 -08:00 committed by GitHub
commit a34c463e2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
59 changed files with 829 additions and 422 deletions

View File

@ -18,6 +18,7 @@
*/ */
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
@ -56,6 +57,16 @@ namespace gtsam {
} }
} }
/* ************************************************************************ */
double DecisionTreeFactor::error(const DiscreteValues& values) const {
return -std::log(evaluate(values));
}
/* ************************************************************************ */
double DecisionTreeFactor::error(const HybridValues& values) const {
return error(values.discrete());
}
/* ************************************************************************ */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum

View File

@ -34,6 +34,7 @@
namespace gtsam { namespace gtsam {
class DiscreteConditional; class DiscreteConditional;
class HybridValues;
/** /**
* A discrete probabilistic factor. * A discrete probabilistic factor.
@ -97,11 +98,20 @@ namespace gtsam {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// Value is just look up in AlgebraicDecisonTree /// Calculate probability for given values `x`,
/// is just look up in AlgebraicDecisionTree.
double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values);
}
/// Evaluate probability density, sugar.
double operator()(const DiscreteValues& values) const override { double operator()(const DiscreteValues& values) const override {
return ADT::operator()(values); return ADT::operator()(values);
} }
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const;
/// multiply two factors /// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, ADT::Ring::mul); return apply(f, ADT::Ring::mul);
@ -230,7 +240,17 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
/// @name HybridValues methods.
/// @{
/**
* Calculate error for HybridValues `x`, is -log(probability)
* Simply dispatches to DiscreteValues version.
*/
double error(const HybridValues& values) const override;
/// @}
private: private:
/** Serialization function */ /** Serialization function */

View File

@ -33,6 +33,15 @@ bool DiscreteBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol); return Base::equals(bn, tol);
} }
/* ************************************************************************* */
double DiscreteBayesNet::logProbability(const DiscreteValues& values) const {
// evaluate all conditionals and add
double result = 0.0;
for (const DiscreteConditional::shared_ptr& conditional : *this)
result += conditional->logProbability(values);
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const { double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
// evaluate all conditionals and multiply // evaluate all conditionals and multiply

View File

@ -103,6 +103,9 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
return evaluate(values); return evaluate(values);
} }
//** log(evaluate(values)) for given DiscreteValues */
double logProbability(const DiscreteValues & values) const;
/** /**
* @brief do ancestral sampling * @brief do ancestral sampling
* *
@ -136,7 +139,15 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
///@} /// @}
/// @name HybridValues methods.
/// @{
using Base::error; // Expose error(const HybridValues&) method..
using Base::evaluate; // Expose evaluate(const HybridValues&) method..
using Base::logProbability; // Expose logProbability(const HybridValues&)
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality /// @name Deprecated functionality

View File

@ -18,9 +18,9 @@
#pragma once #pragma once
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
@ -147,6 +147,11 @@ class GTSAM_EXPORT DiscreteConditional
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// Log-probability is just -error(x).
double logProbability(const DiscreteValues& x) const {
return -error(x);
}
/// print index signature only /// print index signature only
void printSignature( void printSignature(
const std::string& s = "Discrete Conditional: ", const std::string& s = "Discrete Conditional: ",
@ -225,6 +230,21 @@ class GTSAM_EXPORT DiscreteConditional
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @}
/// @name HybridValues methods.
/// @{
/**
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
* This is actually just -error(x).
*/
double logProbability(const HybridValues& x) const override {
return -error(x);
}
using DecisionTreeFactor::evaluate;
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42

View File

@ -19,6 +19,7 @@
#include <gtsam/base/Vector.h> #include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <cmath> #include <cmath>
#include <sstream> #include <sstream>
@ -27,6 +28,16 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values));
}
/* ************************************************************************* */
double DiscreteFactor::error(const HybridValues& c) const {
return this->error(c.discrete());
}
/* ************************************************************************* */ /* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) { std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity(); double maxLogProb = -std::numeric_limits<double>::infinity();

View File

@ -27,6 +27,7 @@ namespace gtsam {
class DecisionTreeFactor; class DecisionTreeFactor;
class DiscreteConditional; class DiscreteConditional;
class HybridValues;
/** /**
* Base class for discrete probabilistic factors * Base class for discrete probabilistic factors
@ -83,6 +84,15 @@ public:
/// Find value for given assignment of values to variables /// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0; virtual double operator()(const DiscreteValues&) const = 0;
/// Error is just -log(value)
double error(const DiscreteValues& values) const;
/**
* The Factor::error simply extracts the \class DiscreteValues from the
* \class HybridValues and calculates the error.
*/
double error(const HybridValues& c) const override;
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

View File

@ -222,6 +222,12 @@ class GTSAM_EXPORT DiscreteFactorGraph
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
/// @}
/// @name HybridValues methods.
/// @{
using Base::error; // Expose error(const HybridValues&) method..
/// @} /// @}
}; // \ DiscreteFactorGraph }; // \ DiscreteFactorGraph

View File

@ -95,6 +95,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint, DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal, const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys); const gtsam::Ordering& orderedKeys);
double logProbability(const gtsam::DiscreteValues& values) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteConditional operator*( gtsam::DiscreteConditional operator*(
const gtsam::DiscreteConditional& other) const; const gtsam::DiscreteConditional& other) const;
gtsam::DiscreteConditional marginal(gtsam::Key key) const; gtsam::DiscreteConditional marginal(gtsam::Key key) const;
@ -157,7 +160,12 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
// Standard interface.
double logProbability(const gtsam::DiscreteValues& values) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;

View File

@ -48,6 +48,9 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9); EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -25,7 +25,6 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
@ -101,6 +100,11 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9"); DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back())); EXPECT(assert_equal(expected2, *chordal->back()));
// Check evaluate and logProbability
auto result = fg.optimize();
EXPECT_DOUBLES_EQUAL(asia.logProbability(result),
std::log(asia.evaluate(result)), 1e-9);
// add evidence, we were in Asia and we have dyspnea // add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1"); fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1"); fg.add(Dyspnea, "0 1");

View File

@ -88,6 +88,29 @@ TEST(DiscreteConditional, constructors3) {
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual))); EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
} }
/* ****************************************************************************/
// Test evaluate for a discrete Prior P(Asia).
TEST(DiscreteConditional, PriorProbability) {
constexpr Key asiaKey = 0;
const DiscreteKey Asia(asiaKey, 2);
DiscreteConditional dc(Asia, "4/6");
DiscreteValues values{{asiaKey, 0}};
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
}
/* ************************************************************************* */
// Check that error, logProbability, evaluate all work as expected.
TEST(DiscreteConditional, probability) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
DiscreteValues given {{C.first, 1}, {D.first, 0}, {E.first, 0}};
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE.evaluate(given), 1e-9);
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
}
/* ************************************************************************* */ /* ************************************************************************* */
// Check calculation of joint P(A,B) // Check calculation of joint P(A,B)
TEST(DiscreteConditional, Multiply) { TEST(DiscreteConditional, Multiply) {

View File

@ -271,15 +271,16 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
} }
/* *******************************************************************************/ /* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error( AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
// functor to calculate to double error value from GaussianConditional. // functor to calculate to double logProbability value from
// GaussianConditional.
auto errorFunc = auto errorFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) { [continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) { if (conditional) {
return conditional->error(continuousValues); return conditional->logProbability(continuousValues);
} else { } else {
// Return arbitrarily large error if conditional is null // Return arbitrarily large logProbability if conditional is null
// Conditional is null if it is pruned out. // Conditional is null if it is pruned out.
return 1e50; return 1e50;
} }
@ -289,10 +290,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
} }
/* *******************************************************************************/ /* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const { double GaussianMixture::logProbability(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree. // Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete()); auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()); return conditional->logProbability(values.continuous());
} }
} // namespace gtsam } // namespace gtsam

View File

@ -164,22 +164,23 @@ class GTSAM_EXPORT GaussianMixture
const Conditionals &conditionals() const; const Conditionals &conditionals() const;
/** /**
* @brief Compute error of the GaussianMixture as a tree. * @brief Compute logProbability of the GaussianMixture as a tree.
* *
* @param continuousValues The continuous VectorValues. * @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals, and leaf values as the error. * as the conditionals, and leaf values as the logProbability.
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;
/** /**
* @brief Compute the error of this Gaussian Mixture given the continuous * @brief Compute the logProbability of this Gaussian Mixture given the
* values and a discrete assignment. * continuous values and a discrete assignment.
* *
* @param values Continuous values and discrete assignment. * @param values Continuous values and discrete assignment.
* @return double * @return double
*/ */
double error(const HybridValues &values) const override; double logProbability(const HybridValues &values) const override;
// /// Calculate probability density for given values `x`. // /// Calculate probability density for given values `x`.
// double evaluate(const HybridValues &values) const; // double evaluate(const HybridValues &values) const;
@ -188,9 +189,6 @@ class GTSAM_EXPORT GaussianMixture
// double operator()(const HybridValues &values) const { return // double operator()(const HybridValues &values) const { return
// evaluate(values); } // evaluate(values); }
// /// Calculate log-density for given values `x`.
// double logDensity(const HybridValues &values) const;
/** /**
* @brief Prune the decision tree of Gaussian factors as per the discrete * @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`. * `decisionTree`.

View File

@ -255,33 +255,6 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
return gbn.optimize(); return gbn.optimize();
} }
/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
const DiscreteValues &discreteValues = values.discrete();
const VectorValues &continuousValues = values.continuous();
double logDensity = 0.0, probability = 1.0;
// Iterate over each conditional.
for (auto &&conditional : *this) {
// TODO: should be delegated to derived classes.
if (auto gm = conditional->asMixture()) {
const auto component = (*gm)(discreteValues);
logDensity += component->logDensity(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, evaluate the probability and multiply.
logDensity += gc->logDensity(continuousValues);
} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability.
probability *= dc->operator()(discreteValues);
}
}
return probability * exp(logDensity);
}
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given, HybridValues HybridBayesNet::sample(const HybridValues &given,
std::mt19937_64 *rng) const { std::mt19937_64 *rng) const {
@ -318,45 +291,45 @@ HybridValues HybridBayesNet::sample() const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double HybridBayesNet::error(const HybridValues &values) const { AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
GaussianBayesNet gbn = choose(values.discrete());
return gbn.error(values.continuous());
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0); AlgebraicDecisionTree<Key> result(0.0);
// Iterate over each conditional. // Iterate over each conditional.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) { if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and compute error. // If conditional is hybrid, select based on assignment and compute
AlgebraicDecisionTree<Key> conditional_error = // logProbability.
gm->error(continuousValues); result = result + gm->logProbability(continuousValues);
error_tree = error_tree + conditional_error;
} else if (auto gc = conditional->asGaussian()) { } else if (auto gc = conditional->asGaussian()) {
// If continuous only, get the (double) error // If continuous, get the (double) logProbability and add it to the
// and add it to the error_tree // result
double error = gc->error(continuousValues); double logProbability = gc->logProbability(continuousValues);
// Add the computed error to every leaf of the error tree. // Add the computed logProbability to every leaf of the logProbability
error_tree = error_tree.apply( // tree.
[error](double leaf_value) { return leaf_value + error; }); result = result.apply([logProbability](double leaf_value) {
return leaf_value + logProbability;
});
} else if (auto dc = conditional->asDiscrete()) { } else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, we skip. // TODO(dellaert): if discrete, we need to add logProbability in the right
// branch?
continue; continue;
} }
} }
return error_tree; return result;
} }
/* ************************************************************************* */ /* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime( AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues); AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
return error_tree.apply([](double error) { return exp(-error); }); return tree.apply([](double log) { return exp(log); });
}
/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
return exp(logProbability(values));
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -187,15 +187,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves); HybridBayesNet prune(size_t maxNrLeaves);
/**
* @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const;
/** /**
* @brief Compute conditional error for each discrete assignment, * @brief Compute conditional error for each discrete assignment,
* and return as a tree. * and return as a tree.
@ -203,7 +194,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param continuousValues Continuous values at which to compute the error. * @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key> * @return AlgebraicDecisionTree<Key>
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;
using BayesNet::logProbability; // expose HybridValues version
/** /**
* @brief Compute unnormalized probability q(μ|M), * @brief Compute unnormalized probability q(μ|M),
@ -215,7 +209,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* probability. * probability.
* @return AlgebraicDecisionTree<Key> * @return AlgebraicDecisionTree<Key>
*/ */
AlgebraicDecisionTree<Key> probPrime( AlgebraicDecisionTree<Key> evaluate(
const VectorValues &continuousValues) const; const VectorValues &continuousValues) const;
/** /**

View File

@ -122,18 +122,18 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
} }
/* ************************************************************************ */ /* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const { double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gm = asMixture()) {
return gm->error(values);
}
if (auto gc = asGaussian()) { if (auto gc = asGaussian()) {
return gc->error(values.continuous()); return gc->logProbability(values.continuous());
}
if (auto gm = asMixture()) {
return gm->logProbability(values);
} }
if (auto dc = asDiscrete()) { if (auto dc = asDiscrete()) {
return -log((*dc)(values.discrete())); return dc->logProbability(values.discrete());
} }
throw std::runtime_error( throw std::runtime_error(
"HybridConditional::error: conditional type not handled"); "HybridConditional::logProbability: conditional type not handled");
} }
} // namespace gtsam } // namespace gtsam

View File

@ -176,8 +176,8 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type /// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() const { return inner_; } boost::shared_ptr<Factor> inner() const { return inner_; }
/// Return the error of the underlying conditional. /// Return the logProbability of the underlying conditional.
double error(const HybridValues& values) const override; double logProbability(const HybridValues& values) const override;
/// Check if VectorValues `measurements` contains all frontal keys. /// Check if VectorValues `measurements` contains all frontal keys.
bool frontalsIn(const VectorValues& measurements) const { bool frontalsIn(const VectorValues& measurements) const {

View File

@ -143,15 +143,6 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
virtual double error(const HybridValues &values) const = 0;
/// True if this is a factor of discrete variables only. /// True if this is a factor of discrete variables only.
bool isDiscrete() const { return isDiscrete_; } bool isDiscrete() const { return isDiscrete_; }

View File

@ -25,17 +25,17 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteKeys HybridFactorGraph::discreteKeys() const { std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
DiscreteKeys keys; std::set<DiscreteKey> keys;
for (auto& factor : factors_) { for (auto& factor : factors_) {
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) { if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) { for (const DiscreteKey& key : p->discreteKeys()) {
keys.push_back(key); keys.insert(key);
} }
} }
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) { if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) { for (const DiscreteKey& key : p->discreteKeys()) {
keys.push_back(key); keys.insert(key);
} }
} }
} }

View File

@ -65,7 +65,7 @@ class HybridFactorGraph : public FactorGraph<Factor> {
/// @{ /// @{
/// Get all the discrete keys in the factor graph. /// Get all the discrete keys in the factor graph.
DiscreteKeys discreteKeys() const; std::set<DiscreteKey> discreteKeys() const;
/// Get all the discrete keys in the factor graph, as a set. /// Get all the discrete keys in the factor graph, as a set.
KeySet discreteKeySet() const; KeySet discreteKeySet() const;

View File

@ -463,24 +463,6 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
return error_tree; return error_tree;
} }
/* ************************************************************************ */
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0;
for (auto &f : factors_) {
if (auto hf = dynamic_pointer_cast<GaussianFactor>(f)) {
error += hf->error(values.continuous());
} else if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
// TODO(dellaert): needs to change when we discard other wrappers.
error += hf->error(values);
} else if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
error -= log((*dtf)(values.discrete()));
} else {
throwRuntimeError("HybridGaussianFactorGraph::error(HV)", f);
}
}
return error;
}
/* ************************************************************************ */ /* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
double error = this->error(values); double error = this->error(values);

View File

@ -145,6 +145,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
using Base::error; // Expose error(const HybridValues&) method..
/** /**
* @brief Compute error for each discrete assignment, * @brief Compute error for each discrete assignment,
* and return as a tree. * and return as a tree.
@ -156,14 +158,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const; AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
/**
* @brief Compute error given a continuous vector values
* and a discrete assignment.
*
* @return double
*/
double error(const HybridValues& values) const;
/** /**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree. * for each discrete assignment, and return as a tree.

View File

@ -55,12 +55,18 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
: Base(graph) {} : Base(graph) {}
/// @} /// @}
/// @name Constructors
/// @{
/// Print the factor graph. /// Print the factor graph.
void print( void print(
const std::string& s = "HybridNonlinearFactorGraph", const std::string& s = "HybridNonlinearFactorGraph",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard Interface
/// @{
/** /**
* @brief Linearize all the continuous factors in the * @brief Linearize all the continuous factors in the
* HybridNonlinearFactorGraph. * HybridNonlinearFactorGraph.
@ -70,6 +76,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
*/ */
HybridGaussianFactorGraph::shared_ptr linearize( HybridGaussianFactorGraph::shared_ptr linearize(
const Values& continuousValues) const; const Values& continuousValues) const;
/// @}
}; };
template <> template <>

View File

@ -37,12 +37,15 @@ namespace gtsam {
*/ */
class GTSAM_EXPORT HybridValues { class GTSAM_EXPORT HybridValues {
private: private:
// VectorValue stored the continuous components of the HybridValues. /// Continuous multi-dimensional vectors for \class GaussianFactor.
VectorValues continuous_; VectorValues continuous_;
// DiscreteValue stored the discrete components of the HybridValues. /// Discrete values for \class DiscreteFactor.
DiscreteValues discrete_; DiscreteValues discrete_;
/// Continuous, differentiable manifold values for \class NonlinearFactor.
Values nonlinear_;
public: public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -54,6 +57,11 @@ class GTSAM_EXPORT HybridValues {
HybridValues(const VectorValues& cv, const DiscreteValues& dv) HybridValues(const VectorValues& cv, const DiscreteValues& dv)
: continuous_(cv), discrete_(dv){}; : continuous_(cv), discrete_(dv){};
/// Construct from all values types.
HybridValues(const VectorValues& cv, const DiscreteValues& dv,
const Values& v)
: continuous_(cv), discrete_(dv), nonlinear_(v){};
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -77,26 +85,30 @@ class GTSAM_EXPORT HybridValues {
/// @name Interface /// @name Interface
/// @{ /// @{
/// Return the discrete MPE assignment /// Return the multi-dimensional vector values.
const DiscreteValues& discrete() const { return discrete_; }
/// Return the delta update for the continuous vectors
const VectorValues& continuous() const { return continuous_; } const VectorValues& continuous() const { return continuous_; }
/// Check whether a variable with key \c j exists in DiscreteValue. /// Return the discrete values.
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); }; const DiscreteValues& discrete() const { return discrete_; }
/// Check whether a variable with key \c j exists in VectorValue. /// Return the nonlinear values.
const Values& nonlinear() const { return nonlinear_; }
/// Check whether a variable with key \c j exists in VectorValues.
bool existsVector(Key j) { return continuous_.exists(j); }; bool existsVector(Key j) { return continuous_.exists(j); };
/// Check whether a variable with key \c j exists. /// Check whether a variable with key \c j exists in DiscreteValues.
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); }; bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };
/** Insert a discrete \c value with key \c j. Replaces the existing value if /// Check whether a variable with key \c j exists in values.
* the key \c j is already used. bool existsNonlinear(Key j) {
* @param value The vector to be inserted. return (nonlinear_.find(j) != nonlinear_.end());
* @param j The index with which the value will be associated. */ };
void insert(Key j, size_t value) { discrete_[j] = value; };
/// Check whether a variable with key \c j exists.
bool exists(Key j) {
return existsVector(j) || existsDiscrete(j) || existsNonlinear(j);
};
/** Insert a vector \c value with key \c j. Throws an invalid_argument /** Insert a vector \c value with key \c j. Throws an invalid_argument
* exception if the key \c j is already used. * exception if the key \c j is already used.
@ -104,6 +116,12 @@ class GTSAM_EXPORT HybridValues {
* @param j The index with which the value will be associated. */ * @param j The index with which the value will be associated. */
void insert(Key j, const Vector& value) { continuous_.insert(j, value); } void insert(Key j, const Vector& value) { continuous_.insert(j, value); }
/** Insert a discrete \c value with key \c j. Replaces the existing value if
* the key \c j is already used.
* @param value The vector to be inserted.
* @param j The index with which the value will be associated. */
void insert(Key j, size_t value) { discrete_[j] = value; };
/** Insert all continuous values from \c values. Throws an invalid_argument /** Insert all continuous values from \c values. Throws an invalid_argument
* exception if any keys to be inserted are already used. */ * exception if any keys to be inserted are already used. */
HybridValues& insert(const VectorValues& values) { HybridValues& insert(const VectorValues& values) {
@ -118,28 +136,36 @@ class GTSAM_EXPORT HybridValues {
return *this; return *this;
} }
/** Insert all values from \c values. Throws an invalid_argument
* exception if any keys to be inserted are already used. */
HybridValues& insert(const Values& values) {
nonlinear_.insert(values);
return *this;
}
/** Insert all values from \c values. Throws an invalid_argument exception if /** Insert all values from \c values. Throws an invalid_argument exception if
* any keys to be inserted are already used. */ * any keys to be inserted are already used. */
HybridValues& insert(const HybridValues& values) { HybridValues& insert(const HybridValues& values) {
continuous_.insert(values.continuous()); continuous_.insert(values.continuous());
discrete_.insert(values.discrete()); discrete_.insert(values.discrete());
nonlinear_.insert(values.nonlinear());
return *this; return *this;
} }
// TODO(Shangjie)- insert_or_assign() , similar to Values.h // TODO(Shangjie)- insert_or_assign() , similar to Values.h
/**
* Read/write access to the discrete value with key \c j, throws
* std::out_of_range if \c j does not exist.
*/
size_t& atDiscrete(Key j) { return discrete_.at(j); };
/** /**
* Read/write access to the vector value with key \c j, throws * Read/write access to the vector value with key \c j, throws
* std::out_of_range if \c j does not exist. * std::out_of_range if \c j does not exist.
*/ */
Vector& at(Key j) { return continuous_.at(j); }; Vector& at(Key j) { return continuous_.at(j); };
/**
* Read/write access to the discrete value with key \c j, throws
* std::out_of_range if \c j does not exist.
*/
size_t& atDiscrete(Key j) { return discrete_.at(j); };
/** For all key/value pairs in \c values, replace continuous values with /** For all key/value pairs in \c values, replace continuous values with
* corresponding keys in this object with those in \c values. Throws * corresponding keys in this object with those in \c values. Throws
* std::out_of_range if any keys in \c values are not present in this object. * std::out_of_range if any keys in \c values are not present in this object.

View File

@ -61,6 +61,9 @@ virtual class HybridConditional {
size_t nrParents() const; size_t nrParents() const;
// Standard interface: // Standard interface:
double logProbability(const gtsam::HybridValues& values) const;
double evaluate(const gtsam::HybridValues& values) const;
double operator()(const gtsam::HybridValues& values) const;
gtsam::GaussianMixture* asMixture() const; gtsam::GaussianMixture* asMixture() const;
gtsam::GaussianConditional* asGaussian() const; gtsam::GaussianConditional* asGaussian() const;
gtsam::DiscreteConditional* asDiscrete() const; gtsam::DiscreteConditional* asDiscrete() const;
@ -133,7 +136,10 @@ class HybridBayesNet {
gtsam::KeySet keys() const; gtsam::KeySet keys() const;
const gtsam::HybridConditional* at(size_t i) const; const gtsam::HybridConditional* at(size_t i) const;
double evaluate(const gtsam::HybridValues& x) const; // Standard interface:
double logProbability(const gtsam::HybridValues& values) const;
double evaluate(const gtsam::HybridValues& values) const;
gtsam::HybridValues optimize() const; gtsam::HybridValues optimize() const;
gtsam::HybridValues sample(const gtsam::HybridValues &given) const; gtsam::HybridValues sample(const gtsam::HybridValues &given) const;
gtsam::HybridValues sample() const; gtsam::HybridValues sample() const;

View File

@ -116,11 +116,12 @@ TEST(GaussianMixture, Error) {
VectorValues values; VectorValues values;
values.insert(X(1), Vector2::Ones()); values.insert(X(1), Vector2::Ones());
values.insert(X(2), Vector2::Zero()); values.insert(X(2), Vector2::Zero());
auto error_tree = mixture.error(values); auto error_tree = mixture.logProbability(values);
// regression // Check result.
std::vector<DiscreteKey> discrete_keys = {m1}; std::vector<DiscreteKey> discrete_keys = {m1};
std::vector<double> leaves = {0.5, 4.3252595}; std::vector<double> leaves = {conditional0->logProbability(values),
conditional1->logProbability(values)};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves); AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
EXPECT(assert_equal(expected_error, error_tree, 1e-6)); EXPECT(assert_equal(expected_error, error_tree, 1e-6));
@ -128,10 +129,11 @@ TEST(GaussianMixture, Error) {
// Regression for non-tree version. // Regression for non-tree version.
DiscreteValues assignment; DiscreteValues assignment;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT_DOUBLES_EQUAL(0.5, mixture.error({values, assignment}), 1e-8); EXPECT_DOUBLES_EQUAL(conditional0->logProbability(values),
mixture.logProbability({values, assignment}), 1e-8);
assignment[M(1)] = 1; assignment[M(1)] = 1;
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error({values, assignment}), EXPECT_DOUBLES_EQUAL(conditional1->logProbability(values),
1e-8); mixture.logProbability({values, assignment}), 1e-8);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -64,10 +64,10 @@ TEST(HybridBayesNet, Add) {
// Test evaluate for a pure discrete Bayes net P(Asia). // Test evaluate for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, EvaluatePureDiscrete) { TEST(HybridBayesNet, EvaluatePureDiscrete) {
HybridBayesNet bayesNet; HybridBayesNet bayesNet;
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1")); bayesNet.emplace_back(new DiscreteConditional(Asia, "4/6"));
HybridValues values; HybridValues values;
values.insert(asiaKey, 0); values.insert(asiaKey, 0);
EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9); EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9);
} }
/* ****************************************************************************/ /* ****************************************************************************/
@ -207,55 +207,57 @@ TEST(HybridBayesNet, Optimize) {
/* ****************************************************************************/ /* ****************************************************************************/
// Test Bayes net error // Test Bayes net error
TEST(HybridBayesNet, Error) { TEST(HybridBayesNet, logProbability) {
Switching s(3); Switching s(3);
HybridBayesNet::shared_ptr hybridBayesNet = HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(); s.linearizedFactorGraph.eliminateSequential();
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
HybridValues delta = hybridBayesNet->optimize(); HybridValues delta = hybridBayesNet->optimize();
auto error_tree = hybridBayesNet->error(delta.continuous()); auto error_tree = hybridBayesNet->logProbability(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}}; std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {0.0097568009, 3.3973404e-31, 0.029126214, std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374};
0.0097568009};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves); AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression // regression
EXPECT(assert_equal(expected_error, error_tree, 1e-9)); EXPECT(assert_equal(expected_error, error_tree, 1e-6));
// Error on pruned Bayes net // logProbability on pruned Bayes net
auto prunedBayesNet = hybridBayesNet->prune(2); auto prunedBayesNet = hybridBayesNet->prune(2);
auto pruned_error_tree = prunedBayesNet.error(delta.continuous()); auto pruned_error_tree = prunedBayesNet.logProbability(delta.continuous());
std::vector<double> pruned_leaves = {2e50, 3.3973404e-31, 2e50, 0.0097568009}; std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374};
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys, AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
pruned_leaves); pruned_leaves);
// regression // regression
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9)); EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
// Verify error computation and check for specific error value // Verify logProbability computation and check for specific logProbability
DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; // value
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
const HybridValues hybridValues{delta.continuous(), discrete_values};
double logProbability = 0;
logProbability +=
hybridBayesNet->at(0)->asMixture()->logProbability(hybridValues);
logProbability +=
hybridBayesNet->at(1)->asMixture()->logProbability(hybridValues);
logProbability +=
hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues);
double total_error = 0; // TODO(dellaert): the discrete errors are not added in logProbability tree!
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { EXPECT_DOUBLES_EQUAL(logProbability, error_tree(discrete_values), 1e-9);
if (hybridBayesNet->at(idx)->isHybrid()) { EXPECT_DOUBLES_EQUAL(logProbability, pruned_error_tree(discrete_values),
double error = hybridBayesNet->at(idx)->asMixture()->error( 1e-9);
{delta.continuous(), discrete_values});
total_error += error;
} else if (hybridBayesNet->at(idx)->isContinuous()) {
double error =
hybridBayesNet->at(idx)->asGaussian()->error(delta.continuous());
total_error += error;
}
}
EXPECT_DOUBLES_EQUAL( logProbability +=
total_error, hybridBayesNet->error({delta.continuous(), discrete_values}), hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values);
1e-9); logProbability +=
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9); hybridBayesNet->at(4)->asDiscrete()->logProbability(discrete_values);
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(logProbability,
hybridBayesNet->logProbability(hybridValues), 1e-9);
} }
/* ****************************************************************************/ /* ****************************************************************************/

View File

@ -60,12 +60,14 @@ TEST(HybridFactorGraph, GaussianFactorGraph) {
Values linearizationPoint; Values linearizationPoint;
linearizationPoint.insert<double>(X(0), 0); linearizationPoint.insert<double>(X(0), 0);
// Linearize the factor graph.
HybridGaussianFactorGraph ghfg = *fg.linearize(linearizationPoint); HybridGaussianFactorGraph ghfg = *fg.linearize(linearizationPoint);
EXPECT_LONGS_EQUAL(1, ghfg.size());
// Add a factor to the GaussianFactorGraph // Check that the error is the same for the nonlinear values.
ghfg.add(JacobianFactor(X(0), I_1x1, Vector1(5))); const VectorValues zero{{X(0), Vector1(0)}};
const HybridValues hybridValues{zero, {}, linearizationPoint};
EXPECT_LONGS_EQUAL(2, ghfg.size()); EXPECT_DOUBLES_EQUAL(fg.error(hybridValues), ghfg.error(hybridValues), 1e-9);
} }
/*************************************************************************** /***************************************************************************

View File

@ -88,6 +88,22 @@ void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
of.close(); of.close();
} }
/* ************************************************************************* */
template <class CONDITIONAL>
double BayesNet<CONDITIONAL>::logProbability(const HybridValues& x) const {
double sum = 0.;
for (const auto& gc : *this) {
if (gc) sum += gc->logProbability(x);
}
return sum;
}
/* ************************************************************************* */
template <class CONDITIONAL>
double BayesNet<CONDITIONAL>::evaluate(const HybridValues& x) const {
return exp(-logProbability(x));
}
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -25,6 +25,8 @@
namespace gtsam { namespace gtsam {
class HybridValues;
/** /**
* A BayesNet is a tree of conditionals, stored in elimination order. * A BayesNet is a tree of conditionals, stored in elimination order.
* @ingroup inference * @ingroup inference
@ -52,9 +54,11 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
/** /**
* Constructor that takes an initializer list of shared pointers. * Constructor that takes an initializer list of shared pointers.
* BayesNet<SymbolicConditional> bn = {make_shared<SymbolicConditional>(), ...}; * BayesNet<SymbolicConditional> bn = {make_shared<SymbolicConditional>(),
* ...};
*/ */
BayesNet(std::initializer_list<sharedConditional> conditionals): Base(conditionals) {} BayesNet(std::initializer_list<sharedConditional> conditionals)
: Base(conditionals) {}
/// @} /// @}
@ -68,7 +72,6 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Graph Display /// @name Graph Display
/// @{ /// @{
@ -86,6 +89,16 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const; const DotWriter& writer = DotWriter()) const;
/// @}
/// @name HybridValues methods
/// @{
// Expose HybridValues version of logProbability.
double logProbability(const HybridValues& x) const;
// Expose HybridValues version of evaluate.
double evaluate(const HybridValues& c) const;
/// @} /// @}
}; };

View File

@ -18,30 +18,42 @@
// \callgraph // \callgraph
#pragma once #pragma once
#include <iostream>
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
#include <cmath>
#include <iostream>
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template<class FACTOR, class DERIVEDFACTOR> template <class FACTOR, class DERIVEDCONDITIONAL>
void Conditional<FACTOR,DERIVEDFACTOR>::print(const std::string& s, const KeyFormatter& formatter) const { void Conditional<FACTOR, DERIVEDCONDITIONAL>::print(
std::cout << s << " P("; const std::string& s, const KeyFormatter& formatter) const {
for(Key key: frontals()) std::cout << s << " P(";
std::cout << " " << formatter(key); for (Key key : frontals()) std::cout << " " << formatter(key);
if (nrParents() > 0) if (nrParents() > 0) std::cout << " |";
std::cout << " |"; for (Key parent : parents()) std::cout << " " << formatter(parent);
for(Key parent: parents()) std::cout << ")" << std::endl;
std::cout << " " << formatter(parent);
std::cout << ")" << std::endl;
}
/* ************************************************************************* */
template<class FACTOR, class DERIVEDFACTOR>
bool Conditional<FACTOR,DERIVEDFACTOR>::equals(const This& c, double tol) const
{
return nrFrontals_ == c.nrFrontals_;
}
} }
/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
bool Conditional<FACTOR, DERIVEDCONDITIONAL>::equals(const This& c,
double tol) const {
return nrFrontals_ == c.nrFrontals_;
}
/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
double Conditional<FACTOR, DERIVEDCONDITIONAL>::logProbability(
const HybridValues& c) const {
throw std::runtime_error("Conditional::logProbability is not implemented");
}
/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
const HybridValues& c) const {
throw std::runtime_error("Conditional::evaluate is not implemented");
}
} // namespace gtsam

View File

@ -24,13 +24,37 @@
namespace gtsam { namespace gtsam {
class HybridValues; // forward declaration.
/** /**
* Base class for conditional densities. This class iterators and * This is the base class for all conditional distributions/densities,
* access to the frontal and separator keys. * which are implemented as specialized factors. This class does not store any
* data other than its keys. Derived classes store data such as matrices and
* probability tables.
*
* The `evaluate` method is used to evaluate the factor, and together with
* `logProbability` is the main methods that need to be implemented in derived
* classes. These two methods relate to the `error` method in the factor by:
* probability(x) = k exp(-error(x))
* where k is a normalization constant making \int probability(x) == 1.0, and
* logProbability(x) = K - error(x)
* i.e., K = log(K).
*
* There are four broad classes of conditionals that derive from Conditional:
*
* - \b Gaussian conditionals, implemented in \class GaussianConditional, a
* Gaussian density over a set of continuous variables.
* - \b Discrete conditionals, implemented in \class DiscreteConditional, which
* represent a discrete conditional distribution over discrete variables.
* - \b Hybrid conditional densities, such as \class GaussianMixture, which is
* a density over continuous variables given discrete/continuous parents.
* - \b Symbolic factors, used to represent a graph structure, implemented in
* \class SymbolicConditional. Only used for symbolic elimination etc.
* *
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer * Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
* to the associated factor type and shared_ptr type of the derived class. See * to the associated factor type and shared_ptr type of the derived class. See
* SymbolicConditional and GaussianConditional for examples. * SymbolicConditional and GaussianConditional for examples.
*
* \nosubgrouping * \nosubgrouping
*/ */
template<class FACTOR, class DERIVEDCONDITIONAL> template<class FACTOR, class DERIVEDCONDITIONAL>
@ -78,6 +102,8 @@ namespace gtsam {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
virtual ~Conditional() {}
/** return the number of frontals */ /** return the number of frontals */
size_t nrFrontals() const { return nrFrontals_; } size_t nrFrontals() const { return nrFrontals_; }
@ -98,6 +124,27 @@ namespace gtsam {
/** return a view of the parent keys */ /** return a view of the parent keys */
Parents parents() const { return boost::make_iterator_range(beginParents(), endParents()); } Parents parents() const { return boost::make_iterator_range(beginParents(), endParents()); }
/**
* All conditional types need to implement a `logProbability` function, for which
* exp(logProbability(x)) = evaluate(x).
*/
virtual double logProbability(const HybridValues& c) const;
/**
* All conditional types need to implement an `evaluate` function, that yields
* a true probability. The default implementation just exponentiates logProbability.
*/
virtual double evaluate(const HybridValues& c) const;
/// Evaluate probability density, sugar.
double operator()(const HybridValues& x) const {
return evaluate(x);
}
/// @}
/// @name Advanced Interface
/// @{
/** Iterator pointing to first frontal key. */ /** Iterator pointing to first frontal key. */
typename FACTOR::const_iterator beginFrontals() const { return asFactor().begin(); } typename FACTOR::const_iterator beginFrontals() const { return asFactor().begin(); }
@ -110,10 +157,6 @@ namespace gtsam {
/** Iterator pointing past the last parent key. */ /** Iterator pointing past the last parent key. */
typename FACTOR::const_iterator endParents() const { return asFactor().end(); } typename FACTOR::const_iterator endParents() const { return asFactor().end(); }
/// @}
/// @name Advanced Interface
/// @{
/** Mutable version of nrFrontals */ /** Mutable version of nrFrontals */
size_t& nrFrontals() { return nrFrontals_; } size_t& nrFrontals() { return nrFrontals_; }

View File

@ -43,4 +43,10 @@ namespace gtsam {
return keys_ == other.keys_; return keys_ == other.keys_;
} }
/* ************************************************************************* */
double Factor::error(const HybridValues& c) const {
throw std::runtime_error("Factor::error is not implemented");
}
} }

View File

@ -29,21 +29,38 @@
#include <gtsam/inference/Key.h> #include <gtsam/inference/Key.h>
namespace gtsam { namespace gtsam {
/// Define collection types:
typedef FastVector<FactorIndex> FactorIndices; /// Define collection types:
typedef FastSet<FactorIndex> FactorIndexSet; typedef FastVector<FactorIndex> FactorIndices;
typedef FastSet<FactorIndex> FactorIndexSet;
class HybridValues; // forward declaration of a Value type for error.
/** /**
* This is the base class for all factor types. This class does not store any * This is the base class for all factor types, as well as conditionals,
* which are implemented as specialized factors. This class does not store any
* data other than its keys. Derived classes store data such as matrices and * data other than its keys. Derived classes store data such as matrices and
* probability tables. * probability tables.
* *
* Note that derived classes *must* redefine the `This` and `shared_ptr` * The `error` method is used to evaluate the factor, and is the only method
* typedefs. See JacobianFactor, etc. for examples. * that is required to be implemented in derived classes, although it has a
* default implementation that throws an exception.
*
* There are five broad classes of factors that derive from Factor:
* *
* This class is \b not virtual for performance reasons - the derived class * - \b Nonlinear factors, such as \class NonlinearFactor and \class NoiseModelFactor, which
* SymbolicFactor needs to be created and destroyed quickly during symbolic * represent a nonlinear likelihood function over a set of variables.
* elimination. GaussianFactor and NonlinearFactor are virtual. * - \b Gaussian factors, such as \class JacobianFactor and \class HessianFactor, which
* represent a Gaussian likelihood over a set of variables.
* - \b Discrete factors, such as \class DiscreteFactor and \class DecisionTreeFactor, which
* represent a discrete distribution over a set of variables.
* - \b Hybrid factors, such as \class HybridFactor, which represent a mixture of
* Gaussian and discrete distributions over a set of variables.
* - \b Symbolic factors, used to represent a graph structure, such as
* \class SymbolicFactor, only used for symbolic elimination etc.
*
* Note that derived classes must also redefine the `This` and `shared_ptr`
* typedefs. See JacobianFactor, etc. for examples.
* *
* \nosubgrouping * \nosubgrouping
*/ */
@ -128,6 +145,12 @@ typedef FastSet<FactorIndex> FactorIndexSet;
/** Iterator at end of involved variable keys */ /** Iterator at end of involved variable keys */
const_iterator end() const { return keys_.end(); } const_iterator end() const { return keys_.end(); }
/**
* All factor types need to implement an error function.
* In factor graphs, this is the negative log-likelihood.
*/
virtual double error(const HybridValues& c) const;
/** /**
* @return the number of variables involved in this factor * @return the number of variables involved in this factor
*/ */
@ -152,7 +175,6 @@ typedef FastSet<FactorIndex> FactorIndexSet;
bool equals(const This& other, double tol = 1e-9) const; bool equals(const This& other, double tol = 1e-9) const;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
@ -165,7 +187,13 @@ typedef FastSet<FactorIndex> FactorIndexSet;
/** Iterator at end of involved variable keys */ /** Iterator at end of involved variable keys */
iterator end() { return keys_.end(); } iterator end() { return keys_.end(); }
/// @}
private: private:
/// @name Serialization
/// @{
/** Serialization function */ /** Serialization function */
friend class boost::serialization::access; friend class boost::serialization::access;
template<class Archive> template<class Archive>
@ -177,4 +205,4 @@ typedef FastSet<FactorIndex> FactorIndexSet;
}; };
} } // \namespace gtsam

View File

@ -61,6 +61,16 @@ bool FactorGraph<FACTOR>::equals(const This& fg, double tol) const {
return true; return true;
} }
/* ************************************************************************ */
template <class FACTOR>
double FactorGraph<FACTOR>::error(const HybridValues &values) const {
double error = 0.0;
for (auto &f : factors_) {
error += f->error(values);
}
return error;
}
/* ************************************************************************* */ /* ************************************************************************* */
template <class FACTOR> template <class FACTOR>
size_t FactorGraph<FACTOR>::nrFactors() const { size_t FactorGraph<FACTOR>::nrFactors() const {

View File

@ -47,6 +47,8 @@ typedef FastVector<FactorIndex> FactorIndices;
template <class CLIQUE> template <class CLIQUE>
class BayesTree; class BayesTree;
class HybridValues;
/** Helper */ /** Helper */
template <class C> template <class C>
class CRefCallPushBack { class CRefCallPushBack {
@ -359,6 +361,9 @@ class FactorGraph {
/** Get the last factor */ /** Get the last factor */
sharedFactor back() const { return factors_.back(); } sharedFactor back() const { return factors_.back(); }
/** Add error for all factors. */
double error(const HybridValues &values) const;
/// @} /// @}
/// @name Modifying Factor Graphs (imperative, discouraged) /// @name Modifying Factor Graphs (imperative, discouraged)
/// @{ /// @{

View File

@ -104,7 +104,25 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
double GaussianBayesNet::error(const VectorValues& x) const { double GaussianBayesNet::error(const VectorValues& x) const {
return GaussianFactorGraph(*this).error(x); double sum = 0.;
for (const auto& gc : *this) {
if (gc) sum += gc->error(x);
}
return sum;
}
/* ************************************************************************* */
double GaussianBayesNet::logProbability(const VectorValues& x) const {
double sum = 0.;
for (const auto& gc : *this) {
if (gc) sum += gc->logProbability(x);
}
return sum;
}
/* ************************************************************************* */
double GaussianBayesNet::evaluate(const VectorValues& x) const {
return exp(logProbability(x));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -225,19 +243,5 @@ namespace gtsam {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double GaussianBayesNet::logDensity(const VectorValues& x) const {
double sum = 0.0;
for (const auto& conditional : *this) {
if (conditional) sum += conditional->logDensity(x);
}
return sum;
}
/* ************************************************************************* */
double GaussianBayesNet::evaluate(const VectorValues& x) const {
return exp(logDensity(x));
}
/* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -97,11 +97,16 @@ namespace gtsam {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// Sum error over all variables.
double error(const VectorValues& x) const;
/// Sum logProbability over all variables.
double logProbability(const VectorValues& x) const;
/** /**
* Calculate probability density for given values `x`: * Calculate probability density for given values `x`:
* exp(-error(x)) / sqrt((2*pi)^n*det(Sigma)) * exp(logProbability)
* where x is the vector of values, and Sigma is the covariance matrix. * where x is the vector of values.
* Note that error(x)=0.5*e'*e includes the 0.5 factor already.
*/ */
double evaluate(const VectorValues& x) const; double evaluate(const VectorValues& x) const;
@ -110,13 +115,6 @@ namespace gtsam {
return evaluate(x); return evaluate(x);
} }
/**
* Calculate log-density for given values `x`:
* -error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
* where x is the vector of values, and Sigma is the covariance matrix.
*/
double logDensity(const VectorValues& x) const;
/// Solve the GaussianBayesNet, i.e. return \f$ x = R^{-1}*d \f$, by /// Solve the GaussianBayesNet, i.e. return \f$ x = R^{-1}*d \f$, by
/// back-substitution /// back-substitution
VectorValues optimize() const; VectorValues optimize() const;
@ -216,9 +214,6 @@ namespace gtsam {
* allocateVectorValues */ * allocateVectorValues */
VectorValues gradientAtZero() const; VectorValues gradientAtZero() const;
/** 0.5 * sum of squared Mahalanobis distances. */
double error(const VectorValues& x) const;
/** /**
* Computes the determinant of a GassianBayesNet. A GaussianBayesNet is an upper triangular * Computes the determinant of a GassianBayesNet. A GaussianBayesNet is an upper triangular
* matrix and for an upper triangular matrix determinant is the product of the diagonal * matrix and for an upper triangular matrix determinant is the product of the diagonal
@ -251,6 +246,14 @@ namespace gtsam {
VectorValues backSubstituteTranspose(const VectorValues& gx) const; VectorValues backSubstituteTranspose(const VectorValues& gx) const;
/// @} /// @}
/// @name HybridValues methods.
/// @{
using Base::evaluate; // Expose evaluate(const HybridValues&) method..
using Base::logProbability; // Expose logProbability(const HybridValues&) method..
using Base::error; // Expose error(const HybridValues&) method..
/// @}
private: private:
/** Serialization function */ /** Serialization function */

View File

@ -19,6 +19,7 @@
#include <gtsam/linear/Sampler.h> #include <gtsam/linear/Sampler.h>
#include <gtsam/linear/VectorValues.h> #include <gtsam/linear/VectorValues.h>
#include <gtsam/linear/linearExceptions.h> #include <gtsam/linear/linearExceptions.h>
#include <gtsam/hybrid/HybridValues.h>
#include <boost/format.hpp> #include <boost/format.hpp>
#ifdef __GNUC__ #ifdef __GNUC__
@ -34,6 +35,7 @@
#include <functional> #include <functional>
#include <list> #include <list>
#include <string> #include <string>
#include <cmath>
// In Wrappers we have no access to this so have a default ready // In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42); static std::mt19937_64 kRandomNumberGenerator(42);
@ -170,39 +172,42 @@ namespace gtsam {
} }
} }
/* ************************************************************************* */ /* ************************************************************************* */
double GaussianConditional::logDeterminant() const { double GaussianConditional::logDeterminant() const {
if (get_model()) { if (get_model()) {
Vector diag = R().diagonal(); Vector diag = R().diagonal();
get_model()->whitenInPlace(diag); get_model()->whitenInPlace(diag);
return diag.unaryExpr([](double x) { return log(x); }).sum(); return diag.unaryExpr([](double x) { return log(x); }).sum();
} else { } else {
return R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); return R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
}
} }
}
/* ************************************************************************* */ /* ************************************************************************* */
// normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) // normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
// log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) // log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
double GaussianConditional::logNormalizationConstant() const { double GaussianConditional::logNormalizationConstant() const {
constexpr double log2pi = 1.8378770664093454835606594728112; constexpr double log2pi = 1.8378770664093454835606594728112;
size_t n = d().size(); size_t n = d().size();
// log det(Sigma)) = - 2.0 * logDeterminant() // log det(Sigma)) = - 2.0 * logDeterminant()
return - 0.5 * n * log2pi + logDeterminant(); return - 0.5 * n * log2pi + logDeterminant();
} }
/* ************************************************************************* */ /* ************************************************************************* */
// density = k exp(-error(x)) // density = k exp(-error(x))
// log = log(k) -error(x) // log = log(k) - error(x)
double GaussianConditional::logDensity(const VectorValues& x) const { double GaussianConditional::logProbability(const VectorValues& x) const {
return logNormalizationConstant() - error(x); return logNormalizationConstant() - error(x);
} }
/* ************************************************************************* */ double GaussianConditional::logProbability(const HybridValues& x) const {
double GaussianConditional::evaluate(const VectorValues& x) const { return logProbability(x.continuous());
return exp(logDensity(x)); }
}
/* ************************************************************************* */
double GaussianConditional::evaluate(const VectorValues& c) const {
return exp(logProbability(c));
}
/* ************************************************************************* */ /* ************************************************************************* */
VectorValues GaussianConditional::solve(const VectorValues& x) const { VectorValues GaussianConditional::solve(const VectorValues& x) const {
// Concatenate all vector values that correspond to parent variables // Concatenate all vector values that correspond to parent variables

View File

@ -132,64 +132,6 @@ namespace gtsam {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/**
* Calculate probability density for given values `x`:
* exp(-error(x)) / sqrt((2*pi)^n*det(Sigma))
* where x is the vector of values, and Sigma is the covariance matrix.
* Note that error(x)=0.5*e'*e includes the 0.5 factor already.
*/
double evaluate(const VectorValues& x) const;
/// Evaluate probability density, sugar.
double operator()(const VectorValues& x) const {
return evaluate(x);
}
/**
* Calculate log-density for given values `x`:
* -error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
* where x is the vector of values, and Sigma is the covariance matrix.
*/
double logDensity(const VectorValues& x) const;
/** Return a view of the upper-triangular R block of the conditional */
constABlock R() const { return Ab_.range(0, nrFrontals()); }
/** Get a view of the parent blocks. */
constABlock S() const { return Ab_.range(nrFrontals(), size()); }
/** Get a view of the S matrix for the variable pointed to by the given key iterator */
constABlock S(const_iterator it) const { return BaseFactor::getA(it); }
/** Get a view of the r.h.s. vector d */
const constBVector d() const { return BaseFactor::getb(); }
/**
* @brief Compute the determinant of the R matrix.
*
* The determinant is computed in log form using logDeterminant for
* numerical stability and then exponentiated.
*
* Note, the covariance matrix \f$ \Sigma = (R^T R)^{-1} \f$, and hence
* \f$ \det(\Sigma) = 1 / \det(R^T R) = 1 / determinant()^ 2 \f$.
*
* @return double
*/
inline double determinant() const { return exp(logDeterminant()); }
/**
* @brief Compute the log determinant of the R matrix.
*
* For numerical stability, the determinant is computed in log
* form, so it is a summation rather than a multiplication.
*
* Note, the covariance matrix \f$ \Sigma = (R^T R)^{-1} \f$, and hence
* \f$ \log \det(\Sigma) = - \log \det(R^T R) = - 2 logDeterminant() \f$.
*
* @return double
*/
double logDeterminant() const;
/** /**
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
* log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) * log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
@ -203,6 +145,27 @@ namespace gtsam {
return exp(logNormalizationConstant()); return exp(logNormalizationConstant());
} }
/**
* Calculate log-probability log(evaluate(x)) for given values `x`:
* -error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
* where x is the vector of values, and Sigma is the covariance matrix.
* This differs from error as it is log, not negative log, and it
* includes the normalization constant.
*/
double logProbability(const VectorValues& x) const;
/**
* Calculate probability density for given values `x`:
* exp(logProbability(x)) == exp(-GaussianFactor::error(x)) / sqrt((2*pi)^n*det(Sigma))
* where x is the vector of values, and Sigma is the covariance matrix.
*/
double evaluate(const VectorValues& x) const;
/// Evaluate probability density, sugar.
double operator()(const VectorValues& x) const {
return evaluate(x);
}
/** /**
* Solves a conditional Gaussian and writes the solution into the entries of * Solves a conditional Gaussian and writes the solution into the entries of
* \c x for each frontal variable of the conditional. The parents are * \c x for each frontal variable of the conditional. The parents are
@ -255,6 +218,63 @@ namespace gtsam {
VectorValues sample(const VectorValues& parentsValues) const; VectorValues sample(const VectorValues& parentsValues) const;
/// @} /// @}
/// @name Linear algebra.
/// @{
/** Return a view of the upper-triangular R block of the conditional */
constABlock R() const { return Ab_.range(0, nrFrontals()); }
/** Get a view of the parent blocks. */
constABlock S() const { return Ab_.range(nrFrontals(), size()); }
/** Get a view of the S matrix for the variable pointed to by the given key iterator */
constABlock S(const_iterator it) const { return BaseFactor::getA(it); }
/** Get a view of the r.h.s. vector d */
const constBVector d() const { return BaseFactor::getb(); }
/**
* @brief Compute the determinant of the R matrix.
*
* The determinant is computed in log form using logDeterminant for
* numerical stability and then exponentiated.
*
* Note, the covariance matrix \f$ \Sigma = (R^T R)^{-1} \f$, and hence
* \f$ \det(\Sigma) = 1 / \det(R^T R) = 1 / determinant()^ 2 \f$.
*
* @return double
*/
inline double determinant() const { return exp(logDeterminant()); }
/**
* @brief Compute the log determinant of the R matrix.
*
* For numerical stability, the determinant is computed in log
* form, so it is a summation rather than a multiplication.
*
* Note, the covariance matrix \f$ \Sigma = (R^T R)^{-1} \f$, and hence
* \f$ \log \det(\Sigma) = - \log \det(R^T R) = - 2 logDeterminant() \f$.
*
* @return double
*/
double logDeterminant() const;
/// @}
/// @name HybridValues methods.
/// @{
/**
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
* Simply dispatches to VectorValues version.
*/
double logProbability(const HybridValues& x) const override;
using Conditional::evaluate; // Expose evaluate(const HybridValues&) method..
using Conditional::operator(); // Expose evaluate(const HybridValues&) method..
using Base::error; // Expose error(const HybridValues&) method..
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated /// @name Deprecated

View File

@ -18,16 +18,24 @@
// \callgraph // \callgraph
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/GaussianFactor.h> #include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h> #include <gtsam/linear/VectorValues.h>
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ double GaussianFactor::error(const VectorValues& c) const {
VectorValues GaussianFactor::hessianDiagonal() const { throw std::runtime_error("GaussianFactor::error is not implemented");
VectorValues d;
hessianDiagonalAdd(d);
return d;
}
} }
double GaussianFactor::error(const HybridValues& c) const {
return this->error(c.continuous());
}
VectorValues GaussianFactor::hessianDiagonal() const {
VectorValues d;
hessianDiagonalAdd(d);
return d;
}
} // namespace gtsam

View File

@ -63,8 +63,20 @@ namespace gtsam {
/** Equals for testable */ /** Equals for testable */
virtual bool equals(const GaussianFactor& lf, double tol = 1e-9) const = 0; virtual bool equals(const GaussianFactor& lf, double tol = 1e-9) const = 0;
/** Print for testable */ /**
virtual double error(const VectorValues& c) const = 0; /** 0.5*(A*x-b)'*D*(A*x-b) */ * In Gaussian factors, the error function returns either the negative log-likelihood, e.g.,
* 0.5*(A*x-b)'*D*(A*x-b)
* for a \class JacobianFactor, or the negative log-density, e.g.,
* 0.5*(A*x-b)'*D*(A*x-b) - log(k)
* for a \class GaussianConditional, where k is the normalization constant.
*/
virtual double error(const VectorValues& c) const;
/**
* The Factor::error simply extracts the \class VectorValues from the
* \class HybridValues and calculates the error.
*/
double error(const HybridValues& c) const override;
/** Return the dimension of the variable pointed to by the given key iterator */ /** Return the dimension of the variable pointed to by the given key iterator */
virtual DenseIndex getDim(const_iterator variable) const = 0; virtual DenseIndex getDim(const_iterator variable) const = 0;

View File

@ -67,6 +67,22 @@ namespace gtsam {
return spec; return spec;
} }
/* ************************************************************************* */
double GaussianFactorGraph::error(const VectorValues& x) const {
double total_error = 0.;
for(const sharedFactor& factor: *this){
if(factor)
total_error += factor->error(x);
}
return total_error;
}
/* ************************************************************************* */
double GaussianFactorGraph::probPrime(const VectorValues& c) const {
// NOTE the 0.5 constant is handled by the factor error.
return exp(-error(c));
}
/* ************************************************************************* */ /* ************************************************************************* */
GaussianFactorGraph::shared_ptr GaussianFactorGraph::cloneToPtr() const { GaussianFactorGraph::shared_ptr GaussianFactorGraph::cloneToPtr() const {
gtsam::GaussianFactorGraph::shared_ptr result(new GaussianFactorGraph()); gtsam::GaussianFactorGraph::shared_ptr result(new GaussianFactorGraph());

View File

@ -167,20 +167,10 @@ namespace gtsam {
std::map<Key, size_t> getKeyDimMap() const; std::map<Key, size_t> getKeyDimMap() const;
/** unnormalized error */ /** unnormalized error */
double error(const VectorValues& x) const { double error(const VectorValues& x) const;
double total_error = 0.;
for(const sharedFactor& factor: *this){
if(factor)
total_error += factor->error(x);
}
return total_error;
}
/** Unnormalized probability. O(n) */ /** Unnormalized probability. O(n) */
double probPrime(const VectorValues& c) const { double probPrime(const VectorValues& c) const;
// NOTE the 0.5 constant is handled by the factor error.
return exp(-error(c));
}
/** /**
* Clone() performs a deep-copy of the graph, including all of the factors. * Clone() performs a deep-copy of the graph, including all of the factors.

View File

@ -497,8 +497,9 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
bool equals(const gtsam::GaussianConditional& cg, double tol) const; bool equals(const gtsam::GaussianConditional& cg, double tol) const;
// Standard Interface // Standard Interface
double logProbability(const gtsam::VectorValues& x) const;
double evaluate(const gtsam::VectorValues& x) const; double evaluate(const gtsam::VectorValues& x) const;
double logDensity(const gtsam::VectorValues& x) const; double error(const gtsam::VectorValues& x) const;
gtsam::Key firstFrontalKey() const; gtsam::Key firstFrontalKey() const;
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const; gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
gtsam::JacobianFactor* likelihood( gtsam::JacobianFactor* likelihood(
@ -558,8 +559,10 @@ virtual class GaussianBayesNet {
gtsam::GaussianConditional* back() const; gtsam::GaussianConditional* back() const;
// Standard interface // Standard interface
// Standard Interface
double logProbability(const gtsam::VectorValues& x) const;
double evaluate(const gtsam::VectorValues& x) const; double evaluate(const gtsam::VectorValues& x) const;
double logDensity(const gtsam::VectorValues& x) const; double error(const gtsam::VectorValues& x) const;
gtsam::VectorValues optimize() const; gtsam::VectorValues optimize() const;
gtsam::VectorValues optimize(const gtsam::VectorValues& given) const; gtsam::VectorValues optimize(const gtsam::VectorValues& given) const;

View File

@ -78,9 +78,13 @@ TEST(GaussianBayesNet, Evaluate1) {
// which at the mean is 1.0! So, the only thing we need to calculate is // which at the mean is 1.0! So, the only thing we need to calculate is
// the normalization constant 1.0/sqrt((2*pi*Sigma).det()). // the normalization constant 1.0/sqrt((2*pi*Sigma).det()).
// The covariance matrix inv(Sigma) = R'*R, so the determinant is // The covariance matrix inv(Sigma) = R'*R, so the determinant is
const double expected = sqrt((invSigma / (2 * M_PI)).determinant()); const double constant = sqrt((invSigma / (2 * M_PI)).determinant());
EXPECT_DOUBLES_EQUAL(log(constant),
smallBayesNet.at(0)->logNormalizationConstant() +
smallBayesNet.at(1)->logNormalizationConstant(),
1e-9);
const double actual = smallBayesNet.evaluate(mean); const double actual = smallBayesNet.evaluate(mean);
EXPECT_DOUBLES_EQUAL(expected, actual, 1e-9); EXPECT_DOUBLES_EQUAL(constant, actual, 1e-9);
} }
// Check the evaluate with non-unit noise. // Check the evaluate with non-unit noise.
@ -89,9 +93,9 @@ TEST(GaussianBayesNet, Evaluate2) {
const VectorValues mean = noisyBayesNet.optimize(); const VectorValues mean = noisyBayesNet.optimize();
const Matrix R = noisyBayesNet.matrix().first; const Matrix R = noisyBayesNet.matrix().first;
const Matrix invSigma = R.transpose() * R; const Matrix invSigma = R.transpose() * R;
const double expected = sqrt((invSigma / (2 * M_PI)).determinant()); const double constant = sqrt((invSigma / (2 * M_PI)).determinant());
const double actual = noisyBayesNet.evaluate(mean); const double actual = noisyBayesNet.evaluate(mean);
EXPECT_DOUBLES_EQUAL(expected, actual, 1e-9); EXPECT_DOUBLES_EQUAL(constant, actual, 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -384,11 +384,17 @@ TEST(GaussianConditional, FromMeanAndStddev) {
double expected1 = 0.5 * e1.dot(e1); double expected1 = 0.5 * e1.dot(e1);
EXPECT_DOUBLES_EQUAL(expected1, conditional1.error(values), 1e-9); EXPECT_DOUBLES_EQUAL(expected1, conditional1.error(values), 1e-9);
double expected2 = conditional1.logNormalizationConstant() - 0.5 * e1.dot(e1);
EXPECT_DOUBLES_EQUAL(expected2, conditional1.logProbability(values), 1e-9);
auto conditional2 = GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), A2, auto conditional2 = GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), A2,
X(2), b, sigma); X(2), b, sigma);
Vector2 e2 = (x0 - (A1 * x1 + A2 * x2 + b)) / sigma; Vector2 e2 = (x0 - (A1 * x1 + A2 * x2 + b)) / sigma;
double expected2 = 0.5 * e2.dot(e2); double expected3 = 0.5 * e2.dot(e2);
EXPECT_DOUBLES_EQUAL(expected2, conditional2.error(values), 1e-9); EXPECT_DOUBLES_EQUAL(expected3, conditional2.error(values), 1e-9);
double expected4 = conditional2.logNormalizationConstant() - 0.5 * e2.dot(e2);
EXPECT_DOUBLES_EQUAL(expected4, conditional2.logProbability(values), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -448,20 +454,24 @@ TEST(GaussianConditional, sample) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(GaussianConditional, LogNormalizationConstant) { TEST(GaussianConditional, Error) {
// Create univariate standard gaussian conditional // Create univariate standard gaussian conditional
auto std_gaussian = auto stdGaussian =
GaussianConditional::FromMeanAndStddev(X(0), Vector1::Zero(), 1.0); GaussianConditional::FromMeanAndStddev(X(0), Vector1::Zero(), 1.0);
VectorValues values; VectorValues values;
values.insert(X(0), Vector1::Zero()); values.insert(X(0), Vector1::Zero());
double logDensity = std_gaussian.logDensity(values); double logProbability = stdGaussian.logProbability(values);
// Regression. // Regression.
// These values were computed by hand for a univariate standard gaussian. // These values were computed by hand for a univariate standard gaussian.
EXPECT_DOUBLES_EQUAL(-0.9189385332046727, logDensity, 1e-9); EXPECT_DOUBLES_EQUAL(-0.9189385332046727, logProbability, 1e-9);
EXPECT_DOUBLES_EQUAL(0.3989422804014327, exp(logDensity), 1e-9); EXPECT_DOUBLES_EQUAL(0.3989422804014327, exp(logProbability), 1e-9);
EXPECT_DOUBLES_EQUAL(stdGaussian(values), exp(logProbability), 1e-9);
}
// Similar test for multivariate gaussian but with sigma 2.0 /* ************************************************************************* */
// Similar test for multivariate gaussian but with sigma 2.0
TEST(GaussianConditional, LogNormalizationConstant) {
double sigma = 2.0; double sigma = 2.0;
auto conditional = GaussianConditional::FromMeanAndStddev(X(0), Vector3::Zero(), sigma); auto conditional = GaussianConditional::FromMeanAndStddev(X(0), Vector3::Zero(), sigma);
VectorValues x; VectorValues x;
@ -469,7 +479,8 @@ TEST(GaussianConditional, LogNormalizationConstant) {
Matrix3 Sigma = I_3x3 * sigma * sigma; Matrix3 Sigma = I_3x3 * sigma * sigma;
double expectedLogNormalizingConstant = log(1 / sqrt((2 * M_PI * Sigma).determinant())); double expectedLogNormalizingConstant = log(1 / sqrt((2 * M_PI * Sigma).determinant()));
EXPECT_DOUBLES_EQUAL(expectedLogNormalizingConstant, conditional.logNormalizationConstant(), 1e-9); EXPECT_DOUBLES_EQUAL(expectedLogNormalizingConstant,
conditional.logNormalizationConstant(), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -52,8 +52,11 @@ TEST(GaussianDensity, FromMeanAndStddev) {
auto density = GaussianDensity::FromMeanAndStddev(X(0), b, sigma); auto density = GaussianDensity::FromMeanAndStddev(X(0), b, sigma);
Vector2 e = (x0 - b) / sigma; Vector2 e = (x0 - b) / sigma;
double expected = 0.5 * e.dot(e); double expected1 = 0.5 * e.dot(e);
EXPECT_DOUBLES_EQUAL(expected, density.error(values), 1e-9); EXPECT_DOUBLES_EQUAL(expected1, density.error(values), 1e-9);
double expected2 = density.logNormalizationConstant()- 0.5 * e.dot(e);
EXPECT_DOUBLES_EQUAL(expected2, density.logProbability(values), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -16,12 +16,23 @@
* @author Richard Roberts * @author Richard Roberts
*/ */
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/nonlinear/NonlinearFactor.h> #include <gtsam/nonlinear/NonlinearFactor.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/format.hpp> #include <boost/format.hpp>
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
double NonlinearFactor::error(const Values& c) const {
throw std::runtime_error("NonlinearFactor::error is not implemented");
}
/* ************************************************************************* */
double NonlinearFactor::error(const HybridValues& c) const {
return this->error(c.nonlinear());
}
/* ************************************************************************* */ /* ************************************************************************* */
void NonlinearFactor::print(const std::string& s, void NonlinearFactor::print(const std::string& s,
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter) const {

View File

@ -74,6 +74,7 @@ public:
/** Check if two factors are equal */ /** Check if two factors are equal */
virtual bool equals(const NonlinearFactor& f, double tol = 1e-9) const; virtual bool equals(const NonlinearFactor& f, double tol = 1e-9) const;
/// @} /// @}
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
@ -81,13 +82,25 @@ public:
/** Destructor */ /** Destructor */
virtual ~NonlinearFactor() {} virtual ~NonlinearFactor() {}
/**
* In nonlinear factors, the error function returns the negative log-likelihood
* as a non-linear function of the values in a \class Values object.
*
* The idea is that Gaussian factors have a quadratic error function that locally
* approximates the negative log-likelihood, and are obtained by \b linearizing
* the nonlinear error function at a given linearization.
*
* The derived class, \class NoiseModelFactor, adds a noise model to the factor,
* and calculates the error by asking the user to implement the method
* \code double evaluateError(const Values& c) const \endcode.
*/
virtual double error(const Values& c) const;
/** /**
* Calculate the error of the factor * The Factor::error simply extracts the \class Values from the
* This is typically equal to log-likelihood, e.g. \f$ 0.5(h(x)-z)^2/sigma^2 \f$ in case of Gaussian. * \class HybridValues and calculates the error.
* You can override this for systems with unusual noise models.
*/ */
virtual double error(const Values& c) const = 0; double error(const HybridValues& c) const override;
/** get the dimension of the factor (number of rows on linearization) */ /** get the dimension of the factor (number of rows on linearization) */
virtual size_t dim() const = 0; virtual size_t dim() const = 0;

View File

@ -15,7 +15,6 @@
* @date Oct 17, 2010 * @date Oct 17, 2010
*/ */
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/symbolic/SymbolicConditional.h> #include <gtsam/symbolic/SymbolicConditional.h>
namespace gtsam { namespace gtsam {
@ -33,4 +32,15 @@ bool SymbolicConditional::equals(const This& c, double tol) const {
return BaseFactor::equals(c) && BaseConditional::equals(c); return BaseFactor::equals(c) && BaseConditional::equals(c);
} }
/* ************************************************************************* */
double SymbolicConditional::logProbability(const HybridValues& c) const {
throw std::runtime_error("SymbolicConditional::logProbability is not implemented");
}
/* ************************************************************************* */
double SymbolicConditional::evaluate(const HybridValues& c) const {
throw std::runtime_error("SymbolicConditional::evaluate is not implemented");
}
} // namespace gtsam } // namespace gtsam

View File

@ -17,10 +17,10 @@
#pragma once #pragma once
#include <gtsam/symbolic/SymbolicFactor.h>
#include <gtsam/inference/Conditional.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/types.h> #include <gtsam/base/types.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/symbolic/SymbolicFactor.h>
namespace gtsam { namespace gtsam {
@ -95,13 +95,10 @@ namespace gtsam {
return FromIteratorsShared(keys.begin(), keys.end(), nrFrontals); return FromIteratorsShared(keys.begin(), keys.end(), nrFrontals);
} }
~SymbolicConditional() override {}
/// Copy this object as its actual derived type. /// Copy this object as its actual derived type.
SymbolicFactor::shared_ptr clone() const { return boost::make_shared<This>(*this); } SymbolicFactor::shared_ptr clone() const { return boost::make_shared<This>(*this); }
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -114,6 +111,19 @@ namespace gtsam {
bool equals(const This& c, double tol = 1e-9) const; bool equals(const This& c, double tol = 1e-9) const;
/// @} /// @}
/// @name HybridValues methods.
/// @{
/// logProbability throws exception, symbolic.
double logProbability(const HybridValues& x) const override;
/// evaluate throws exception, symbolic.
double evaluate(const HybridValues& x) const override;
using Conditional::operator(); // Expose evaluate(const HybridValues&) method..
using SymbolicFactor::error; // Expose error(const HybridValues&) method..
/// @}
private: private:
/** Serialization function */ /** Serialization function */

View File

@ -26,6 +26,11 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
double SymbolicFactor::error(const HybridValues& c) const {
throw std::runtime_error("SymbolicFactor::error is not implemented");
}
/* ************************************************************************* */ /* ************************************************************************* */
std::pair<boost::shared_ptr<SymbolicConditional>, boost::shared_ptr<SymbolicFactor> > std::pair<boost::shared_ptr<SymbolicConditional>, boost::shared_ptr<SymbolicFactor> >
EliminateSymbolic(const SymbolicFactorGraph& factors, const Ordering& keys) EliminateSymbolic(const SymbolicFactorGraph& factors, const Ordering& keys)

View File

@ -30,6 +30,7 @@ namespace gtsam {
// Forward declarations // Forward declarations
class SymbolicConditional; class SymbolicConditional;
class HybridValues;
class Ordering; class Ordering;
/** SymbolicFactor represents a symbolic factor that specifies graph topology but is not /** SymbolicFactor represents a symbolic factor that specifies graph topology but is not
@ -46,7 +47,7 @@ namespace gtsam {
/** Overriding the shared_ptr typedef */ /** Overriding the shared_ptr typedef */
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
/// @name Standard Interface /// @name Standard Constructors
/// @{ /// @{
/** Default constructor for I/O */ /** Default constructor for I/O */
@ -106,10 +107,9 @@ namespace gtsam {
} }
/// @} /// @}
/// @name Advanced Constructors /// @name Advanced Constructors
/// @{ /// @{
public:
/** Constructor from a collection of keys */ /** Constructor from a collection of keys */
template<typename KEYITERATOR> template<typename KEYITERATOR>
static SymbolicFactor FromIterators(KEYITERATOR beginKey, KEYITERATOR endKey) { static SymbolicFactor FromIterators(KEYITERATOR beginKey, KEYITERATOR endKey) {
@ -143,6 +143,9 @@ namespace gtsam {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// The `error` method throws an exception.
double error(const HybridValues& c) const override;
/** Eliminate the variables in \c keys, in the order specified in \c keys, returning a /** Eliminate the variables in \c keys, in the order specified in \c keys, returning a
* conditional and marginal. */ * conditional and marginal. */
std::pair<boost::shared_ptr<SymbolicConditional>, boost::shared_ptr<SymbolicFactor> > std::pair<boost::shared_ptr<SymbolicConditional>, boost::shared_ptr<SymbolicFactor> >

View File

@ -11,13 +11,15 @@ Author: Frank Dellaert
# pylint: disable=no-name-in-module, invalid-name # pylint: disable=no-name-in-module, invalid-name
import math
import textwrap import textwrap
import unittest import unittest
from gtsam.utils.test_case import GtsamTestCase
import gtsam import gtsam
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteDistribution, from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteDistribution,
DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering) DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase
# Some keys: # Some keys:
Asia = (0, 2) Asia = (0, 2)
@ -111,7 +113,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
self.assertEqual(len(actualSample), 8) self.assertEqual(len(actualSample), 8)
def test_fragment(self): def test_fragment(self):
"""Test sampling and optimizing for Asia fragment.""" """Test evaluate/sampling/optimizing for Asia fragment."""
# Create a reverse-topologically sorted fragment: # Create a reverse-topologically sorted fragment:
fragment = DiscreteBayesNet() fragment = DiscreteBayesNet()
@ -125,8 +127,14 @@ class TestDiscreteBayesNet(GtsamTestCase):
given[key[0]] = 0 given[key[0]] = 0
# Now sample from fragment: # Now sample from fragment:
actual = fragment.sample(given) values = fragment.sample(given)
self.assertEqual(len(actual), 5) self.assertEqual(len(values), 5)
for i in [0, 1, 2]:
self.assertAlmostEqual(fragment.at(i).logProbability(values),
math.log(fragment.at(i).evaluate(values)))
self.assertAlmostEqual(fragment.logProbability(values),
math.log(fragment.evaluate(values)))
def test_dot(self): def test_dot(self):
"""Check that dot works with position hints.""" """Check that dot works with position hints."""
@ -139,7 +147,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
# Make sure we can *update* position hints # Make sure we can *update* position hints
writer = gtsam.DotWriter() writer = gtsam.DotWriter()
ph: dict = writer.positionHints ph: dict = writer.positionHints
ph['a'] = 2 # hint at symbol position ph['a'] = 2 # hint at symbol position
writer.positionHints = ph writer.positionHints = ph
# Check the output of dot # Check the output of dot

View File

@ -10,15 +10,14 @@ Author: Frank Dellaert
""" """
# pylint: disable=invalid-name, no-name-in-module, no-member # pylint: disable=invalid-name, no-name-in-module, no-member
from __future__ import print_function
import unittest import unittest
import gtsam
import numpy as np import numpy as np
from gtsam import GaussianBayesNet, GaussianConditional
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
import gtsam
from gtsam import GaussianBayesNet, GaussianConditional
# some keys # some keys
_x_ = 11 _x_ = 11
_y_ = 22 _y_ = 22
@ -45,6 +44,18 @@ class TestGaussianBayesNet(GtsamTestCase):
np.testing.assert_equal(R, R1) np.testing.assert_equal(R, R1)
np.testing.assert_equal(d, d1) np.testing.assert_equal(d, d1)
def test_evaluate(self):
"""Test evaluate method"""
bayesNet = smallBayesNet()
values = gtsam.VectorValues()
values.insert(_x_, np.array([9.0]))
values.insert(_y_, np.array([5.0]))
for i in [0, 1]:
self.assertAlmostEqual(bayesNet.at(i).logProbability(values),
np.log(bayesNet.at(i).evaluate(values)))
self.assertAlmostEqual(bayesNet.logProbability(values),
np.log(bayesNet.evaluate(values)))
def test_sample(self): def test_sample(self):
"""Test sample method""" """Test sample method"""
bayesNet = smallBayesNet() bayesNet = smallBayesNet()

View File

@ -10,14 +10,15 @@ Author: Frank Dellaert
""" """
# pylint: disable=invalid-name, no-name-in-module, no-member # pylint: disable=invalid-name, no-name-in-module, no-member
import math
import unittest import unittest
import numpy as np import numpy as np
from gtsam.symbol_shorthand import A, X from gtsam.symbol_shorthand import A, X
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
from gtsam import (DiscreteKeys, GaussianMixture, DiscreteConditional, GaussianConditional, GaussianMixture, from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
HybridBayesNet, HybridValues, noiseModel) GaussianMixture, HybridBayesNet, HybridValues, noiseModel)
class TestHybridBayesNet(GtsamTestCase): class TestHybridBayesNet(GtsamTestCase):
@ -30,8 +31,8 @@ class TestHybridBayesNet(GtsamTestCase):
# Create the continuous conditional # Create the continuous conditional
I_1x1 = np.eye(1) I_1x1 = np.eye(1)
gc = GaussianConditional.FromMeanAndStddev(X(0), 2 * I_1x1, X(1), [-4], conditional = GaussianConditional.FromMeanAndStddev(X(0), 2 * I_1x1, X(1), [-4],
5.0) 5.0)
# Create the noise models # Create the noise models
model0 = noiseModel.Diagonal.Sigmas([2.0]) model0 = noiseModel.Diagonal.Sigmas([2.0])
@ -45,7 +46,7 @@ class TestHybridBayesNet(GtsamTestCase):
# Create hybrid Bayes net. # Create hybrid Bayes net.
bayesNet = HybridBayesNet() bayesNet = HybridBayesNet()
bayesNet.push_back(gc) bayesNet.push_back(conditional)
bayesNet.push_back(GaussianMixture( bayesNet.push_back(GaussianMixture(
[X(1)], [], discrete_keys, [conditional0, conditional1])) [X(1)], [], discrete_keys, [conditional0, conditional1]))
bayesNet.push_back(DiscreteConditional(Asia, "99/1")) bayesNet.push_back(DiscreteConditional(Asia, "99/1"))
@ -56,13 +57,17 @@ class TestHybridBayesNet(GtsamTestCase):
values.insert(X(0), [-6]) values.insert(X(0), [-6])
values.insert(X(1), [1]) values.insert(X(1), [1])
conditionalProbability = gc.evaluate(values.continuous()) conditionalProbability = conditional.evaluate(values.continuous())
mixtureProbability = conditional0.evaluate(values.continuous()) mixtureProbability = conditional0.evaluate(values.continuous())
self.assertAlmostEqual(conditionalProbability * mixtureProbability * self.assertAlmostEqual(conditionalProbability * mixtureProbability *
0.99, 0.99,
bayesNet.evaluate(values), bayesNet.evaluate(values),
places=5) places=5)
# Check logProbability
self.assertAlmostEqual(bayesNet.logProbability(values),
math.log(bayesNet.evaluate(values)))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -126,17 +126,22 @@ TEST(DoglegOptimizer, Iterate) {
double Delta = 1.0; double Delta = 1.0;
for(size_t it=0; it<10; ++it) { for(size_t it=0; it<10; ++it) {
GaussianBayesNet gbn = *fg.linearize(config)->eliminateSequential(); auto linearized = fg.linearize(config);
// Iterate assumes that linear error = nonlinear error at the linearization point, and this should be true // Iterate assumes that linear error = nonlinear error at the linearization point, and this should be true
double nonlinearError = fg.error(config); double nonlinearError = fg.error(config);
double linearError = GaussianFactorGraph(gbn).error(config.zeroVectors()); double linearError = linearized->error(config.zeroVectors());
DOUBLES_EQUAL(nonlinearError, linearError, 1e-5); DOUBLES_EQUAL(nonlinearError, linearError, 1e-5);
// cout << "it " << it << ", Delta = " << Delta << ", error = " << fg->error(*config) << endl;
VectorValues dx_u = gbn.optimizeGradientSearch(); auto gbn = linearized->eliminateSequential();
VectorValues dx_n = gbn.optimize(); VectorValues dx_u = gbn->optimizeGradientSearch();
DoglegOptimizerImpl::IterationResult result = DoglegOptimizerImpl::Iterate(Delta, DoglegOptimizerImpl::SEARCH_EACH_ITERATION, dx_u, dx_n, gbn, fg, config, fg.error(config)); VectorValues dx_n = gbn->optimize();
DoglegOptimizerImpl::IterationResult result = DoglegOptimizerImpl::Iterate(
Delta, DoglegOptimizerImpl::SEARCH_EACH_ITERATION, dx_u, dx_n, *gbn, fg,
config, fg.error(config));
Delta = result.delta; Delta = result.delta;
EXPECT(result.f_error < fg.error(config)); // Check that error decreases EXPECT(result.f_error < fg.error(config)); // Check that error decreases
Values newConfig(config.retract(result.dx_d)); Values newConfig(config.retract(result.dx_d));
config = newConfig; config = newConfig;
DOUBLES_EQUAL(fg.error(config), result.f_error, 1e-5); // Check that error is correctly filled in DOUBLES_EQUAL(fg.error(config), result.f_error, 1e-5); // Check that error is correctly filled in