Constructor from PMF

release/4.3a0
Frank Dellaert 2022-01-15 08:15:46 -05:00
parent 1000825b03
commit be5aa56df7
4 changed files with 22 additions and 10 deletions

View File

@ -48,17 +48,17 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
DiscretePrior(const Signature& s) : Base(s) {} DiscretePrior(const Signature& s) : Base(s) {}
/** /**
* Construct from key and a Signature::Table specifying the * Construct from key and a vector of floats specifying the probability mass
* conditional probability table (CPT). * function (PMF).
* *
* Example: DiscretePrior P(D, table); * Example: DiscretePrior P(D, {0.4, 0.6});
*/ */
DiscretePrior(const DiscreteKey& key, const Signature::Table& table) DiscretePrior(const DiscreteKey& key, const std::vector<double>& spec)
: Base(Signature(key, {}, table)) {} : DiscretePrior(Signature(key, {}, Signature::Table{spec})) {}
/** /**
* Construct from key and a string specifying the conditional * Construct from key and a string specifying the probability mass function
* probability table (CPT). * (PMF).
* *
* Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9");
*/ */

View File

@ -120,6 +120,7 @@ virtual class DiscretePrior : gtsam::DiscreteConditional {
DiscretePrior(); DiscretePrior();
DiscretePrior(const gtsam::DecisionTreeFactor& f); DiscretePrior(const gtsam::DecisionTreeFactor& f);
DiscretePrior(const gtsam::DiscreteKey& key, string spec); DiscretePrior(const gtsam::DiscreteKey& key, string spec);
DiscretePrior(const gtsam::DiscreteKey& key, std::vector<double> spec);
void print(string s = "Discrete Prior\n", void print(string s = "Discrete Prior\n",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;

View File

@ -27,12 +27,19 @@ static const DiscreteKey X(0, 2);
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscretePrior, constructors) { TEST(DiscretePrior, constructors) {
DecisionTreeFactor f(X, "0.4 0.6");
DiscretePrior expected(f);
DiscretePrior actual(X % "2/3"); DiscretePrior actual(X % "2/3");
EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
EXPECT_LONGS_EQUAL(0, actual.nrParents()); EXPECT_LONGS_EQUAL(0, actual.nrParents());
DecisionTreeFactor f(X, "0.4 0.6");
DiscretePrior expected(f);
EXPECT(assert_equal(expected, actual, 1e-9)); EXPECT(assert_equal(expected, actual, 1e-9));
const vector<double> pmf{0.4, 0.6};
DiscretePrior actual2(X, pmf);
EXPECT_LONGS_EQUAL(1, actual2.nrFrontals());
EXPECT_LONGS_EQUAL(0, actual2.nrParents());
EXPECT(assert_equal(expected, actual2, 1e-9));
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -25,12 +25,16 @@ class TestDiscretePrior(GtsamTestCase):
def test_constructor(self): def test_constructor(self):
"""Test various constructors.""" """Test various constructors."""
actual = DiscretePrior(X, "2/3")
keys = DiscreteKeys() keys = DiscreteKeys()
keys.push_back(X) keys.push_back(X)
f = DecisionTreeFactor(keys, "0.4 0.6") f = DecisionTreeFactor(keys, "0.4 0.6")
expected = DiscretePrior(f) expected = DiscretePrior(f)
actual = DiscretePrior(X, "2/3")
self.gtsamAssertEquals(actual, expected) self.gtsamAssertEquals(actual, expected)
actual2 = DiscretePrior(X, [0.4, 0.6])
self.gtsamAssertEquals(actual2, expected)
def test_operator(self): def test_operator(self):
prior = DiscretePrior(X, "2/3") prior = DiscretePrior(X, "2/3")