Merge pull request #1039 from borglab/feature/DiscreteDistribution

release/4.3a0
Frank Dellaert 2022-01-16 16:43:51 -05:00 committed by GitHub
commit 950ab111d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 70 additions and 67 deletions

View File

@ -19,7 +19,7 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscretePrior.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
@ -79,9 +79,9 @@ namespace gtsam {
// Add inherited versions of add. // Add inherited versions of add.
using Base::add; using Base::add;
/** Add a DiscretePrior using a table or a string */ /** Add a DiscreteDistribution using a table or a string */
void add(const DiscreteKey& key, const std::string& spec) { void add(const DiscreteKey& key, const std::string& spec) {
emplace_shared<DiscretePrior>(key, spec); emplace_shared<DiscreteDistribution>(key, spec);
} }
/** Add a DiscreteCondtional */ /** Add a DiscreteCondtional */

View File

@ -89,7 +89,7 @@ class GTSAM_EXPORT DiscreteConditional
const std::string& spec) const std::string& spec)
: DiscreteConditional(Signature(key, parents, spec)) {} : DiscreteConditional(Signature(key, parents, spec)) {}
/// No-parent specialization; can also use DiscretePrior. /// No-parent specialization; can also use DiscreteDistribution.
DiscreteConditional(const DiscreteKey& key, const std::string& spec) DiscreteConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteConditional(Signature(key, {}, spec)) {} : DiscreteConditional(Signature(key, {}, spec)) {}

View File

