Constructor from PMF
parent
1000825b03
commit
be5aa56df7
|
@ -48,17 +48,17 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
||||||
DiscretePrior(const Signature& s) : Base(s) {}
|
DiscretePrior(const Signature& s) : Base(s) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from key and a Signature::Table specifying the
|
* Construct from key and a vector of floats specifying the probability mass
|
||||||
* conditional probability table (CPT).
|
* function (PMF).
|
||||||
*
|
*
|
||||||
* Example: DiscretePrior P(D, table);
|
* Example: DiscretePrior P(D, {0.4, 0.6});
|
||||||
*/
|
*/
|
||||||
DiscretePrior(const DiscreteKey& key, const Signature::Table& table)
|
DiscretePrior(const DiscreteKey& key, const std::vector<double>& spec)
|
||||||
: Base(Signature(key, {}, table)) {}
|
: DiscretePrior(Signature(key, {}, Signature::Table{spec})) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from key and a string specifying the conditional
|
* Construct from key and a string specifying the probability mass function
|
||||||
* probability table (CPT).
|
* (PMF).
|
||||||
*
|
*
|
||||||
* Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9");
|
* Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9");
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -120,6 +120,7 @@ virtual class DiscretePrior : gtsam::DiscreteConditional {
|
||||||
DiscretePrior();
|
DiscretePrior();
|
||||||
DiscretePrior(const gtsam::DecisionTreeFactor& f);
|
DiscretePrior(const gtsam::DecisionTreeFactor& f);
|
||||||
DiscretePrior(const gtsam::DiscreteKey& key, string spec);
|
DiscretePrior(const gtsam::DiscreteKey& key, string spec);
|
||||||
|
DiscretePrior(const gtsam::DiscreteKey& key, std::vector<double> spec);
|
||||||
void print(string s = "Discrete Prior\n",
|
void print(string s = "Discrete Prior\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
|
@ -27,12 +27,19 @@ static const DiscreteKey X(0, 2);
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscretePrior, constructors) {
|
TEST(DiscretePrior, constructors) {
|
||||||
|
DecisionTreeFactor f(X, "0.4 0.6");
|
||||||
|
DiscretePrior expected(f);
|
||||||
|
|
||||||
DiscretePrior actual(X % "2/3");
|
DiscretePrior actual(X % "2/3");
|
||||||
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
|
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
|
||||||
EXPECT_LONGS_EQUAL(0, actual.nrParents());
|
EXPECT_LONGS_EQUAL(0, actual.nrParents());
|
||||||
DecisionTreeFactor f(X, "0.4 0.6");
|
|
||||||
DiscretePrior expected(f);
|
|
||||||
EXPECT(assert_equal(expected, actual, 1e-9));
|
EXPECT(assert_equal(expected, actual, 1e-9));
|
||||||
|
|
||||||
|
const vector<double> pmf{0.4, 0.6};
|
||||||
|
DiscretePrior actual2(X, pmf);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual2.nrParents());
|
||||||
|
EXPECT(assert_equal(expected, actual2, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -25,12 +25,16 @@ class TestDiscretePrior(GtsamTestCase):
|
||||||
|
|
||||||
def test_constructor(self):
|
def test_constructor(self):
|
||||||
"""Test various constructors."""
|
"""Test various constructors."""
|
||||||
actual = DiscretePrior(X, "2/3")
|
|
||||||
keys = DiscreteKeys()
|
keys = DiscreteKeys()
|
||||||
keys.push_back(X)
|
keys.push_back(X)
|
||||||
f = DecisionTreeFactor(keys, "0.4 0.6")
|
f = DecisionTreeFactor(keys, "0.4 0.6")
|
||||||
expected = DiscretePrior(f)
|
expected = DiscretePrior(f)
|
||||||
|
|
||||||
|
actual = DiscretePrior(X, "2/3")
|
||||||
self.gtsamAssertEquals(actual, expected)
|
self.gtsamAssertEquals(actual, expected)
|
||||||
|
|
||||||
|
actual2 = DiscretePrior(X, [0.4, 0.6])
|
||||||
|
self.gtsamAssertEquals(actual2, expected)
|
||||||
|
|
||||||
def test_operator(self):
|
def test_operator(self):
|
||||||
prior = DiscretePrior(X, "2/3")
|
prior = DiscretePrior(X, "2/3")
|
||||||
|
|
Loading…
Reference in New Issue