diff --git a/examples/DiscreteBayesNetExample.cpp b/examples/DiscreteBayesNetExample.cpp new file mode 100644 index 000000000..629043431 --- /dev/null +++ b/examples/DiscreteBayesNetExample.cpp @@ -0,0 +1,83 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteBayesNetExample.cpp + * @brief Discrete Bayes Net example with famous Asia Bayes Network + * @author Frank Dellaert + * @date JULY 10, 2020 + */ + +#include +#include +#include + +#include + +using namespace std; +using namespace gtsam; + +int main(int argc, char **argv) { + DiscreteBayesNet asia; + DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), + Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); + asia.add(Asia % "99/1"); + asia.add(Smoking % "50/50"); + + asia.add(Tuberculosis | Asia = "99/1 95/5"); + asia.add(LungCancer | Smoking = "99/1 90/10"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + + asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); + + asia.add(XRay | Either = "95/5 2/98"); + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + + // print + vector pretty = {"Asia", "Dyspnea", "XRay", "Tuberculosis", + "Smoking", "Either", "LungCancer", "Bronchitis"}; + auto formatter = [pretty](Key key) { return pretty[key]; }; + asia.print("Asia", formatter); + + // Convert to factor graph + DiscreteFactorGraph fg(asia); + + // Create solver and eliminate + Ordering ordering; + ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); + DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); + + // solve + DiscreteFactor::sharedValues mpe = chordal->optimize(); + GTSAM_PRINT(*mpe); + + // We can also build a Bayes tree (directed junction tree). + // The elimination order above will do fine: + auto bayesTree = fg.eliminateMultifrontal(ordering); + bayesTree->print("bayesTree", formatter); + + // add evidence, we were in Asia and we have dyspnea + fg.add(Asia, "0 1"); + fg.add(Dyspnea, "0 1"); + + // solve again, now with evidence + DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); + DiscreteFactor::sharedValues mpe2 = chordal2->optimize(); + GTSAM_PRINT(*mpe2); + + // We can also sample from it + cout << "\n10 samples:" << endl; + for (size_t i = 0; i < 10; i++) { + DiscreteFactor::sharedValues sample = chordal2->sample(); + GTSAM_PRINT(*sample); + } + return 0; +} diff --git a/examples/DiscreteBayesNet_FG.cpp b/examples/DiscreteBayesNet_FG.cpp index 9802b5984..121df4bef 100644 --- a/examples/DiscreteBayesNet_FG.cpp +++ b/examples/DiscreteBayesNet_FG.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file DiscreteBayesNet_graph.cpp + * @file DiscreteBayesNet_FG.cpp * @brief Discrete Bayes Net example using Factor Graphs * @author Abhijit * @date Jun 4, 2012 diff --git a/examples/HMMExample.cpp b/examples/HMMExample.cpp new file mode 100644 index 000000000..a56058633 --- /dev/null +++ b/examples/HMMExample.cpp @@ -0,0 +1,94 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteBayesNetExample.cpp + * @brief Hidden Markov Model example, discrete. + * @author Frank Dellaert + * @date July 12, 2020 + */ + +#include +#include +#include + +#include +#include + +using namespace std; +using namespace gtsam; + +int main(int argc, char **argv) { + const int nrNodes = 4; + const size_t nrStates = 3; + + // Define variables as well as ordering + Ordering ordering; + vector keys; + for (int k = 0; k < nrNodes; k++) { + DiscreteKey key_i(k, nrStates); + keys.push_back(key_i); + ordering.emplace_back(k); + } + + // Create HMM as a DiscreteBayesNet + DiscreteBayesNet hmm; + + // Define backbone + const string transition = "8/1/1 1/8/1 1/1/8"; + for (int k = 1; k < nrNodes; k++) { + hmm.add(keys[k] | keys[k - 1] = transition); + } + + // Add some measurements, not needed for all time steps! + hmm.add(keys[0] % "7/2/1"); + hmm.add(keys[1] % "1/9/0"); + hmm.add(keys.back() % "5/4/1"); + + // print + hmm.print("HMM"); + + // Convert to factor graph + DiscreteFactorGraph factorGraph(hmm); + + // Create solver and eliminate + // This will create a DAG ordered with arrow of time reversed + DiscreteBayesNet::shared_ptr chordal = + factorGraph.eliminateSequential(ordering); + chordal->print("Eliminated"); + + // solve + DiscreteFactor::sharedValues mpe = chordal->optimize(); + GTSAM_PRINT(*mpe); + + // We can also sample from it + cout << "\n10 samples:" << endl; + for (size_t k = 0; k < 10; k++) { + DiscreteFactor::sharedValues sample = chordal->sample(); + GTSAM_PRINT(*sample); + } + + // Or compute the marginals. This re-eliminates the FG into a Bayes tree + cout << "\nComputing Node Marginals .." << endl; + DiscreteMarginals marginals(factorGraph); + for (int k = 0; k < nrNodes; k++) { + Vector margProbs = marginals.marginalProbabilities(keys[k]); + stringstream ss; + ss << "marginal " << k; + print(margProbs, ss.str()); + } + + // TODO(frank): put in the glue to have DiscreteMarginals produce *arbitrary* + // joints efficiently, by the Bayes tree shortcut magic. All the code is there + // but it's not yet connected. + + return 0; +} diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 2ab3054a8..ac7c58405 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include using namespace std; @@ -61,16 +62,26 @@ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, } /* ******************************************************************************** */ -DiscreteConditional::DiscreteConditional(const Signature& signature) : - BaseFactor(signature.discreteKeysParentsFirst(), signature.cpt()), BaseConditional( - 1) { -} +DiscreteConditional::DiscreteConditional(const Signature& signature) + : BaseFactor(signature.discreteKeys(), signature.cpt()), + BaseConditional(1) {} /* ******************************************************************************** */ -void DiscreteConditional::print(const std::string& s, - const KeyFormatter& formatter) const { - std::cout << s << std::endl; - Potentials::print(s); +void DiscreteConditional::print(const string& s, + const KeyFormatter& formatter) const { + cout << s << " P( "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { + cout << formatter(*it) << " "; + } + if (nrParents()) { + cout << "| "; + for (const_iterator it = beginParents(); it != endParents(); ++it) { + cout << formatter(*it) << " "; + } + } + cout << ")"; + Potentials::print(""); + cout << endl; } /* ******************************************************************************** */ @@ -173,55 +184,28 @@ size_t DiscreteConditional::solve(const Values& parentsValues) const { /* ******************************************************************************** */ size_t DiscreteConditional::sample(const Values& parentsValues) const { - - static mt19937 rng(2); // random number generator - - bool debug = ISDEBUG("DiscreteConditional::sample"); + static mt19937 rng(2); // random number generator // Get the correct conditional density - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) - if (debug) - GTSAM_PRINT(pFS); + ADT pFS = choose(parentsValues); // P(F|S=parentsValues) - // get cumulative distribution function (cdf) - // TODO, only works for one key now, seems horribly slow this way + // TODO(Duy): only works for one key now, seems horribly slow this way assert(nrFrontals() == 1); - Key j = (firstFrontalKey()); - size_t nj = cardinality(j); - vector cdf(nj); + Key key = firstFrontalKey(); + size_t nj = cardinality(key); + vector p(nj); Values frontals; - double sum = 0; for (size_t value = 0; value < nj; value++) { - frontals[j] = value; - double pValueS = pFS(frontals); // P(F=value|S=parentsValues) - sum += pValueS; // accumulate - if (debug) - cout << sum << " "; - if (pValueS == 1) { - if (debug) - cout << "--> " << value << endl; - return value; // shortcut exit + frontals[key] = value; + p[value] = pFS(frontals); // P(F=value|S=parentsValues) + if (p[value] == 1.0) { + return value; // shortcut exit } - cdf[value] = sum; } - - // inspired by http://www.boost.org/doc/libs/1_46_1/doc/html/boost_random/tutorial.html - uniform_real_distribution dist(0, cdf.back()); - size_t sampled = lower_bound(cdf.begin(), cdf.end(), dist(rng)) - cdf.begin(); - if (debug) - cout << "-> " << sampled << endl; - - return sampled; - - return 0; + std::discrete_distribution distribution(p.begin(), p.end()); + return distribution(rng); } -/* ******************************************************************************** */ -//void DiscreteConditional::permuteWithInverse( -// const Permutation& inversePermutation) { -// IndexConditionalOrdered::permuteWithInverse(inversePermutation); -// Potentials::permuteWithInverse(inversePermutation); -//} /* ******************************************************************************** */ }// namespace diff --git a/gtsam/discrete/Potentials.cpp b/gtsam/discrete/Potentials.cpp index c4cdbe0ef..fe99ea975 100644 --- a/gtsam/discrete/Potentials.cpp +++ b/gtsam/discrete/Potentials.cpp @@ -15,50 +15,52 @@ * @author Frank Dellaert */ -#include #include +#include + #include +#include + using namespace std; namespace gtsam { - // explicit instantiation - template class DecisionTree ; - template class AlgebraicDecisionTree ; +// explicit instantiation +template class DecisionTree; +template class AlgebraicDecisionTree; - /* ************************************************************************* */ - double Potentials::safe_div(const double& a, const double& b) { - // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b)); - // The use for safe_div is when we divide the product factor by the sum factor. - // If the product or sum is zero, we accord zero probability to the event. - return (a == 0 || b == 0) ? 0 : (a / b); - } +/* ************************************************************************* */ +double Potentials::safe_div(const double& a, const double& b) { + // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b)); + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); +} - /* ******************************************************************************** */ - Potentials::Potentials() : - ADT(1.0) { - } +/* ******************************************************************************** + */ +Potentials::Potentials() : ADT(1.0) {} - /* ******************************************************************************** */ - Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) : - ADT(decisionTree), cardinalities_(keys.cardinalities()) { - } +/* ******************************************************************************** + */ +Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) + : ADT(decisionTree), cardinalities_(keys.cardinalities()) {} - /* ************************************************************************* */ - bool Potentials::equals(const Potentials& other, double tol) const { - return ADT::equals(other, tol); - } +/* ************************************************************************* */ +bool Potentials::equals(const Potentials& other, double tol) const { + return ADT::equals(other, tol); +} - /* ************************************************************************* */ - void Potentials::print(const string& s, - const KeyFormatter& formatter) const { - cout << s << "\n Cardinalities: "; - for(const DiscreteKey& key: cardinalities_) - cout << formatter(key.first) << "=" << formatter(key.second) << " "; - cout << endl; - ADT::print(" "); - } +/* ************************************************************************* */ +void Potentials::print(const string& s, const KeyFormatter& formatter) const { + cout << s << "\n Cardinalities: {"; + for (const DiscreteKey& key : cardinalities_) + cout << formatter(key.first) << ":" << key.second << ", "; + cout << "}" << endl; + ADT::print(" "); +} // // /* ************************************************************************* */ // template @@ -95,4 +97,4 @@ namespace gtsam { /* ************************************************************************* */ -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam/discrete/Signature.cpp b/gtsam/discrete/Signature.cpp index 89e763703..94b160a29 100644 --- a/gtsam/discrete/Signature.cpp +++ b/gtsam/discrete/Signature.cpp @@ -122,28 +122,30 @@ namespace gtsam { key_(key) { } - DiscreteKeys Signature::discreteKeysParentsFirst() const { + DiscreteKeys Signature::discreteKeys() const { DiscreteKeys keys; - for(const DiscreteKey& key: parents_) - keys.push_back(key); keys.push_back(key_); + for (const DiscreteKey& key : parents_) keys.push_back(key); return keys; } KeyVector Signature::indices() const { KeyVector js; js.push_back(key_.first); - for(const DiscreteKey& key: parents_) - js.push_back(key.first); + for (const DiscreteKey& key : parents_) js.push_back(key.first); return js; } vector Signature::cpt() const { vector cpt; if (table_) { - for(const Row& row: *table_) - for(const double& x: row) - cpt.push_back(x); + const size_t nrStates = table_->at(0).size(); + for (size_t j = 0; j < nrStates; j++) { + for (const Row& row : *table_) { + assert(row.size() == nrStates); + cpt.push_back(row[j]); + } + } } return cpt; } diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h index 587cd6b30..6c59b5bff 100644 --- a/gtsam/discrete/Signature.h +++ b/gtsam/discrete/Signature.h @@ -86,8 +86,8 @@ namespace gtsam { return parents_; } - /** All keys, with variable key last */ - DiscreteKeys discreteKeysParentsFirst() const; + /** All keys, with variable key first */ + DiscreteKeys discreteKeys() const; /** All key indices, with variable key first */ KeyVector indices() const; diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index 753af16d8..0261ef833 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -132,7 +132,7 @@ TEST(ADT, example3) /** Convert Signature into CPT */ ADT create(const Signature& signature) { - ADT p(signature.discreteKeysParentsFirst(), signature.cpt()); + ADT p(signature.discreteKeys(), signature.cpt()); static size_t count = 0; const DiscreteKey& key = signature.key(); string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); @@ -181,19 +181,20 @@ TEST(ADT, joint) dot(joint, "Asia-ASTLBEX"); joint = apply(joint, pD, &mul); dot(joint, "Asia-ASTLBEXD"); - EXPECT_LONGS_EQUAL(346, (long)muls); + EXPECT_LONGS_EQUAL(346, muls); gttoc_(asiaJoint); tictoc_getNode(asiaJointNode, asiaJoint); elapsed = asiaJointNode->secs() + asiaJointNode->wall(); tictoc_reset_(); printCounts("Asia joint"); + // Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S) ADT pASTL = pA; pASTL = apply(pASTL, pS, &mul); pASTL = apply(pASTL, pT, &mul); pASTL = apply(pASTL, pL, &mul); - // test combine + // test combine to check that P(A) = \sum_{S,T,L} P(A,S,T,L) ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_); EXPECT(assert_equal(pA, fAa)); ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_); diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 5ed662332..2b440e5a0 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -18,110 +18,135 @@ #include #include -#include +#include #include +#include +#include #include -#include + #include +#include using namespace boost::assign; #include +#include +#include using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST(DiscreteBayesNet, Asia) -{ +TEST(DiscreteBayesNet, bayesNet) { + DiscreteBayesNet bayesNet; + DiscreteKey Parent(0, 2), Child(1, 2); + + auto prior = boost::make_shared(Parent % "6/4"); + CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"), + (Potentials::ADT)*prior)); + bayesNet.push_back(prior); + + auto conditional = + boost::make_shared(Child | Parent = "7/3 8/2"); + EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals())); + Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2"); + CHECK(assert_equal(expected, (Potentials::ADT)*conditional)); + bayesNet.push_back(conditional); + + DiscreteFactorGraph fg(bayesNet); + LONGS_EQUAL(2, fg.back()->size()); + + // Check the marginals + const double expectedMarginal[2]{0.4, 0.6 * 0.3 + 0.4 * 0.2}; + DiscreteMarginals marginals(fg); + for (size_t j = 0; j < 2; j++) { + Vector FT = marginals.marginalProbabilities(DiscreteKey(j, 2)); + EXPECT_DOUBLES_EQUAL(expectedMarginal[j], FT[1], 1e-3); + EXPECT_DOUBLES_EQUAL(FT[0], 1.0 - FT[1], 1e-9); + } +} + +/* ************************************************************************* */ +TEST(DiscreteBayesNet, Asia) { DiscreteBayesNet asia; -// DiscreteKey A("Asia"), S("Smoking"), T("Tuberculosis"), L("LungCancer"), B( -// "Bronchitis"), E("Either"), X("XRay"), D("Dyspnoea"); - DiscreteKey A(0,2), S(4,2), T(3,2), L(6,2), B(7,2), E(5,2), X(2,2), D(1,2); + DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), + Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); - // TODO: make a version that doesn't use the parser - asia.add(A % "99/1"); - asia.add(S % "50/50"); + asia.add(Asia % "99/1"); + asia.add(Smoking % "50/50"); - asia.add(T | A = "99/1 95/5"); - asia.add(L | S = "99/1 90/10"); - asia.add(B | S = "70/30 40/60"); + asia.add(Tuberculosis | Asia = "99/1 95/5"); + asia.add(LungCancer | Smoking = "99/1 90/10"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); - asia.add((E | T, L) = "F T T T"); + asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - asia.add(X | E = "95/5 2/98"); - // next lines are same as asia.add((D | E, B) = "9/1 2/8 3/7 1/9"); - DiscreteConditional::shared_ptr actual = - boost::make_shared((D | E, B) = "9/1 2/8 3/7 1/9"); - asia.push_back(actual); - // GTSAM_PRINT(asia); + asia.add(XRay | Either = "95/5 2/98"); + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); // Convert to factor graph DiscreteFactorGraph fg(asia); -// GTSAM_PRINT(fg); - LONGS_EQUAL(3,fg.back()->size()); - Potentials::ADT expected(B & D & E, "0.9 0.3 0.1 0.7 0.2 0.1 0.8 0.9"); - CHECK(assert_equal(expected,(Potentials::ADT)*actual)); + LONGS_EQUAL(3, fg.back()->size()); + + // Check the marginals we know (of the parent-less nodes) + DiscreteMarginals marginals(fg); + Vector2 va(0.99, 0.01), vs(0.5, 0.5); + EXPECT(assert_equal(va, marginals.marginalProbabilities(Asia))); + EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); // Create solver and eliminate Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3),Key(4),Key(5),Key(6),Key(7); + ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); -// GTSAM_PRINT(*chordal); - DiscreteConditional expected2(B % "11/9"); - CHECK(assert_equal(expected2,*chordal->back())); + DiscreteConditional expected2(Bronchitis % "11/9"); + EXPECT(assert_equal(expected2, *chordal->back())); // solve DiscreteFactor::sharedValues actualMPE = chordal->optimize(); DiscreteFactor::Values expectedMPE; - insert(expectedMPE)(A.first, 0)(D.first, 0)(X.first, 0)(T.first, 0)(S.first, - 0)(E.first, 0)(L.first, 0)(B.first, 0); + insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)( + Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)( + LungCancer.first, 0)(Bronchitis.first, 0); EXPECT(assert_equal(expectedMPE, *actualMPE)); - // add evidence, we were in Asia and we have Dispnoea - fg.add(A, "0 1"); - fg.add(D, "0 1"); -// fg.product().dot("fg"); + // add evidence, we were in Asia and we have dyspnea + fg.add(Asia, "0 1"); + fg.add(Dyspnea, "0 1"); // solve again, now with evidence DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); -// GTSAM_PRINT(*chordal2); DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize(); DiscreteFactor::Values expectedMPE2; - insert(expectedMPE2)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)(S.first, - 1)(E.first, 0)(L.first, 0)(B.first, 1); + insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)( + Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)( + LungCancer.first, 0)(Bronchitis.first, 1); EXPECT(assert_equal(expectedMPE2, *actualMPE2)); // now sample from it DiscreteFactor::Values expectedSample; SETDEBUG("DiscreteConditional::sample", false); - insert(expectedSample)(A.first, 1)(D.first, 1)(X.first, 1)(T.first, 0)( - S.first, 1)(E.first, 1)(L.first, 1)(B.first, 0); + insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)( + Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)( + LungCancer.first, 1)(Bronchitis.first, 0); DiscreteFactor::sharedValues actualSample = chordal2->sample(); EXPECT(assert_equal(expectedSample, *actualSample)); } /* ************************************************************************* */ -TEST_UNSAFE(DiscreteBayesNet, Sugar) -{ - DiscreteKey T(0,2), L(1,2), E(2,2), D(3,2), C(8,3), S(7,2); +TEST_UNSAFE(DiscreteBayesNet, Sugar) { + DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2); DiscreteBayesNet bn; - // test some mistakes - // add(bn, D); - // add(bn, D | E); - // add(bn, D | E = "blah"); - // try logic bn.add((E | T, L) = "OR"); bn.add((E | T, L) = "AND"); - // // try multivalued - bn.add(C % "1/1/2"); - bn.add(C | S = "1/1/2 5/2/3"); + // try multivalued + bn.add(C % "1/1/2"); + bn.add(C | S = "1/1/2 5/2/3"); } /* ************************************************************************* */ @@ -130,4 +155,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 9950c014e..11a88af59 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -80,6 +80,12 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); } + // Check frontals and parents + for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) { + auto clique_i = (*bayesTree)[i]; + EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals())); + } + auto R = bayesTree->roots().front(); // Check whether BN and BT give the same answer on all configurations @@ -104,16 +110,22 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { double px = bayesTree->evaluate(x); for (size_t i = 0; i < 15; i++) if (x[i]) marginals[i] += px; - if (x[12] && x[14]) joint_12_14 += px; - if (x[9] && x[12] && x[14]) joint_9_12_14 += px; - if (x[8] && x[12] && x[14]) joint_8_12_14 += px; + if (x[12] && x[14]) { + joint_12_14 += px; + if (x[9]) joint_9_12_14 += px; + if (x[8]) joint_8_12_14 += px; + } if (x[8] && x[12]) joint_8_12 += px; - if (x[8] && x[2]) joint82 += px; - if (x[1] && x[2]) joint12 += px; - if (x[2] && x[4]) joint24 += px; - if (x[4] && x[5]) joint45 += px; - if (x[4] && x[6]) joint46 += px; - if (x[4] && x[11]) joint_4_11 += px; + if (x[2]) { + if (x[8]) joint82 += px; + if (x[1]) joint12 += px; + } + if (x[4]) { + if (x[2]) joint24 += px; + if (x[5]) joint45 += px; + if (x[6]) joint46 += px; + if (x[11]) joint_4_11 += px; + } if (x[11] && x[13]) { joint_11_13 += px; if (x[8] && x[12]) joint_8_11_12_13 += px; @@ -129,50 +141,50 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { DiscreteFactor::Values all1 = allPosbValues.back(); // check separator marginal P(S0) - auto c = (*bayesTree)[0]; + auto clique = (*bayesTree)[0]; DiscreteFactorGraph separatorMarginal0 = - c->separatorMarginal(EliminateDiscrete); + clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); // check separator marginal P(S9), should be P(14) - c = (*bayesTree)[9]; + clique = (*bayesTree)[9]; DiscreteFactorGraph separatorMarginal9 = - c->separatorMarginal(EliminateDiscrete); + clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); // check separator marginal of root, should be empty - c = (*bayesTree)[11]; + clique = (*bayesTree)[11]; DiscreteFactorGraph separatorMarginal11 = - c->separatorMarginal(EliminateDiscrete); + clique->separatorMarginal(EliminateDiscrete); LONGS_EQUAL(0, separatorMarginal11.size()); // check shortcut P(S9||R) to root - c = (*bayesTree)[9]; - DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); + clique = (*bayesTree)[9]; + DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete); LONGS_EQUAL(1, shortcut.size()); DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S8||R) to root - c = (*bayesTree)[8]; - shortcut = c->shortcut(R, EliminateDiscrete); + clique = (*bayesTree)[8]; + shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S2||R) to root - c = (*bayesTree)[2]; - shortcut = c->shortcut(R, EliminateDiscrete); + clique = (*bayesTree)[2]; + shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S0||R) to root - c = (*bayesTree)[0]; - shortcut = c->shortcut(R, EliminateDiscrete); + clique = (*bayesTree)[0]; + shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // calculate all shortcuts to root DiscreteBayesTree::Nodes cliques = bayesTree->nodes(); - for (auto c : cliques) { - DiscreteBayesNet shortcut = c.second->shortcut(R, EliminateDiscrete); + for (auto clique : cliques) { + DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); if (debug) { - c.second->conditional_->printSignature(); + clique.second->conditional_->printSignature(); shortcut.print("shortcut:"); } } diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 888bf76df..3ac3ffc9e 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -16,9 +16,9 @@ * @date Feb 14, 2011 */ -#include #include #include +#include using namespace boost::assign; #include @@ -36,6 +36,11 @@ TEST( DiscreteConditional, constructors) DiscreteConditional::shared_ptr expected1 = // boost::make_shared(X | Y = "1/1 2/3 1/4"); EXPECT(expected1); + EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals())); + EXPECT_LONGS_EQUAL(2, *(expected1->beginParents())); + EXPECT(expected1->endParents() == expected1->end()); + EXPECT(expected1->endFrontals() == expected1->beginParents()); + DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional actual1(1, f1); EXPECT(assert_equal(*expected1, actual1, 1e-9)); @@ -43,71 +48,68 @@ TEST( DiscreteConditional, constructors) DecisionTreeFactor f2(X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor::shared_ptr actual2factor = actual2.toFactor(); -// EXPECT(assert_equal(f2, *actual2factor, 1e-9)); + EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); } /* ************************************************************************* */ -TEST( DiscreteConditional, constructors_alt_interface) -{ - DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! +TEST(DiscreteConditional, constructors_alt_interface) { + DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! Signature::Table table; Signature::Row r1, r2, r3; - r1 += 1.0, 1.0; r2 += 2.0, 3.0; r3 += 1.0, 4.0; + r1 += 1.0, 1.0; + r2 += 2.0, 3.0; + r3 += 1.0, 4.0; table += r1, r2, r3; - DiscreteConditional::shared_ptr expected1 = // - boost::make_shared(X | Y = table); - EXPECT(expected1); + auto actual1 = boost::make_shared(X | Y = table); + EXPECT(actual1); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); - DiscreteConditional actual1(1, f1); - EXPECT(assert_equal(*expected1, actual1, 1e-9)); + DiscreteConditional expected1(1, f1); + EXPECT(assert_equal(expected1, *actual1, 1e-9)); - DecisionTreeFactor f2(X & Y & Z, - "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); + DecisionTreeFactor f2( + X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor::shared_ptr actual2factor = actual2.toFactor(); -// EXPECT(assert_equal(f2, *actual2factor, 1e-9)); + EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); } /* ************************************************************************* */ -TEST( DiscreteConditional, constructors2) -{ +TEST(DiscreteConditional, constructors2) { // Declare keys and ordering - DiscreteKey C(0,2), B(1,2); - DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25"); + DiscreteKey C(0, 2), B(1, 2); + DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25"); Signature signature((C | B) = "4/1 3/1"); - DiscreteConditional actual(signature); - DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor(); - EXPECT(assert_equal(expected, *actualFactor)); + DiscreteConditional expected(signature); + DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); + EXPECT(assert_equal(*expectedFactor, actual)); } /* ************************************************************************* */ -TEST( DiscreteConditional, constructors3) -{ +TEST(DiscreteConditional, constructors3) { // Declare keys and ordering - DiscreteKey C(0,2), B(1,2), A(2,2); - DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); + DiscreteKey C(0, 2), B(1, 2), A(2, 2); + DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); - DiscreteConditional actual(signature); - DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor(); - EXPECT(assert_equal(expected, *actualFactor)); + DiscreteConditional expected(signature); + DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); + EXPECT(assert_equal(*expectedFactor, actual)); } /* ************************************************************************* */ -TEST( DiscreteConditional, Combine) { +TEST(DiscreteConditional, Combine) { DiscreteKey A(0, 2), B(1, 2); vector c; c.push_back(boost::make_shared(A | B = "1/2 2/1")); c.push_back(boost::make_shared(B % "1/2")); DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222"); - DiscreteConditional expected(2, factor); - DiscreteConditional::shared_ptr actual = DiscreteConditional::Combine( - c.begin(), c.end()); - EXPECT(assert_equal(expected, *actual,1e-5)); + DiscreteConditional actual(2, factor); + auto expected = DiscreteConditional::Combine(c.begin(), c.end()); + EXPECT(assert_equal(*expected, actual, 1e-5)); } /* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} /* ************************************************************************* */ - diff --git a/gtsam/discrete/tests/testDiscreteMarginals.cpp b/gtsam/discrete/tests/testDiscreteMarginals.cpp index 4e9f956b6..e1eb92af3 100644 --- a/gtsam/discrete/tests/testDiscreteMarginals.cpp +++ b/gtsam/discrete/tests/testDiscreteMarginals.cpp @@ -146,8 +146,7 @@ TEST_UNSAFE( DiscreteMarginals, truss ) { /* ************************************************************************* */ // Second truss example with non-trivial factors -TEST_UNSAFE( DiscreteMarginals, truss2 ) { - +TEST_UNSAFE(DiscreteMarginals, truss2) { const int nrNodes = 5; const size_t nrStates = 2; @@ -160,40 +159,39 @@ TEST_UNSAFE( DiscreteMarginals, truss2 ) { // create graph and add three truss potentials DiscreteFactorGraph graph; - graph.add(key[0] & key[2] & key[4],"1 2 3 4 5 6 7 8"); - graph.add(key[1] & key[3] & key[4],"1 2 3 4 5 6 7 8"); - graph.add(key[2] & key[3] & key[4],"1 2 3 4 5 6 7 8"); + graph.add(key[0] & key[2] & key[4], "1 2 3 4 5 6 7 8"); + graph.add(key[1] & key[3] & key[4], "1 2 3 4 5 6 7 8"); + graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8"); // Calculate the marginals by brute force - vector allPosbValues = cartesianProduct( - key[0] & key[1] & key[2] & key[3] & key[4]); + vector allPosbValues = + cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]); Vector T = Z_5x1, F = Z_5x1; for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteFactor::Values x = allPosbValues[i]; double px = graph(x); - for (size_t j=0;j<5;j++) - if (x[j]) T[j]+=px; else F[j]+=px; - // cout << x[0] << " " << x[1] << " "<< x[2] << " " << x[3] << " " << x[4] << " :\t" << px << endl; + for (size_t j = 0; j < 5; j++) + if (x[j]) + T[j] += px; + else + F[j] += px; } // Check all marginals given by a sequential solver and Marginals -// DiscreteSequentialSolver solver(graph); + // DiscreteSequentialSolver solver(graph); DiscreteMarginals marginals(graph); - for (size_t j=0;j<5;j++) { - double sum = T[j]+F[j]; - T[j]/=sum; - F[j]/=sum; - -// // solver -// Vector actualV = solver.marginalProbabilities(key[j]); -// EXPECT(assert_equal((Vector(2) << F[j], T[j]), actualV)); + for (size_t j = 0; j < 5; j++) { + double sum = T[j] + F[j]; + T[j] /= sum; + F[j] /= sum; // Marginals vector table; - table += F[j],T[j]; - DecisionTreeFactor expectedM(key[j],table); + table += F[j], T[j]; + DecisionTreeFactor expectedM(key[j], table); DiscreteFactor::shared_ptr actualM = marginals(j); - EXPECT(assert_equal(expectedM, *boost::dynamic_pointer_cast(actualM))); + EXPECT(assert_equal( + expectedM, *boost::dynamic_pointer_cast(actualM))); } } diff --git a/gtsam/discrete/tests/testSignature.cpp b/gtsam/discrete/tests/testSignature.cpp index de47a00f3..049c455f7 100644 --- a/gtsam/discrete/tests/testSignature.cpp +++ b/gtsam/discrete/tests/testSignature.cpp @@ -11,36 +11,43 @@ /** * @file testSignature - * @brief Tests focusing on the details of Signatures to evaluate boost compliance + * @brief Tests focusing on the details of Signatures to evaluate boost + * compliance * @author Alex Cunningham * @date Sept 19th 2011 */ -#include #include - #include #include +#include +#include + using namespace std; using namespace gtsam; using namespace boost::assign; -DiscreteKey X(0,2), Y(1,3), Z(2,2); +DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); /* ************************************************************************* */ TEST(testSignature, simple_conditional) { Signature sig(X | Y = "1/1 2/3 1/4"); + Signature::Table table = *sig.table(); + vector row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}}; + CHECK(row[0] == table[0]); + CHECK(row[1] == table[1]); + CHECK(row[2] == table[2]); DiscreteKey actKey = sig.key(); - LONGS_EQUAL((long)X.first, (long)actKey.first); + LONGS_EQUAL(X.first, actKey.first); - DiscreteKeys actKeys = sig.discreteKeysParentsFirst(); - LONGS_EQUAL(2, (long)actKeys.size()); - LONGS_EQUAL((long)Y.first, (long)actKeys.front().first); - LONGS_EQUAL((long)X.first, (long)actKeys.back().first); + DiscreteKeys actKeys = sig.discreteKeys(); + LONGS_EQUAL(2, actKeys.size()); + LONGS_EQUAL(X.first, actKeys.front().first); + LONGS_EQUAL(Y.first, actKeys.back().first); vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, (long)actCpt.size()); + EXPECT_LONGS_EQUAL(6, actCpt.size()); } /* ************************************************************************* */ @@ -54,17 +61,20 @@ TEST(testSignature, simple_conditional_nonparser) { Signature sig(X | Y = table); DiscreteKey actKey = sig.key(); - EXPECT_LONGS_EQUAL((long)X.first, (long)actKey.first); + EXPECT_LONGS_EQUAL(X.first, actKey.first); - DiscreteKeys actKeys = sig.discreteKeysParentsFirst(); - LONGS_EQUAL(2, (long)actKeys.size()); - LONGS_EQUAL((long)Y.first, (long)actKeys.front().first); - LONGS_EQUAL((long)X.first, (long)actKeys.back().first); + DiscreteKeys actKeys = sig.discreteKeys(); + LONGS_EQUAL(2, actKeys.size()); + LONGS_EQUAL(X.first, actKeys.front().first); + LONGS_EQUAL(Y.first, actKeys.back().first); vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, (long)actCpt.size()); + EXPECT_LONGS_EQUAL(6, actCpt.size()); } /* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} /* ************************************************************************* */