Merge pull request #1041 from borglab/release/4.2a3
commit
76c74195de
|
|
@ -15,7 +15,7 @@ jobs:
|
||||||
BOOST_VERSION: 1.67.0
|
BOOST_VERSION: 1.67.0
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: true
|
||||||
matrix:
|
matrix:
|
||||||
# Github Actions requires a single row to be added to the build matrix.
|
# Github Actions requires a single row to be added to the build matrix.
|
||||||
# See https://help.github.com/en/articles/workflow-syntax-for-github-actions.
|
# See https://help.github.com/en/articles/workflow-syntax-for-github-actions.
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ endif()
|
||||||
set (GTSAM_VERSION_MAJOR 4)
|
set (GTSAM_VERSION_MAJOR 4)
|
||||||
set (GTSAM_VERSION_MINOR 2)
|
set (GTSAM_VERSION_MINOR 2)
|
||||||
set (GTSAM_VERSION_PATCH 0)
|
set (GTSAM_VERSION_PATCH 0)
|
||||||
set (GTSAM_PRERELEASE_VERSION "a2")
|
set (GTSAM_PRERELEASE_VERSION "a3")
|
||||||
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
|
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
|
||||||
|
|
||||||
if (${GTSAM_VERSION_PATCH} EQUAL 0)
|
if (${GTSAM_VERSION_PATCH} EQUAL 0)
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ namespace gtsam {
|
||||||
/** Default constructor for I/O */
|
/** Default constructor for I/O */
|
||||||
DecisionTreeFactor();
|
DecisionTreeFactor();
|
||||||
|
|
||||||
/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
|
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
||||||
|
|
||||||
/** Constructor from doubles */
|
/** Constructor from doubles */
|
||||||
|
|
@ -139,14 +139,14 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* Apply binary operator (*this) "op" f
|
* Apply binary operator (*this) "op" f
|
||||||
* @param f the second argument for op
|
* @param f the second argument for op
|
||||||
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
|
* @param op a binary operator that operates on AlgebraicDecisionTree
|
||||||
*/
|
*/
|
||||||
DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;
|
DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Combine frontal variables using binary operator "op"
|
* Combine frontal variables using binary operator "op"
|
||||||
* @param nrFrontals nr. of frontal to combine variables in this factor
|
* @param nrFrontals nr. of frontal to combine variables in this factor
|
||||||
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
|
* @param op a binary operator that operates on AlgebraicDecisionTree
|
||||||
* @return shared pointer to newly created DecisionTreeFactor
|
* @return shared pointer to newly created DecisionTreeFactor
|
||||||
*/
|
*/
|
||||||
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
|
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
|
||||||
|
|
@ -154,7 +154,7 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* Combine frontal variables in an Ordering using binary operator "op"
|
* Combine frontal variables in an Ordering using binary operator "op"
|
||||||
* @param nrFrontals nr. of frontal to combine variables in this factor
|
* @param nrFrontals nr. of frontal to combine variables in this factor
|
||||||
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
|
* @param op a binary operator that operates on AlgebraicDecisionTree
|
||||||
* @return shared pointer to newly created DecisionTreeFactor
|
* @return shared pointer to newly created DecisionTreeFactor
|
||||||
*/
|
*/
|
||||||
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
|
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/DiscretePrior.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
|
||||||
|
|
@ -79,9 +79,9 @@ namespace gtsam {
|
||||||
// Add inherited versions of add.
|
// Add inherited versions of add.
|
||||||
using Base::add;
|
using Base::add;
|
||||||
|
|
||||||
/** Add a DiscretePrior using a table or a string */
|
/** Add a DiscreteDistribution using a table or a string */
|
||||||
void add(const DiscreteKey& key, const std::string& spec) {
|
void add(const DiscreteKey& key, const std::string& spec) {
|
||||||
emplace_shared<DiscretePrior>(key, spec);
|
emplace_shared<DiscreteDistribution>(key, spec);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Add a DiscreteCondtional */
|
/** Add a DiscreteCondtional */
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using std::stringstream;
|
using std::stringstream;
|
||||||
|
|
@ -38,38 +39,97 @@ using std::pair;
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
template class GTSAM_EXPORT Conditional<DecisionTreeFactor, DiscreteConditional> ;
|
template class GTSAM_EXPORT
|
||||||
|
Conditional<DecisionTreeFactor, DiscreteConditional>;
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
||||||
const DecisionTreeFactor& f) :
|
const DecisionTreeFactor& f)
|
||||||
BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {
|
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||||
const DecisionTreeFactor& marginal) :
|
const DiscreteKeys& keys,
|
||||||
BaseFactor(
|
const ADT& potentials)
|
||||||
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional(
|
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
|
||||||
joint.size()-marginal.size()) {
|
|
||||||
if (ISDEBUG("DiscreteConditional::DiscreteConditional"))
|
|
||||||
cout << (firstFrontalKey()) << endl; //TODO Print all keys
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal, const Ordering& orderedKeys) :
|
const DecisionTreeFactor& marginal)
|
||||||
DiscreteConditional(joint, marginal) {
|
: BaseFactor(joint / marginal),
|
||||||
|
BaseConditional(joint.size() - marginal.size()) {}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
|
const DecisionTreeFactor& marginal,
|
||||||
|
const Ordering& orderedKeys)
|
||||||
|
: DiscreteConditional(joint, marginal) {
|
||||||
keys_.clear();
|
keys_.clear();
|
||||||
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const Signature& signature)
|
DiscreteConditional::DiscreteConditional(const Signature& signature)
|
||||||
: BaseFactor(signature.discreteKeys(), signature.cpt()),
|
: BaseFactor(signature.discreteKeys(), signature.cpt()),
|
||||||
BaseConditional(1) {}
|
BaseConditional(1) {}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
DiscreteConditional DiscreteConditional::operator*(
|
||||||
|
const DiscreteConditional& other) const {
|
||||||
|
// Take union of frontal keys
|
||||||
|
std::set<Key> newFrontals;
|
||||||
|
for (auto&& key : this->frontals()) newFrontals.insert(key);
|
||||||
|
for (auto&& key : other.frontals()) newFrontals.insert(key);
|
||||||
|
|
||||||
|
// Check if frontals overlapped
|
||||||
|
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteConditional::operator* called with overlapping frontal keys.");
|
||||||
|
|
||||||
|
// Now, add cardinalities.
|
||||||
|
DiscreteKeys discreteKeys;
|
||||||
|
for (auto&& key : frontals())
|
||||||
|
discreteKeys.emplace_back(key, cardinality(key));
|
||||||
|
for (auto&& key : other.frontals())
|
||||||
|
discreteKeys.emplace_back(key, other.cardinality(key));
|
||||||
|
|
||||||
|
// Sort
|
||||||
|
std::sort(discreteKeys.begin(), discreteKeys.end());
|
||||||
|
|
||||||
|
// Add parents to set, to make them unique
|
||||||
|
std::set<DiscreteKey> parents;
|
||||||
|
for (auto&& key : this->parents())
|
||||||
|
if (!newFrontals.count(key)) parents.emplace(key, cardinality(key));
|
||||||
|
for (auto&& key : other.parents())
|
||||||
|
if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key));
|
||||||
|
|
||||||
|
// Finally, add parents to keys, in order
|
||||||
|
for (auto&& dk : parents) discreteKeys.push_back(dk);
|
||||||
|
|
||||||
|
ADT product = ADT::apply(other, ADT::Ring::mul);
|
||||||
|
return DiscreteConditional(newFrontals.size(), discreteKeys, product);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteConditional DiscreteConditional::marginal(Key key) const {
|
||||||
|
if (nrParents() > 0)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteConditional::marginal: single argument version only valid for "
|
||||||
|
"fully specified joint distributions (i.e., no parents).");
|
||||||
|
|
||||||
|
// Calculate the keys as the frontal keys without the given key.
|
||||||
|
DiscreteKeys discreteKeys{{key, cardinality(key)}};
|
||||||
|
|
||||||
|
// Calculate sum
|
||||||
|
ADT adt(*this);
|
||||||
|
for (auto&& k : frontals())
|
||||||
|
if (k != key) adt = adt.sum(k, cardinality(k));
|
||||||
|
|
||||||
|
// Return new factor
|
||||||
|
return DiscreteConditional(1, discreteKeys, adt);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::print(const string& s,
|
void DiscreteConditional::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
cout << s << " P( ";
|
cout << s << " P( ";
|
||||||
|
|
@ -82,7 +142,7 @@ void DiscreteConditional::print(const string& s,
|
||||||
cout << formatter(*it) << " ";
|
cout << formatter(*it) << " ";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << ")";
|
cout << "):\n";
|
||||||
ADT::print("");
|
ADT::print("");
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -49,14 +49,21 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** default constructor needed for serialization */
|
/// Default constructor needed for serialization.
|
||||||
DiscreteConditional() {}
|
DiscreteConditional() {}
|
||||||
|
|
||||||
/** constructor from factor */
|
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||||
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
|
||||||
|
* `nFrontals` keys as frontals, in the order given.
|
||||||
|
*/
|
||||||
|
DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys,
|
||||||
|
const ADT& potentials);
|
||||||
|
|
||||||
/** Construct from signature */
|
/** Construct from signature */
|
||||||
DiscreteConditional(const Signature& signature);
|
explicit DiscreteConditional(const Signature& signature);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from key, parents, and a Signature::Table specifying the
|
* Construct from key, parents, and a Signature::Table specifying the
|
||||||
|
|
@ -82,31 +89,45 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
const std::string& spec)
|
const std::string& spec)
|
||||||
: DiscreteConditional(Signature(key, parents, spec)) {}
|
: DiscreteConditional(Signature(key, parents, spec)) {}
|
||||||
|
|
||||||
/// No-parent specialization; can also use DiscretePrior.
|
/// No-parent specialization; can also use DiscreteDistribution.
|
||||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||||
: DiscreteConditional(Signature(key, {}, spec)) {}
|
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||||
|
|
||||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
/**
|
||||||
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
|
*/
|
||||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal);
|
const DecisionTreeFactor& marginal);
|
||||||
|
|
||||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
/**
|
||||||
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
|
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
||||||
|
*/
|
||||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal,
|
const DecisionTreeFactor& marginal,
|
||||||
const Ordering& orderedKeys);
|
const Ordering& orderedKeys);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Combine several conditional into a single one.
|
* @brief Combine two conditionals, yielding a new conditional with the union
|
||||||
* The conditionals must be given in increasing order, meaning that the
|
* of the frontal keys, ordered by gtsam::Key.
|
||||||
* parents of any conditional may not include a conditional coming before it.
|
*
|
||||||
* @param firstConditional Iterator to the first conditional to combine, must
|
* The two conditionals must make a valid Bayes net fragment, i.e.,
|
||||||
* dereference to a shared_ptr<DiscreteConditional>.
|
* the frontal variables cannot overlap, and must be acyclic:
|
||||||
* @param lastConditional Iterator to after the last conditional to combine,
|
* Example of correct use:
|
||||||
* must dereference to a shared_ptr<DiscreteConditional>.
|
* P(A,B) = P(A|B) * P(B)
|
||||||
* */
|
* P(A,B|C) = P(A|B) * P(B|C)
|
||||||
template <typename ITERATOR>
|
* P(A,B,C) = P(A,B|C) * P(C)
|
||||||
static shared_ptr Combine(ITERATOR firstConditional,
|
* Example of incorrect use:
|
||||||
ITERATOR lastConditional);
|
* P(A|B) * P(A|C) = ?
|
||||||
|
* P(A|B) * P(B|A) = ?
|
||||||
|
* We check for overlapping frontals, but do *not* check for cyclic.
|
||||||
|
*/
|
||||||
|
DiscreteConditional operator*(const DiscreteConditional& other) const;
|
||||||
|
|
||||||
|
/** Calculate marginal on given key, no parent case. */
|
||||||
|
DiscreteConditional marginal(Key key) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
|
@ -136,11 +157,6 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Convert to a factor */
|
|
||||||
DecisionTreeFactor::shared_ptr toFactor() const {
|
|
||||||
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Restrict to given parent values, returns DecisionTreeFactor */
|
/** Restrict to given parent values, returns DecisionTreeFactor */
|
||||||
DecisionTreeFactor::shared_ptr choose(
|
DecisionTreeFactor::shared_ptr choose(
|
||||||
const DiscreteValues& parentsValues) const;
|
const DiscreteValues& parentsValues) const;
|
||||||
|
|
@ -208,23 +224,4 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
template <>
|
template <>
|
||||||
struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
|
struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
template <typename ITERATOR>
|
|
||||||
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
|
|
||||||
ITERATOR firstConditional, ITERATOR lastConditional) {
|
|
||||||
// TODO: check for being a clique
|
|
||||||
|
|
||||||
// multiply all the potentials of the given conditionals
|
|
||||||
size_t nrFrontals = 0;
|
|
||||||
DecisionTreeFactor product;
|
|
||||||
for (ITERATOR it = firstConditional; it != lastConditional;
|
|
||||||
++it, ++nrFrontals) {
|
|
||||||
DiscreteConditional::shared_ptr c = *it;
|
|
||||||
DecisionTreeFactor::shared_ptr factor = c->toFactor();
|
|
||||||
product = (*factor) * product;
|
|
||||||
}
|
|
||||||
// and then create a new multi-frontal conditional
|
|
||||||
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -10,21 +10,23 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file DiscretePrior.cpp
|
* @file DiscreteDistribution.cpp
|
||||||
* @date December 2021
|
* @date December 2021
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscretePrior.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
void DiscretePrior::print(const std::string& s,
|
void DiscreteDistribution::print(const std::string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
Base::print(s, formatter);
|
Base::print(s, formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
double DiscretePrior::operator()(size_t value) const {
|
double DiscreteDistribution::operator()(size_t value) const {
|
||||||
if (nrFrontals() != 1)
|
if (nrFrontals() != 1)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Single value operator can only be invoked on single-variable "
|
"Single value operator can only be invoked on single-variable "
|
||||||
|
|
@ -34,10 +36,10 @@ double DiscretePrior::operator()(size_t value) const {
|
||||||
return Base::operator()(values);
|
return Base::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<double> DiscretePrior::pmf() const {
|
std::vector<double> DiscreteDistribution::pmf() const {
|
||||||
if (nrFrontals() != 1)
|
if (nrFrontals() != 1)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"DiscretePrior::pmf only defined for single-variable priors");
|
"DiscreteDistribution::pmf only defined for single-variable priors");
|
||||||
const size_t nrValues = cardinalities_.at(keys_[0]);
|
const size_t nrValues = cardinalities_.at(keys_[0]);
|
||||||
std::vector<double> array;
|
std::vector<double> array;
|
||||||
array.reserve(nrValues);
|
array.reserve(nrValues);
|
||||||
|
|
@ -10,7 +10,7 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file DiscretePrior.h
|
* @file DiscreteDistribution.h
|
||||||
* @date December 2021
|
* @date December 2021
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
@ -20,6 +20,7 @@
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -27,7 +28,7 @@ namespace gtsam {
|
||||||
* A prior probability on a set of discrete variables.
|
* A prior probability on a set of discrete variables.
|
||||||
* Derives from DiscreteConditional
|
* Derives from DiscreteConditional
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
|
||||||
public:
|
public:
|
||||||
using Base = DiscreteConditional;
|
using Base = DiscreteConditional;
|
||||||
|
|
||||||
|
|
@ -35,35 +36,36 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Default constructor needed for serialization.
|
/// Default constructor needed for serialization.
|
||||||
DiscretePrior() {}
|
DiscreteDistribution() {}
|
||||||
|
|
||||||
/// Constructor from factor.
|
/// Constructor from factor.
|
||||||
DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {}
|
explicit DiscreteDistribution(const DecisionTreeFactor& f)
|
||||||
|
: Base(f.size(), f) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from a Signature.
|
* Construct from a Signature.
|
||||||
*
|
*
|
||||||
* Example: DiscretePrior P(D % "3/2");
|
* Example: DiscreteDistribution P(D % "3/2");
|
||||||
*/
|
*/
|
||||||
DiscretePrior(const Signature& s) : Base(s) {}
|
explicit DiscreteDistribution(const Signature& s) : Base(s) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from key and a Signature::Table specifying the
|
* Construct from key and a vector of floats specifying the probability mass
|
||||||
* conditional probability table (CPT).
|
* function (PMF).
|
||||||
*
|
*
|
||||||
* Example: DiscretePrior P(D, table);
|
* Example: DiscreteDistribution P(D, {0.4, 0.6});
|
||||||
*/
|
*/
|
||||||
DiscretePrior(const DiscreteKey& key, const Signature::Table& table)
|
DiscreteDistribution(const DiscreteKey& key, const std::vector<double>& spec)
|
||||||
: Base(Signature(key, {}, table)) {}
|
: DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from key and a string specifying the conditional
|
* Construct from key and a string specifying the probability mass function
|
||||||
* probability table (CPT).
|
* (PMF).
|
||||||
*
|
*
|
||||||
* Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9");
|
* Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9");
|
||||||
*/
|
*/
|
||||||
DiscretePrior(const DiscreteKey& key, const std::string& spec)
|
DiscreteDistribution(const DiscreteKey& key, const std::string& spec)
|
||||||
: DiscretePrior(Signature(key, {}, spec)) {}
|
: DiscreteDistribution(Signature(key, {}, spec)) {}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
|
@ -102,10 +104,10 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
// DiscretePrior
|
// DiscreteDistribution
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template <>
|
template <>
|
||||||
struct traits<DiscretePrior> : public Testable<DiscretePrior> {};
|
struct traits<DiscreteDistribution> : public Testable<DiscreteDistribution> {};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
@ -58,6 +58,15 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
|
||||||
|
size_t cardinality(gtsam::Key j) const;
|
||||||
|
gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const;
|
||||||
|
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
|
||||||
|
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
|
||||||
|
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
|
||||||
|
|
||||||
string dot(
|
string dot(
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
bool showZero = true) const;
|
bool showZero = true) const;
|
||||||
|
|
@ -86,14 +95,18 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||||
const gtsam::DecisionTreeFactor& marginal,
|
const gtsam::DecisionTreeFactor& marginal,
|
||||||
const gtsam::Ordering& orderedKeys);
|
const gtsam::Ordering& orderedKeys);
|
||||||
|
gtsam::DiscreteConditional operator*(
|
||||||
|
const gtsam::DiscreteConditional& other) const;
|
||||||
|
DiscreteConditional marginal(gtsam::Key key) const;
|
||||||
void print(string s = "Discrete Conditional\n",
|
void print(string s = "Discrete Conditional\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
||||||
|
size_t nrFrontals() const;
|
||||||
|
size_t nrParents() const;
|
||||||
void printSignature(
|
void printSignature(
|
||||||
string s = "Discrete Conditional: ",
|
string s = "Discrete Conditional: ",
|
||||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
gtsam::DecisionTreeFactor* toFactor() const;
|
|
||||||
gtsam::DecisionTreeFactor* choose(
|
gtsam::DecisionTreeFactor* choose(
|
||||||
const gtsam::DiscreteValues& parentsValues) const;
|
const gtsam::DiscreteValues& parentsValues) const;
|
||||||
gtsam::DecisionTreeFactor* likelihood(
|
gtsam::DecisionTreeFactor* likelihood(
|
||||||
|
|
@ -115,11 +128,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscretePrior.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
virtual class DiscretePrior : gtsam::DiscreteConditional {
|
virtual class DiscreteDistribution : gtsam::DiscreteConditional {
|
||||||
DiscretePrior();
|
DiscreteDistribution();
|
||||||
DiscretePrior(const gtsam::DecisionTreeFactor& f);
|
DiscreteDistribution(const gtsam::DecisionTreeFactor& f);
|
||||||
DiscretePrior(const gtsam::DiscreteKey& key, string spec);
|
DiscreteDistribution(const gtsam::DiscreteKey& key, string spec);
|
||||||
|
DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
|
||||||
void print(string s = "Discrete Prior\n",
|
void print(string s = "Discrete Prior\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,12 @@
|
||||||
* @author Duy-Nguyen Ta
|
* @author Duy-Nguyen Ta
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/Signature.h>
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
#include <boost/assign/std/map.hpp>
|
#include <boost/assign/std/map.hpp>
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
|
|
@ -51,17 +53,21 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST_UNSAFE( DecisionTreeFactor, multiplication)
|
TEST(DecisionTreeFactor, multiplication) {
|
||||||
{
|
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
|
||||||
DiscreteKey v0(0,2), v1(1,2), v2(2,2);
|
|
||||||
|
|
||||||
|
// Multiply with a DiscreteDistribution, i.e., Bayes Law!
|
||||||
|
DiscreteDistribution prior(v1 % "1/3");
|
||||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
||||||
|
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
|
||||||
|
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1));
|
||||||
|
CHECK(assert_equal(expected, f1 * prior));
|
||||||
|
|
||||||
|
// Multiply two factors
|
||||||
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
|
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
|
||||||
|
|
||||||
DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
|
||||||
|
|
||||||
DecisionTreeFactor actual = f1 * f2;
|
DecisionTreeFactor actual = f1 * f2;
|
||||||
CHECK(assert_equal(expected, actual));
|
DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
||||||
|
CHECK(assert_equal(expected2, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -34,20 +34,21 @@ using namespace gtsam;
|
||||||
TEST(DiscreteConditional, constructors) {
|
TEST(DiscreteConditional, constructors) {
|
||||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||||
|
|
||||||
DiscreteConditional expected(X | Y = "1/1 2/3 1/4");
|
DiscreteConditional actual(X | Y = "1/1 2/3 1/4");
|
||||||
EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals()));
|
EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals()));
|
||||||
EXPECT_LONGS_EQUAL(2, *(expected.beginParents()));
|
EXPECT_LONGS_EQUAL(2, *(actual.beginParents()));
|
||||||
EXPECT(expected.endParents() == expected.end());
|
EXPECT(actual.endParents() == actual.end());
|
||||||
EXPECT(expected.endFrontals() == expected.beginParents());
|
EXPECT(actual.endFrontals() == actual.beginParents());
|
||||||
|
|
||||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
DiscreteConditional actual1(1, f1);
|
DiscreteConditional expected1(1, f1);
|
||||||
EXPECT(assert_equal(expected, actual1, 1e-9));
|
EXPECT(assert_equal(expected1, actual, 1e-9));
|
||||||
|
|
||||||
DecisionTreeFactor f2(
|
DecisionTreeFactor f2(
|
||||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||||
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -61,6 +62,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
||||||
r3 += 1.0, 4.0;
|
r3 += 1.0, 4.0;
|
||||||
table += r1, r2, r3;
|
table += r1, r2, r3;
|
||||||
DiscreteConditional actual1(X, {Y}, table);
|
DiscreteConditional actual1(X, {Y}, table);
|
||||||
|
|
||||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
DiscreteConditional expected1(1, f1);
|
DiscreteConditional expected1(1, f1);
|
||||||
EXPECT(assert_equal(expected1, actual1, 1e-9));
|
EXPECT(assert_equal(expected1, actual1, 1e-9));
|
||||||
|
|
@ -68,41 +70,141 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
||||||
DecisionTreeFactor f2(
|
DecisionTreeFactor f2(
|
||||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||||
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteConditional, constructors2) {
|
TEST(DiscreteConditional, constructors2) {
|
||||||
// Declare keys and ordering
|
|
||||||
DiscreteKey C(0, 2), B(1, 2);
|
DiscreteKey C(0, 2), B(1, 2);
|
||||||
DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25");
|
|
||||||
Signature signature((C | B) = "4/1 3/1");
|
Signature signature((C | B) = "4/1 3/1");
|
||||||
DiscreteConditional expected(signature);
|
DiscreteConditional actual(signature);
|
||||||
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
|
|
||||||
EXPECT(assert_equal(*expectedFactor, actual));
|
DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25");
|
||||||
|
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteConditional, constructors3) {
|
TEST(DiscreteConditional, constructors3) {
|
||||||
// Declare keys and ordering
|
|
||||||
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||||
DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
|
||||||
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
||||||
DiscreteConditional expected(signature);
|
DiscreteConditional actual(signature);
|
||||||
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
|
|
||||||
EXPECT(assert_equal(*expectedFactor, actual));
|
DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
||||||
|
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteConditional, Combine) {
|
// Check calculation of joint P(A,B)
|
||||||
DiscreteKey A(0, 2), B(1, 2);
|
TEST(DiscreteConditional, Multiply) {
|
||||||
vector<DiscreteConditional::shared_ptr> c;
|
DiscreteKey A(1, 2), B(0, 2);
|
||||||
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
|
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||||
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
|
DiscreteConditional prior(B % "1/2");
|
||||||
DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222");
|
|
||||||
DiscreteConditional expected(2, factor);
|
// The expected factor
|
||||||
auto actual = DiscreteConditional::Combine(c.begin(), c.end());
|
DecisionTreeFactor f(A & B, "1 4 2 2");
|
||||||
EXPECT(assert_equal(expected, *actual, 1e-5));
|
DiscreteConditional expected(2, f);
|
||||||
|
|
||||||
|
// P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
|
||||||
|
for (auto&& actual : {prior * conditional, conditional * prior}) {
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{0, 1}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9);
|
||||||
|
}
|
||||||
|
// And for good measure:
|
||||||
|
EXPECT(assert_equal(expected, actual));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of conditional joint P(A,B|C)
|
||||||
|
TEST(DiscreteConditional, Multiply2) {
|
||||||
|
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
|
||||||
|
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||||
|
DiscreteConditional B_given_C(B | C = "1/3 3/1");
|
||||||
|
|
||||||
|
// P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
|
||||||
|
for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual.nrParents());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{0, 1}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of conditional joint P(A,B|C), double check keys
|
||||||
|
TEST(DiscreteConditional, Multiply3) {
|
||||||
|
DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!!
|
||||||
|
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||||
|
DiscreteConditional B_given_C(B | C = "1/3 3/1");
|
||||||
|
|
||||||
|
// P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
|
||||||
|
for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual.nrParents());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{1, 2}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
|
||||||
|
TEST(DiscreteConditional, Multiply4) {
|
||||||
|
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2);
|
||||||
|
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||||
|
DiscreteConditional B_given_D(B | D = "1/3 3/1");
|
||||||
|
DiscreteConditional AB_given_D = A_given_B * B_given_D;
|
||||||
|
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||||
|
|
||||||
|
// P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D)
|
||||||
|
for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) {
|
||||||
|
EXPECT_LONGS_EQUAL(3, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrParents());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{0, 1, 2}));
|
||||||
|
KeyVector parents(actual.beginParents(), actual.endParents());
|
||||||
|
EXPECT((parents == KeyVector{3, 4}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of marginals for joint P(A,B)
|
||||||
|
TEST(DiscreteConditional, marginals) {
|
||||||
|
DiscreteKey A(1, 2), B(0, 2);
|
||||||
|
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||||
|
DiscreteConditional prior(B % "1/2");
|
||||||
|
DiscreteConditional pAB = prior * conditional;
|
||||||
|
|
||||||
|
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||||
|
DiscreteConditional pA(A % "5/4");
|
||||||
|
EXPECT(assert_equal(pA, actualA));
|
||||||
|
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
|
||||||
|
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
|
||||||
|
EXPECT((frontalsA == KeyVector{1}));
|
||||||
|
|
||||||
|
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||||
|
EXPECT(assert_equal(prior, actualB));
|
||||||
|
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
|
||||||
|
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals());
|
||||||
|
EXPECT((frontalsB == KeyVector{0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -11,47 +11,66 @@
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* @file testDiscretePrior.cpp
|
* @file testDiscretePrior.cpp
|
||||||
* @brief unit tests for DiscretePrior
|
* @brief unit tests for DiscreteDistribution
|
||||||
* @author Frank dellaert
|
* @author Frank dellaert
|
||||||
* @date December 2021
|
* @date December 2021
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/discrete/DiscretePrior.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
static const DiscreteKey X(0, 2);
|
static const DiscreteKey X(0, 2);
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscretePrior, constructors) {
|
TEST(DiscreteDistribution, constructors) {
|
||||||
DiscretePrior actual(X % "2/3");
|
DecisionTreeFactor f(X, "0.4 0.6");
|
||||||
|
DiscreteDistribution expected(f);
|
||||||
|
|
||||||
|
DiscreteDistribution actual(X % "2/3");
|
||||||
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
|
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
|
||||||
EXPECT_LONGS_EQUAL(0, actual.nrParents());
|
EXPECT_LONGS_EQUAL(0, actual.nrParents());
|
||||||
DecisionTreeFactor f(X, "0.4 0.6");
|
|
||||||
DiscretePrior expected(f);
|
|
||||||
EXPECT(assert_equal(expected, actual, 1e-9));
|
EXPECT(assert_equal(expected, actual, 1e-9));
|
||||||
|
|
||||||
|
const std::vector<double> pmf{0.4, 0.6};
|
||||||
|
DiscreteDistribution actual2(X, pmf);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual2.nrParents());
|
||||||
|
EXPECT(assert_equal(expected, actual2, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscretePrior, operator) {
|
TEST(DiscreteDistribution, Multiply) {
|
||||||
DiscretePrior prior(X % "2/3");
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||||
|
DiscreteDistribution prior(B, "1/2");
|
||||||
|
DiscreteConditional actual = prior * conditional; // P(A|B) * P(B)
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B)
|
||||||
|
DecisionTreeFactor factor(A & B, "1 4 2 2");
|
||||||
|
DiscreteConditional expected(2, factor);
|
||||||
|
EXPECT(assert_equal(expected, actual, 1e-5));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteDistribution, operator) {
|
||||||
|
DiscreteDistribution prior(X % "2/3");
|
||||||
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
|
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
|
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscretePrior, pmf) {
|
TEST(DiscreteDistribution, pmf) {
|
||||||
DiscretePrior prior(X % "2/3");
|
DiscreteDistribution prior(X % "2/3");
|
||||||
vector<double> expected {0.4, 0.6};
|
std::vector<double> expected{0.4, 0.6};
|
||||||
EXPECT(prior.pmf() == expected);
|
EXPECT(prior.pmf() == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscretePrior, sample) {
|
TEST(DiscreteDistribution, sample) {
|
||||||
DiscretePrior prior(X % "2/3");
|
DiscreteDistribution prior(X % "2/3");
|
||||||
prior.sample();
|
prior.sample();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -154,7 +154,8 @@ namespace gtsam {
|
||||||
|
|
||||||
/** Unnormalized probability. O(n) */
|
/** Unnormalized probability. O(n) */
|
||||||
double probPrime(const VectorValues& c) const {
|
double probPrime(const VectorValues& c) const {
|
||||||
return exp(-0.5 * error(c));
|
// NOTE the 0.5 constant is handled by the factor error.
|
||||||
|
return exp(-error(c));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -426,6 +426,7 @@ TEST(GaussianFactorGraph, hessianDiagonal) {
|
||||||
EXPECT(assert_equal(expected, actual));
|
EXPECT(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
TEST(GaussianFactorGraph, DenseSolve) {
|
TEST(GaussianFactorGraph, DenseSolve) {
|
||||||
GaussianFactorGraph fg = createSimpleGaussianFactorGraph();
|
GaussianFactorGraph fg = createSimpleGaussianFactorGraph();
|
||||||
VectorValues expected = fg.optimize();
|
VectorValues expected = fg.optimize();
|
||||||
|
|
@ -433,6 +434,28 @@ TEST(GaussianFactorGraph, DenseSolve) {
|
||||||
EXPECT(assert_equal(expected, actual));
|
EXPECT(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(GaussianFactorGraph, ProbPrime) {
|
||||||
|
GaussianFactorGraph gfg;
|
||||||
|
gfg.emplace_shared<JacobianFactor>(1, I_1x1, Z_1x1,
|
||||||
|
noiseModel::Isotropic::Sigma(1, 1.0));
|
||||||
|
|
||||||
|
VectorValues values;
|
||||||
|
values.insert(1, I_1x1);
|
||||||
|
|
||||||
|
// We are testing the normal distribution PDF where info matrix Σ = 1,
|
||||||
|
// mean mu = 0 and x = 1.
|
||||||
|
// Therefore factor squared error: y = 0.5 * (Σ*x - mu)^2 =
|
||||||
|
// 0.5 * (1.0 - 0)^2 = 0.5
|
||||||
|
// NOTE the 0.5 constant is a part of the factor error.
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.5, gfg.error(values), 1e-12);
|
||||||
|
|
||||||
|
// The gaussian PDF value is: exp^(-0.5 * (Σ*x - mu)^2) / sqrt(2 * PI)
|
||||||
|
// Ignore the denominator and we get: exp^(-0.5 * (1.0)^2) = exp^(-0.5)
|
||||||
|
double expected = exp(-0.5);
|
||||||
|
EXPECT_DOUBLES_EQUAL(expected, gfg.probPrime(values), 1e-12);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,8 @@ template class FactorGraph<NonlinearFactor>;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double NonlinearFactorGraph::probPrime(const Values& values) const {
|
double NonlinearFactorGraph::probPrime(const Values& values) const {
|
||||||
return exp(-0.5 * error(values));
|
// NOTE the 0.5 constant is handled by the factor error.
|
||||||
|
return exp(-error(values));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -54,9 +55,14 @@ void NonlinearFactorGraph::print(const std::string& str, const KeyFormatter& key
|
||||||
for (size_t i = 0; i < factors_.size(); i++) {
|
for (size_t i = 0; i < factors_.size(); i++) {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "Factor " << i << ": ";
|
ss << "Factor " << i << ": ";
|
||||||
if (factors_[i] != nullptr) factors_[i]->print(ss.str(), keyFormatter);
|
if (factors_[i] != nullptr) {
|
||||||
cout << endl;
|
factors_[i]->print(ss.str(), keyFormatter);
|
||||||
|
cout << "\n";
|
||||||
|
} else {
|
||||||
|
cout << ss.str() << "nullptr\n";
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
std::cout.flush();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -80,8 +86,9 @@ void NonlinearFactorGraph::printErrors(const Values& values, const std::string&
|
||||||
factor->print(ss.str(), keyFormatter);
|
factor->print(ss.str(), keyFormatter);
|
||||||
cout << "error = " << errorValue << "\n";
|
cout << "error = " << errorValue << "\n";
|
||||||
}
|
}
|
||||||
cout << endl; // only one "endl" at end might be faster, \n for each factor
|
cout << "\n";
|
||||||
}
|
}
|
||||||
|
std::cout.flush();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,7 @@ namespace gtsam {
|
||||||
/** Test equality */
|
/** Test equality */
|
||||||
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
|
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
/** unnormalized error, \f$ 0.5 \sum_i (h_i(X_i)-z)^2/\sigma^2 \f$ in the most common case */
|
/** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */
|
||||||
double error(const Values& values) const;
|
double error(const Values& values) const;
|
||||||
|
|
||||||
/** Unnormalized probability. O(n) */
|
/** Unnormalized probability. O(n) */
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ namespace gtsam {
|
||||||
// ######
|
// ######
|
||||||
|
|
||||||
#include <gtsam/slam/BetweenFactor.h>
|
#include <gtsam/slam/BetweenFactor.h>
|
||||||
template <T = {Vector, gtsam::Point2, gtsam::Point3, gtsam::Rot2, gtsam::SO3,
|
template <T = {double, Vector, gtsam::Point2, gtsam::Point3, gtsam::Rot2, gtsam::SO3,
|
||||||
gtsam::SO4, gtsam::Rot3, gtsam::Pose2, gtsam::Pose3,
|
gtsam::SO4, gtsam::Rot3, gtsam::Pose2, gtsam::Pose3,
|
||||||
gtsam::imuBias::ConstantBias}>
|
gtsam::imuBias::ConstantBias}>
|
||||||
virtual class BetweenFactor : gtsam::NoiseModelFactor {
|
virtual class BetweenFactor : gtsam::NoiseModelFactor {
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ Author: Frank Dellaert
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys
|
from gtsam import DecisionTreeFactor, DiscreteValues, DiscreteDistribution, Ordering
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,15 +21,59 @@ class TestDecisionTreeFactor(GtsamTestCase):
|
||||||
"""Tests for DecisionTreeFactors."""
|
"""Tests for DecisionTreeFactors."""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
A = (12, 3)
|
self.A = (12, 3)
|
||||||
B = (5, 2)
|
self.B = (5, 2)
|
||||||
self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6")
|
self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")
|
||||||
|
|
||||||
def test_enumerate(self):
|
def test_enumerate(self):
|
||||||
actual = self.factor.enumerate()
|
actual = self.factor.enumerate()
|
||||||
_, values = zip(*actual)
|
_, values = zip(*actual)
|
||||||
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||||
|
|
||||||
|
def test_multiplication(self):
|
||||||
|
"""Test whether multiplication works with overloading."""
|
||||||
|
v0 = (0, 2)
|
||||||
|
v1 = (1, 2)
|
||||||
|
v2 = (2, 2)
|
||||||
|
|
||||||
|
# Multiply with a DiscreteDistribution, i.e., Bayes Law!
|
||||||
|
prior = DiscreteDistribution(v1, [1, 3])
|
||||||
|
f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
|
||||||
|
expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
|
||||||
|
self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected)
|
||||||
|
self.gtsamAssertEquals(f1 * prior, expected)
|
||||||
|
|
||||||
|
# Multiply two factors
|
||||||
|
f2 = DecisionTreeFactor([v1, v2], "5 6 7 8")
|
||||||
|
actual = f1 * f2
|
||||||
|
expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32")
|
||||||
|
self.gtsamAssertEquals(actual, expected2)
|
||||||
|
|
||||||
|
def test_methods(self):
|
||||||
|
"""Test whether we can call methods in python."""
|
||||||
|
# double operator()(const DiscreteValues& values) const;
|
||||||
|
values = DiscreteValues()
|
||||||
|
values[self.A[0]] = 0
|
||||||
|
values[self.B[0]] = 0
|
||||||
|
self.assertIsInstance(self.factor(values), float)
|
||||||
|
|
||||||
|
# size_t cardinality(Key j) const;
|
||||||
|
self.assertIsInstance(self.factor.cardinality(self.A[0]), int)
|
||||||
|
|
||||||
|
# DecisionTreeFactor operator/(const DecisionTreeFactor& f) const;
|
||||||
|
self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor)
|
||||||
|
|
||||||
|
# DecisionTreeFactor* sum(size_t nrFrontals) const;
|
||||||
|
self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor)
|
||||||
|
|
||||||
|
# DecisionTreeFactor* sum(const Ordering& keys) const;
|
||||||
|
ordering = Ordering()
|
||||||
|
ordering.push_back(self.A[0])
|
||||||
|
self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor)
|
||||||
|
|
||||||
|
# DecisionTreeFactor* max(size_t nrFrontals) const;
|
||||||
|
self.assertIsInstance(self.factor.max(1), DecisionTreeFactor)
|
||||||
|
|
||||||
def test_markdown(self):
|
def test_markdown(self):
|
||||||
"""Test whether the _repr_markdown_ method."""
|
"""Test whether the _repr_markdown_ method."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ Author: Frank Dellaert
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
||||||
DiscreteKeys, DiscretePrior, DiscreteValues, Ordering)
|
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -74,7 +74,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
for j in range(8):
|
for j in range(8):
|
||||||
ordering.push_back(j)
|
ordering.push_back(j)
|
||||||
chordal = fg.eliminateSequential(ordering)
|
chordal = fg.eliminateSequential(ordering)
|
||||||
expected2 = DiscretePrior(Bronchitis, "11/9")
|
expected2 = DiscreteDistribution(Bronchitis, "11/9")
|
||||||
self.gtsamAssertEquals(chordal.at(7), expected2)
|
self.gtsamAssertEquals(chordal.at(7), expected2)
|
||||||
|
|
||||||
# solve
|
# solve
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,13 @@ import unittest
|
||||||
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
|
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
# Some DiscreteKeys for binary variables:
|
||||||
|
A = 0, 2
|
||||||
|
B = 1, 2
|
||||||
|
C = 2, 2
|
||||||
|
D = 4, 2
|
||||||
|
E = 3, 2
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteConditional(GtsamTestCase):
|
class TestDiscreteConditional(GtsamTestCase):
|
||||||
"""Tests for Discrete Conditionals."""
|
"""Tests for Discrete Conditionals."""
|
||||||
|
|
@ -36,6 +43,53 @@ class TestDiscreteConditional(GtsamTestCase):
|
||||||
actual = conditional.sample(2)
|
actual = conditional.sample(2)
|
||||||
self.assertIsInstance(actual, int)
|
self.assertIsInstance(actual, int)
|
||||||
|
|
||||||
|
def test_multiply(self):
|
||||||
|
"""Check calculation of joint P(A,B)"""
|
||||||
|
conditional = DiscreteConditional(A, [B], "1/2 2/1")
|
||||||
|
prior = DiscreteConditional(B, "1/2")
|
||||||
|
|
||||||
|
# P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
|
||||||
|
for actual in [prior * conditional, conditional * prior]:
|
||||||
|
self.assertEqual(2, actual.nrFrontals())
|
||||||
|
for v, value in actual.enumerate():
|
||||||
|
self.assertAlmostEqual(actual(v), conditional(v) * prior(v))
|
||||||
|
|
||||||
|
def test_multiply2(self):
|
||||||
|
"""Check calculation of conditional joint P(A,B|C)"""
|
||||||
|
A_given_B = DiscreteConditional(A, [B], "1/3 3/1")
|
||||||
|
B_given_C = DiscreteConditional(B, [C], "1/3 3/1")
|
||||||
|
|
||||||
|
# P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
|
||||||
|
for actual in [A_given_B * B_given_C, B_given_C * A_given_B]:
|
||||||
|
self.assertEqual(2, actual.nrFrontals())
|
||||||
|
self.assertEqual(1, actual.nrParents())
|
||||||
|
for v, value in actual.enumerate():
|
||||||
|
self.assertAlmostEqual(actual(v), A_given_B(v) * B_given_C(v))
|
||||||
|
|
||||||
|
def test_multiply4(self):
|
||||||
|
"""Check calculation of joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)"""
|
||||||
|
A_given_B = DiscreteConditional(A, [B], "1/3 3/1")
|
||||||
|
B_given_D = DiscreteConditional(B, [D], "1/3 3/1")
|
||||||
|
AB_given_D = A_given_B * B_given_D
|
||||||
|
C_given_DE = DiscreteConditional(C, [D, E], "4/1 1/1 1/1 1/4")
|
||||||
|
|
||||||
|
# P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D)
|
||||||
|
for actual in [AB_given_D * C_given_DE, C_given_DE * AB_given_D]:
|
||||||
|
self.assertEqual(3, actual.nrFrontals())
|
||||||
|
self.assertEqual(2, actual.nrParents())
|
||||||
|
for v, value in actual.enumerate():
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
actual(v), AB_given_D(v) * C_given_DE(v))
|
||||||
|
|
||||||
|
def test_marginals(self):
|
||||||
|
conditional = DiscreteConditional(A, [B], "1/2 2/1")
|
||||||
|
prior = DiscreteConditional(B, "1/2")
|
||||||
|
pAB = prior * conditional
|
||||||
|
self.gtsamAssertEquals(prior, pAB.marginal(B[0]))
|
||||||
|
|
||||||
|
pA = DiscreteConditional(A, "5/4")
|
||||||
|
self.gtsamAssertEquals(pA, pAB.marginal(A[0]))
|
||||||
|
|
||||||
def test_markdown(self):
|
def test_markdown(self):
|
||||||
"""Test whether the _repr_markdown_ method."""
|
"""Test whether the _repr_markdown_ method."""
|
||||||
|
|
||||||
|
|
@ -48,8 +102,7 @@ class TestDiscreteConditional(GtsamTestCase):
|
||||||
|
|
||||||
conditional = DiscreteConditional(A, parents,
|
conditional = DiscreteConditional(A, parents,
|
||||||
"0/1 1/3 1/1 3/1 0/1 1/0")
|
"0/1 1/3 1/1 3/1 0/1 1/0")
|
||||||
expected = \
|
expected = " *P(A|B,C):*\n\n" \
|
||||||
" *P(A|B,C):*\n\n" \
|
|
||||||
"|*B*|*C*|0|1|\n" \
|
"|*B*|*C*|0|1|\n" \
|
||||||
"|:-:|:-:|:-:|:-:|\n" \
|
"|:-:|:-:|:-:|:-:|\n" \
|
||||||
"|0|0|0|1|\n" \
|
"|0|0|0|1|\n" \
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ Author: Frank Dellaert
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior
|
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscreteDistribution
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
X = 0, 2
|
X = 0, 2
|
||||||
|
|
@ -25,32 +25,36 @@ class TestDiscretePrior(GtsamTestCase):
|
||||||
|
|
||||||
def test_constructor(self):
|
def test_constructor(self):
|
||||||
"""Test various constructors."""
|
"""Test various constructors."""
|
||||||
actual = DiscretePrior(X, "2/3")
|
|
||||||
keys = DiscreteKeys()
|
keys = DiscreteKeys()
|
||||||
keys.push_back(X)
|
keys.push_back(X)
|
||||||
f = DecisionTreeFactor(keys, "0.4 0.6")
|
f = DecisionTreeFactor(keys, "0.4 0.6")
|
||||||
expected = DiscretePrior(f)
|
expected = DiscreteDistribution(f)
|
||||||
|
|
||||||
|
actual = DiscreteDistribution(X, "2/3")
|
||||||
self.gtsamAssertEquals(actual, expected)
|
self.gtsamAssertEquals(actual, expected)
|
||||||
|
|
||||||
|
actual2 = DiscreteDistribution(X, [0.4, 0.6])
|
||||||
|
self.gtsamAssertEquals(actual2, expected)
|
||||||
|
|
||||||
def test_operator(self):
|
def test_operator(self):
|
||||||
prior = DiscretePrior(X, "2/3")
|
prior = DiscreteDistribution(X, "2/3")
|
||||||
self.assertAlmostEqual(prior(0), 0.4)
|
self.assertAlmostEqual(prior(0), 0.4)
|
||||||
self.assertAlmostEqual(prior(1), 0.6)
|
self.assertAlmostEqual(prior(1), 0.6)
|
||||||
|
|
||||||
def test_pmf(self):
|
def test_pmf(self):
|
||||||
prior = DiscretePrior(X, "2/3")
|
prior = DiscreteDistribution(X, "2/3")
|
||||||
expected = np.array([0.4, 0.6])
|
expected = np.array([0.4, 0.6])
|
||||||
np.testing.assert_allclose(expected, prior.pmf())
|
np.testing.assert_allclose(expected, prior.pmf())
|
||||||
|
|
||||||
def test_sample(self):
|
def test_sample(self):
|
||||||
prior = DiscretePrior(X, "2/3")
|
prior = DiscreteDistribution(X, "2/3")
|
||||||
actual = prior.sample()
|
actual = prior.sample()
|
||||||
self.assertIsInstance(actual, int)
|
self.assertIsInstance(actual, int)
|
||||||
|
|
||||||
def test_markdown(self):
|
def test_markdown(self):
|
||||||
"""Test the _repr_markdown_ method."""
|
"""Test the _repr_markdown_ method."""
|
||||||
|
|
||||||
prior = DiscretePrior(X, "2/3")
|
prior = DiscreteDistribution(X, "2/3")
|
||||||
expected = " *P(0):*\n\n" \
|
expected = " *P(0):*\n\n" \
|
||||||
"|0|value|\n" \
|
"|0|value|\n" \
|
||||||
"|:-:|:-:|\n" \
|
"|:-:|:-:|\n" \
|
||||||
|
|
@ -107,6 +107,24 @@ TEST( NonlinearFactorGraph, probPrime )
|
||||||
DOUBLES_EQUAL(expected,actual,0);
|
DOUBLES_EQUAL(expected,actual,0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(NonlinearFactorGraph, ProbPrime2) {
|
||||||
|
NonlinearFactorGraph fg;
|
||||||
|
fg.emplace_shared<PriorFactor<double>>(1, 0.0,
|
||||||
|
noiseModel::Isotropic::Sigma(1, 1.0));
|
||||||
|
|
||||||
|
Values values;
|
||||||
|
values.insert(1, 1.0);
|
||||||
|
|
||||||
|
// The prior factor squared error is: 0.5.
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.5, fg.error(values), 1e-12);
|
||||||
|
|
||||||
|
// The probability value is: exp^(-factor_error) / sqrt(2 * PI)
|
||||||
|
// Ignore the denominator and we get: exp^(-factor_error) = exp^(-0.5)
|
||||||
|
double expected = exp(-0.5);
|
||||||
|
EXPECT_DOUBLES_EQUAL(expected, fg.probPrime(values), 1e-12);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( NonlinearFactorGraph, linearize )
|
TEST( NonlinearFactorGraph, linearize )
|
||||||
{
|
{
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue