marginals without parents
parent
23a8dba716
commit
64cd58843a
|
|
@ -110,7 +110,26 @@ DiscreteConditional DiscreteConditional::operator*(
|
||||||
return DiscreteConditional(newFrontals.size(), discreteKeys, product);
|
return DiscreteConditional(newFrontals.size(), discreteKeys, product);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
DiscreteConditional DiscreteConditional::marginal(Key key) const {
|
||||||
|
if (nrParents() > 0)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteConditional::marginal: single argument version only valid for "
|
||||||
|
"fully specified joint distributions (i.e., no parents).");
|
||||||
|
|
||||||
|
// Calculate the keys as the frontal keys without the given key.
|
||||||
|
DiscreteKeys discreteKeys{{key, cardinality(key)}};
|
||||||
|
|
||||||
|
// Calculate sum
|
||||||
|
ADT adt(*this);
|
||||||
|
for (auto&& k : frontals())
|
||||||
|
if (k != key) adt = adt.sum(k, cardinality(k));
|
||||||
|
|
||||||
|
// Return new factor
|
||||||
|
return DiscreteConditional(1, discreteKeys, adt);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::print(const string& s,
|
void DiscreteConditional::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
cout << s << " P( ";
|
cout << s << " P( ";
|
||||||
|
|
|
||||||
|
|
@ -126,6 +126,9 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
*/
|
*/
|
||||||
DiscreteConditional operator*(const DiscreteConditional& other) const;
|
DiscreteConditional operator*(const DiscreteConditional& other) const;
|
||||||
|
|
||||||
|
/** Calculate marginal on given key, no parent case. */
|
||||||
|
DiscreteConditional marginal(Key key) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
||||||
|
|
@ -97,6 +97,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
const gtsam::Ordering& orderedKeys);
|
const gtsam::Ordering& orderedKeys);
|
||||||
gtsam::DiscreteConditional operator*(
|
gtsam::DiscreteConditional operator*(
|
||||||
const gtsam::DiscreteConditional& other) const;
|
const gtsam::DiscreteConditional& other) const;
|
||||||
|
DiscreteConditional marginal(gtsam::Key key) const;
|
||||||
void print(string s = "Discrete Conditional\n",
|
void print(string s = "Discrete Conditional\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
|
||||||
|
|
@ -97,10 +97,14 @@ TEST(DiscreteConditional, constructors3) {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check calculation of joint P(A,B)
|
// Check calculation of joint P(A,B)
|
||||||
TEST(DiscreteConditional, Multiply) {
|
TEST(DiscreteConditional, Multiply) {
|
||||||
DiscreteKey A(0, 2), B(1, 2);
|
DiscreteKey A(1, 2), B(0, 2);
|
||||||
DiscreteConditional conditional(A | B = "1/2 2/1");
|
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||||
DiscreteConditional prior(B % "1/2");
|
DiscreteConditional prior(B % "1/2");
|
||||||
|
|
||||||
|
// The expected factor
|
||||||
|
DecisionTreeFactor f(A & B, "1 4 2 2");
|
||||||
|
DiscreteConditional expected(2, f);
|
||||||
|
|
||||||
// P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
|
// P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
|
||||||
for (auto&& actual : {prior * conditional, conditional * prior}) {
|
for (auto&& actual : {prior * conditional, conditional * prior}) {
|
||||||
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||||
|
|
@ -110,8 +114,11 @@ TEST(DiscreteConditional, Multiply) {
|
||||||
const DiscreteValues& v = it.first;
|
const DiscreteValues& v = it.first;
|
||||||
EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9);
|
EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9);
|
||||||
}
|
}
|
||||||
|
// And for good measure:
|
||||||
|
EXPECT(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check calculation of conditional joint P(A,B|C)
|
// Check calculation of conditional joint P(A,B|C)
|
||||||
TEST(DiscreteConditional, Multiply2) {
|
TEST(DiscreteConditional, Multiply2) {
|
||||||
|
|
@ -131,6 +138,7 @@ TEST(DiscreteConditional, Multiply2) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check calculation of conditional joint P(A,B|C), double check keys
|
// Check calculation of conditional joint P(A,B|C), double check keys
|
||||||
TEST(DiscreteConditional, Multiply3) {
|
TEST(DiscreteConditional, Multiply3) {
|
||||||
|
|
@ -150,6 +158,7 @@ TEST(DiscreteConditional, Multiply3) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
|
// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
|
||||||
TEST(DiscreteConditional, Multiply4) {
|
TEST(DiscreteConditional, Multiply4) {
|
||||||
|
|
@ -173,6 +182,31 @@ TEST(DiscreteConditional, Multiply4) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of marginals for joint P(A,B)
|
||||||
|
TEST(DiscreteConditional, marginals) {
|
||||||
|
DiscreteKey A(1, 2), B(0, 2);
|
||||||
|
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||||
|
DiscreteConditional prior(B % "1/2");
|
||||||
|
DiscreteConditional pAB = prior * conditional;
|
||||||
|
|
||||||
|
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||||
|
DiscreteConditional pA(A % "5/4");
|
||||||
|
EXPECT(assert_equal(pA, actualA));
|
||||||
|
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
|
||||||
|
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
|
||||||
|
EXPECT((frontalsA == KeyVector{1}));
|
||||||
|
|
||||||
|
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||||
|
EXPECT(assert_equal(prior, actualB));
|
||||||
|
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
|
||||||
|
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals());
|
||||||
|
EXPECT((frontalsB == KeyVector{0}));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteConditional, likelihood) {
|
TEST(DiscreteConditional, likelihood) {
|
||||||
DiscreteKey X(0, 2), Y(1, 3);
|
DiscreteKey X(0, 2), Y(1, 3);
|
||||||
|
|
|
||||||
|
|
@ -81,6 +81,15 @@ class TestDiscreteConditional(GtsamTestCase):
|
||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
actual(v), AB_given_D(v) * C_given_DE(v))
|
actual(v), AB_given_D(v) * C_given_DE(v))
|
||||||
|
|
||||||
|
def test_marginals(self):
|
||||||
|
conditional = DiscreteConditional(A, [B], "1/2 2/1")
|
||||||
|
prior = DiscreteConditional(B, "1/2")
|
||||||
|
pAB = prior * conditional
|
||||||
|
self.gtsamAssertEquals(prior, pAB.marginal(B[0]))
|
||||||
|
|
||||||
|
pA = DiscreteConditional(A % "5/4")
|
||||||
|
self.gtsamAssertEquals(pA, pAB.marginal(A[0]))
|
||||||
|
|
||||||
def test_markdown(self):
|
def test_markdown(self):
|
||||||
"""Test whether the _repr_markdown_ method."""
|
"""Test whether the _repr_markdown_ method."""
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue