Merge pull request #1039 from borglab/feature/DiscreteDistribution
commit
950ab111d0
|
@ -19,7 +19,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscretePrior.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
|
||||
|
@ -79,9 +79,9 @@ namespace gtsam {
|
|||
// Add inherited versions of add.
|
||||
using Base::add;
|
||||
|
||||
/** Add a DiscretePrior using a table or a string */
|
||||
/** Add a DiscreteDistribution using a table or a string */
|
||||
void add(const DiscreteKey& key, const std::string& spec) {
|
||||
emplace_shared<DiscretePrior>(key, spec);
|
||||
emplace_shared<DiscreteDistribution>(key, spec);
|
||||
}
|
||||
|
||||
/** Add a DiscreteCondtional */
|
||||
|
|
|
@ -89,7 +89,7 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
const std::string& spec)
|
||||
: DiscreteConditional(Signature(key, parents, spec)) {}
|
||||
|
||||
/// No-parent specialization; can also use DiscretePrior.
|
||||
/// No-parent specialization; can also use DiscreteDistribution.
|
||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||
|
||||
|
|
|
@ -10,21 +10,23 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscretePrior.cpp
|
||||
* @file DiscreteDistribution.cpp
|
||||
* @date December 2021
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscretePrior.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
void DiscretePrior::print(const std::string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
void DiscreteDistribution::print(const std::string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
double DiscretePrior::operator()(size_t value) const {
|
||||
double DiscreteDistribution::operator()(size_t value) const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value operator can only be invoked on single-variable "
|
||||
|
@ -34,10 +36,10 @@ double DiscretePrior::operator()(size_t value) const {
|
|||
return Base::operator()(values);
|
||||
}
|
||||
|
||||
std::vector<double> DiscretePrior::pmf() const {
|
||||
std::vector<double> DiscreteDistribution::pmf() const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"DiscretePrior::pmf only defined for single-variable priors");
|
||||
"DiscreteDistribution::pmf only defined for single-variable priors");
|
||||
const size_t nrValues = cardinalities_.at(keys_[0]);
|
||||
std::vector<double> array;
|
||||
array.reserve(nrValues);
|
|
@ -10,7 +10,7 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscretePrior.h
|
||||
* @file DiscreteDistribution.h
|
||||
* @date December 2021
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
@ -20,6 +20,7 @@
|
|||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -27,7 +28,7 @@ namespace gtsam {
|
|||
* A prior probability on a set of discrete variables.
|
||||
* Derives from DiscreteConditional
|
||||
*/
|
||||
class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
||||
class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
|
||||
public:
|
||||
using Base = DiscreteConditional;
|
||||
|
||||
|
@ -35,35 +36,36 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
|||
/// @{
|
||||
|
||||
/// Default constructor needed for serialization.
|
||||
DiscretePrior() {}
|
||||
DiscreteDistribution() {}
|
||||
|
||||
/// Constructor from factor.
|
||||
DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {}
|
||||
explicit DiscreteDistribution(const DecisionTreeFactor& f)
|
||||
: Base(f.size(), f) {}
|
||||
|
||||
/**
|
||||
* Construct from a Signature.
|
||||
*
|
||||
* Example: DiscretePrior P(D % "3/2");
|
||||
* Example: DiscreteDistribution P(D % "3/2");
|
||||
*/
|
||||
DiscretePrior(const Signature& s) : Base(s) {}
|
||||
explicit DiscreteDistribution(const Signature& s) : Base(s) {}
|
||||
|
||||
/**
|
||||
* Construct from key and a vector of floats specifying the probability mass
|
||||
* function (PMF).
|
||||
*
|
||||
* Example: DiscretePrior P(D, {0.4, 0.6});
|
||||
* Example: DiscreteDistribution P(D, {0.4, 0.6});
|
||||
*/
|
||||
DiscretePrior(const DiscreteKey& key, const std::vector<double>& spec)
|
||||
: DiscretePrior(Signature(key, {}, Signature::Table{spec})) {}
|
||||
DiscreteDistribution(const DiscreteKey& key, const std::vector<double>& spec)
|
||||
: DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {}
|
||||
|
||||
/**
|
||||
* Construct from key and a string specifying the probability mass function
|
||||
* (PMF).
|
||||
*
|
||||
* Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9");
|
||||
* Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9");
|
||||
*/
|
||||
DiscretePrior(const DiscreteKey& key, const std::string& spec)
|
||||
: DiscretePrior(Signature(key, {}, spec)) {}
|
||||
DiscreteDistribution(const DiscreteKey& key, const std::string& spec)
|
||||
: DiscreteDistribution(Signature(key, {}, spec)) {}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
|
@ -102,10 +104,10 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
|||
|
||||
/// @}
|
||||
};
|
||||
// DiscretePrior
|
||||
// DiscreteDistribution
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<DiscretePrior> : public Testable<DiscretePrior> {};
|
||||
struct traits<DiscreteDistribution> : public Testable<DiscreteDistribution> {};
|
||||
|
||||
} // namespace gtsam
|
|
@ -128,12 +128,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscretePrior.h>
|
||||
virtual class DiscretePrior : gtsam::DiscreteConditional {
|
||||
DiscretePrior();
|
||||
DiscretePrior(const gtsam::DecisionTreeFactor& f);
|
||||
DiscretePrior(const gtsam::DiscreteKey& key, string spec);
|
||||
DiscretePrior(const gtsam::DiscreteKey& key, std::vector<double> spec);
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
virtual class DiscreteDistribution : gtsam::DiscreteConditional {
|
||||
DiscreteDistribution();
|
||||
DiscreteDistribution(const gtsam::DecisionTreeFactor& f);
|
||||
DiscreteDistribution(const gtsam::DiscreteKey& key, string spec);
|
||||
DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
|
||||
void print(string s = "Discrete Prior\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscretePrior.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
#include <boost/assign/std/map.hpp>
|
||||
|
@ -56,8 +56,8 @@ TEST( DecisionTreeFactor, constructors)
|
|||
TEST(DecisionTreeFactor, multiplication) {
|
||||
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
|
||||
|
||||
// Multiply with a DiscretePrior, i.e., Bayes Law!
|
||||
DiscretePrior prior(v1 % "1/3");
|
||||
// Multiply with a DiscreteDistribution, i.e., Bayes Law!
|
||||
DiscreteDistribution prior(v1 % "1/3");
|
||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
||||
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
|
||||
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1));
|
||||
|
|
|
@ -11,42 +11,41 @@
|
|||
|
||||
/*
|
||||
* @file testDiscretePrior.cpp
|
||||
* @brief unit tests for DiscretePrior
|
||||
* @brief unit tests for DiscreteDistribution
|
||||
* @author Frank dellaert
|
||||
* @date December 2021
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/discrete/DiscretePrior.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
static const DiscreteKey X(0, 2);
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscretePrior, constructors) {
|
||||
TEST(DiscreteDistribution, constructors) {
|
||||
DecisionTreeFactor f(X, "0.4 0.6");
|
||||
DiscretePrior expected(f);
|
||||
DiscreteDistribution expected(f);
|
||||
|
||||
DiscretePrior actual(X % "2/3");
|
||||
DiscreteDistribution actual(X % "2/3");
|
||||
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(0, actual.nrParents());
|
||||
EXPECT(assert_equal(expected, actual, 1e-9));
|
||||
|
||||
const vector<double> pmf{0.4, 0.6};
|
||||
DiscretePrior actual2(X, pmf);
|
||||
const std::vector<double> pmf{0.4, 0.6};
|
||||
DiscreteDistribution actual2(X, pmf);
|
||||
EXPECT_LONGS_EQUAL(1, actual2.nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(0, actual2.nrParents());
|
||||
EXPECT(assert_equal(expected, actual2, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscretePrior, Multiply) {
|
||||
TEST(DiscreteDistribution, Multiply) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||
DiscretePrior prior(B, "1/2");
|
||||
DiscreteDistribution prior(B, "1/2");
|
||||
DiscreteConditional actual = prior * conditional; // P(A|B) * P(B)
|
||||
|
||||
EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B)
|
||||
|
@ -56,22 +55,22 @@ TEST(DiscretePrior, Multiply) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscretePrior, operator) {
|
||||
DiscretePrior prior(X % "2/3");
|
||||
TEST(DiscreteDistribution, operator) {
|
||||
DiscreteDistribution prior(X % "2/3");
|
||||
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscretePrior, pmf) {
|
||||
DiscretePrior prior(X % "2/3");
|
||||
vector<double> expected {0.4, 0.6};
|
||||
EXPECT(prior.pmf() == expected);
|
||||
TEST(DiscreteDistribution, pmf) {
|
||||
DiscreteDistribution prior(X % "2/3");
|
||||
std::vector<double> expected{0.4, 0.6};
|
||||
EXPECT(prior.pmf() == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscretePrior, sample) {
|
||||
DiscretePrior prior(X % "2/3");
|
||||
TEST(DiscreteDistribution, sample) {
|
||||
DiscreteDistribution prior(X % "2/3");
|
||||
prior.sample();
|
||||
}
|
||||
|
|
@ -13,7 +13,7 @@ Author: Frank Dellaert
|
|||
|
||||
import unittest
|
||||
|
||||
from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering
|
||||
from gtsam import DecisionTreeFactor, DiscreteValues, DiscreteDistribution, Ordering
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
|
@ -36,8 +36,8 @@ class TestDecisionTreeFactor(GtsamTestCase):
|
|||
v1 = (1, 2)
|
||||
v2 = (2, 2)
|
||||
|
||||
# Multiply with a DiscretePrior, i.e., Bayes Law!
|
||||
prior = DiscretePrior(v1, [1, 3])
|
||||
# Multiply with a DiscreteDistribution, i.e., Bayes Law!
|
||||
prior = DiscreteDistribution(v1, [1, 3])
|
||||
f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
|
||||
expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
|
||||
self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected)
|
||||
|
|
|
@ -14,7 +14,7 @@ Author: Frank Dellaert
|
|||
import unittest
|
||||
|
||||
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
||||
DiscreteKeys, DiscretePrior, DiscreteValues, Ordering)
|
||||
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
|
@ -74,7 +74,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
for j in range(8):
|
||||
ordering.push_back(j)
|
||||
chordal = fg.eliminateSequential(ordering)
|
||||
expected2 = DiscretePrior(Bronchitis, "11/9")
|
||||
expected2 = DiscreteDistribution(Bronchitis, "11/9")
|
||||
self.gtsamAssertEquals(chordal.at(7), expected2)
|
||||
|
||||
# solve
|
||||
|
|
|
@ -14,7 +14,7 @@ Author: Frank Dellaert
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior
|
||||
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscreteDistribution
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
X = 0, 2
|
||||
|
@ -28,33 +28,33 @@ class TestDiscretePrior(GtsamTestCase):
|
|||
keys = DiscreteKeys()
|
||||
keys.push_back(X)
|
||||
f = DecisionTreeFactor(keys, "0.4 0.6")
|
||||
expected = DiscretePrior(f)
|
||||
|
||||
actual = DiscretePrior(X, "2/3")
|
||||
expected = DiscreteDistribution(f)
|
||||
|
||||
actual = DiscreteDistribution(X, "2/3")
|
||||
self.gtsamAssertEquals(actual, expected)
|
||||
|
||||
actual2 = DiscretePrior(X, [0.4, 0.6])
|
||||
|
||||
actual2 = DiscreteDistribution(X, [0.4, 0.6])
|
||||
self.gtsamAssertEquals(actual2, expected)
|
||||
|
||||
def test_operator(self):
|
||||
prior = DiscretePrior(X, "2/3")
|
||||
prior = DiscreteDistribution(X, "2/3")
|
||||
self.assertAlmostEqual(prior(0), 0.4)
|
||||
self.assertAlmostEqual(prior(1), 0.6)
|
||||
|
||||
def test_pmf(self):
|
||||
prior = DiscretePrior(X, "2/3")
|
||||
prior = DiscreteDistribution(X, "2/3")
|
||||
expected = np.array([0.4, 0.6])
|
||||
np.testing.assert_allclose(expected, prior.pmf())
|
||||
|
||||
def test_sample(self):
|
||||
prior = DiscretePrior(X, "2/3")
|
||||
prior = DiscreteDistribution(X, "2/3")
|
||||
actual = prior.sample()
|
||||
self.assertIsInstance(actual, int)
|
||||
|
||||
def test_markdown(self):
|
||||
"""Test the _repr_markdown_ method."""
|
||||
|
||||
prior = DiscretePrior(X, "2/3")
|
||||
prior = DiscreteDistribution(X, "2/3")
|
||||
expected = " *P(0):*\n\n" \
|
||||
"|0|value|\n" \
|
||||
"|:-:|:-:|\n" \
|
Loading…
Reference in New Issue