diff --git a/cpp/BinaryConditional.h b/cpp/BinaryConditional.h new file mode 100644 index 000000000..860971aa6 --- /dev/null +++ b/cpp/BinaryConditional.h @@ -0,0 +1,88 @@ +/** + * @file DiscreteConditional.h + * @brief Discrete Conditional node for use in Bayes nets + * @author Manohar Paluri + */ + +// \callgraph + +#pragma once + +#include +#include +#include +#include +#include // TODO: make cpp file +#include +#include "Conditional.h" + +namespace gtsam { + + /** + * Conditional node for use in a Bayes net + */ + class BinaryConditional: public Conditional { + + private: + + std::list parents_; + + public: + + /** convenience typename for a shared pointer to this class */ + typedef boost::shared_ptr shared_ptr; + + /** + * Empty Constructor to make serialization possible + */ + BinaryConditional(){} + + /** + * No parents + */ + BinaryConditional(const std::string& key, double p) : + Conditional(key) { + } + + /** + * Single parent + */ + BinaryConditional(const std::string& key, const std::string& parent, const std::vector& cpt) : + Conditional(key) { + parents_.push_back(parent); + } + + /** print */ + void print(const std::string& s = "BinaryConditional") const { + std::cout << s << " P(" << key_; + if (parents_.size()>0) std::cout << " |"; + BOOST_FOREACH(std::string parent, parents_) std::cout << " " << parent; + std::cout << ")" << std::endl; + } + + /** check equality */ + bool equals(const Conditional& c, double tol = 1e-9) const { + if (!Conditional::equals(c)) return false; + const BinaryConditional* p = dynamic_cast (&c); + if (p == NULL) return false; + return parents_ == p->parents_; + } + + /** return parents */ + std::list parents() const { return parents_;} + + /** find the number of parents */ + size_t nrParents() const { + return parents_.size(); + } + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive & ar, const unsigned int version) { + ar & boost::serialization::make_nvp("Conditional", boost::serialization::base_object(*this)); + ar & BOOST_SERIALIZATION_NVP(parents_); + } + }; +} /// namespace gtsam diff --git a/cpp/Makefile.am b/cpp/Makefile.am index ab1f474dc..b40233b6d 100644 --- a/cpp/Makefile.am +++ b/cpp/Makefile.am @@ -85,6 +85,12 @@ testSymbolicFactorGraph_LDADD = libgtsam.la testSymbolicBayesNet_SOURCES = $(example) testSymbolicBayesNet.cpp testSymbolicBayesNet_LDADD = libgtsam.la +# Binary Inference +headers += BinaryConditional.h +check_PROGRAMS += testBinaryBayesNet +testBinaryBayesNet_SOURCES = testBinaryBayesNet.cpp +testBinaryBayesNet_LDADD = libgtsam.la + # Gaussian inference headers += GaussianFactorSet.h sources += VectorConfig.cpp GaussianFactor.cpp GaussianFactorGraph.cpp GaussianConditional.cpp GaussianBayesNet.cpp diff --git a/cpp/testBinaryBayesNet.cpp b/cpp/testBinaryBayesNet.cpp index 4717cac72..8f7ee360d 100644 --- a/cpp/testBinaryBayesNet.cpp +++ b/cpp/testBinaryBayesNet.cpp @@ -1,27 +1,72 @@ /** * @file testBinaryBayesNet.cpp - * @brief Unit tests for Bayes Tree - * @author Frank Dellaert + * @brief Unit tests for BinaryBayesNet + * @author Manohar Paluri */ +// STL/C++ +#include +#include #include +#include +#include +#include // for operator += +using namespace boost::assign; + +#ifdef HAVE_BOOST_SERIALIZATION +#include +#include +#endif //HAVE_BOOST_SERIALIZATION + +#include "BinaryConditional.h" +#include "BayesNet-inl.h" +#include "smallExample.h" +#include "Ordering.h" + +using namespace std; using namespace gtsam; +/** A Bayes net made from binary conditional probability tables */ +typedef BayesNet BinaryBayesNet; + +struct BinaryConfig { + bool px_; + bool py_; + + BinaryConfig( bool px, bool py ):px_(px), py_(py){} +}; + +double probability(const BinaryBayesNet& bayesNet, const BinaryConfig& config) { + + return 0; +} + /* ************************************************************************* */ TEST( BinaryBayesNet, constructor ) { - map tables; - BinaryCPT pA(0.01);tables.insert("A",pA); - BinaryCPT pB("S",0.6,0.3); - BinaryBayesNet binaryBayesNet(tables); - BinaryConfig allFalse(false,false,false,...); - DOUBLES_EQUAL(0.12,binaryBayesNet.probability(allFalse)); + // small Bayes Net x <- y + // p(y) = 0.2 + // p(x|y=0) = 0.3 + // p(x|y=1) = 0.5 + + // unary conditional for y + boost::shared_ptr py(new BinaryConditional("y",0.2)); + + // single parent conditional for x + vector cpt; + cpt += 0.3, 0.5; // array index corresponds to binary parent configuration + boost::shared_ptr px_y(new BinaryConditional("x","y",cpt)); + + // push back conditionals in topological sort order (parents last) + BinaryBayesNet bbn; + bbn.push_back(py); + bbn.push_back(px_y); + + // Test probability of 00,01,10,11 + //DOUBLES_EQUAL(0.56,probability(bbn,BinaryConfig(false,false)),0.01); // P(y=0)P(x=0|y=0) = 0.8 * 0.7 = 0.56; } /* ************************************************************************* */ -int main() { - TestResult tr; - return TestRegistry::runAllTests(tr); -} +int main() { TestResult tr; return TestRegistry::runAllTests(tr);} /* ************************************************************************* */