diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 0bdc7d7b5..5acd7c0f6 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -30,6 +30,7 @@ #include #include #include +#include using namespace std; using std::stringstream; @@ -38,37 +39,77 @@ using std::pair; namespace gtsam { // Instantiate base class -template class GTSAM_EXPORT Conditional ; +template class GTSAM_EXPORT + Conditional; -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, - const DecisionTreeFactor& f) : - BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) { -} + const DecisionTreeFactor& f) + : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} -/* ******************************************************************************** */ -DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal) : - BaseFactor( - ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional( - joint.size()-marginal.size()) { - if (ISDEBUG("DiscreteConditional::DiscreteConditional")) - cout << (firstFrontalKey()) << endl; //TODO Print all keys -} +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(size_t nrFrontals, + const DiscreteKeys& keys, + const ADT& potentials) + : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, const Ordering& orderedKeys) : - DiscreteConditional(joint, marginal) { + const DecisionTreeFactor& marginal) + : BaseFactor(joint / marginal), + BaseConditional(joint.size() - marginal.size()) {} + +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal, + const Ordering& orderedKeys) + : DiscreteConditional(joint, marginal) { keys_.clear(); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); } -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const Signature& signature) : BaseFactor(signature.discreteKeys(), signature.cpt()), BaseConditional(1) {} +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::operator*( + const DiscreteConditional& other) const { + // Take union of frontal keys + std::set newFrontals; + for (auto&& key : this->frontals()) newFrontals.insert(key); + for (auto&& key : other.frontals()) newFrontals.insert(key); + + // Check if frontals overlapped + if (nrFrontals() + other.nrFrontals() > newFrontals.size()) + throw std::invalid_argument( + "DiscreteConditional::operator* called with overlapping frontal keys."); + + // Now, add cardinalities. + DiscreteKeys discreteKeys; + for (auto&& key : frontals()) + discreteKeys.emplace_back(key, cardinality(key)); + for (auto&& key : other.frontals()) + discreteKeys.emplace_back(key, other.cardinality(key)); + + // Sort + std::sort(discreteKeys.begin(), discreteKeys.end()); + + // Add parents to set, to make them unique + std::set parents; + for (auto&& key : this->parents()) + if (!newFrontals.count(key)) parents.emplace(key, cardinality(key)); + for (auto&& key : other.parents()) + if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key)); + + // Finally, add parents to keys, in order + for (auto&& dk : parents) discreteKeys.push_back(dk); + + ADT product = ADT::apply(other, ADT::Ring::mul); + return DiscreteConditional(newFrontals.size(), discreteKeys, product); +} + /* ******************************************************************************** */ void DiscreteConditional::print(const string& s, const KeyFormatter& formatter) const { @@ -82,7 +123,7 @@ void DiscreteConditional::print(const string& s, cout << formatter(*it) << " "; } } - cout << ")"; + cout << "):\n"; ADT::print(""); cout << endl; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 4a83ff83a..450af57ab 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -49,14 +49,21 @@ class GTSAM_EXPORT DiscreteConditional /// @name Standard Constructors /// @{ - /** default constructor needed for serialization */ + /// Default constructor needed for serialization. DiscreteConditional() {} - /** constructor from factor */ + /// Construct from factor, taking the first `nFrontals` keys as frontals. DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + /** + * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first + * `nFrontals` keys as frontals, in the order given. + */ + DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials); + /** Construct from signature */ - DiscreteConditional(const Signature& signature); + explicit DiscreteConditional(const Signature& signature); /** * Construct from key, parents, and a Signature::Table specifying the @@ -86,27 +93,38 @@ class GTSAM_EXPORT DiscreteConditional DiscreteConditional(const DiscreteKey& key, const std::string& spec) : DiscreteConditional(Signature(key, {}, spec)) {} - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + * Makes sure the keys are ordered as given. Does not check orderedKeys. + */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal, const Ordering& orderedKeys); /** - * Combine several conditional into a single one. - * The conditionals must be given in increasing order, meaning that the - * parents of any conditional may not include a conditional coming before it. - * @param firstConditional Iterator to the first conditional to combine, must - * dereference to a shared_ptr. - * @param lastConditional Iterator to after the last conditional to combine, - * must dereference to a shared_ptr. - * */ - template - static shared_ptr Combine(ITERATOR firstConditional, - ITERATOR lastConditional); + * @brief Combine two conditionals, yielding a new conditional with the union + * of the frontal keys, ordered by gtsam::Key. + * + * The two conditionals must make a valid Bayes net fragment, i.e., + * the frontal variables cannot overlap, and must be acyclic: + * Example of correct use: + * P(A,B) = P(A|B) * P(B) + * P(A,B|C) = P(A|B) * P(B|C) + * P(A,B,C) = P(A,B|C) * P(C) + * Example of incorrect use: + * P(A|B) * P(A|C) = ? + * P(A|B) * P(B|A) = ? + * We check for overlapping frontals, but do *not* check for cyclic. + */ + DiscreteConditional operator*(const DiscreteConditional& other) const; /// @} /// @name Testable @@ -136,11 +154,6 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - /** Convert to a factor */ - DecisionTreeFactor::shared_ptr toFactor() const { - return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); - } - /** Restrict to given parent values, returns DecisionTreeFactor */ DecisionTreeFactor::shared_ptr choose( const DiscreteValues& parentsValues) const; @@ -208,23 +221,4 @@ class GTSAM_EXPORT DiscreteConditional template <> struct traits : public Testable {}; -/* ************************************************************************* */ -template -DiscreteConditional::shared_ptr DiscreteConditional::Combine( - ITERATOR firstConditional, ITERATOR lastConditional) { - // TODO: check for being a clique - - // multiply all the potentials of the given conditionals - size_t nrFrontals = 0; - DecisionTreeFactor product; - for (ITERATOR it = firstConditional; it != lastConditional; - ++it, ++nrFrontals) { - DiscreteConditional::shared_ptr c = *it; - DecisionTreeFactor::shared_ptr factor = c->toFactor(); - product = (*factor) * product; - } - // and then create a new multi-frontal conditional - return boost::make_shared(nrFrontals, product); -} - } // namespace gtsam diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 24a941056..5fce25cf5 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -102,7 +102,6 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { void printSignature( string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; - gtsam::DecisionTreeFactor* toFactor() const; gtsam::DecisionTreeFactor* choose( const gtsam::DiscreteValues& parentsValues) const; gtsam::DecisionTreeFactor* likelihood( diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index f2ab5f6bc..7e89874a5 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -60,7 +60,7 @@ TEST(DecisionTreeFactor, multiplication) { DiscretePrior prior(v1 % "1/3"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); - CHECK(assert_equal(expected, prior * f1)); + CHECK(assert_equal(expected, static_cast(prior) * f1)); CHECK(assert_equal(expected, f1 * prior)); // Multiply two factors diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 3fb67a615..03766136c 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -34,20 +34,21 @@ using namespace gtsam; TEST(DiscreteConditional, constructors) { DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! - DiscreteConditional expected(X | Y = "1/1 2/3 1/4"); - EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals())); - EXPECT_LONGS_EQUAL(2, *(expected.beginParents())); - EXPECT(expected.endParents() == expected.end()); - EXPECT(expected.endFrontals() == expected.beginParents()); + DiscreteConditional actual(X | Y = "1/1 2/3 1/4"); + EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals())); + EXPECT_LONGS_EQUAL(2, *(actual.beginParents())); + EXPECT(actual.endParents() == actual.end()); + EXPECT(actual.endFrontals() == actual.beginParents()); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); - DiscreteConditional actual1(1, f1); - EXPECT(assert_equal(expected, actual1, 1e-9)); + DiscreteConditional expected1(1, f1); + EXPECT(assert_equal(expected1, actual, 1e-9)); DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ @@ -61,6 +62,7 @@ TEST(DiscreteConditional, constructors_alt_interface) { r3 += 1.0, 4.0; table += r1, r2, r3; DiscreteConditional actual1(X, {Y}, table); + DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional expected1(1, f1); EXPECT(assert_equal(expected1, actual1, 1e-9)); @@ -68,43 +70,109 @@ TEST(DiscreteConditional, constructors_alt_interface) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors2) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2); - DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25"); Signature signature((C | B) = "4/1 3/1"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25"); + EXPECT(assert_equal(expected, static_cast(actual))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors3) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2), A(2, 2); - DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); + EXPECT(assert_equal(expected, static_cast(actual))); } /* ************************************************************************* */ -TEST(DiscreteConditional, Combine) { +// Check calculation of joint P(A,B) +TEST(DiscreteConditional, Multiply) { DiscreteKey A(0, 2), B(1, 2); - vector c; - c.push_back(boost::make_shared(A | B = "1/2 2/1")); - c.push_back(boost::make_shared(B % "1/2")); - DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222"); - DiscreteConditional expected(2, factor); - auto actual = DiscreteConditional::Combine(c.begin(), c.end()); - EXPECT(assert_equal(expected, *actual, 1e-5)); -} + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + // P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for (auto&& actual : {prior * conditional, conditional * prior}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9); + } + } +} +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B|C) +TEST(DiscreteConditional, Multiply2) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B|C), double check keys +TEST(DiscreteConditional, Multiply3) { + DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!! + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{1, 2})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E) +TEST(DiscreteConditional, Multiply4) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_D(B | D = "1/3 3/1"); + DiscreteConditional AB_given_D = A_given_B * B_given_D; + DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); + + // P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) { + EXPECT_LONGS_EQUAL(3, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(2, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1, 2})); + KeyVector parents(actual.beginParents(), actual.endParents()); + EXPECT((parents == KeyVector{3, 4})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9); + } + } +} /* ************************************************************************* */ TEST(DiscreteConditional, likelihood) { DiscreteKey X(0, 2), Y(1, 3); diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp index 6225d227e..6ef57c7ff 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -42,6 +42,19 @@ TEST(DiscretePrior, constructors) { EXPECT(assert_equal(expected, actual2, 1e-9)); } +/* ************************************************************************* */ +TEST(DiscretePrior, Multiply) { + DiscreteKey A(0, 2), B(1, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscretePrior prior(B, "1/2"); + DiscreteConditional actual = prior * conditional; // P(A|B) * P(B) + + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B) + DecisionTreeFactor factor(A & B, "1 4 2 2"); + DiscreteConditional expected(2, factor); + EXPECT(assert_equal(expected, actual, 1e-5)); +} + /* ************************************************************************* */ TEST(DiscretePrior, operator) { DiscretePrior prior(X % "2/3");