diff --git a/cpp/BinaryConditional.h b/cpp/BinaryConditional.h index 860971aa6..31257b9e8 100644 --- a/cpp/BinaryConditional.h +++ b/cpp/BinaryConditional.h @@ -14,6 +14,7 @@ #include #include // TODO: make cpp file #include +#include #include "Conditional.h" namespace gtsam { @@ -26,6 +27,7 @@ namespace gtsam { private: std::list parents_; + std::vector cpt_; public: @@ -42,6 +44,8 @@ namespace gtsam { */ BinaryConditional(const std::string& key, double p) : Conditional(key) { + cpt_.push_back(1-p); + cpt_.push_back(p); } /** @@ -50,6 +54,7 @@ namespace gtsam { BinaryConditional(const std::string& key, const std::string& parent, const std::vector& cpt) : Conditional(key) { parents_.push_back(parent); + cpt_ = cpt; } /** print */ @@ -58,6 +63,9 @@ namespace gtsam { if (parents_.size()>0) std::cout << " |"; BOOST_FOREACH(std::string parent, parents_) std::cout << " " << parent; std::cout << ")" << std::endl; + std::cout << "Conditional Probability Table::" << std::endl; + BOOST_FOREACH(double p, cpt_) std::cout << p << "\t"; + std::cout<< std::endl; } /** check equality */ @@ -65,12 +73,15 @@ namespace gtsam { if (!Conditional::equals(c)) return false; const BinaryConditional* p = dynamic_cast (&c); if (p == NULL) return false; - return parents_ == p->parents_; + return (parents_ == p->parents_ && cpt_ == p->cpt_); } /** return parents */ std::list parents() const { return parents_;} + /** return Conditional probability table*/ + std::vector cpt() const { return cpt_;} + /** find the number of parents */ size_t nrParents() const { return parents_.size(); @@ -83,6 +94,7 @@ namespace gtsam { void serialize(Archive & ar, const unsigned int version) { ar & boost::serialization::make_nvp("Conditional", boost::serialization::base_object(*this)); ar & BOOST_SERIALIZATION_NVP(parents_); + ar & BOOST_SERIALIZATION_NVP(cpt_); } }; } /// namespace gtsam diff --git a/cpp/testBinaryBayesNet.cpp b/cpp/testBinaryBayesNet.cpp index 8f7ee360d..09fa9c3e6 100644 --- a/cpp/testBinaryBayesNet.cpp +++ b/cpp/testBinaryBayesNet.cpp @@ -38,8 +38,9 @@ struct BinaryConfig { }; double probability(const BinaryBayesNet& bayesNet, const BinaryConfig& config) { - - return 0; + double result = 1.0; + /* TODO: using config multiply the probabilities */ + return result; } /* ************************************************************************* */ @@ -52,11 +53,13 @@ TEST( BinaryBayesNet, constructor ) // unary conditional for y boost::shared_ptr py(new BinaryConditional("y",0.2)); + py->print("py"); // single parent conditional for x vector cpt; - cpt += 0.3, 0.5; // array index corresponds to binary parent configuration + cpt += 0.7, 0.5, 0.3, 0.5 ; // array index corresponds to binary parent configuration boost::shared_ptr px_y(new BinaryConditional("x","y",cpt)); + px_y->print("px_y"); // push back conditionals in topological sort order (parents last) BinaryBayesNet bbn; @@ -64,7 +67,7 @@ TEST( BinaryBayesNet, constructor ) bbn.push_back(px_y); // Test probability of 00,01,10,11 - //DOUBLES_EQUAL(0.56,probability(bbn,BinaryConfig(false,false)),0.01); // P(y=0)P(x=0|y=0) = 0.8 * 0.7 = 0.56; + DOUBLES_EQUAL(0.56,probability(bbn,BinaryConfig(false,false)),0.01); // P(y=0)P(x=0|y=0) = 0.8 * 0.7 = 0.56; } /* ************************************************************************* */