marginals without parents

release/4.3a0
Frank Dellaert 2022-01-15 16:28:34 -05:00
parent 23a8dba716
commit 64cd58843a
5 changed files with 68 additions and 2 deletions

View File

@ -110,7 +110,26 @@ DiscreteConditional DiscreteConditional::operator*(
return DiscreteConditional(newFrontals.size(), discreteKeys, product);
}
/* ******************************************************************************** */
/* ************************************************************************** */
DiscreteConditional DiscreteConditional::marginal(Key key) const {
if (nrParents() > 0)
throw std::invalid_argument(
"DiscreteConditional::marginal: single argument version only valid for "
"fully specified joint distributions (i.e., no parents).");
// Calculate the keys as the frontal keys without the given key.
DiscreteKeys discreteKeys{{key, cardinality(key)}};
// Calculate sum
ADT adt(*this);
for (auto&& k : frontals())
if (k != key) adt = adt.sum(k, cardinality(k));
// Return new factor
return DiscreteConditional(1, discreteKeys, adt);
}
/* ************************************************************************** */
void DiscreteConditional::print(const string& s,
const KeyFormatter& formatter) const {
cout << s << " P( ";

View File

@ -126,6 +126,9 @@ class GTSAM_EXPORT DiscreteConditional
*/
DiscreteConditional operator*(const DiscreteConditional& other) const;
/** Calculate marginal on given key, no parent case. */
DiscreteConditional marginal(Key key) const;
/// @}
/// @name Testable
/// @{

View File

@ -97,6 +97,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
const gtsam::Ordering& orderedKeys);
gtsam::DiscreteConditional operator*(
const gtsam::DiscreteConditional& other) const;
DiscreteConditional marginal(gtsam::Key key) const;
void print(string s = "Discrete Conditional\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;

View File

@ -97,10 +97,14 @@ TEST(DiscreteConditional, constructors3) {
/* ************************************************************************* */
// Check calculation of joint P(A,B)
TEST(DiscreteConditional, Multiply) {
DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(1, 2), B(0, 2);
DiscreteConditional conditional(A | B = "1/2 2/1");
DiscreteConditional prior(B % "1/2");
// The expected factor
DecisionTreeFactor f(A & B, "1 4 2 2");
DiscreteConditional expected(2, f);
// P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
for (auto&& actual : {prior * conditional, conditional * prior}) {
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
@ -110,8 +114,11 @@ TEST(DiscreteConditional, Multiply) {
const DiscreteValues& v = it.first;
EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9);
}
// And for good measure:
EXPECT(assert_equal(expected, actual));
}
}
/* ************************************************************************* */
// Check calculation of conditional joint P(A,B|C)
TEST(DiscreteConditional, Multiply2) {
@ -131,6 +138,7 @@ TEST(DiscreteConditional, Multiply2) {
}
}
}
/* ************************************************************************* */
// Check calculation of conditional joint P(A,B|C), double check keys
TEST(DiscreteConditional, Multiply3) {
@ -150,6 +158,7 @@ TEST(DiscreteConditional, Multiply3) {
}
}
}
/* ************************************************************************* */
// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
TEST(DiscreteConditional, Multiply4) {
@ -173,6 +182,31 @@ TEST(DiscreteConditional, Multiply4) {
}
}
}
/* ************************************************************************* */
// Check calculation of marginals for joint P(A,B)
TEST(DiscreteConditional, marginals) {
DiscreteKey A(1, 2), B(0, 2);
DiscreteConditional conditional(A | B = "1/2 2/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "5/4");
EXPECT(assert_equal(pA, actualA));
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals());
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
EXPECT((frontalsA == KeyVector{1}));
DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB));
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals());
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals());
EXPECT((frontalsB == KeyVector{0}));
}
/* ************************************************************************* */
TEST(DiscreteConditional, likelihood) {
DiscreteKey X(0, 2), Y(1, 3);

View File

@ -81,6 +81,15 @@ class TestDiscreteConditional(GtsamTestCase):
self.assertAlmostEqual(
actual(v), AB_given_D(v) * C_given_DE(v))
def test_marginals(self):
conditional = DiscreteConditional(A, [B], "1/2 2/1")
prior = DiscreteConditional(B, "1/2")
pAB = prior * conditional
self.gtsamAssertEquals(prior, pAB.marginal(B[0]))
pA = DiscreteConditional(A % "5/4")
self.gtsamAssertEquals(pA, pAB.marginal(A[0]))
def test_markdown(self):
"""Test whether the _repr_markdown_ method."""