Constructor from PMF
parent
1000825b03
commit
be5aa56df7
|
@ -48,17 +48,17 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
|||
DiscretePrior(const Signature& s) : Base(s) {}
|
||||
|
||||
/**
|
||||
* Construct from key and a Signature::Table specifying the
|
||||
* conditional probability table (CPT).
|
||||
* Construct from key and a vector of floats specifying the probability mass
|
||||
* function (PMF).
|
||||
*
|
||||
* Example: DiscretePrior P(D, table);
|
||||
* Example: DiscretePrior P(D, {0.4, 0.6});
|
||||
*/
|
||||
DiscretePrior(const DiscreteKey& key, const Signature::Table& table)
|
||||
: Base(Signature(key, {}, table)) {}
|
||||
DiscretePrior(const DiscreteKey& key, const std::vector<double>& spec)
|
||||
: DiscretePrior(Signature(key, {}, Signature::Table{spec})) {}
|
||||
|
||||
/**
|
||||
* Construct from key and a string specifying the conditional
|
||||
* probability table (CPT).
|
||||
* Construct from key and a string specifying the probability mass function
|
||||
* (PMF).
|
||||
*
|
||||
* Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9");
|
||||
*/
|
||||
|
|
|
@ -120,6 +120,7 @@ virtual class DiscretePrior : gtsam::DiscreteConditional {
|
|||
DiscretePrior();
|
||||
DiscretePrior(const gtsam::DecisionTreeFactor& f);
|
||||
DiscretePrior(const gtsam::DiscreteKey& key, string spec);
|
||||
DiscretePrior(const gtsam::DiscreteKey& key, std::vector<double> spec);
|
||||
void print(string s = "Discrete Prior\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
|
|
@ -27,12 +27,19 @@ static const DiscreteKey X(0, 2);
|
|||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscretePrior, constructors) {
|
||||
DecisionTreeFactor f(X, "0.4 0.6");
|
||||
DiscretePrior expected(f);
|
||||
|
||||
DiscretePrior actual(X % "2/3");
|
||||
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(0, actual.nrParents());
|
||||
DecisionTreeFactor f(X, "0.4 0.6");
|
||||
DiscretePrior expected(f);
|
||||
EXPECT(assert_equal(expected, actual, 1e-9));
|
||||
|
||||
const vector<double> pmf{0.4, 0.6};
|
||||
DiscretePrior actual2(X, pmf);
|
||||
EXPECT_LONGS_EQUAL(1, actual2.nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(0, actual2.nrParents());
|
||||
EXPECT(assert_equal(expected, actual2, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -25,12 +25,16 @@ class TestDiscretePrior(GtsamTestCase):
|
|||
|
||||
def test_constructor(self):
|
||||
"""Test various constructors."""
|
||||
actual = DiscretePrior(X, "2/3")
|
||||
keys = DiscreteKeys()
|
||||
keys.push_back(X)
|
||||
f = DecisionTreeFactor(keys, "0.4 0.6")
|
||||
expected = DiscretePrior(f)
|
||||
|
||||
actual = DiscretePrior(X, "2/3")
|
||||
self.gtsamAssertEquals(actual, expected)
|
||||
|
||||
actual2 = DiscretePrior(X, [0.4, 0.6])
|
||||
self.gtsamAssertEquals(actual2, expected)
|
||||
|
||||
def test_operator(self):
|
||||
prior = DiscretePrior(X, "2/3")
|
||||
|
|
Loading…
Reference in New Issue