marginals without parents
parent
23a8dba716
commit
64cd58843a
|
|
@ -110,7 +110,26 @@ DiscreteConditional DiscreteConditional::operator*(
|
|||
return DiscreteConditional(newFrontals.size(), discreteKeys, product);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional DiscreteConditional::marginal(Key key) const {
|
||||
if (nrParents() > 0)
|
||||
throw std::invalid_argument(
|
||||
"DiscreteConditional::marginal: single argument version only valid for "
|
||||
"fully specified joint distributions (i.e., no parents).");
|
||||
|
||||
// Calculate the keys as the frontal keys without the given key.
|
||||
DiscreteKeys discreteKeys{{key, cardinality(key)}};
|
||||
|
||||
// Calculate sum
|
||||
ADT adt(*this);
|
||||
for (auto&& k : frontals())
|
||||
if (k != key) adt = adt.sum(k, cardinality(k));
|
||||
|
||||
// Return new factor
|
||||
return DiscreteConditional(1, discreteKeys, adt);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteConditional::print(const string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
cout << s << " P( ";
|
||||
|
|
|
|||
|
|
@ -126,6 +126,9 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
*/
|
||||
DiscreteConditional operator*(const DiscreteConditional& other) const;
|
||||
|
||||
/** Calculate marginal on given key, no parent case. */
|
||||
DiscreteConditional marginal(Key key) const;
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
const gtsam::Ordering& orderedKeys);
|
||||
gtsam::DiscreteConditional operator*(
|
||||
const gtsam::DiscreteConditional& other) const;
|
||||
DiscreteConditional marginal(gtsam::Key key) const;
|
||||
void print(string s = "Discrete Conditional\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
|
|
|||
|
|
@ -97,10 +97,14 @@ TEST(DiscreteConditional, constructors3) {
|
|||
/* ************************************************************************* */
|
||||
// Check calculation of joint P(A,B)
|
||||
TEST(DiscreteConditional, Multiply) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
DiscreteKey A(1, 2), B(0, 2);
|
||||
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||
DiscreteConditional prior(B % "1/2");
|
||||
|
||||
// The expected factor
|
||||
DecisionTreeFactor f(A & B, "1 4 2 2");
|
||||
DiscreteConditional expected(2, f);
|
||||
|
||||
// P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
|
||||
for (auto&& actual : {prior * conditional, conditional * prior}) {
|
||||
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||
|
|
@ -110,8 +114,11 @@ TEST(DiscreteConditional, Multiply) {
|
|||
const DiscreteValues& v = it.first;
|
||||
EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9);
|
||||
}
|
||||
// And for good measure:
|
||||
EXPECT(assert_equal(expected, actual));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of conditional joint P(A,B|C)
|
||||
TEST(DiscreteConditional, Multiply2) {
|
||||
|
|
@ -131,6 +138,7 @@ TEST(DiscreteConditional, Multiply2) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of conditional joint P(A,B|C), double check keys
|
||||
TEST(DiscreteConditional, Multiply3) {
|
||||
|
|
@ -150,6 +158,7 @@ TEST(DiscreteConditional, Multiply3) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
|
||||
TEST(DiscreteConditional, Multiply4) {
|
||||
|
|
@ -173,6 +182,31 @@ TEST(DiscreteConditional, Multiply4) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of marginals for joint P(A,B)
|
||||
TEST(DiscreteConditional, marginals) {
|
||||
DiscreteKey A(1, 2), B(0, 2);
|
||||
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||
DiscreteConditional prior(B % "1/2");
|
||||
DiscreteConditional pAB = prior * conditional;
|
||||
|
||||
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||
DiscreteConditional pA(A % "5/4");
|
||||
EXPECT(assert_equal(pA, actualA));
|
||||
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
|
||||
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
|
||||
EXPECT((frontalsA == KeyVector{1}));
|
||||
|
||||
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||
EXPECT(assert_equal(prior, actualB));
|
||||
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
|
||||
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals());
|
||||
EXPECT((frontalsB == KeyVector{0}));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, likelihood) {
|
||||
DiscreteKey X(0, 2), Y(1, 3);
|
||||
|
|
|
|||
|
|
@ -81,6 +81,15 @@ class TestDiscreteConditional(GtsamTestCase):
|
|||
self.assertAlmostEqual(
|
||||
actual(v), AB_given_D(v) * C_given_DE(v))
|
||||
|
||||
def test_marginals(self):
|
||||
conditional = DiscreteConditional(A, [B], "1/2 2/1")
|
||||
prior = DiscreteConditional(B, "1/2")
|
||||
pAB = prior * conditional
|
||||
self.gtsamAssertEquals(prior, pAB.marginal(B[0]))
|
||||
|
||||
pA = DiscreteConditional(A % "5/4")
|
||||
self.gtsamAssertEquals(pA, pAB.marginal(A[0]))
|
||||
|
||||
def test_markdown(self):
|
||||
"""Test whether the _repr_markdown_ method."""
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue