From be5aa56df72f654f338168d6e79c69e915186ebc Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 08:15:46 -0500 Subject: [PATCH 1/7] Constructor from PMF --- gtsam/discrete/DiscretePrior.h | 14 +++++++------- gtsam/discrete/discrete.i | 1 + gtsam/discrete/tests/testDiscretePrior.cpp | 11 +++++++++-- python/gtsam/tests/test_DiscretePrior.py | 6 +++++- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h index 9ac8acb17..1da188215 100644 --- a/gtsam/discrete/DiscretePrior.h +++ b/gtsam/discrete/DiscretePrior.h @@ -48,17 +48,17 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { DiscretePrior(const Signature& s) : Base(s) {} /** - * Construct from key and a Signature::Table specifying the - * conditional probability table (CPT). + * Construct from key and a vector of floats specifying the probability mass + * function (PMF). * - * Example: DiscretePrior P(D, table); + * Example: DiscretePrior P(D, {0.4, 0.6}); */ - DiscretePrior(const DiscreteKey& key, const Signature::Table& table) - : Base(Signature(key, {}, table)) {} + DiscretePrior(const DiscreteKey& key, const std::vector& spec) + : DiscretePrior(Signature(key, {}, Signature::Table{spec})) {} /** - * Construct from key and a string specifying the conditional - * probability table (CPT). + * Construct from key and a string specifying the probability mass function + * (PMF). * * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); */ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 218b790e8..12bd5be54 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -120,6 +120,7 @@ virtual class DiscretePrior : gtsam::DiscreteConditional { DiscretePrior(); DiscretePrior(const gtsam::DecisionTreeFactor& f); DiscretePrior(const gtsam::DiscreteKey& key, string spec); + DiscretePrior(const gtsam::DiscreteKey& key, std::vector spec); void print(string s = "Discrete Prior\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp index 23f093b22..6225d227e 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -27,12 +27,19 @@ static const DiscreteKey X(0, 2); /* ************************************************************************* */ TEST(DiscretePrior, constructors) { + DecisionTreeFactor f(X, "0.4 0.6"); + DiscretePrior expected(f); + DiscretePrior actual(X % "2/3"); EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); EXPECT_LONGS_EQUAL(0, actual.nrParents()); - DecisionTreeFactor f(X, "0.4 0.6"); - DiscretePrior expected(f); EXPECT(assert_equal(expected, actual, 1e-9)); + + const vector pmf{0.4, 0.6}; + DiscretePrior actual2(X, pmf); + EXPECT_LONGS_EQUAL(1, actual2.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual2.nrParents()); + EXPECT(assert_equal(expected, actual2, 1e-9)); } /* ************************************************************************* */ diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscretePrior.py index 2c923589c..06bdc81ca 100644 --- a/python/gtsam/tests/test_DiscretePrior.py +++ b/python/gtsam/tests/test_DiscretePrior.py @@ -25,12 +25,16 @@ class TestDiscretePrior(GtsamTestCase): def test_constructor(self): """Test various constructors.""" - actual = DiscretePrior(X, "2/3") keys = DiscreteKeys() keys.push_back(X) f = DecisionTreeFactor(keys, "0.4 0.6") expected = DiscretePrior(f) + + actual = DiscretePrior(X, "2/3") self.gtsamAssertEquals(actual, expected) + + actual2 = DiscretePrior(X, [0.4, 0.6]) + self.gtsamAssertEquals(actual2, expected) def test_operator(self): prior = DiscretePrior(X, "2/3") From c15bbed9dc044ffa159ec5a243dce6985e5203cd Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 08:44:10 -0500 Subject: [PATCH 2/7] exposing more factor methods --- gtsam/discrete/discrete.i | 9 ++++ .../discrete/tests/testDecisionTreeFactor.cpp | 26 ++++++---- python/gtsam/tests/test_DecisionTreeFactor.py | 52 +++++++++++++++++-- 3 files changed, 73 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 12bd5be54..24a941056 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -58,6 +58,15 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; + + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; + size_t cardinality(gtsam::Key j) const; + gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const; + gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const; + gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const; + gtsam::DecisionTreeFactor* max(size_t nrFrontals) const; + string dot( const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, bool showZero = true) const; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 594134edf..f2ab5f6bc 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -17,10 +17,12 @@ * @author Duy-Nguyen Ta */ -#include -#include -#include #include +#include +#include +#include +#include + #include using namespace boost::assign; @@ -51,17 +53,21 @@ TEST( DecisionTreeFactor, constructors) } /* ************************************************************************* */ -TEST_UNSAFE( DecisionTreeFactor, multiplication) -{ - DiscreteKey v0(0,2), v1(1,2), v2(2,2); +TEST(DecisionTreeFactor, multiplication) { + DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); + // Multiply with a DiscretePrior, i.e., Bayes Law! + DiscretePrior prior(v1 % "1/3"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); + DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); + CHECK(assert_equal(expected, prior * f1)); + CHECK(assert_equal(expected, f1 * prior)); + + // Multiply two factors DecisionTreeFactor f2(v1 & v2, "5 6 7 8"); - - DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); - DecisionTreeFactor actual = f1 * f2; - CHECK(assert_equal(expected, actual)); + DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); + CHECK(assert_equal(expected2, actual)); } /* ************************************************************************* */ diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 12a60d5cb..03d9f82d7 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -13,7 +13,7 @@ Author: Frank Dellaert import unittest -from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys +from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering from gtsam.utils.test_case import GtsamTestCase @@ -21,15 +21,59 @@ class TestDecisionTreeFactor(GtsamTestCase): """Tests for DecisionTreeFactors.""" def setUp(self): - A = (12, 3) - B = (5, 2) - self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6") + self.A = (12, 3) + self.B = (5, 2) + self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6") def test_enumerate(self): actual = self.factor.enumerate() _, values = zip(*actual) self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + def test_multiplication(self): + """Test whether multiplication works with overloading.""" + v0 = (0, 2) + v1 = (1, 2) + v2 = (2, 2) + + # Multiply with a DiscretePrior, i.e., Bayes Law! + prior = DiscretePrior(v1, [1, 3]) + f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") + expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") + self.gtsamAssertEquals(prior * f1, expected) + self.gtsamAssertEquals(f1 * prior, expected) + + # Multiply two factors + f2 = DecisionTreeFactor([v1, v2], "5 6 7 8") + actual = f1 * f2 + expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32") + self.gtsamAssertEquals(actual, expected2) + + def test_methods(self): + """Test whether we can call methods in python.""" + # double operator()(const DiscreteValues& values) const; + values = DiscreteValues() + values[self.A[0]] = 0 + values[self.B[0]] = 0 + self.assertIsInstance(self.factor(values), float) + + # size_t cardinality(Key j) const; + self.assertIsInstance(self.factor.cardinality(self.A[0]), int) + + # DecisionTreeFactor operator/(const DecisionTreeFactor& f) const; + self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor) + + # DecisionTreeFactor* sum(size_t nrFrontals) const; + self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor) + + # DecisionTreeFactor* sum(const Ordering& keys) const; + ordering = Ordering() + ordering.push_back(self.A[0]) + self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor) + + # DecisionTreeFactor* max(size_t nrFrontals) const; + self.assertIsInstance(self.factor.max(1), DecisionTreeFactor) + def test_markdown(self): """Test whether the _repr_markdown_ method.""" From 0909e9838915ab6b6332d27462d9dd58309b438a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 15:11:25 -0500 Subject: [PATCH 3/7] Comments only --- gtsam/discrete/DecisionTreeFactor.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index b5f6c0c4a..8beeb4c4a 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -57,7 +57,7 @@ namespace gtsam { /** Default constructor for I/O */ DecisionTreeFactor(); - /** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */ + /** Constructor from DiscreteKeys and AlgebraicDecisionTree */ DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); /** Constructor from doubles */ @@ -139,14 +139,14 @@ namespace gtsam { /** * Apply binary operator (*this) "op" f * @param f the second argument for op - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree */ DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const; /** * Combine frontal variables using binary operator "op" * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; @@ -154,7 +154,7 @@ namespace gtsam { /** * Combine frontal variables in an Ordering using binary operator "op" * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ shared_ptr combine(const Ordering& keys, ADT::Binary op) const; From f9dd225ca5d4498bcd9b3f1aa75441c0a351e3f1 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 15:12:55 -0500 Subject: [PATCH 4/7] Replace buggy/awkward Combine with principled operator*, remove toFactor --- gtsam/discrete/DiscreteConditional.cpp | 79 ++++++++--- gtsam/discrete/DiscreteConditional.h | 74 +++++------ gtsam/discrete/discrete.i | 1 - .../discrete/tests/testDecisionTreeFactor.cpp | 2 +- .../tests/testDiscreteConditional.cpp | 124 ++++++++++++++---- gtsam/discrete/tests/testDiscretePrior.cpp | 13 ++ 6 files changed, 204 insertions(+), 89 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 0bdc7d7b5..5acd7c0f6 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -30,6 +30,7 @@ #include #include #include +#include using namespace std; using std::stringstream; @@ -38,37 +39,77 @@ using std::pair; namespace gtsam { // Instantiate base class -template class GTSAM_EXPORT Conditional ; +template class GTSAM_EXPORT + Conditional; -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, - const DecisionTreeFactor& f) : - BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) { -} + const DecisionTreeFactor& f) + : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} -/* ******************************************************************************** */ -DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal) : - BaseFactor( - ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional( - joint.size()-marginal.size()) { - if (ISDEBUG("DiscreteConditional::DiscreteConditional")) - cout << (firstFrontalKey()) << endl; //TODO Print all keys -} +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(size_t nrFrontals, + const DiscreteKeys& keys, + const ADT& potentials) + : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, const Ordering& orderedKeys) : - DiscreteConditional(joint, marginal) { + const DecisionTreeFactor& marginal) + : BaseFactor(joint / marginal), + BaseConditional(joint.size() - marginal.size()) {} + +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal, + const Ordering& orderedKeys) + : DiscreteConditional(joint, marginal) { keys_.clear(); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); } -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const Signature& signature) : BaseFactor(signature.discreteKeys(), signature.cpt()), BaseConditional(1) {} +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::operator*( + const DiscreteConditional& other) const { + // Take union of frontal keys + std::set newFrontals; + for (auto&& key : this->frontals()) newFrontals.insert(key); + for (auto&& key : other.frontals()) newFrontals.insert(key); + + // Check if frontals overlapped + if (nrFrontals() + other.nrFrontals() > newFrontals.size()) + throw std::invalid_argument( + "DiscreteConditional::operator* called with overlapping frontal keys."); + + // Now, add cardinalities. + DiscreteKeys discreteKeys; + for (auto&& key : frontals()) + discreteKeys.emplace_back(key, cardinality(key)); + for (auto&& key : other.frontals()) + discreteKeys.emplace_back(key, other.cardinality(key)); + + // Sort + std::sort(discreteKeys.begin(), discreteKeys.end()); + + // Add parents to set, to make them unique + std::set parents; + for (auto&& key : this->parents()) + if (!newFrontals.count(key)) parents.emplace(key, cardinality(key)); + for (auto&& key : other.parents()) + if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key)); + + // Finally, add parents to keys, in order + for (auto&& dk : parents) discreteKeys.push_back(dk); + + ADT product = ADT::apply(other, ADT::Ring::mul); + return DiscreteConditional(newFrontals.size(), discreteKeys, product); +} + /* ******************************************************************************** */ void DiscreteConditional::print(const string& s, const KeyFormatter& formatter) const { @@ -82,7 +123,7 @@ void DiscreteConditional::print(const string& s, cout << formatter(*it) << " "; } } - cout << ")"; + cout << "):\n"; ADT::print(""); cout << endl; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 4a83ff83a..450af57ab 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -49,14 +49,21 @@ class GTSAM_EXPORT DiscreteConditional /// @name Standard Constructors /// @{ - /** default constructor needed for serialization */ + /// Default constructor needed for serialization. DiscreteConditional() {} - /** constructor from factor */ + /// Construct from factor, taking the first `nFrontals` keys as frontals. DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + /** + * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first + * `nFrontals` keys as frontals, in the order given. + */ + DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials); + /** Construct from signature */ - DiscreteConditional(const Signature& signature); + explicit DiscreteConditional(const Signature& signature); /** * Construct from key, parents, and a Signature::Table specifying the @@ -86,27 +93,38 @@ class GTSAM_EXPORT DiscreteConditional DiscreteConditional(const DiscreteKey& key, const std::string& spec) : DiscreteConditional(Signature(key, {}, spec)) {} - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + * Makes sure the keys are ordered as given. Does not check orderedKeys. + */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal, const Ordering& orderedKeys); /** - * Combine several conditional into a single one. - * The conditionals must be given in increasing order, meaning that the - * parents of any conditional may not include a conditional coming before it. - * @param firstConditional Iterator to the first conditional to combine, must - * dereference to a shared_ptr. - * @param lastConditional Iterator to after the last conditional to combine, - * must dereference to a shared_ptr. - * */ - template - static shared_ptr Combine(ITERATOR firstConditional, - ITERATOR lastConditional); + * @brief Combine two conditionals, yielding a new conditional with the union + * of the frontal keys, ordered by gtsam::Key. + * + * The two conditionals must make a valid Bayes net fragment, i.e., + * the frontal variables cannot overlap, and must be acyclic: + * Example of correct use: + * P(A,B) = P(A|B) * P(B) + * P(A,B|C) = P(A|B) * P(B|C) + * P(A,B,C) = P(A,B|C) * P(C) + * Example of incorrect use: + * P(A|B) * P(A|C) = ? + * P(A|B) * P(B|A) = ? + * We check for overlapping frontals, but do *not* check for cyclic. + */ + DiscreteConditional operator*(const DiscreteConditional& other) const; /// @} /// @name Testable @@ -136,11 +154,6 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - /** Convert to a factor */ - DecisionTreeFactor::shared_ptr toFactor() const { - return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); - } - /** Restrict to given parent values, returns DecisionTreeFactor */ DecisionTreeFactor::shared_ptr choose( const DiscreteValues& parentsValues) const; @@ -208,23 +221,4 @@ class GTSAM_EXPORT DiscreteConditional template <> struct traits : public Testable {}; -/* ************************************************************************* */ -template -DiscreteConditional::shared_ptr DiscreteConditional::Combine( - ITERATOR firstConditional, ITERATOR lastConditional) { - // TODO: check for being a clique - - // multiply all the potentials of the given conditionals - size_t nrFrontals = 0; - DecisionTreeFactor product; - for (ITERATOR it = firstConditional; it != lastConditional; - ++it, ++nrFrontals) { - DiscreteConditional::shared_ptr c = *it; - DecisionTreeFactor::shared_ptr factor = c->toFactor(); - product = (*factor) * product; - } - // and then create a new multi-frontal conditional - return boost::make_shared(nrFrontals, product); -} - } // namespace gtsam diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 24a941056..5fce25cf5 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -102,7 +102,6 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { void printSignature( string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; - gtsam::DecisionTreeFactor* toFactor() const; gtsam::DecisionTreeFactor* choose( const gtsam::DiscreteValues& parentsValues) const; gtsam::DecisionTreeFactor* likelihood( diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index f2ab5f6bc..7e89874a5 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -60,7 +60,7 @@ TEST(DecisionTreeFactor, multiplication) { DiscretePrior prior(v1 % "1/3"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); - CHECK(assert_equal(expected, prior * f1)); + CHECK(assert_equal(expected, static_cast(prior) * f1)); CHECK(assert_equal(expected, f1 * prior)); // Multiply two factors diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 3fb67a615..03766136c 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -34,20 +34,21 @@ using namespace gtsam; TEST(DiscreteConditional, constructors) { DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! - DiscreteConditional expected(X | Y = "1/1 2/3 1/4"); - EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals())); - EXPECT_LONGS_EQUAL(2, *(expected.beginParents())); - EXPECT(expected.endParents() == expected.end()); - EXPECT(expected.endFrontals() == expected.beginParents()); + DiscreteConditional actual(X | Y = "1/1 2/3 1/4"); + EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals())); + EXPECT_LONGS_EQUAL(2, *(actual.beginParents())); + EXPECT(actual.endParents() == actual.end()); + EXPECT(actual.endFrontals() == actual.beginParents()); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); - DiscreteConditional actual1(1, f1); - EXPECT(assert_equal(expected, actual1, 1e-9)); + DiscreteConditional expected1(1, f1); + EXPECT(assert_equal(expected1, actual, 1e-9)); DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ @@ -61,6 +62,7 @@ TEST(DiscreteConditional, constructors_alt_interface) { r3 += 1.0, 4.0; table += r1, r2, r3; DiscreteConditional actual1(X, {Y}, table); + DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional expected1(1, f1); EXPECT(assert_equal(expected1, actual1, 1e-9)); @@ -68,43 +70,109 @@ TEST(DiscreteConditional, constructors_alt_interface) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors2) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2); - DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25"); Signature signature((C | B) = "4/1 3/1"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25"); + EXPECT(assert_equal(expected, static_cast(actual))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors3) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2), A(2, 2); - DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); + EXPECT(assert_equal(expected, static_cast(actual))); } /* ************************************************************************* */ -TEST(DiscreteConditional, Combine) { +// Check calculation of joint P(A,B) +TEST(DiscreteConditional, Multiply) { DiscreteKey A(0, 2), B(1, 2); - vector c; - c.push_back(boost::make_shared(A | B = "1/2 2/1")); - c.push_back(boost::make_shared(B % "1/2")); - DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222"); - DiscreteConditional expected(2, factor); - auto actual = DiscreteConditional::Combine(c.begin(), c.end()); - EXPECT(assert_equal(expected, *actual, 1e-5)); -} + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + // P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for (auto&& actual : {prior * conditional, conditional * prior}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9); + } + } +} +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B|C) +TEST(DiscreteConditional, Multiply2) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B|C), double check keys +TEST(DiscreteConditional, Multiply3) { + DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!! + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{1, 2})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E) +TEST(DiscreteConditional, Multiply4) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_D(B | D = "1/3 3/1"); + DiscreteConditional AB_given_D = A_given_B * B_given_D; + DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); + + // P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) { + EXPECT_LONGS_EQUAL(3, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(2, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1, 2})); + KeyVector parents(actual.beginParents(), actual.endParents()); + EXPECT((parents == KeyVector{3, 4})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9); + } + } +} /* ************************************************************************* */ TEST(DiscreteConditional, likelihood) { DiscreteKey X(0, 2), Y(1, 3); diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp index 6225d227e..6ef57c7ff 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -42,6 +42,19 @@ TEST(DiscretePrior, constructors) { EXPECT(assert_equal(expected, actual2, 1e-9)); } +/* ************************************************************************* */ +TEST(DiscretePrior, Multiply) { + DiscreteKey A(0, 2), B(1, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscretePrior prior(B, "1/2"); + DiscreteConditional actual = prior * conditional; // P(A|B) * P(B) + + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B) + DecisionTreeFactor factor(A & B, "1 4 2 2"); + DiscreteConditional expected(2, factor); + EXPECT(assert_equal(expected, actual, 1e-5)); +} + /* ************************************************************************* */ TEST(DiscretePrior, operator) { DiscretePrior prior(X % "2/3"); From 23a8dba7163f57988a495c898828d938b1a678dd Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 15:33:01 -0500 Subject: [PATCH 5/7] Wrapped multiplication --- gtsam/discrete/discrete.i | 4 ++ python/gtsam/tests/test_DecisionTreeFactor.py | 2 +- .../gtsam/tests/test_DiscreteConditional.py | 48 ++++++++++++++++++- 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 5fce25cf5..8bcb8b4aa 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -95,10 +95,14 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); + gtsam::DiscreteConditional operator*( + const gtsam::DiscreteConditional& other) const; void print(string s = "Discrete Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; + size_t nrFrontals() const; + size_t nrParents() const; void printSignature( string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 03d9f82d7..a13a43e26 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -40,7 +40,7 @@ class TestDecisionTreeFactor(GtsamTestCase): prior = DiscretePrior(v1, [1, 3]) f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") - self.gtsamAssertEquals(prior * f1, expected) + self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected) self.gtsamAssertEquals(f1 * prior, expected) # Multiply two factors diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 0ae66c2d4..190c22181 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -16,6 +16,13 @@ import unittest from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys from gtsam.utils.test_case import GtsamTestCase +# Some DiscreteKeys for binary variables: +A = 0, 2 +B = 1, 2 +C = 2, 2 +D = 4, 2 +E = 3, 2 + class TestDiscreteConditional(GtsamTestCase): """Tests for Discrete Conditionals.""" @@ -36,6 +43,44 @@ class TestDiscreteConditional(GtsamTestCase): actual = conditional.sample(2) self.assertIsInstance(actual, int) + def test_multiply(self): + """Check calculation of joint P(A,B)""" + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + + # P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for actual in [prior * conditional, conditional * prior]: + self.assertEqual(2, actual.nrFrontals()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), conditional(v) * prior(v)) + + def test_multiply2(self): + """Check calculation of conditional joint P(A,B|C)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_C = DiscreteConditional(B, [C], "1/3 3/1") + + # P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for actual in [A_given_B * B_given_C, B_given_C * A_given_B]: + self.assertEqual(2, actual.nrFrontals()) + self.assertEqual(1, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), A_given_B(v) * B_given_C(v)) + + def test_multiply4(self): + """Check calculation of joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_D = DiscreteConditional(B, [D], "1/3 3/1") + AB_given_D = A_given_B * B_given_D + C_given_DE = DiscreteConditional(C, [D, E], "4/1 1/1 1/1 1/4") + + # P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for actual in [AB_given_D * C_given_DE, C_given_DE * AB_given_D]: + self.assertEqual(3, actual.nrFrontals()) + self.assertEqual(2, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual( + actual(v), AB_given_D(v) * C_given_DE(v)) + def test_markdown(self): """Test whether the _repr_markdown_ method.""" @@ -48,8 +93,7 @@ class TestDiscreteConditional(GtsamTestCase): conditional = DiscreteConditional(A, parents, "0/1 1/3 1/1 3/1 0/1 1/0") - expected = \ - " *P(A|B,C):*\n\n" \ + expected = " *P(A|B,C):*\n\n" \ "|*B*|*C*|0|1|\n" \ "|:-:|:-:|:-:|:-:|\n" \ "|0|0|0|1|\n" \ From 64cd58843acf7664cf84169cd829507fff3050fa Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 16:28:34 -0500 Subject: [PATCH 6/7] marginals without parents --- gtsam/discrete/DiscreteConditional.cpp | 21 ++++++++++- gtsam/discrete/DiscreteConditional.h | 3 ++ gtsam/discrete/discrete.i | 1 + .../tests/testDiscreteConditional.cpp | 36 ++++++++++++++++++- .../gtsam/tests/test_DiscreteConditional.py | 9 +++++ 5 files changed, 68 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 5acd7c0f6..e8aa4511d 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -110,7 +110,26 @@ DiscreteConditional DiscreteConditional::operator*( return DiscreteConditional(newFrontals.size(), discreteKeys, product); } -/* ******************************************************************************** */ +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::marginal(Key key) const { + if (nrParents() > 0) + throw std::invalid_argument( + "DiscreteConditional::marginal: single argument version only valid for " + "fully specified joint distributions (i.e., no parents)."); + + // Calculate the keys as the frontal keys without the given key. + DiscreteKeys discreteKeys{{key, cardinality(key)}}; + + // Calculate sum + ADT adt(*this); + for (auto&& k : frontals()) + if (k != key) adt = adt.sum(k, cardinality(k)); + + // Return new factor + return DiscreteConditional(1, discreteKeys, adt); +} + +/* ************************************************************************** */ void DiscreteConditional::print(const string& s, const KeyFormatter& formatter) const { cout << s << " P( "; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 450af57ab..836aa3920 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -126,6 +126,9 @@ class GTSAM_EXPORT DiscreteConditional */ DiscreteConditional operator*(const DiscreteConditional& other) const; + /** Calculate marginal on given key, no parent case. */ + DiscreteConditional marginal(Key key) const; + /// @} /// @name Testable /// @{ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 8bcb8b4aa..cd3e85598 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -97,6 +97,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { const gtsam::Ordering& orderedKeys); gtsam::DiscreteConditional operator*( const gtsam::DiscreteConditional& other) const; + DiscreteConditional marginal(gtsam::Key key) const; void print(string s = "Discrete Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 03766136c..125659517 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -97,10 +97,14 @@ TEST(DiscreteConditional, constructors3) { /* ************************************************************************* */ // Check calculation of joint P(A,B) TEST(DiscreteConditional, Multiply) { - DiscreteKey A(0, 2), B(1, 2); + DiscreteKey A(1, 2), B(0, 2); DiscreteConditional conditional(A | B = "1/2 2/1"); DiscreteConditional prior(B % "1/2"); + // The expected factor + DecisionTreeFactor f(A & B, "1 4 2 2"); + DiscreteConditional expected(2, f); + // P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) for (auto&& actual : {prior * conditional, conditional * prior}) { EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); @@ -110,8 +114,11 @@ TEST(DiscreteConditional, Multiply) { const DiscreteValues& v = it.first; EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9); } + // And for good measure: + EXPECT(assert_equal(expected, actual)); } } + /* ************************************************************************* */ // Check calculation of conditional joint P(A,B|C) TEST(DiscreteConditional, Multiply2) { @@ -131,6 +138,7 @@ TEST(DiscreteConditional, Multiply2) { } } } + /* ************************************************************************* */ // Check calculation of conditional joint P(A,B|C), double check keys TEST(DiscreteConditional, Multiply3) { @@ -150,6 +158,7 @@ TEST(DiscreteConditional, Multiply3) { } } } + /* ************************************************************************* */ // Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E) TEST(DiscreteConditional, Multiply4) { @@ -173,6 +182,31 @@ TEST(DiscreteConditional, Multiply4) { } } } + +/* ************************************************************************* */ +// Check calculation of marginals for joint P(A,B) +TEST(DiscreteConditional, marginals) { + DiscreteKey A(1, 2), B(0, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + DiscreteConditional pAB = prior * conditional; + + DiscreteConditional actualA = pAB.marginal(A.first); + DiscreteConditional pA(A % "5/4"); + EXPECT(assert_equal(pA, actualA)); + EXPECT_LONGS_EQUAL(1, actualA.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actualA.nrParents()); + KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals()); + EXPECT((frontalsA == KeyVector{1})); + + DiscreteConditional actualB = pAB.marginal(B.first); + EXPECT(assert_equal(prior, actualB)); + EXPECT_LONGS_EQUAL(1, actualB.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actualB.nrParents()); + KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals()); + EXPECT((frontalsB == KeyVector{0})); +} + /* ************************************************************************* */ TEST(DiscreteConditional, likelihood) { DiscreteKey X(0, 2), Y(1, 3); diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 190c22181..f46a0e877 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -81,6 +81,15 @@ class TestDiscreteConditional(GtsamTestCase): self.assertAlmostEqual( actual(v), AB_given_D(v) * C_given_DE(v)) + def test_marginals(self): + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + pAB = prior * conditional + self.gtsamAssertEquals(prior, pAB.marginal(B[0])) + + pA = DiscreteConditional(A % "5/4") + self.gtsamAssertEquals(pA, pAB.marginal(A[0])) + def test_markdown(self): """Test whether the _repr_markdown_ method.""" From 0b11b127609c3a0f7492050bf0613457f02fba22 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 16 Jan 2022 12:02:22 -0500 Subject: [PATCH 7/7] fix tests --- gtsam/slam/slam.i | 2 +- python/gtsam/tests/test_DiscreteConditional.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/slam/slam.i b/gtsam/slam/slam.i index a0a7329dd..602b2afe3 100644 --- a/gtsam/slam/slam.i +++ b/gtsam/slam/slam.i @@ -11,7 +11,7 @@ namespace gtsam { // ###### #include -template virtual class BetweenFactor : gtsam::NoiseModelFactor { diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index f46a0e877..241a5f0be 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -87,7 +87,7 @@ class TestDiscreteConditional(GtsamTestCase): pAB = prior * conditional self.gtsamAssertEquals(prior, pAB.marginal(B[0])) - pA = DiscreteConditional(A % "5/4") + pA = DiscreteConditional(A, "5/4") self.gtsamAssertEquals(pA, pAB.marginal(A[0])) def test_markdown(self):