@ -10,21 +10,23 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file DiscretePrior.cpp * @file DiscreteDistribution.cpp
* @date December 2021 * @date December 2021
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/discrete/DiscretePrior.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <vector>
namespace gtsam { namespace gtsam {
void DiscretePrior::print(const std::string& s, void DiscreteDistribution::print(const std::string& s,
const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
Base::print(s, formatter); Base::print(s, formatter);
} }
double DiscretePrior::operator()(size_t value) const { double DiscreteDistribution::operator()(size_t value) const {
if (nrFrontals() != 1) if (nrFrontals() != 1)
throw std::invalid_argument( throw std::invalid_argument(
"Single value operator can only be invoked on single-variable " "Single value operator can only be invoked on single-variable "
@ -34,10 +36,10 @@ double DiscretePrior::operator()(size_t value) const {
return Base::operator()(values); return Base::operator()(values);
} }
std::vector<double> DiscretePrior::pmf() const { std::vector<double> DiscreteDistribution::pmf() const {
if (nrFrontals() != 1) if (nrFrontals() != 1)
throw std::invalid_argument( throw std::invalid_argument(
"DiscretePrior::pmf only defined for single-variable priors"); "DiscreteDistribution::pmf only defined for single-variable priors");
const size_t nrValues = cardinalities_.at(keys_[0]); const size_t nrValues = cardinalities_.at(keys_[0]);
std::vector<double> array; std::vector<double> array;
array.reserve(nrValues); array.reserve(nrValues);

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file DiscretePrior.h * @file DiscreteDistribution.h
* @date December 2021 * @date December 2021
* @author Frank Dellaert * @author Frank Dellaert
*/ */
@ -20,6 +20,7 @@
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <string> #include <string>
#include <vector>
namespace gtsam { namespace gtsam {
@ -27,7 +28,7 @@ namespace gtsam {
* A prior probability on a set of discrete variables. * A prior probability on a set of discrete variables.
* Derives from DiscreteConditional * Derives from DiscreteConditional
*/ */
class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
public: public:
using Base = DiscreteConditional; using Base = DiscreteConditional;
@ -35,35 +36,36 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
/// @{ /// @{
/// Default constructor needed for serialization. /// Default constructor needed for serialization.
DiscretePrior() {} DiscreteDistribution() {}
/// Constructor from factor. /// Constructor from factor.
DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {} explicit DiscreteDistribution(const DecisionTreeFactor& f)
: Base(f.size(), f) {}
/** /**
* Construct from a Signature. * Construct from a Signature.
* *
* Example: DiscretePrior P(D % "3/2"); * Example: DiscreteDistribution P(D % "3/2");
*/ */
DiscretePrior(const Signature& s) : Base(s) {} explicit DiscreteDistribution(const Signature& s) : Base(s) {}
/** /**
* Construct from key and a vector of floats specifying the probability mass * Construct from key and a vector of floats specifying the probability mass
* function (PMF). * function (PMF).
* *
* Example: DiscretePrior P(D, {0.4, 0.6}); * Example: DiscreteDistribution P(D, {0.4, 0.6});
*/ */
DiscretePrior(const DiscreteKey& key, const std::vector<double>& spec) DiscreteDistribution(const DiscreteKey& key, const std::vector<double>& spec)
: DiscretePrior(Signature(key, {}, Signature::Table{spec})) {} : DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {}
/** /**
* Construct from key and a string specifying the probability mass function * Construct from key and a string specifying the probability mass function
* (PMF). * (PMF).
* *
* Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); * Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9");
*/ */
DiscretePrior(const DiscreteKey& key, const std::string& spec) DiscreteDistribution(const DiscreteKey& key, const std::string& spec)
: DiscretePrior(Signature(key, {}, spec)) {} : DiscreteDistribution(Signature(key, {}, spec)) {}
/// @} /// @}
/// @name Testable /// @name Testable
@ -102,10 +104,10 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
/// @} /// @}
}; };
// DiscretePrior // DiscreteDistribution
// traits // traits
template <> template <>
struct traits<DiscretePrior> : public Testable<DiscretePrior> {}; struct traits<DiscreteDistribution> : public Testable<DiscreteDistribution> {};
} // namespace gtsam } // namespace gtsam

View File

@ -128,12 +128,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
std::map<gtsam::Key, std::vector<std::string>> names) const; std::map<gtsam::Key, std::vector<std::string>> names) const;
}; };
#include <gtsam/discrete/DiscretePrior.h> #include <gtsam/discrete/DiscreteDistribution.h>
virtual class DiscretePrior : gtsam::DiscreteConditional { virtual class DiscreteDistribution : gtsam::DiscreteConditional {
DiscretePrior(); DiscreteDistribution();
DiscretePrior(const gtsam::DecisionTreeFactor& f); DiscreteDistribution(const gtsam::DecisionTreeFactor& f);
DiscretePrior(const gtsam::DiscreteKey& key, string spec); DiscreteDistribution(const gtsam::DiscreteKey& key, string spec);
DiscretePrior(const gtsam::DiscreteKey& key, std::vector<double> spec); DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
void print(string s = "Discrete Prior\n", void print(string s = "Discrete Prior\n",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;

View File

@ -20,7 +20,7 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscretePrior.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
@ -56,8 +56,8 @@ TEST( DecisionTreeFactor, constructors)
TEST(DecisionTreeFactor, multiplication) { TEST(DecisionTreeFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
// Multiply with a DiscretePrior, i.e., Bayes Law! // Multiply with a DiscreteDistribution, i.e., Bayes Law!
DiscretePrior prior(v1 % "1/3"); DiscreteDistribution prior(v1 % "1/3");
DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1)); CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1));

View File

@ -11,42 +11,41 @@
/* /*
* @file testDiscretePrior.cpp * @file testDiscretePrior.cpp
* @brief unit tests for DiscretePrior * @brief unit tests for DiscreteDistribution
* @author Frank dellaert * @author Frank dellaert
* @date December 2021 * @date December 2021
*/ */
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/DiscretePrior.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
using namespace std;
using namespace gtsam; using namespace gtsam;
static const DiscreteKey X(0, 2); static const DiscreteKey X(0, 2);
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscretePrior, constructors) { TEST(DiscreteDistribution, constructors) {
DecisionTreeFactor f(X, "0.4 0.6"); DecisionTreeFactor f(X, "0.4 0.6");
DiscretePrior expected(f); DiscreteDistribution expected(f);
DiscretePrior actual(X % "2/3"); DiscreteDistribution actual(X % "2/3");
EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
EXPECT_LONGS_EQUAL(0, actual.nrParents()); EXPECT_LONGS_EQUAL(0, actual.nrParents());
EXPECT(assert_equal(expected, actual, 1e-9)); EXPECT(assert_equal(expected, actual, 1e-9));
const vector<double> pmf{0.4, 0.6}; const std::vector<double> pmf{0.4, 0.6};
DiscretePrior actual2(X, pmf); DiscreteDistribution actual2(X, pmf);
EXPECT_LONGS_EQUAL(1, actual2.nrFrontals()); EXPECT_LONGS_EQUAL(1, actual2.nrFrontals());
EXPECT_LONGS_EQUAL(0, actual2.nrParents()); EXPECT_LONGS_EQUAL(0, actual2.nrParents());
EXPECT(assert_equal(expected, actual2, 1e-9)); EXPECT(assert_equal(expected, actual2, 1e-9));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscretePrior, Multiply) { TEST(DiscreteDistribution, Multiply) {
DiscreteKey A(0, 2), B(1, 2); DiscreteKey A(0, 2), B(1, 2);
DiscreteConditional conditional(A | B = "1/2 2/1"); DiscreteConditional conditional(A | B = "1/2 2/1");
DiscretePrior prior(B, "1/2"); DiscreteDistribution prior(B, "1/2");
DiscreteConditional actual = prior * conditional; // P(A|B) * P(B) DiscreteConditional actual = prior * conditional; // P(A|B) * P(B)
EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B) EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B)
@ -56,22 +55,22 @@ TEST(DiscretePrior, Multiply) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscretePrior, operator) { TEST(DiscreteDistribution, operator) {
DiscretePrior prior(X % "2/3"); DiscreteDistribution prior(X % "2/3");
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9); EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9); EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscretePrior, pmf) { TEST(DiscreteDistribution, pmf) {
DiscretePrior prior(X % "2/3"); DiscreteDistribution prior(X % "2/3");
vector<double> expected {0.4, 0.6}; std::vector<double> expected{0.4, 0.6};
EXPECT(prior.pmf() == expected); EXPECT(prior.pmf() == expected);
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscretePrior, sample) { TEST(DiscreteDistribution, sample) {
DiscretePrior prior(X % "2/3"); DiscreteDistribution prior(X % "2/3");
prior.sample(); prior.sample();
} }

View File

@ -13,7 +13,7 @@ Author: Frank Dellaert
import unittest import unittest
from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering from gtsam import DecisionTreeFactor, DiscreteValues, DiscreteDistribution, Ordering
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
@ -36,8 +36,8 @@ class TestDecisionTreeFactor(GtsamTestCase):
v1 = (1, 2) v1 = (1, 2)
v2 = (2, 2) v2 = (2, 2)
# Multiply with a DiscretePrior, i.e., Bayes Law! # Multiply with a DiscreteDistribution, i.e., Bayes Law!
prior = DiscretePrior(v1, [1, 3]) prior = DiscreteDistribution(v1, [1, 3])
f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected) self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected)

View File

@ -14,7 +14,7 @@ Author: Frank Dellaert
import unittest import unittest
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscretePrior, DiscreteValues, Ordering) DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
@ -74,7 +74,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
for j in range(8): for j in range(8):
ordering.push_back(j) ordering.push_back(j)
chordal = fg.eliminateSequential(ordering) chordal = fg.eliminateSequential(ordering)
expected2 = DiscretePrior(Bronchitis, "11/9") expected2 = DiscreteDistribution(Bronchitis, "11/9")
self.gtsamAssertEquals(chordal.at(7), expected2) self.gtsamAssertEquals(chordal.at(7), expected2)
# solve # solve

View File

@ -14,7 +14,7 @@ Author: Frank Dellaert
import unittest import unittest
import numpy as np import numpy as np
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior from gtsam import DecisionTreeFactor, DiscreteKeys, DiscreteDistribution
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
X = 0, 2 X = 0, 2
@ -28,33 +28,33 @@ class TestDiscretePrior(GtsamTestCase):
keys = DiscreteKeys() keys = DiscreteKeys()
keys.push_back(X) keys.push_back(X)
f = DecisionTreeFactor(keys, "0.4 0.6") f = DecisionTreeFactor(keys, "0.4 0.6")
expected = DiscretePrior(f) expected = DiscreteDistribution(f)
actual = DiscretePrior(X, "2/3") actual = DiscreteDistribution(X, "2/3")
self.gtsamAssertEquals(actual, expected) self.gtsamAssertEquals(actual, expected)
actual2 = DiscretePrior(X, [0.4, 0.6]) actual2 = DiscreteDistribution(X, [0.4, 0.6])
self.gtsamAssertEquals(actual2, expected) self.gtsamAssertEquals(actual2, expected)
def test_operator(self): def test_operator(self):
prior = DiscretePrior(X, "2/3") prior = DiscreteDistribution(X, "2/3")
self.assertAlmostEqual(prior(0), 0.4) self.assertAlmostEqual(prior(0), 0.4)
self.assertAlmostEqual(prior(1), 0.6) self.assertAlmostEqual(prior(1), 0.6)
def test_pmf(self): def test_pmf(self):
prior = DiscretePrior(X, "2/3") prior = DiscreteDistribution(X, "2/3")
expected = np.array([0.4, 0.6]) expected = np.array([0.4, 0.6])
np.testing.assert_allclose(expected, prior.pmf()) np.testing.assert_allclose(expected, prior.pmf())
def test_sample(self): def test_sample(self):
prior = DiscretePrior(X, "2/3") prior = DiscreteDistribution(X, "2/3")
actual = prior.sample() actual = prior.sample()
self.assertIsInstance(actual, int) self.assertIsInstance(actual, int)
def test_markdown(self): def test_markdown(self):
"""Test the _repr_markdown_ method.""" """Test the _repr_markdown_ method."""
prior = DiscretePrior(X, "2/3") prior = DiscreteDistribution(X, "2/3")
expected = " *P(0):*\n\n" \ expected = " *P(0):*\n\n" \
"|0|value|\n" \ "|0|value|\n" \
"|:-:|:-:|\n" \ "|:-:|:-:|\n" \