Merge pull request #402 from borglab/feature/bayesnet_example
Two new discrete examplesrelease/4.3a0
						commit
						52fcdb51a8
					
				|  | @ -0,0 +1,83 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /**
 | ||||
|  * @file  DiscreteBayesNetExample.cpp | ||||
|  * @brief   Discrete Bayes Net example with famous Asia Bayes Network | ||||
|  * @author  Frank Dellaert | ||||
|  * @date  JULY 10, 2020 | ||||
|  */ | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteFactorGraph.h> | ||||
| #include <gtsam/discrete/DiscreteMarginals.h> | ||||
| #include <gtsam/inference/BayesNet-inst.h> | ||||
| 
 | ||||
| #include <iomanip> | ||||
| 
 | ||||
| using namespace std; | ||||
| using namespace gtsam; | ||||
| 
 | ||||
| int main(int argc, char **argv) { | ||||
|   DiscreteBayesNet asia; | ||||
|   DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), | ||||
|       Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); | ||||
|   asia.add(Asia % "99/1"); | ||||
|   asia.add(Smoking % "50/50"); | ||||
| 
 | ||||
|   asia.add(Tuberculosis | Asia = "99/1 95/5"); | ||||
|   asia.add(LungCancer | Smoking = "99/1 90/10"); | ||||
|   asia.add(Bronchitis | Smoking = "70/30 40/60"); | ||||
| 
 | ||||
|   asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); | ||||
| 
 | ||||
|   asia.add(XRay | Either = "95/5 2/98"); | ||||
|   asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); | ||||
| 
 | ||||
|   // print
 | ||||
|   vector<string> pretty = {"Asia",    "Dyspnea", "XRay",       "Tuberculosis", | ||||
|                            "Smoking", "Either",  "LungCancer", "Bronchitis"}; | ||||
|   auto formatter = [pretty](Key key) { return pretty[key]; }; | ||||
|   asia.print("Asia", formatter); | ||||
| 
 | ||||
|   // Convert to factor graph
 | ||||
|   DiscreteFactorGraph fg(asia); | ||||
| 
 | ||||
|   // Create solver and eliminate
 | ||||
|   Ordering ordering; | ||||
|   ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); | ||||
|   DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); | ||||
| 
 | ||||
|   // solve
 | ||||
|   DiscreteFactor::sharedValues mpe = chordal->optimize(); | ||||
|   GTSAM_PRINT(*mpe); | ||||
| 
 | ||||
|   // We can also build a Bayes tree (directed junction tree).
 | ||||
|   // The elimination order above will do fine:
 | ||||
|   auto bayesTree = fg.eliminateMultifrontal(ordering); | ||||
|   bayesTree->print("bayesTree", formatter); | ||||
| 
 | ||||
|   // add evidence, we were in Asia and we have dyspnea
 | ||||
|   fg.add(Asia, "0 1"); | ||||
|   fg.add(Dyspnea, "0 1"); | ||||
| 
 | ||||
|   // solve again, now with evidence
 | ||||
|   DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); | ||||
|   DiscreteFactor::sharedValues mpe2 = chordal2->optimize(); | ||||
|   GTSAM_PRINT(*mpe2); | ||||
| 
 | ||||
|   // We can also sample from it
 | ||||
|   cout << "\n10 samples:" << endl; | ||||
|   for (size_t i = 0; i < 10; i++) { | ||||
|     DiscreteFactor::sharedValues sample = chordal2->sample(); | ||||
|     GTSAM_PRINT(*sample); | ||||
|   } | ||||
|   return 0; | ||||
| } | ||||
|  | @ -10,7 +10,7 @@ | |||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /**
 | ||||
|  * @file  DiscreteBayesNet_graph.cpp | ||||
|  * @file  DiscreteBayesNet_FG.cpp | ||||
|  * @brief   Discrete Bayes Net example using Factor Graphs | ||||
|  * @author  Abhijit | ||||
|  * @date  Jun 4, 2012 | ||||
|  |  | |||
|  | @ -0,0 +1,94 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /**
 | ||||
|  * @file  DiscreteBayesNetExample.cpp | ||||
|  * @brief   Hidden Markov Model example, discrete. | ||||
|  * @author  Frank Dellaert | ||||
|  * @date  July 12, 2020 | ||||
|  */ | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteFactorGraph.h> | ||||
| #include <gtsam/discrete/DiscreteMarginals.h> | ||||
| #include <gtsam/inference/BayesNet-inst.h> | ||||
| 
 | ||||
| #include <iomanip> | ||||
| #include <sstream> | ||||
| 
 | ||||
| using namespace std; | ||||
| using namespace gtsam; | ||||
| 
 | ||||
| int main(int argc, char **argv) { | ||||
|   const int nrNodes = 4; | ||||
|   const size_t nrStates = 3; | ||||
| 
 | ||||
|   // Define variables as well as ordering
 | ||||
|   Ordering ordering; | ||||
|   vector<DiscreteKey> keys; | ||||
|   for (int k = 0; k < nrNodes; k++) { | ||||
|     DiscreteKey key_i(k, nrStates); | ||||
|     keys.push_back(key_i); | ||||
|     ordering.emplace_back(k); | ||||
|   } | ||||
| 
 | ||||
|   // Create HMM as a DiscreteBayesNet
 | ||||
|   DiscreteBayesNet hmm; | ||||
| 
 | ||||
|   // Define backbone
 | ||||
|   const string transition = "8/1/1 1/8/1 1/1/8"; | ||||
|   for (int k = 1; k < nrNodes; k++) { | ||||
|     hmm.add(keys[k] | keys[k - 1] = transition); | ||||
|   } | ||||
| 
 | ||||
|   // Add some measurements, not needed for all time steps!
 | ||||
|   hmm.add(keys[0] % "7/2/1"); | ||||
|   hmm.add(keys[1] % "1/9/0"); | ||||
|   hmm.add(keys.back() % "5/4/1"); | ||||
| 
 | ||||
|   // print
 | ||||
|   hmm.print("HMM"); | ||||
| 
 | ||||
|   // Convert to factor graph
 | ||||
|   DiscreteFactorGraph factorGraph(hmm); | ||||
| 
 | ||||
|   // Create solver and eliminate
 | ||||
|   // This will create a DAG ordered with arrow of time reversed
 | ||||
|   DiscreteBayesNet::shared_ptr chordal = | ||||
|       factorGraph.eliminateSequential(ordering); | ||||
|   chordal->print("Eliminated"); | ||||
| 
 | ||||
|   // solve
 | ||||
|   DiscreteFactor::sharedValues mpe = chordal->optimize(); | ||||
|   GTSAM_PRINT(*mpe); | ||||
| 
 | ||||
|   // We can also sample from it
 | ||||
|   cout << "\n10 samples:" << endl; | ||||
|   for (size_t k = 0; k < 10; k++) { | ||||
|     DiscreteFactor::sharedValues sample = chordal->sample(); | ||||
|     GTSAM_PRINT(*sample); | ||||
|   } | ||||
| 
 | ||||
|   // Or compute the marginals. This re-eliminates the FG into a Bayes tree
 | ||||
|   cout << "\nComputing Node Marginals .." << endl; | ||||
|   DiscreteMarginals marginals(factorGraph); | ||||
|   for (int k = 0; k < nrNodes; k++) { | ||||
|     Vector margProbs = marginals.marginalProbabilities(keys[k]); | ||||
|     stringstream ss; | ||||
|     ss << "marginal " << k; | ||||
|     print(margProbs, ss.str()); | ||||
|   } | ||||
| 
 | ||||
|   // TODO(frank): put in the glue to have DiscreteMarginals produce *arbitrary*
 | ||||
|   // joints efficiently, by the Bayes tree shortcut magic. All the code is there
 | ||||
|   // but it's not yet connected.
 | ||||
| 
 | ||||
|   return 0; | ||||
| } | ||||
|  | @ -27,6 +27,7 @@ | |||
| #include <algorithm> | ||||
| #include <random> | ||||
| #include <stdexcept> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| using namespace std; | ||||
|  | @ -61,16 +62,26 @@ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, | |||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| DiscreteConditional::DiscreteConditional(const Signature& signature) : | ||||
|         BaseFactor(signature.discreteKeysParentsFirst(), signature.cpt()), BaseConditional( | ||||
|             1) { | ||||
| } | ||||
| DiscreteConditional::DiscreteConditional(const Signature& signature) | ||||
|     : BaseFactor(signature.discreteKeys(), signature.cpt()), | ||||
|       BaseConditional(1) {} | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| void DiscreteConditional::print(const std::string& s, | ||||
|     const KeyFormatter& formatter) const { | ||||
|   std::cout << s << std::endl; | ||||
|   Potentials::print(s); | ||||
| void DiscreteConditional::print(const string& s, | ||||
|                                 const KeyFormatter& formatter) const { | ||||
|   cout << s << " P( "; | ||||
|   for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { | ||||
|     cout << formatter(*it) << " "; | ||||
|   } | ||||
|   if (nrParents()) { | ||||
|     cout << "| "; | ||||
|     for (const_iterator it = beginParents(); it != endParents(); ++it) { | ||||
|       cout << formatter(*it) << " "; | ||||
|     } | ||||
|   } | ||||
|   cout << ")"; | ||||
|   Potentials::print(""); | ||||
|   cout << endl; | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
|  | @ -173,55 +184,28 @@ size_t DiscreteConditional::solve(const Values& parentsValues) const { | |||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| size_t DiscreteConditional::sample(const Values& parentsValues) const { | ||||
| 
 | ||||
|   static mt19937 rng(2); // random number generator
 | ||||
| 
 | ||||
|   bool debug = ISDEBUG("DiscreteConditional::sample"); | ||||
|   static mt19937 rng(2);  // random number generator
 | ||||
| 
 | ||||
|   // Get the correct conditional density
 | ||||
|   ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
 | ||||
|   if (debug) | ||||
|     GTSAM_PRINT(pFS); | ||||
|   ADT pFS = choose(parentsValues);  // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // get cumulative distribution function (cdf)
 | ||||
|   // TODO, only works for one key now, seems horribly slow this way
 | ||||
|   // TODO(Duy): only works for one key now, seems horribly slow this way
 | ||||
|   assert(nrFrontals() == 1); | ||||
|   Key j = (firstFrontalKey()); | ||||
|   size_t nj = cardinality(j); | ||||
|   vector<double> cdf(nj); | ||||
|   Key key = firstFrontalKey(); | ||||
|   size_t nj = cardinality(key); | ||||
|   vector<double> p(nj); | ||||
|   Values frontals; | ||||
|   double sum = 0; | ||||
|   for (size_t value = 0; value < nj; value++) { | ||||
|     frontals[j] = value; | ||||
|     double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
 | ||||
|     sum += pValueS; // accumulate
 | ||||
|     if (debug) | ||||
|       cout << sum << " "; | ||||
|     if (pValueS == 1) { | ||||
|       if (debug) | ||||
|         cout << "--> " << value << endl; | ||||
|       return value; // shortcut exit
 | ||||
|     frontals[key] = value; | ||||
|     p[value] = pFS(frontals);  // P(F=value|S=parentsValues)
 | ||||
|     if (p[value] == 1.0) { | ||||
|       return value;  // shortcut exit
 | ||||
|     } | ||||
|     cdf[value] = sum; | ||||
|   } | ||||
| 
 | ||||
|   // inspired by http://www.boost.org/doc/libs/1_46_1/doc/html/boost_random/tutorial.html
 | ||||
|   uniform_real_distribution<double> dist(0, cdf.back()); | ||||
|   size_t sampled = lower_bound(cdf.begin(), cdf.end(), dist(rng)) - cdf.begin(); | ||||
|   if (debug) | ||||
|     cout << "-> " << sampled << endl; | ||||
| 
 | ||||
|   return sampled; | ||||
| 
 | ||||
|   return 0; | ||||
|   std::discrete_distribution<size_t> distribution(p.begin(), p.end()); | ||||
|   return distribution(rng); | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| //void DiscreteConditional::permuteWithInverse(
 | ||||
| //    const Permutation& inversePermutation) {
 | ||||
| //  IndexConditionalOrdered::permuteWithInverse(inversePermutation);
 | ||||
| //  Potentials::permuteWithInverse(inversePermutation);
 | ||||
| //}
 | ||||
| /* ******************************************************************************** */ | ||||
| 
 | ||||
| }// namespace
 | ||||
|  |  | |||
|  | @ -15,50 +15,52 @@ | |||
|  * @author Frank Dellaert | ||||
|  */ | ||||
| 
 | ||||
| #include <gtsam/discrete/Potentials.h> | ||||
| #include <gtsam/discrete/DecisionTree-inl.h> | ||||
| #include <gtsam/discrete/Potentials.h> | ||||
| 
 | ||||
| #include <boost/format.hpp> | ||||
| 
 | ||||
| #include <string> | ||||
| 
 | ||||
| using namespace std; | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
|   // explicit instantiation
 | ||||
|   template class DecisionTree<Key, double> ; | ||||
|   template class AlgebraicDecisionTree<Key> ; | ||||
| // explicit instantiation
 | ||||
| template class DecisionTree<Key, double>; | ||||
| template class AlgebraicDecisionTree<Key>; | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   double Potentials::safe_div(const double& a, const double& b) { | ||||
|     // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
 | ||||
|     // The use for safe_div is when we divide the product factor by the sum factor.
 | ||||
|     // If the product or sum is zero, we accord zero probability to the event.
 | ||||
|     return (a == 0 || b == 0) ? 0 : (a / b); | ||||
|   } | ||||
| /* ************************************************************************* */ | ||||
| double Potentials::safe_div(const double& a, const double& b) { | ||||
|   // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
 | ||||
|   // The use for safe_div is when we divide the product factor by the sum
 | ||||
|   // factor. If the product or sum is zero, we accord zero probability to the
 | ||||
|   // event.
 | ||||
|   return (a == 0 || b == 0) ? 0 : (a / b); | ||||
| } | ||||
| 
 | ||||
|   /* ******************************************************************************** */ | ||||
|   Potentials::Potentials() : | ||||
|       ADT(1.0) { | ||||
|   } | ||||
| /* ********************************************************************************
 | ||||
|  */ | ||||
| Potentials::Potentials() : ADT(1.0) {} | ||||
| 
 | ||||
|   /* ******************************************************************************** */ | ||||
|   Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) : | ||||
|       ADT(decisionTree), cardinalities_(keys.cardinalities()) { | ||||
|   } | ||||
| /* ********************************************************************************
 | ||||
|  */ | ||||
| Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) | ||||
|     : ADT(decisionTree), cardinalities_(keys.cardinalities()) {} | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   bool Potentials::equals(const Potentials& other, double tol) const { | ||||
|     return ADT::equals(other, tol); | ||||
|   } | ||||
| /* ************************************************************************* */ | ||||
| bool Potentials::equals(const Potentials& other, double tol) const { | ||||
|   return ADT::equals(other, tol); | ||||
| } | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   void Potentials::print(const string& s, | ||||
|       const KeyFormatter& formatter) const { | ||||
|     cout << s << "\n  Cardinalities: "; | ||||
|     for(const DiscreteKey& key: cardinalities_) | ||||
|       cout << formatter(key.first) << "=" << formatter(key.second) << " "; | ||||
|     cout << endl; | ||||
|     ADT::print(" "); | ||||
|   } | ||||
| /* ************************************************************************* */ | ||||
| void Potentials::print(const string& s, const KeyFormatter& formatter) const { | ||||
|   cout << s << "\n  Cardinalities: {"; | ||||
|   for (const DiscreteKey& key : cardinalities_) | ||||
|     cout << formatter(key.first) << ":" << key.second << ", "; | ||||
|   cout << "}" << endl; | ||||
|   ADT::print(" "); | ||||
| } | ||||
| //
 | ||||
| //  /* ************************************************************************* */
 | ||||
| //  template<class P>
 | ||||
|  | @ -95,4 +97,4 @@ namespace gtsam { | |||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
| 
 | ||||
| } // namespace gtsam
 | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -122,28 +122,30 @@ namespace gtsam { | |||
|     key_(key) { | ||||
|   } | ||||
| 
 | ||||
|   DiscreteKeys Signature::discreteKeysParentsFirst() const { | ||||
|   DiscreteKeys Signature::discreteKeys() const { | ||||
|     DiscreteKeys keys; | ||||
|     for(const DiscreteKey& key: parents_) | ||||
|       keys.push_back(key); | ||||
|     keys.push_back(key_); | ||||
|     for (const DiscreteKey& key : parents_) keys.push_back(key); | ||||
|     return keys; | ||||
|   } | ||||
| 
 | ||||
|   KeyVector Signature::indices() const { | ||||
|     KeyVector js; | ||||
|     js.push_back(key_.first); | ||||
|     for(const DiscreteKey& key: parents_) | ||||
|       js.push_back(key.first); | ||||
|     for (const DiscreteKey& key : parents_) js.push_back(key.first); | ||||
|     return js; | ||||
|   } | ||||
| 
 | ||||
|   vector<double> Signature::cpt() const { | ||||
|     vector<double> cpt; | ||||
|     if (table_) { | ||||
|       for(const Row& row: *table_) | ||||
|               for(const double& x: row) | ||||
|                       cpt.push_back(x); | ||||
|       const size_t nrStates = table_->at(0).size(); | ||||
|       for (size_t j = 0; j < nrStates; j++) { | ||||
|         for (const Row& row : *table_) { | ||||
|           assert(row.size() == nrStates); | ||||
|           cpt.push_back(row[j]); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     return cpt; | ||||
|   } | ||||
|  |  | |||
|  | @ -86,8 +86,8 @@ namespace gtsam { | |||
|       return parents_; | ||||
|     } | ||||
| 
 | ||||
|     /** All keys, with variable key last */ | ||||
|     DiscreteKeys discreteKeysParentsFirst() const; | ||||
|     /** All keys, with variable key first */ | ||||
|     DiscreteKeys discreteKeys() const; | ||||
| 
 | ||||
|     /** All key indices, with variable key first */ | ||||
|     KeyVector indices() const; | ||||
|  |  | |||
|  | @ -132,7 +132,7 @@ TEST(ADT, example3) | |||
| 
 | ||||
| /** Convert Signature into CPT */ | ||||
| ADT create(const Signature& signature) { | ||||
|   ADT p(signature.discreteKeysParentsFirst(), signature.cpt()); | ||||
|   ADT p(signature.discreteKeys(), signature.cpt()); | ||||
|   static size_t count = 0; | ||||
|   const DiscreteKey& key = signature.key(); | ||||
|   string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); | ||||
|  | @ -181,19 +181,20 @@ TEST(ADT, joint) | |||
|   dot(joint, "Asia-ASTLBEX"); | ||||
|   joint = apply(joint, pD, &mul); | ||||
|   dot(joint, "Asia-ASTLBEXD"); | ||||
|   EXPECT_LONGS_EQUAL(346, (long)muls); | ||||
|   EXPECT_LONGS_EQUAL(346, muls); | ||||
|   gttoc_(asiaJoint); | ||||
|   tictoc_getNode(asiaJointNode, asiaJoint); | ||||
|   elapsed = asiaJointNode->secs() + asiaJointNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("Asia joint"); | ||||
| 
 | ||||
|   // Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
 | ||||
|   ADT pASTL = pA; | ||||
|   pASTL = apply(pASTL, pS, &mul); | ||||
|   pASTL = apply(pASTL, pT, &mul); | ||||
|   pASTL = apply(pASTL, pL, &mul); | ||||
| 
 | ||||
|   // test combine
 | ||||
|   // test combine to check that P(A) = \sum_{S,T,L} P(A,S,T,L)
 | ||||
|   ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_); | ||||
|   EXPECT(assert_equal(pA, fAa)); | ||||
|   ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_); | ||||
|  |  | |||
|  | @ -18,110 +18,135 @@ | |||
| 
 | ||||
| #include <gtsam/discrete/DiscreteBayesNet.h> | ||||
| #include <gtsam/discrete/DiscreteFactorGraph.h> | ||||
| #include <gtsam/base/Testable.h> | ||||
| #include <gtsam/discrete/DiscreteMarginals.h> | ||||
| #include <gtsam/base/debug.h> | ||||
| #include <gtsam/base/Testable.h> | ||||
| #include <gtsam/base/Vector.h> | ||||
| 
 | ||||
| #include <CppUnitLite/TestHarness.h> | ||||
| 
 | ||||
| #include <boost/assign/std/map.hpp> | ||||
| 
 | ||||
| #include <boost/assign/list_inserter.hpp> | ||||
| #include <boost/assign/std/map.hpp> | ||||
| 
 | ||||
| using namespace boost::assign; | ||||
| 
 | ||||
| #include <iostream> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| using namespace std; | ||||
| using namespace gtsam; | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesNet, Asia) | ||||
| { | ||||
| TEST(DiscreteBayesNet, bayesNet) { | ||||
|   DiscreteBayesNet bayesNet; | ||||
|   DiscreteKey Parent(0, 2), Child(1, 2); | ||||
| 
 | ||||
|   auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4"); | ||||
|   CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"), | ||||
|                      (Potentials::ADT)*prior)); | ||||
|   bayesNet.push_back(prior); | ||||
| 
 | ||||
|   auto conditional = | ||||
|       boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2"); | ||||
|   EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals())); | ||||
|   Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2"); | ||||
|   CHECK(assert_equal(expected, (Potentials::ADT)*conditional)); | ||||
|   bayesNet.push_back(conditional); | ||||
| 
 | ||||
|   DiscreteFactorGraph fg(bayesNet); | ||||
|   LONGS_EQUAL(2, fg.back()->size()); | ||||
| 
 | ||||
|   // Check the marginals
 | ||||
|   const double expectedMarginal[2]{0.4, 0.6 * 0.3 + 0.4 * 0.2}; | ||||
|   DiscreteMarginals marginals(fg); | ||||
|   for (size_t j = 0; j < 2; j++) { | ||||
|     Vector FT = marginals.marginalProbabilities(DiscreteKey(j, 2)); | ||||
|     EXPECT_DOUBLES_EQUAL(expectedMarginal[j], FT[1], 1e-3); | ||||
|     EXPECT_DOUBLES_EQUAL(FT[0], 1.0 - FT[1], 1e-9); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteBayesNet, Asia) { | ||||
|   DiscreteBayesNet asia; | ||||
| //  DiscreteKey A("Asia"), S("Smoking"), T("Tuberculosis"), L("LungCancer"), B(
 | ||||
| //      "Bronchitis"), E("Either"), X("XRay"), D("Dyspnoea");
 | ||||
|   DiscreteKey A(0,2), S(4,2), T(3,2), L(6,2), B(7,2), E(5,2), X(2,2), D(1,2); | ||||
|   DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), | ||||
|       Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); | ||||
| 
 | ||||
|   // TODO: make a version that doesn't use the parser
 | ||||
|   asia.add(A % "99/1"); | ||||
|   asia.add(S % "50/50"); | ||||
|   asia.add(Asia % "99/1"); | ||||
|   asia.add(Smoking % "50/50"); | ||||
| 
 | ||||
|   asia.add(T | A = "99/1 95/5"); | ||||
|   asia.add(L | S = "99/1 90/10"); | ||||
|   asia.add(B | S = "70/30 40/60"); | ||||
|   asia.add(Tuberculosis | Asia = "99/1 95/5"); | ||||
|   asia.add(LungCancer | Smoking = "99/1 90/10"); | ||||
|   asia.add(Bronchitis | Smoking = "70/30 40/60"); | ||||
| 
 | ||||
|   asia.add((E | T, L) = "F T T T"); | ||||
|   asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); | ||||
| 
 | ||||
|   asia.add(X | E = "95/5 2/98"); | ||||
|   // next lines are same as asia.add((D | E, B) = "9/1 2/8 3/7 1/9");
 | ||||
|   DiscreteConditional::shared_ptr actual = | ||||
|       boost::make_shared<DiscreteConditional>((D | E, B) = "9/1 2/8 3/7 1/9"); | ||||
|   asia.push_back(actual); | ||||
|   //  GTSAM_PRINT(asia);
 | ||||
|   asia.add(XRay | Either = "95/5 2/98"); | ||||
|   asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); | ||||
| 
 | ||||
|   // Convert to factor graph
 | ||||
|   DiscreteFactorGraph fg(asia); | ||||
| //    GTSAM_PRINT(fg);
 | ||||
|   LONGS_EQUAL(3,fg.back()->size()); | ||||
|   Potentials::ADT expected(B & D & E, "0.9 0.3 0.1 0.7 0.2 0.1 0.8 0.9"); | ||||
|   CHECK(assert_equal(expected,(Potentials::ADT)*actual)); | ||||
|   LONGS_EQUAL(3, fg.back()->size()); | ||||
| 
 | ||||
|   // Check the marginals we know (of the parent-less nodes)
 | ||||
|   DiscreteMarginals marginals(fg); | ||||
|   Vector2 va(0.99, 0.01), vs(0.5, 0.5); | ||||
|   EXPECT(assert_equal(va, marginals.marginalProbabilities(Asia))); | ||||
|   EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); | ||||
| 
 | ||||
|   // Create solver and eliminate
 | ||||
|   Ordering ordering; | ||||
|   ordering += Key(0),Key(1),Key(2),Key(3),Key(4),Key(5),Key(6),Key(7); | ||||
|   ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); | ||||
|   DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); | ||||
| //    GTSAM_PRINT(*chordal);
 | ||||
|   DiscreteConditional expected2(B % "11/9"); | ||||
|   CHECK(assert_equal(expected2,*chordal->back())); | ||||
|   DiscreteConditional expected2(Bronchitis % "11/9"); | ||||
|   EXPECT(assert_equal(expected2, *chordal->back())); | ||||
| 
 | ||||
|   // solve
 | ||||
|   DiscreteFactor::sharedValues actualMPE = chordal->optimize(); | ||||
|   DiscreteFactor::Values expectedMPE; | ||||
|   insert(expectedMPE)(A.first, 0)(D.first, 0)(X.first, 0)(T.first, 0)(S.first, | ||||
|       0)(E.first, 0)(L.first, 0)(B.first, 0); | ||||
|   insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)( | ||||
|       Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)( | ||||
|       LungCancer.first, 0)(Bronchitis.first, 0); | ||||
|   EXPECT(assert_equal(expectedMPE, *actualMPE)); | ||||
| 
 | ||||
|   // add evidence, we were in Asia and we have Dispnoea
 | ||||
|   fg.add(A, "0 1"); | ||||
|   fg.add(D, "0 1"); | ||||
| //  fg.product().dot("fg");
 | ||||
|   // add evidence, we were in Asia and we have dyspnea
 | ||||
|   fg.add(Asia, "0 1"); | ||||
|   fg.add(Dyspnea, "0 1"); | ||||
| 
 | ||||
|   // solve again, now with evidence
 | ||||
|   DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); | ||||
| //  GTSAM_PRINT(*chordal2);
 | ||||
|   DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize(); | ||||
|   DiscreteFactor::Values expectedMPE2; | ||||
|   insert(expectedMPE2)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)(S.first, | ||||
|       1)(E.first, 0)(L.first, 0)(B.first, 1); | ||||
|   insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)( | ||||
|       Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)( | ||||
|       LungCancer.first, 0)(Bronchitis.first, 1); | ||||
|   EXPECT(assert_equal(expectedMPE2, *actualMPE2)); | ||||
| 
 | ||||
|   // now sample from it
 | ||||
|   DiscreteFactor::Values expectedSample; | ||||
|   SETDEBUG("DiscreteConditional::sample", false); | ||||
|   insert(expectedSample)(A.first, 1)(D.first, 1)(X.first, 1)(T.first, 0)( | ||||
|       S.first, 1)(E.first, 1)(L.first, 1)(B.first, 0); | ||||
|   insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)( | ||||
|       Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)( | ||||
|       LungCancer.first, 1)(Bronchitis.first, 0); | ||||
|   DiscreteFactor::sharedValues actualSample = chordal2->sample(); | ||||
|   EXPECT(assert_equal(expectedSample, *actualSample)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST_UNSAFE(DiscreteBayesNet, Sugar) | ||||
| { | ||||
|   DiscreteKey T(0,2), L(1,2), E(2,2), D(3,2), C(8,3), S(7,2); | ||||
| TEST_UNSAFE(DiscreteBayesNet, Sugar) { | ||||
|   DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2); | ||||
| 
 | ||||
|   DiscreteBayesNet bn; | ||||
| 
 | ||||
|   // test some mistakes
 | ||||
|   //  add(bn, D);
 | ||||
|   //  add(bn, D | E);
 | ||||
|   //  add(bn, D | E = "blah");
 | ||||
| 
 | ||||
|   // try logic
 | ||||
|   bn.add((E | T, L) = "OR"); | ||||
|   bn.add((E | T, L) = "AND"); | ||||
| 
 | ||||
|   //  // try multivalued
 | ||||
|  bn.add(C % "1/1/2"); | ||||
|  bn.add(C | S = "1/1/2 5/2/3"); | ||||
|   // try multivalued
 | ||||
|   bn.add(C % "1/1/2"); | ||||
|   bn.add(C | S = "1/1/2 5/2/3"); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  | @ -130,4 +155,3 @@ int main() { | |||
|   return TestRegistry::runAllTests(tr); | ||||
| } | ||||
| /* ************************************************************************* */ | ||||
| 
 | ||||
|  |  | |||
|  | @ -80,6 +80,12 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { | |||
|     bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); | ||||
|   } | ||||
| 
 | ||||
|   // Check frontals and parents
 | ||||
|   for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) { | ||||
|     auto clique_i = (*bayesTree)[i]; | ||||
|     EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals())); | ||||
|   } | ||||
| 
 | ||||
|   auto R = bayesTree->roots().front(); | ||||
| 
 | ||||
|   // Check whether BN and BT give the same answer on all configurations
 | ||||
|  | @ -104,16 +110,22 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { | |||
|     double px = bayesTree->evaluate(x); | ||||
|     for (size_t i = 0; i < 15; i++) | ||||
|       if (x[i]) marginals[i] += px; | ||||
|     if (x[12] && x[14]) joint_12_14 += px; | ||||
|     if (x[9] && x[12] && x[14]) joint_9_12_14 += px; | ||||
|     if (x[8] && x[12] && x[14]) joint_8_12_14 += px; | ||||
|     if (x[12] && x[14]) { | ||||
|       joint_12_14 += px; | ||||
|       if (x[9]) joint_9_12_14 += px; | ||||
|       if (x[8]) joint_8_12_14 += px; | ||||
|     } | ||||
|     if (x[8] && x[12]) joint_8_12 += px; | ||||
|     if (x[8] && x[2]) joint82 += px; | ||||
|     if (x[1] && x[2]) joint12 += px; | ||||
|     if (x[2] && x[4]) joint24 += px; | ||||
|     if (x[4] && x[5]) joint45 += px; | ||||
|     if (x[4] && x[6]) joint46 += px; | ||||
|     if (x[4] && x[11]) joint_4_11 += px; | ||||
|     if (x[2]) { | ||||
|       if (x[8]) joint82 += px; | ||||
|       if (x[1]) joint12 += px; | ||||
|     } | ||||
|     if (x[4]) { | ||||
|       if (x[2]) joint24 += px; | ||||
|       if (x[5]) joint45 += px; | ||||
|       if (x[6]) joint46 += px; | ||||
|       if (x[11]) joint_4_11 += px; | ||||
|     } | ||||
|     if (x[11] && x[13]) { | ||||
|       joint_11_13 += px; | ||||
|       if (x[8] && x[12]) joint_8_11_12_13 += px; | ||||
|  | @ -129,50 +141,50 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { | |||
|   DiscreteFactor::Values all1 = allPosbValues.back(); | ||||
| 
 | ||||
|   // check separator marginal P(S0)
 | ||||
|   auto c = (*bayesTree)[0]; | ||||
|   auto clique = (*bayesTree)[0]; | ||||
|   DiscreteFactorGraph separatorMarginal0 = | ||||
|       c->separatorMarginal(EliminateDiscrete); | ||||
|       clique->separatorMarginal(EliminateDiscrete); | ||||
|   DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); | ||||
| 
 | ||||
|   // check separator marginal P(S9), should be P(14)
 | ||||
|   c = (*bayesTree)[9]; | ||||
|   clique = (*bayesTree)[9]; | ||||
|   DiscreteFactorGraph separatorMarginal9 = | ||||
|       c->separatorMarginal(EliminateDiscrete); | ||||
|       clique->separatorMarginal(EliminateDiscrete); | ||||
|   DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); | ||||
| 
 | ||||
|   // check separator marginal of root, should be empty
 | ||||
|   c = (*bayesTree)[11]; | ||||
|   clique = (*bayesTree)[11]; | ||||
|   DiscreteFactorGraph separatorMarginal11 = | ||||
|       c->separatorMarginal(EliminateDiscrete); | ||||
|       clique->separatorMarginal(EliminateDiscrete); | ||||
|   LONGS_EQUAL(0, separatorMarginal11.size()); | ||||
| 
 | ||||
|   // check shortcut P(S9||R) to root
 | ||||
|   c = (*bayesTree)[9]; | ||||
|   DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); | ||||
|   clique = (*bayesTree)[9]; | ||||
|   DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete); | ||||
|   LONGS_EQUAL(1, shortcut.size()); | ||||
|   DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); | ||||
| 
 | ||||
|   // check shortcut P(S8||R) to root
 | ||||
|   c = (*bayesTree)[8]; | ||||
|   shortcut = c->shortcut(R, EliminateDiscrete); | ||||
|   clique = (*bayesTree)[8]; | ||||
|   shortcut = clique->shortcut(R, EliminateDiscrete); | ||||
|   DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); | ||||
| 
 | ||||
|   // check shortcut P(S2||R) to root
 | ||||
|   c = (*bayesTree)[2]; | ||||
|   shortcut = c->shortcut(R, EliminateDiscrete); | ||||
|   clique = (*bayesTree)[2]; | ||||
|   shortcut = clique->shortcut(R, EliminateDiscrete); | ||||
|   DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); | ||||
| 
 | ||||
|   // check shortcut P(S0||R) to root
 | ||||
|   c = (*bayesTree)[0]; | ||||
|   shortcut = c->shortcut(R, EliminateDiscrete); | ||||
|   clique = (*bayesTree)[0]; | ||||
|   shortcut = clique->shortcut(R, EliminateDiscrete); | ||||
|   DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); | ||||
| 
 | ||||
|   // calculate all shortcuts to root
 | ||||
|   DiscreteBayesTree::Nodes cliques = bayesTree->nodes(); | ||||
|   for (auto c : cliques) { | ||||
|     DiscreteBayesNet shortcut = c.second->shortcut(R, EliminateDiscrete); | ||||
|   for (auto clique : cliques) { | ||||
|     DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); | ||||
|     if (debug) { | ||||
|       c.second->conditional_->printSignature(); | ||||
|       clique.second->conditional_->printSignature(); | ||||
|       shortcut.print("shortcut:"); | ||||
|     } | ||||
|   } | ||||
|  |  | |||
|  | @ -16,9 +16,9 @@ | |||
|  * @date Feb 14, 2011 | ||||
|  */ | ||||
| 
 | ||||
| #include <boost/make_shared.hpp> | ||||
| #include <boost/assign/std/map.hpp> | ||||
| #include <boost/assign/std/vector.hpp> | ||||
| #include <boost/make_shared.hpp> | ||||
| using namespace boost::assign; | ||||
| 
 | ||||
| #include <CppUnitLite/TestHarness.h> | ||||
|  | @ -36,6 +36,11 @@ TEST( DiscreteConditional, constructors) | |||
|   DiscreteConditional::shared_ptr expected1 = //
 | ||||
|       boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4"); | ||||
|   EXPECT(expected1); | ||||
|   EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals())); | ||||
|   EXPECT_LONGS_EQUAL(2, *(expected1->beginParents())); | ||||
|   EXPECT(expected1->endParents() == expected1->end()); | ||||
|   EXPECT(expected1->endFrontals() == expected1->beginParents()); | ||||
|    | ||||
|   DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); | ||||
|   DiscreteConditional actual1(1, f1); | ||||
|   EXPECT(assert_equal(*expected1, actual1, 1e-9)); | ||||
|  | @ -43,71 +48,68 @@ TEST( DiscreteConditional, constructors) | |||
|   DecisionTreeFactor f2(X & Y & Z, | ||||
|       "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); | ||||
|   DiscreteConditional actual2(1, f2); | ||||
|   DecisionTreeFactor::shared_ptr actual2factor = actual2.toFactor(); | ||||
| //  EXPECT(assert_equal(f2, *actual2factor, 1e-9));
 | ||||
|   EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST( DiscreteConditional, constructors_alt_interface) | ||||
| { | ||||
|   DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
 | ||||
| TEST(DiscreteConditional, constructors_alt_interface) { | ||||
|   DiscreteKey X(0, 2), Y(2, 3), Z(1, 2);  // watch ordering !
 | ||||
| 
 | ||||
|   Signature::Table table; | ||||
|   Signature::Row r1, r2, r3; | ||||
|   r1 += 1.0, 1.0; r2 += 2.0, 3.0; r3 += 1.0, 4.0; | ||||
|   r1 += 1.0, 1.0; | ||||
|   r2 += 2.0, 3.0; | ||||
|   r3 += 1.0, 4.0; | ||||
|   table += r1, r2, r3; | ||||
|   DiscreteConditional::shared_ptr expected1 = //
 | ||||
|       boost::make_shared<DiscreteConditional>(X | Y = table); | ||||
|   EXPECT(expected1); | ||||
|   auto actual1 = boost::make_shared<DiscreteConditional>(X | Y = table); | ||||
|   EXPECT(actual1); | ||||
|   DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); | ||||
|   DiscreteConditional actual1(1, f1); | ||||
|   EXPECT(assert_equal(*expected1, actual1, 1e-9)); | ||||
|   DiscreteConditional expected1(1, f1); | ||||
|   EXPECT(assert_equal(expected1, *actual1, 1e-9)); | ||||
| 
 | ||||
|   DecisionTreeFactor f2(X & Y & Z, | ||||
|       "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); | ||||
|   DecisionTreeFactor f2( | ||||
|       X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); | ||||
|   DiscreteConditional actual2(1, f2); | ||||
|   DecisionTreeFactor::shared_ptr actual2factor = actual2.toFactor(); | ||||
| //  EXPECT(assert_equal(f2, *actual2factor, 1e-9));
 | ||||
|   EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST( DiscreteConditional, constructors2) | ||||
| { | ||||
| TEST(DiscreteConditional, constructors2) { | ||||
|   // Declare keys and ordering
 | ||||
|   DiscreteKey C(0,2), B(1,2); | ||||
|   DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25"); | ||||
|   DiscreteKey C(0, 2), B(1, 2); | ||||
|   DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25"); | ||||
|   Signature signature((C | B) = "4/1 3/1"); | ||||
|   DiscreteConditional actual(signature); | ||||
|   DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor(); | ||||
|   EXPECT(assert_equal(expected, *actualFactor)); | ||||
|   DiscreteConditional expected(signature); | ||||
|   DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); | ||||
|   EXPECT(assert_equal(*expectedFactor, actual)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST( DiscreteConditional, constructors3) | ||||
| { | ||||
| TEST(DiscreteConditional, constructors3) { | ||||
|   // Declare keys and ordering
 | ||||
|   DiscreteKey C(0,2), B(1,2), A(2,2); | ||||
|   DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); | ||||
|   DiscreteKey C(0, 2), B(1, 2), A(2, 2); | ||||
|   DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); | ||||
|   Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); | ||||
|   DiscreteConditional actual(signature); | ||||
|   DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor(); | ||||
|   EXPECT(assert_equal(expected, *actualFactor)); | ||||
|   DiscreteConditional expected(signature); | ||||
|   DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); | ||||
|   EXPECT(assert_equal(*expectedFactor, actual)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST( DiscreteConditional, Combine) { | ||||
| TEST(DiscreteConditional, Combine) { | ||||
|   DiscreteKey A(0, 2), B(1, 2); | ||||
|   vector<DiscreteConditional::shared_ptr> c; | ||||
|   c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1")); | ||||
|   c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2")); | ||||
|   DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222"); | ||||
|   DiscreteConditional expected(2, factor); | ||||
|   DiscreteConditional::shared_ptr actual = DiscreteConditional::Combine( | ||||
|       c.begin(), c.end()); | ||||
|   EXPECT(assert_equal(expected, *actual,1e-5)); | ||||
|   DiscreteConditional actual(2, factor); | ||||
|   auto expected = DiscreteConditional::Combine(c.begin(), c.end()); | ||||
|   EXPECT(assert_equal(*expected, actual, 1e-5)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() {  TestResult tr; return TestRegistry::runAllTests(tr); } | ||||
| int main() { | ||||
|   TestResult tr; | ||||
|   return TestRegistry::runAllTests(tr); | ||||
| } | ||||
| /* ************************************************************************* */ | ||||
| 
 | ||||
|  |  | |||
|  | @ -146,8 +146,7 @@ TEST_UNSAFE( DiscreteMarginals, truss ) { | |||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Second truss example with non-trivial factors
 | ||||
| TEST_UNSAFE( DiscreteMarginals, truss2 ) { | ||||
| 
 | ||||
| TEST_UNSAFE(DiscreteMarginals, truss2) { | ||||
|   const int nrNodes = 5; | ||||
|   const size_t nrStates = 2; | ||||
| 
 | ||||
|  | @ -160,40 +159,39 @@ TEST_UNSAFE( DiscreteMarginals, truss2 ) { | |||
| 
 | ||||
|   // create graph and add three truss potentials
 | ||||
|   DiscreteFactorGraph graph; | ||||
|   graph.add(key[0] & key[2] & key[4],"1 2 3 4 5 6 7 8"); | ||||
|   graph.add(key[1] & key[3] & key[4],"1 2 3 4 5 6 7 8"); | ||||
|   graph.add(key[2] & key[3] & key[4],"1 2 3 4 5 6 7 8"); | ||||
|   graph.add(key[0] & key[2] & key[4], "1 2 3 4 5 6 7 8"); | ||||
|   graph.add(key[1] & key[3] & key[4], "1 2 3 4 5 6 7 8"); | ||||
|   graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8"); | ||||
| 
 | ||||
|   // Calculate the marginals by brute force
 | ||||
|   vector<DiscreteFactor::Values> allPosbValues = cartesianProduct( | ||||
|       key[0] & key[1] & key[2] & key[3] & key[4]); | ||||
|   vector<DiscreteFactor::Values> allPosbValues = | ||||
|       cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]); | ||||
|   Vector T = Z_5x1, F = Z_5x1; | ||||
|   for (size_t i = 0; i < allPosbValues.size(); ++i) { | ||||
|     DiscreteFactor::Values x = allPosbValues[i]; | ||||
|     double px = graph(x); | ||||
|     for (size_t j=0;j<5;j++) | ||||
|       if (x[j]) T[j]+=px; else F[j]+=px; | ||||
|     // cout << x[0] << " " << x[1] << " "<< x[2] << " " << x[3] << " " << x[4] << " :\t" << px << endl;
 | ||||
|     for (size_t j = 0; j < 5; j++) | ||||
|       if (x[j]) | ||||
|         T[j] += px; | ||||
|       else | ||||
|         F[j] += px; | ||||
|   } | ||||
| 
 | ||||
|   // Check all marginals given by a sequential solver and Marginals
 | ||||
| //  DiscreteSequentialSolver solver(graph);
 | ||||
|   //  DiscreteSequentialSolver solver(graph);
 | ||||
|   DiscreteMarginals marginals(graph); | ||||
|   for (size_t j=0;j<5;j++) { | ||||
|     double sum = T[j]+F[j]; | ||||
|     T[j]/=sum; | ||||
|     F[j]/=sum; | ||||
| 
 | ||||
| //    // solver
 | ||||
| //    Vector actualV = solver.marginalProbabilities(key[j]);
 | ||||
| //    EXPECT(assert_equal((Vector(2) << F[j], T[j]), actualV));
 | ||||
|   for (size_t j = 0; j < 5; j++) { | ||||
|     double sum = T[j] + F[j]; | ||||
|     T[j] /= sum; | ||||
|     F[j] /= sum; | ||||
| 
 | ||||
|     // Marginals
 | ||||
|     vector<double> table; | ||||
|     table += F[j],T[j]; | ||||
|     DecisionTreeFactor expectedM(key[j],table); | ||||
|     table += F[j], T[j]; | ||||
|     DecisionTreeFactor expectedM(key[j], table); | ||||
|     DiscreteFactor::shared_ptr actualM = marginals(j); | ||||
|     EXPECT(assert_equal(expectedM, *boost::dynamic_pointer_cast<DecisionTreeFactor>(actualM))); | ||||
|     EXPECT(assert_equal( | ||||
|         expectedM, *boost::dynamic_pointer_cast<DecisionTreeFactor>(actualM))); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -11,36 +11,43 @@ | |||
| 
 | ||||
| /**
 | ||||
|  * @file testSignature | ||||
|  * @brief Tests focusing on the details of Signatures to evaluate boost compliance | ||||
|  * @brief Tests focusing on the details of Signatures to evaluate boost | ||||
|  * compliance | ||||
|  * @author Alex Cunningham | ||||
|  * @date Sept 19th 2011 | ||||
|  */ | ||||
| 
 | ||||
| #include <boost/assign/std/vector.hpp> | ||||
| #include <CppUnitLite/TestHarness.h> | ||||
| 
 | ||||
| #include <gtsam/base/Testable.h> | ||||
| #include <gtsam/discrete/Signature.h> | ||||
| 
 | ||||
| #include <boost/assign/std/vector.hpp> | ||||
| #include <vector> | ||||
| 
 | ||||
| using namespace std; | ||||
| using namespace gtsam; | ||||
| using namespace boost::assign; | ||||
| 
 | ||||
| DiscreteKey X(0,2), Y(1,3), Z(2,2); | ||||
| DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(testSignature, simple_conditional) { | ||||
|   Signature sig(X | Y = "1/1 2/3 1/4"); | ||||
|   Signature::Table table = *sig.table(); | ||||
|   vector<double> row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}}; | ||||
|   CHECK(row[0] == table[0]); | ||||
|   CHECK(row[1] == table[1]); | ||||
|   CHECK(row[2] == table[2]); | ||||
|   DiscreteKey actKey = sig.key(); | ||||
|   LONGS_EQUAL((long)X.first, (long)actKey.first); | ||||
|   LONGS_EQUAL(X.first, actKey.first); | ||||
| 
 | ||||
|   DiscreteKeys actKeys = sig.discreteKeysParentsFirst(); | ||||
|   LONGS_EQUAL(2, (long)actKeys.size()); | ||||
|   LONGS_EQUAL((long)Y.first, (long)actKeys.front().first); | ||||
|   LONGS_EQUAL((long)X.first, (long)actKeys.back().first); | ||||
|   DiscreteKeys actKeys = sig.discreteKeys(); | ||||
|   LONGS_EQUAL(2, actKeys.size()); | ||||
|   LONGS_EQUAL(X.first, actKeys.front().first); | ||||
|   LONGS_EQUAL(Y.first, actKeys.back().first); | ||||
| 
 | ||||
|   vector<double> actCpt = sig.cpt(); | ||||
|   EXPECT_LONGS_EQUAL(6, (long)actCpt.size()); | ||||
|   EXPECT_LONGS_EQUAL(6, actCpt.size()); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  | @ -54,17 +61,20 @@ TEST(testSignature, simple_conditional_nonparser) { | |||
| 
 | ||||
|   Signature sig(X | Y = table); | ||||
|   DiscreteKey actKey = sig.key(); | ||||
|   EXPECT_LONGS_EQUAL((long)X.first, (long)actKey.first); | ||||
|   EXPECT_LONGS_EQUAL(X.first, actKey.first); | ||||
| 
 | ||||
|   DiscreteKeys actKeys = sig.discreteKeysParentsFirst(); | ||||
|   LONGS_EQUAL(2, (long)actKeys.size()); | ||||
|   LONGS_EQUAL((long)Y.first, (long)actKeys.front().first); | ||||
|   LONGS_EQUAL((long)X.first, (long)actKeys.back().first); | ||||
|   DiscreteKeys actKeys = sig.discreteKeys(); | ||||
|   LONGS_EQUAL(2, actKeys.size()); | ||||
|   LONGS_EQUAL(X.first, actKeys.front().first); | ||||
|   LONGS_EQUAL(Y.first, actKeys.back().first); | ||||
| 
 | ||||
|   vector<double> actCpt = sig.cpt(); | ||||
|   EXPECT_LONGS_EQUAL(6, (long)actCpt.size()); | ||||
|   EXPECT_LONGS_EQUAL(6, actCpt.size()); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { TestResult tr; return TestRegistry::runAllTests(tr); } | ||||
| int main() { | ||||
|   TestResult tr; | ||||
|   return TestRegistry::runAllTests(tr); | ||||
| } | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue