diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h index 9ac8acb17..1da188215 100644 --- a/gtsam/discrete/DiscretePrior.h +++ b/gtsam/discrete/DiscretePrior.h @@ -48,17 +48,17 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { DiscretePrior(const Signature& s) : Base(s) {} /** - * Construct from key and a Signature::Table specifying the - * conditional probability table (CPT). + * Construct from key and a vector of floats specifying the probability mass + * function (PMF). * - * Example: DiscretePrior P(D, table); + * Example: DiscretePrior P(D, {0.4, 0.6}); */ - DiscretePrior(const DiscreteKey& key, const Signature::Table& table) - : Base(Signature(key, {}, table)) {} + DiscretePrior(const DiscreteKey& key, const std::vector& spec) + : DiscretePrior(Signature(key, {}, Signature::Table{spec})) {} /** - * Construct from key and a string specifying the conditional - * probability table (CPT). + * Construct from key and a string specifying the probability mass function + * (PMF). * * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); */ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 218b790e8..12bd5be54 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -120,6 +120,7 @@ virtual class DiscretePrior : gtsam::DiscreteConditional { DiscretePrior(); DiscretePrior(const gtsam::DecisionTreeFactor& f); DiscretePrior(const gtsam::DiscreteKey& key, string spec); + DiscretePrior(const gtsam::DiscreteKey& key, std::vector spec); void print(string s = "Discrete Prior\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp index 23f093b22..6225d227e 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -27,12 +27,19 @@ static const DiscreteKey X(0, 2); /* ************************************************************************* */ TEST(DiscretePrior, constructors) { + DecisionTreeFactor f(X, "0.4 0.6"); + DiscretePrior expected(f); + DiscretePrior actual(X % "2/3"); EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); EXPECT_LONGS_EQUAL(0, actual.nrParents()); - DecisionTreeFactor f(X, "0.4 0.6"); - DiscretePrior expected(f); EXPECT(assert_equal(expected, actual, 1e-9)); + + const vector pmf{0.4, 0.6}; + DiscretePrior actual2(X, pmf); + EXPECT_LONGS_EQUAL(1, actual2.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual2.nrParents()); + EXPECT(assert_equal(expected, actual2, 1e-9)); } /* ************************************************************************* */ diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscretePrior.py index 2c923589c..06bdc81ca 100644 --- a/python/gtsam/tests/test_DiscretePrior.py +++ b/python/gtsam/tests/test_DiscretePrior.py @@ -25,12 +25,16 @@ class TestDiscretePrior(GtsamTestCase): def test_constructor(self): """Test various constructors.""" - actual = DiscretePrior(X, "2/3") keys = DiscreteKeys() keys.push_back(X) f = DecisionTreeFactor(keys, "0.4 0.6") expected = DiscretePrior(f) + + actual = DiscretePrior(X, "2/3") self.gtsamAssertEquals(actual, expected) + + actual2 = DiscretePrior(X, [0.4, 0.6]) + self.gtsamAssertEquals(actual2, expected) def test_operator(self): prior = DiscretePrior(X, "2/3")