Merge pull request #990 from borglab/feature/api_improvements
commit
e5b928c610
|
@ -9,12 +9,18 @@ endif()
|
|||
|
||||
# Set the version number for the library
|
||||
set (GTSAM_VERSION_MAJOR 4)
|
||||
set (GTSAM_VERSION_MINOR 1)
|
||||
set (GTSAM_VERSION_PATCH 1)
|
||||
set (GTSAM_VERSION_MINOR 2)
|
||||
set (GTSAM_VERSION_PATCH 0)
|
||||
set (GTSAM_PRERELEASE_VERSION "a0")
|
||||
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
|
||||
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}")
|
||||
|
||||
set (CMAKE_PROJECT_VERSION ${GTSAM_VERSION_STRING})
|
||||
if (${GTSAM_VERSION_PATCH} EQUAL 0)
|
||||
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}${GTSAM_PRERELEASE_VERSION}")
|
||||
else()
|
||||
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}")
|
||||
endif()
|
||||
message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}")
|
||||
|
||||
set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR})
|
||||
set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR})
|
||||
set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH})
|
||||
|
|
|
@ -134,17 +134,34 @@ namespace gtsam {
|
|||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
|
||||
// Get all possible assignments
|
||||
std::vector<std::pair<Key, size_t>> pairs;
|
||||
for (auto& key : keys()) {
|
||||
pairs.emplace_back(key, cardinalities_.at(key));
|
||||
}
|
||||
// Reverse to make cartesianProduct output a more natural ordering.
|
||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||
const auto assignments = cartesianProduct(rpairs);
|
||||
|
||||
// Construct unordered_map with values
|
||||
std::vector<std::pair<DiscreteValues, double>> result;
|
||||
for (const auto& assignment : assignments) {
|
||||
result.emplace_back(assignment, operator()(assignment));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::string DecisionTreeFactor::markdown(
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::stringstream ss;
|
||||
|
||||
// Print out header and construct argument for `cartesianProduct`.
|
||||
std::vector<std::pair<Key, size_t>> pairs;
|
||||
ss << "|";
|
||||
for (auto& key : keys()) {
|
||||
ss << keyFormatter(key) << "|";
|
||||
pairs.emplace_back(key, cardinalities_.at(key));
|
||||
}
|
||||
ss << "value|\n";
|
||||
|
||||
|
@ -154,12 +171,12 @@ namespace gtsam {
|
|||
ss << ":-:|\n";
|
||||
|
||||
// Print out all rows.
|
||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||
const auto assignments = cartesianProduct(rpairs);
|
||||
for (const auto& assignment : assignments) {
|
||||
auto rows = enumerate();
|
||||
for (const auto& kv : rows) {
|
||||
ss << "|";
|
||||
auto assignment = kv.first;
|
||||
for (auto& key : keys()) ss << assignment.at(key) << "|";
|
||||
ss << operator()(assignment) << "|\n";
|
||||
ss << kv.second << "|\n";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
|
|
@ -61,6 +61,15 @@ namespace gtsam {
|
|||
DiscreteFactor(keys.indices()), Potentials(keys, table) {
|
||||
}
|
||||
|
||||
/// Single-key specialization
|
||||
template <class SOURCE>
|
||||
DecisionTreeFactor(const DiscreteKey& key, SOURCE table)
|
||||
: DecisionTreeFactor(DiscreteKeys{key}, table) {}
|
||||
|
||||
/// Single-key specialization, with vector of doubles.
|
||||
DecisionTreeFactor(const DiscreteKey& key, const std::vector<double>& row)
|
||||
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
|
||||
|
||||
/** Construct from a DiscreteConditional type */
|
||||
DecisionTreeFactor(const DiscreteConditional& c);
|
||||
|
||||
|
@ -162,6 +171,9 @@ namespace gtsam {
|
|||
// Potentials::reduceWithInverse(inverseReduction);
|
||||
// }
|
||||
|
||||
/// Enumerate all values into a map from values to double.
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <boost/shared_ptr.hpp>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/discrete/DiscretePrior.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -75,6 +76,11 @@ namespace gtsam {
|
|||
// Add inherited versions of add.
|
||||
using Base::add;
|
||||
|
||||
/** Add a DiscretePrior using a table or a string */
|
||||
void add(const DiscreteKey& key, const std::string& spec) {
|
||||
emplace_shared<DiscretePrior>(key, spec);
|
||||
}
|
||||
|
||||
/** Add a DiscreteCondtional */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
|
|
|
@ -97,45 +97,90 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
|
|||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
Potentials::ADT DiscreteConditional::choose(
|
||||
const DiscreteValues& parentsValues) const {
|
||||
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||
const DiscreteValues& parentsValues) {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
ADT pFS(*this);
|
||||
DiscreteConditional::ADT adt(conditional);
|
||||
size_t value;
|
||||
for (Key j : parents()) {
|
||||
for (Key j : conditional.parents()) {
|
||||
try {
|
||||
value = parentsValues.at(j);
|
||||
pFS = pFS.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (exception&) {
|
||||
cout << "Key: " << j << " Value: " << value << endl;
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (std::out_of_range&) {
|
||||
parentsValues.print("parentsValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||
};
|
||||
}
|
||||
return pFS;
|
||||
return adt;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor(
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
|
||||
const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = choose(parentsValues);
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
ADT adt(*this);
|
||||
size_t value;
|
||||
for (Key j : parents()) {
|
||||
try {
|
||||
value = parentsValues.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (exception&) {
|
||||
parentsValues.print("parentsValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||
};
|
||||
}
|
||||
|
||||
// Convert ADT to factor.
|
||||
if (nrFrontals() != 1) {
|
||||
throw std::runtime_error("Expected only one frontal variable in choose.");
|
||||
DiscreteKeys discreteKeys;
|
||||
for (Key j : frontals()) {
|
||||
discreteKeys.emplace_back(j, this->cardinality(j));
|
||||
}
|
||||
DiscreteKeys keys;
|
||||
const Key frontalKey = keys_[0];
|
||||
size_t frontalCardinality = this->cardinality(frontalKey);
|
||||
keys.push_back(DiscreteKey(frontalKey, frontalCardinality));
|
||||
return boost::make_shared<DecisionTreeFactor>(keys, pFS);
|
||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
const DiscreteValues& frontalValues) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the frontal variables.
|
||||
ADT adt(*this);
|
||||
size_t value;
|
||||
for (Key j : frontals()) {
|
||||
try {
|
||||
value = frontalValues.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (exception&) {
|
||||
frontalValues.print("frontalValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
||||
};
|
||||
}
|
||||
|
||||
// Convert ADT to factor.
|
||||
DiscreteKeys discreteKeys;
|
||||
for (Key j : parents()) {
|
||||
discreteKeys.emplace_back(j, this->cardinality(j));
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
size_t parent_value) const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value likelihood can only be invoked on single-variable "
|
||||
"conditional");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_[0], parent_value);
|
||||
return likelihood(values);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
|
||||
ADT pFS = choose(*values); // P(F|S=parentsValues)
|
||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
DiscreteValues mpe;
|
||||
|
@ -177,7 +222,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
|||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||
|
||||
// TODO: is this really the fastest way? I think it is.
|
||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO, only works for one key now, seems horribly slow this way
|
||||
|
@ -203,10 +248,14 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
|||
static mt19937 rng(2); // random number generator
|
||||
|
||||
// Get the correct conditional density
|
||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
assert(nrFrontals() == 1);
|
||||
if (nrFrontals() != 1) {
|
||||
throw std::invalid_argument(
|
||||
"DiscreteConditional::sample can only be called on single variable "
|
||||
"conditionals");
|
||||
}
|
||||
Key key = firstFrontalKey();
|
||||
size_t nj = cardinality(key);
|
||||
vector<double> p(nj);
|
||||
|
@ -222,13 +271,24 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
|||
return distribution(rng);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||
if (nrParents() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value sample() can only be invoked on single-parent "
|
||||
"conditional");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_.back(), parent_value);
|
||||
return sample(values);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::string DiscreteConditional::markdown(
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::stringstream ss;
|
||||
|
||||
// Print out signature.
|
||||
ss << " $P(";
|
||||
ss << " *P(";
|
||||
bool first = true;
|
||||
for (Key key : frontals()) {
|
||||
if (!first) ss << ",";
|
||||
|
@ -237,7 +297,7 @@ std::string DiscreteConditional::markdown(
|
|||
}
|
||||
if (nrParents() == 0) {
|
||||
// We have no parents, call factor method.
|
||||
ss << ")$:" << std::endl;
|
||||
ss << ")*:\n" << std::endl;
|
||||
ss << DecisionTreeFactor::markdown(keyFormatter);
|
||||
return ss.str();
|
||||
}
|
||||
|
@ -250,7 +310,7 @@ std::string DiscreteConditional::markdown(
|
|||
ss << keyFormatter(parent);
|
||||
first = false;
|
||||
}
|
||||
ss << ")$:" << std::endl;
|
||||
ss << ")*:\n" << std::endl;
|
||||
|
||||
// Print out header and construct argument for `cartesianProduct`.
|
||||
std::vector<std::pair<Key, size_t>> pairs;
|
||||
|
|
|
@ -62,8 +62,6 @@ public:
|
|||
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* The first string is parsed to add a key and parents.
|
||||
*
|
||||
* Example: DiscreteConditional P(D, {B,E}, table);
|
||||
*/
|
||||
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
|
@ -75,8 +73,7 @@ public:
|
|||
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
|
||||
* be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* The first string is parsed to add a key and parents. The second string
|
||||
* parses into a table.
|
||||
* The string is parsed into a Signature::Table.
|
||||
*
|
||||
* Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||
*/
|
||||
|
@ -84,6 +81,10 @@ public:
|
|||
const std::string& spec)
|
||||
: DiscreteConditional(Signature(key, parents, spec)) {}
|
||||
|
||||
/// No-parent specialization; can also use DiscretePrior.
|
||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||
|
||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal);
|
||||
|
@ -135,13 +136,17 @@ public:
|
|||
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
|
||||
}
|
||||
|
||||
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
|
||||
ADT choose(const DiscreteValues& parentsValues) const;
|
||||
|
||||
/** Restrict to given parent values, returns DecisionTreeFactor */
|
||||
DecisionTreeFactor::shared_ptr chooseAsFactor(
|
||||
DecisionTreeFactor::shared_ptr choose(
|
||||
const DiscreteValues& parentsValues) const;
|
||||
|
||||
/** Convert to a likelihood factor by providing value before bar. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(
|
||||
const DiscreteValues& frontalValues) const;
|
||||
|
||||
/** Single variable version of likelihood. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
||||
|
||||
/**
|
||||
* solve a conditional
|
||||
* @param parentsValues Known values of the parents
|
||||
|
@ -156,6 +161,10 @@ public:
|
|||
*/
|
||||
size_t sample(const DiscreteValues& parentsValues) const;
|
||||
|
||||
|
||||
/// Single value version.
|
||||
size_t sample(size_t parent_value) const;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
|
|
@ -101,29 +101,12 @@ public:
|
|||
|
||||
/// @}
|
||||
|
||||
// Add single key decision-tree factor.
|
||||
template <class SOURCE>
|
||||
void add(const DiscreteKey& j, SOURCE table) {
|
||||
DiscreteKeys keys;
|
||||
keys.push_back(j);
|
||||
emplace_shared<DecisionTreeFactor>(keys, table);
|
||||
/** Add a decision-tree factor */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
emplace_shared<DecisionTreeFactor>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
// Add binary key decision-tree factor.
|
||||
template <class SOURCE>
|
||||
void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) {
|
||||
DiscreteKeys keys;
|
||||
keys.push_back(j1);
|
||||
keys.push_back(j2);
|
||||
emplace_shared<DecisionTreeFactor>(keys, table);
|
||||
}
|
||||
|
||||
// Add shared discreteFactor immediately from arguments.
|
||||
template <class SOURCE>
|
||||
void add(const DiscreteKeys& keys, SOURCE table) {
|
||||
emplace_shared<DecisionTreeFactor>(keys, table);
|
||||
}
|
||||
|
||||
|
||||
/** Return the set of variables involved in the factors (set union) */
|
||||
KeySet keys() const;
|
||||
|
||||
|
|
|
@ -43,9 +43,7 @@ namespace gtsam {
|
|||
DiscreteKeys() : std::vector<DiscreteKey>::vector() {}
|
||||
|
||||
/// Construct from a key
|
||||
DiscreteKeys(const DiscreteKey& key) {
|
||||
push_back(key);
|
||||
}
|
||||
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
|
||||
|
||||
/// Construct from a vector of keys
|
||||
DiscreteKeys(const std::vector<DiscreteKey>& keys) :
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscretePrior.cpp
|
||||
* @date December 2021
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscretePrior.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
void DiscretePrior::print(const std::string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
double DiscretePrior::operator()(size_t value) const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value operator can only be invoked on single-variable "
|
||||
"priors");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_[0], value);
|
||||
return Base::operator()(values);
|
||||
}
|
||||
|
||||
std::vector<double> DiscretePrior::pmf() const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"DiscretePrior::pmf only defined for single-variable priors");
|
||||
const size_t nrValues = cardinalities_.at(keys_[0]);
|
||||
std::vector<double> array;
|
||||
array.reserve(nrValues);
|
||||
for (size_t v = 0; v < nrValues; v++) {
|
||||
array.push_back(operator()(v));
|
||||
}
|
||||
return array;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
|
@ -0,0 +1,111 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscretePrior.h
|
||||
* @date December 2021
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* A prior probability on a set of discrete variables.
|
||||
* Derives from DiscreteConditional
|
||||
*/
|
||||
class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
||||
public:
|
||||
using Base = DiscreteConditional;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor needed for serialization.
|
||||
DiscretePrior() {}
|
||||
|
||||
/// Constructor from factor.
|
||||
DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {}
|
||||
|
||||
/**
|
||||
* Construct from a Signature.
|
||||
*
|
||||
* Example: DiscretePrior P(D % "3/2");
|
||||
*/
|
||||
DiscretePrior(const Signature& s) : Base(s) {}
|
||||
|
||||
/**
|
||||
* Construct from key and a Signature::Table specifying the
|
||||
* conditional probability table (CPT).
|
||||
*
|
||||
* Example: DiscretePrior P(D, table);
|
||||
*/
|
||||
DiscretePrior(const DiscreteKey& key, const Signature::Table& table)
|
||||
: Base(Signature(key, {}, table)) {}
|
||||
|
||||
/**
|
||||
* Construct from key and a string specifying the conditional
|
||||
* probability table (CPT).
|
||||
*
|
||||
* Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9");
|
||||
*/
|
||||
DiscretePrior(const DiscreteKey& key, const std::string& spec)
|
||||
: DiscretePrior(Signature(key, {}, spec)) {}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/// GTSAM-style print
|
||||
void print(
|
||||
const std::string& s = "Discrete Prior: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard interface
|
||||
/// @{
|
||||
|
||||
/// Evaluate given a single value.
|
||||
double operator()(size_t value) const;
|
||||
|
||||
/// We also want to keep the Base version, taking DiscreteValues:
|
||||
// TODO(dellaert): does not play well with wrapper!
|
||||
// using Base::operator();
|
||||
|
||||
/// Return entire probability mass function.
|
||||
std::vector<double> pmf() const;
|
||||
|
||||
/**
|
||||
* solve a conditional
|
||||
* @return MPE value of the child (1 frontal variable).
|
||||
*/
|
||||
size_t solve() const { return Base::solve({}); }
|
||||
|
||||
/**
|
||||
* sample
|
||||
* @return sample from conditional
|
||||
*/
|
||||
size_t sample() const { return Base::sample({}); }
|
||||
|
||||
/// @}
|
||||
};
|
||||
// DiscretePrior
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<DiscretePrior> : public Testable<DiscretePrior> {};
|
||||
|
||||
} // namespace gtsam
|
|
@ -30,25 +30,37 @@ class DiscreteFactor {
|
|||
};
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
|
||||
virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||
DecisionTreeFactor();
|
||||
|
||||
DecisionTreeFactor(const gtsam::DiscreteKey& key,
|
||||
const std::vector<double>& spec);
|
||||
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
|
||||
|
||||
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
|
||||
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
|
||||
|
||||
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
|
||||
|
||||
void print(string s = "DecisionTreeFactor\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||
string dot(bool showZero = false) const;
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||
DiscreteConditional();
|
||||
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
||||
DiscreteConditional(const gtsam::DiscreteKey& key, string spec);
|
||||
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||
const gtsam::DiscreteKeys& parents, string spec);
|
||||
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||
const std::vector<gtsam::DiscreteKey>& parents, string spec);
|
||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||
const gtsam::DecisionTreeFactor& marginal);
|
||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||
|
@ -62,20 +74,43 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
string s = "Discrete Conditional: ",
|
||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DecisionTreeFactor* toFactor() const;
|
||||
gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const;
|
||||
gtsam::DecisionTreeFactor* choose(
|
||||
const gtsam::DiscreteValues& parentsValues) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(
|
||||
const gtsam::DiscreteValues& frontalValues) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||
void solveInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
||||
void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
||||
size_t sample(size_t value) const;
|
||||
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscretePrior.h>
|
||||
virtual class DiscretePrior : gtsam::DiscreteConditional {
|
||||
DiscretePrior();
|
||||
DiscretePrior(const gtsam::DecisionTreeFactor& f);
|
||||
DiscretePrior(const gtsam::DiscreteKey& key, string spec);
|
||||
void print(string s = "Discrete Prior\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(size_t value) const;
|
||||
std::vector<double> pmf() const;
|
||||
size_t solve() const;
|
||||
size_t sample() const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
class DiscreteBayesNet {
|
||||
class DiscreteBayesNet {
|
||||
DiscreteBayesNet();
|
||||
void add(const gtsam::DiscreteConditional& s);
|
||||
void add(const gtsam::DiscreteKey& key, string spec);
|
||||
void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents,
|
||||
string spec);
|
||||
void add(const gtsam::DiscreteKey& key,
|
||||
const gtsam::DiscreteKeys& parents, string spec);
|
||||
const std::vector<gtsam::DiscreteKey>& parents, string spec);
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
|
@ -86,15 +121,13 @@ class DiscreteBayesNet {
|
|||
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void saveGraph(string s,
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void add(const gtsam::DiscreteConditional& s);
|
||||
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
gtsam::DiscreteValues sample() const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
|
@ -142,11 +175,13 @@ class DotWriter {
|
|||
class DiscreteFactorGraph {
|
||||
DiscreteFactorGraph();
|
||||
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||
|
||||
|
||||
void add(const gtsam::DiscreteKey& j, string table);
|
||||
void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table);
|
||||
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
||||
|
||||
void add(const gtsam::DiscreteKeys& keys, string table);
|
||||
|
||||
void add(const std::vector<gtsam::DiscreteKey>& keys, string table);
|
||||
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
|
|
|
@ -34,7 +34,7 @@ TEST( DecisionTreeFactor, constructors)
|
|||
DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
||||
|
||||
// Create factors
|
||||
DecisionTreeFactor f1(X, "2 8");
|
||||
DecisionTreeFactor f1(X, {2, 8});
|
||||
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
|
||||
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
||||
EXPECT_LONGS_EQUAL(1,f1.size());
|
||||
|
@ -82,11 +82,29 @@ TEST( DecisionTreeFactor, sum_max)
|
|||
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check enumerate yields the correct list of assignment/value pairs.
|
||||
TEST(DecisionTreeFactor, enumerate) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||
auto actual = f.enumerate();
|
||||
std::vector<std::pair<DiscreteValues, double>> expected;
|
||||
DiscreteValues values;
|
||||
for (size_t a : {0, 1, 2}) {
|
||||
for (size_t b : {0, 1}) {
|
||||
values[12] = a;
|
||||
values[5] = b;
|
||||
expected.emplace_back(values, f(values));
|
||||
}
|
||||
}
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected.
|
||||
TEST(DecisionTreeFactor, markdown) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
|
||||
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||
string expected =
|
||||
"|A|B|value|\n"
|
||||
"|:-:|:-:|:-:|\n"
|
||||
|
@ -97,7 +115,7 @@ TEST(DecisionTreeFactor, markdown) {
|
|||
"|2|0|5|\n"
|
||||
"|2|1|6|\n";
|
||||
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
string actual = f1.markdown(formatter);
|
||||
string actual = f.markdown(formatter);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
|
|
|
@ -75,8 +75,8 @@ TEST(DiscreteBayesNet, bayesNet) {
|
|||
TEST(DiscreteBayesNet, Asia) {
|
||||
DiscreteBayesNet asia;
|
||||
|
||||
asia.add(Asia % "99/1");
|
||||
asia.add(Smoking % "50/50");
|
||||
asia.add(Asia, "99/1");
|
||||
asia.add(Smoking % "50/50"); // Signature version
|
||||
|
||||
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
||||
asia.add(LungCancer | Smoking = "99/1 90/10");
|
||||
|
@ -180,13 +180,13 @@ TEST(DiscreteBayesNet, markdown) {
|
|||
string expected =
|
||||
"`DiscreteBayesNet` of size 2\n"
|
||||
"\n"
|
||||
" $P(Asia)$:\n"
|
||||
" *P(Asia)*:\n\n"
|
||||
"|Asia|value|\n"
|
||||
"|:-:|:-:|\n"
|
||||
"|0|0.99|\n"
|
||||
"|1|0.01|\n"
|
||||
"\n"
|
||||
" $P(Smoking|Asia)$:\n"
|
||||
" *P(Smoking|Asia)*:\n\n"
|
||||
"|Asia|0|1|\n"
|
||||
"|:-:|:-:|:-:|\n"
|
||||
"|0|0.8|0.2|\n"
|
||||
|
|
|
@ -10,10 +10,11 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file testDecisionTreeFactor.cpp
|
||||
* @file testDiscreteConditional.cpp
|
||||
* @brief unit tests for DiscreteConditional
|
||||
* @author Duy-Nguyen Ta
|
||||
* @date Feb 14, 2011
|
||||
* @author Frank dellaert
|
||||
* @date Feb 14, 2011
|
||||
*/
|
||||
|
||||
#include <boost/assign/std/map.hpp>
|
||||
|
@ -30,24 +31,21 @@ using namespace std;
|
|||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteConditional, constructors)
|
||||
{
|
||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||
TEST(DiscreteConditional, constructors) {
|
||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||
|
||||
DiscreteConditional expected(X | Y = "1/1 2/3 1/4");
|
||||
EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals()));
|
||||
EXPECT_LONGS_EQUAL(2, *(expected.beginParents()));
|
||||
EXPECT(expected.endParents() == expected.end());
|
||||
EXPECT(expected.endFrontals() == expected.beginParents());
|
||||
|
||||
DiscreteConditional::shared_ptr expected1 = //
|
||||
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
|
||||
EXPECT(expected1);
|
||||
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
|
||||
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
|
||||
EXPECT(expected1->endParents() == expected1->end());
|
||||
EXPECT(expected1->endFrontals() == expected1->beginParents());
|
||||
|
||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||
DiscreteConditional actual1(1, f1);
|
||||
EXPECT(assert_equal(*expected1, actual1, 1e-9));
|
||||
EXPECT(assert_equal(expected, actual1, 1e-9));
|
||||
|
||||
DecisionTreeFactor f2(X & Y & Z,
|
||||
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
DecisionTreeFactor f2(
|
||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
DiscreteConditional actual2(1, f2);
|
||||
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
||||
}
|
||||
|
@ -107,13 +105,27 @@ TEST(DiscreteConditional, Combine) {
|
|||
EXPECT(assert_equal(expected, *actual, 1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, likelihood) {
|
||||
DiscreteKey X(0, 2), Y(1, 3);
|
||||
DiscreteConditional conditional(X | Y = "2/8 4/6 5/5");
|
||||
|
||||
auto actual0 = conditional.likelihood(0);
|
||||
DecisionTreeFactor expected0(Y, "0.2 0.4 0.5");
|
||||
EXPECT(assert_equal(expected0, *actual0, 1e-9));
|
||||
|
||||
auto actual1 = conditional.likelihood(1);
|
||||
DecisionTreeFactor expected1(Y, "0.8 0.6 0.5");
|
||||
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, no parents.
|
||||
TEST(DiscreteConditional, markdown_prior) {
|
||||
DiscreteKey A(Symbol('x', 1), 3);
|
||||
DiscreteConditional conditional(A % "1/2/2");
|
||||
string expected =
|
||||
" $P(x1)$:\n"
|
||||
" *P(x1)*:\n\n"
|
||||
"|x1|value|\n"
|
||||
"|:-:|:-:|\n"
|
||||
"|0|0.2|\n"
|
||||
|
@ -130,7 +142,7 @@ TEST(DiscreteConditional, markdown_multivalued) {
|
|||
DiscreteConditional conditional(
|
||||
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
|
||||
string expected =
|
||||
" $P(a1|b1)$:\n"
|
||||
" *P(a1|b1)*:\n\n"
|
||||
"|b1|0|1|2|\n"
|
||||
"|:-:|:-:|:-:|:-:|\n"
|
||||
"|0|0.02|0.88|0.1|\n"
|
||||
|
@ -148,7 +160,7 @@ TEST(DiscreteConditional, markdown) {
|
|||
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
|
||||
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
||||
string expected =
|
||||
" $P(A|B,C)$:\n"
|
||||
" *P(A|B,C)*:\n\n"
|
||||
"|B|C|0|1|\n"
|
||||
"|:-:|:-:|:-:|:-:|\n"
|
||||
"|0|0|0|1|\n"
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file testDiscretePrior.cpp
|
||||
* @brief unit tests for DiscretePrior
|
||||
* @author Frank dellaert
|
||||
* @date December 2021
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/discrete/DiscretePrior.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
static const DiscreteKey X(0, 2);
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscretePrior, constructors) {
|
||||
DiscretePrior actual(X % "2/3");
|
||||
DecisionTreeFactor f(X, "0.4 0.6");
|
||||
DiscretePrior expected(f);
|
||||
EXPECT(assert_equal(expected, actual, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscretePrior, operator) {
|
||||
DiscretePrior prior(X % "2/3");
|
||||
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscretePrior, to_vector) {
|
||||
DiscretePrior prior(X % "2/3");
|
||||
vector<double> expected {0.4, 0.6};
|
||||
EXPECT(prior.pmf() == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
|
@ -133,10 +133,10 @@ void Scheduler::addStudentSpecificConstraints(size_t i,
|
|||
Potentials::ADT p(dummy & areaKey,
|
||||
available_); // available_ is Doodle string
|
||||
Potentials::ADT q = p.choose(dummyIndex, *slot);
|
||||
DiscreteFactor::shared_ptr f(new DecisionTreeFactor(areaKey, q));
|
||||
CSP::push_back(f);
|
||||
CSP::add(areaKey, q);
|
||||
} else {
|
||||
CSP::add(s.key_, areaKey, available_); // available_ is Doodle string
|
||||
DiscreteKeys keys {s.key_, areaKey};
|
||||
CSP::add(keys, available_); // available_ is Doodle string
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||
Atlanta, Georgia 30332-0415
|
||||
All Rights Reserved
|
||||
|
||||
See LICENSE for the license information
|
||||
|
||||
Unit tests for DecisionTreeFactors.
|
||||
Author: Frank Dellaert
|
||||
"""
|
||||
|
||||
# pylint: disable=no-name-in-module, invalid-name
|
||||
|
||||
import unittest
|
||||
|
||||
from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
class TestDecisionTreeFactor(GtsamTestCase):
|
||||
"""Tests for DecisionTreeFactors."""
|
||||
|
||||
def setUp(self):
|
||||
A = (12, 3)
|
||||
B = (5, 2)
|
||||
self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6")
|
||||
|
||||
def test_enumerate(self):
|
||||
actual = self.factor.enumerate()
|
||||
_, values = zip(*actual)
|
||||
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||
|
||||
def test_markdown(self):
|
||||
"""Test whether the _repr_markdown_ method."""
|
||||
|
||||
expected = \
|
||||
"|A|B|value|\n" \
|
||||
"|:-:|:-:|:-:|\n" \
|
||||
"|0|0|1|\n" \
|
||||
"|0|1|2|\n" \
|
||||
"|1|0|3|\n" \
|
||||
"|1|1|4|\n" \
|
||||
"|2|0|5|\n" \
|
||||
"|2|1|6|\n"
|
||||
|
||||
def formatter(x: int):
|
||||
return "A" if x == 12 else "B"
|
||||
|
||||
actual = self.factor._repr_markdown_(formatter)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -14,7 +14,7 @@ Author: Frank Dellaert
|
|||
import unittest
|
||||
|
||||
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
||||
DiscreteKeys, DiscreteValues, Ordering)
|
||||
DiscreteKeys, DiscretePrior, DiscreteValues, Ordering)
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
|
@ -53,24 +53,18 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
XRay = (2, 2)
|
||||
Dyspnea = (1, 2)
|
||||
|
||||
def P(keys):
|
||||
dks = DiscreteKeys()
|
||||
for key in keys:
|
||||
dks.push_back(key)
|
||||
return dks
|
||||
|
||||
asia = DiscreteBayesNet()
|
||||
asia.add(Asia, P([]), "99/1")
|
||||
asia.add(Smoking, P([]), "50/50")
|
||||
asia.add(Asia, "99/1")
|
||||
asia.add(Smoking, "50/50")
|
||||
|
||||
asia.add(Tuberculosis, P([Asia]), "99/1 95/5")
|
||||
asia.add(LungCancer, P([Smoking]), "99/1 90/10")
|
||||
asia.add(Bronchitis, P([Smoking]), "70/30 40/60")
|
||||
asia.add(Tuberculosis, [Asia], "99/1 95/5")
|
||||
asia.add(LungCancer, [Smoking], "99/1 90/10")
|
||||
asia.add(Bronchitis, [Smoking], "70/30 40/60")
|
||||
|
||||
asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T")
|
||||
asia.add(Either, [Tuberculosis, LungCancer], "F T T T")
|
||||
|
||||
asia.add(XRay, P([Either]), "95/5 2/98")
|
||||
asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9")
|
||||
asia.add(XRay, [Either], "95/5 2/98")
|
||||
asia.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9")
|
||||
|
||||
# Convert to factor graph
|
||||
fg = DiscreteFactorGraph(asia)
|
||||
|
@ -80,7 +74,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
for j in range(8):
|
||||
ordering.push_back(j)
|
||||
chordal = fg.eliminateSequential(ordering)
|
||||
expected2 = DiscreteConditional(Bronchitis, P([]), "11/9")
|
||||
expected2 = DiscretePrior(Bronchitis, "11/9")
|
||||
self.gtsamAssertEquals(chordal.at(7), expected2)
|
||||
|
||||
# solve
|
||||
|
|
|
@ -14,20 +14,10 @@ Author: Frank Dellaert
|
|||
import unittest
|
||||
|
||||
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
||||
DiscreteConditional, DiscreteFactorGraph, DiscreteKeys,
|
||||
Ordering)
|
||||
DiscreteConditional, DiscreteFactorGraph, Ordering)
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
def P(*args):
|
||||
""" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs."""
|
||||
# TODO: We can make life easier by providing variable argument functions in C++ itself.
|
||||
dks = DiscreteKeys()
|
||||
for key in args:
|
||||
dks.push_back(key)
|
||||
return dks
|
||||
|
||||
|
||||
class TestDiscreteBayesNet(GtsamTestCase):
|
||||
"""Tests for Discrete Bayes Nets."""
|
||||
|
||||
|
@ -40,25 +30,25 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
# Create thin-tree Bayesnet.
|
||||
bayesNet = DiscreteBayesNet()
|
||||
|
||||
bayesNet.add(keys[0], P(keys[8], keys[12]), "2/3 1/4 3/2 4/1")
|
||||
bayesNet.add(keys[1], P(keys[8], keys[12]), "4/1 2/3 3/2 1/4")
|
||||
bayesNet.add(keys[2], P(keys[9], keys[12]), "1/4 8/2 2/3 4/1")
|
||||
bayesNet.add(keys[3], P(keys[9], keys[12]), "1/4 2/3 3/2 4/1")
|
||||
bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1")
|
||||
bayesNet.add(keys[1], [keys[8], keys[12]], "4/1 2/3 3/2 1/4")
|
||||
bayesNet.add(keys[2], [keys[9], keys[12]], "1/4 8/2 2/3 4/1")
|
||||
bayesNet.add(keys[3], [keys[9], keys[12]], "1/4 2/3 3/2 4/1")
|
||||
|
||||
bayesNet.add(keys[4], P(keys[10], keys[13]), "2/3 1/4 3/2 4/1")
|
||||
bayesNet.add(keys[5], P(keys[10], keys[13]), "4/1 2/3 3/2 1/4")
|
||||
bayesNet.add(keys[6], P(keys[11], keys[13]), "1/4 3/2 2/3 4/1")
|
||||
bayesNet.add(keys[7], P(keys[11], keys[13]), "1/4 2/3 3/2 4/1")
|
||||
bayesNet.add(keys[4], [keys[10], keys[13]], "2/3 1/4 3/2 4/1")
|
||||
bayesNet.add(keys[5], [keys[10], keys[13]], "4/1 2/3 3/2 1/4")
|
||||
bayesNet.add(keys[6], [keys[11], keys[13]], "1/4 3/2 2/3 4/1")
|
||||
bayesNet.add(keys[7], [keys[11], keys[13]], "1/4 2/3 3/2 4/1")
|
||||
|
||||
bayesNet.add(keys[8], P(keys[12], keys[14]), "T 1/4 3/2 4/1")
|
||||
bayesNet.add(keys[9], P(keys[12], keys[14]), "4/1 2/3 F 1/4")
|
||||
bayesNet.add(keys[10], P(keys[13], keys[14]), "1/4 3/2 2/3 4/1")
|
||||
bayesNet.add(keys[11], P(keys[13], keys[14]), "1/4 2/3 3/2 4/1")
|
||||
bayesNet.add(keys[8], [keys[12], keys[14]], "T 1/4 3/2 4/1")
|
||||
bayesNet.add(keys[9], [keys[12], keys[14]], "4/1 2/3 F 1/4")
|
||||
bayesNet.add(keys[10], [keys[13], keys[14]], "1/4 3/2 2/3 4/1")
|
||||
bayesNet.add(keys[11], [keys[13], keys[14]], "1/4 2/3 3/2 4/1")
|
||||
|
||||
bayesNet.add(keys[12], P(keys[14]), "3/1 3/1")
|
||||
bayesNet.add(keys[13], P(keys[14]), "1/3 3/1")
|
||||
bayesNet.add(keys[12], [keys[14]], "3/1 3/1")
|
||||
bayesNet.add(keys[13], [keys[14]], "1/3 3/1")
|
||||
|
||||
bayesNet.add(keys[14], P(), "1/3")
|
||||
bayesNet.add(keys[14], "1/3")
|
||||
|
||||
# Create a factor graph out of the Bayes net.
|
||||
factorGraph = DiscreteFactorGraph(bayesNet)
|
||||
|
|
|
@ -13,12 +13,29 @@ Author: Varun Agrawal
|
|||
|
||||
import unittest
|
||||
|
||||
from gtsam import DiscreteConditional, DiscreteKeys
|
||||
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
class TestDiscreteConditional(GtsamTestCase):
|
||||
"""Tests for Discrete Conditionals."""
|
||||
|
||||
def test_single_value_versions(self):
|
||||
X = (0, 2)
|
||||
Y = (1, 3)
|
||||
conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5")
|
||||
|
||||
actual0 = conditional.likelihood(0)
|
||||
expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5")
|
||||
self.gtsamAssertEquals(actual0, expected0, 1e-9)
|
||||
|
||||
actual1 = conditional.likelihood(1)
|
||||
expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
|
||||
self.gtsamAssertEquals(actual1, expected1, 1e-9)
|
||||
|
||||
actual = conditional.sample(2)
|
||||
self.assertIsInstance(actual, int)
|
||||
|
||||
def test_markdown(self):
|
||||
"""Test whether the _repr_markdown_ method."""
|
||||
|
||||
|
@ -32,7 +49,7 @@ class TestDiscreteConditional(GtsamTestCase):
|
|||
conditional = DiscreteConditional(A, parents,
|
||||
"0/1 1/3 1/1 3/1 0/1 1/0")
|
||||
expected = \
|
||||
" $P(A|B,C)$:\n" \
|
||||
" *P(A|B,C)*:\n\n" \
|
||||
"|B|C|0|1|\n" \
|
||||
"|:-:|:-:|:-:|:-:|\n" \
|
||||
"|0|0|0|1|\n" \
|
||||
|
|
|
@ -32,11 +32,11 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
|||
graph = DiscreteFactorGraph()
|
||||
|
||||
# Add two unary factors (priors)
|
||||
graph.add(P1, "0.9 0.3")
|
||||
graph.add(P1, [0.9, 0.3])
|
||||
graph.add(P2, "0.9 0.6")
|
||||
|
||||
# Add a binary factor
|
||||
graph.add(P1, P2, "4 1 10 4")
|
||||
graph.add([P1, P2], "4 1 10 4")
|
||||
|
||||
# Instantiate Values
|
||||
assignment = DiscreteValues()
|
||||
|
@ -85,8 +85,8 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
|||
# A simple factor graph (A)-fAC-(C)-fBC-(B)
|
||||
# with smoothness priors
|
||||
graph = DiscreteFactorGraph()
|
||||
graph.add(A, C, "3 1 1 3")
|
||||
graph.add(C, B, "3 1 1 3")
|
||||
graph.add([A, C], "3 1 1 3")
|
||||
graph.add([C, B], "3 1 1 3")
|
||||
|
||||
# Test optimization
|
||||
expectedValues = DiscreteValues()
|
||||
|
@ -105,8 +105,8 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
|||
|
||||
# Create Factor graph
|
||||
graph = DiscreteFactorGraph()
|
||||
graph.add(C, A, "0.2 0.8 0.3 0.7")
|
||||
graph.add(C, B, "0.1 0.9 0.4 0.6")
|
||||
graph.add([C, A], "0.2 0.8 0.3 0.7")
|
||||
graph.add([C, B], "0.1 0.9 0.4 0.6")
|
||||
|
||||
actualMPE = graph.optimize()
|
||||
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
"""
|
||||
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||
Atlanta, Georgia 30332-0415
|
||||
All Rights Reserved
|
||||
|
||||
See LICENSE for the license information
|
||||
|
||||
Unit tests for Discrete Priors.
|
||||
Author: Varun Agrawal
|
||||
"""
|
||||
|
||||
# pylint: disable=no-name-in-module, invalid-name
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
X = 0, 2
|
||||
|
||||
|
||||
class TestDiscretePrior(GtsamTestCase):
|
||||
"""Tests for Discrete Priors."""
|
||||
|
||||
def test_constructor(self):
|
||||
"""Test various constructors."""
|
||||
actual = DiscretePrior(X, "2/3")
|
||||
keys = DiscreteKeys()
|
||||
keys.push_back(X)
|
||||
f = DecisionTreeFactor(keys, "0.4 0.6")
|
||||
expected = DiscretePrior(f)
|
||||
self.gtsamAssertEquals(actual, expected)
|
||||
|
||||
def test_operator(self):
|
||||
prior = DiscretePrior(X, "2/3")
|
||||
self.assertAlmostEqual(prior(0), 0.4)
|
||||
self.assertAlmostEqual(prior(1), 0.6)
|
||||
|
||||
def test_pmf(self):
|
||||
prior = DiscretePrior(X, "2/3")
|
||||
expected = np.array([0.4, 0.6])
|
||||
np.testing.assert_allclose(expected, prior.pmf())
|
||||
|
||||
def test_markdown(self):
|
||||
"""Test the _repr_markdown_ method."""
|
||||
|
||||
prior = DiscretePrior(X, "2/3")
|
||||
expected = " *P(0)*:\n\n" \
|
||||
"|0|value|\n" \
|
||||
"|:-:|:-:|\n" \
|
||||
"|0|0.4|\n" \
|
||||
"|1|0.6|\n" \
|
||||
|
||||
actual = prior._repr_markdown_()
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue