diff --git a/gtsam/inference/BayesNet-inl.h b/gtsam/inference/BayesNet-inl.h index 0875ec9c3..72d314d2b 100644 --- a/gtsam/inference/BayesNet-inl.h +++ b/gtsam/inference/BayesNet-inl.h @@ -49,6 +49,24 @@ namespace gtsam { return equal(conditionals_.begin(),conditionals_.end(),cbn.conditionals_.begin(),equals_star(tol)); } + /* ************************************************************************* */ + template + typename BayesNet::const_iterator BayesNet::find(Index key) const { + for(const_iterator it = begin(); it != end(); ++it) + if(std::find((*it)->beginFrontals(), (*it)->endFrontals(), key) != (*it)->endFrontals()) + return it; + return end(); + } + + /* ************************************************************************* */ + template + typename BayesNet::iterator BayesNet::find(Index key) { + for(iterator it = begin(); it != end(); ++it) + if(std::find((*it)->beginFrontals(), (*it)->endFrontals(), key) != (*it)->endFrontals()) + return it; + return end(); + } + /* ************************************************************************* */ template void BayesNet::permuteWithInverse(const Permutation& inversePermutation) { @@ -82,6 +100,21 @@ namespace gtsam { push_front(conditional); } + /* ************************************************************************* */ + template + void BayesNet::popLeaf(iterator conditional) { +#ifndef NDEBUG + BOOST_FOREACH(typename CONDITIONAL::shared_ptr checkConditional, conditionals_) { + BOOST_FOREACH(Index key, (*conditional)->frontals()) { + if(std::find(checkConditional->beginParents(), checkConditional->endParents(), key) != checkConditional->endParents()) + throw std::invalid_argument( + "Debug mode exception: in BayesNet::popLeaf, the requested conditional is not a leaf."); + } + } +#endif + conditionals_.erase(conditional); + } + /* ************************************************************************* */ template FastList BayesNet::ordering() const { @@ -118,7 +151,7 @@ namespace gtsam { } } throw(invalid_argument((boost::format( - "BayesNet::operator['%1%']: not found") % key).str())); + "BayesNet::operator['%1%']: not found") % key).str())); } /* ************************************************************************* */ diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index 8d4c71e56..f79065e9d 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -30,126 +31,181 @@ namespace gtsam { - /** - * A BayesNet is a list of conditionals, stored in elimination order, i.e. - * leaves first, parents last. GaussianBayesNet and SymbolicBayesNet are - * defined as typedefs of this class, using GaussianConditional and - * IndexConditional as the CONDITIONAL template argument. - * - * todo: Symbolic using Index is a misnomer. - * todo: how to handle Bayes nets with an optimize function? Currently using global functions. - */ - template - class BayesNet: public Testable > { +/** + * A BayesNet is a list of conditionals, stored in elimination order, i.e. + * leaves first, parents last. GaussianBayesNet and SymbolicBayesNet are + * defined as typedefs of this class, using GaussianConditional and + * IndexConditional as the CONDITIONAL template argument. + * + * todo: Symbolic using Index is a misnomer. + * todo: how to handle Bayes nets with an optimize function? Currently using global functions. + */ +template +class BayesNet: public Testable > { - public: +public: - typedef typename boost::shared_ptr > shared_ptr; + typedef typename boost::shared_ptr > shared_ptr; - /** We store shared pointers to Conditional densities */ - typedef typename boost::shared_ptr sharedConditional; - typedef typename boost::shared_ptr const_sharedConditional; - typedef typename std::list Conditionals; + /** We store shared pointers to Conditional densities */ + typedef typename boost::shared_ptr sharedConditional; + typedef typename boost::shared_ptr const_sharedConditional; + typedef typename std::list Conditionals; - typedef typename Conditionals::const_iterator iterator; - typedef typename Conditionals::const_reverse_iterator reverse_iterator; - typedef typename Conditionals::const_iterator const_iterator; - typedef typename Conditionals::const_reverse_iterator const_reverse_iterator; + typedef typename Conditionals::iterator iterator; + typedef typename Conditionals::reverse_iterator reverse_iterator; + typedef typename Conditionals::const_iterator const_iterator; + typedef typename Conditionals::const_reverse_iterator const_reverse_iterator; - protected: +protected: - /** - * Conditional densities are stored in reverse topological sort order (i.e., leaves first, - * parents last), which corresponds to the elimination ordering if so obtained, - * and is consistent with the column (block) ordering of an upper triangular matrix. - */ - Conditionals conditionals_; + /** + * Conditional densities are stored in reverse topological sort order (i.e., leaves first, + * parents last), which corresponds to the elimination ordering if so obtained, + * and is consistent with the column (block) ordering of an upper triangular matrix. + */ + Conditionals conditionals_; - public: +public: - /** Default constructor as an empty BayesNet */ - BayesNet() {}; + /** Default constructor as an empty BayesNet */ + BayesNet() {}; - /** BayesNet with 1 conditional */ - BayesNet(const sharedConditional& conditional) { push_back(conditional); } + /** BayesNet with 1 conditional */ + BayesNet(const sharedConditional& conditional) { push_back(conditional); } - /** print */ - void print(const std::string& s = "") const; + /** print */ + void print(const std::string& s = "") const; - /** check equality */ - bool equals(const BayesNet& other, double tol = 1e-9) const; + /** check equality */ + bool equals(const BayesNet& other, double tol = 1e-9) const; - /** push_back: use reverse topological sort (i.e. parents last / elimination order) */ - inline void push_back(const sharedConditional& conditional) { - conditionals_.push_back(conditional); - } + /** Find an iterator pointing to the conditional where the specified key + * appears as a frontal variable, or end() if no conditional contains this + * key. Running time is approximately \f$ O(n) \f$ in the number of + * conditionals in the BayesNet. + * @param key The index to find in the frontal variables of a conditional. + */ + const_iterator find(Index key) const; - /** push_front: use topological sort (i.e. parents first / reverse elimination order) */ - inline void push_front(const sharedConditional& conditional) { - conditionals_.push_front(conditional); - } + /** Find an iterator pointing to the conditional where the specified key + * appears as a frontal variable, or end() if no conditional contains this + * key. Running time is approximately \f$ O(n) \f$ in the number of + * conditionals in the BayesNet. + * @param key The index to find in the frontal variables of a conditional. + */ + iterator find(Index key); - // push_back an entire Bayes net */ - void push_back(const BayesNet bn); + /** push_back: use reverse topological sort (i.e. parents last / elimination order) */ + inline void push_back(const sharedConditional& conditional) { + conditionals_.push_back(conditional); + } - // push_front an entire Bayes net */ - void push_front(const BayesNet bn); + /** push_front: use topological sort (i.e. parents first / reverse elimination order) */ + inline void push_front(const sharedConditional& conditional) { + conditionals_.push_front(conditional); + } - /** - * pop_front: remove node at the bottom, used in marginalization - * For example P(ABC)=P(A|BC)P(B|C)P(C) becomes P(BC)=P(B|C)P(C) - */ - inline void pop_front() {conditionals_.pop_front();} + /// push_back an entire Bayes net + void push_back(const BayesNet bn); - /** Permute the variables in the BayesNet */ - void permuteWithInverse(const Permutation& inversePermutation); + /// push_front an entire Bayes net + void push_front(const BayesNet bn); - /** - * Permute the variables when only separator variables need to be permuted. - * Returns true if any reordered variables appeared in the separator and - * false if not. - */ - bool permuteSeparatorWithInverse(const Permutation& inversePermutation); + /** += syntax for push_back, e.g. bayesNet += c1, c2, c3 + * @param conditional The conditional to add to the back of the BayesNet + */ + boost::assign::list_inserter >, sharedConditional> + operator+=(const sharedConditional& conditional) { + return boost::assign::make_list_inserter(boost::assign_detail::call_push_back >(*this))(conditional); } - /** size is the number of nodes */ - inline size_t size() const { - return conditionals_.size(); - } + /** + * pop_front: remove node at the bottom, used in marginalization + * For example P(ABC)=P(A|BC)P(B|C)P(C) becomes P(BC)=P(B|C)P(C) + */ + void pop_front() {conditionals_.pop_front();} - /** return keys in reverse topological sort order, i.e., elimination order */ - FastList ordering() const; + /** + * Remove any leaf conditional. The conditional to remove is specified by + * iterator. To find the iterator pointing to the conditional containing a + * particular key, use find(), which has \f$ O(n) \f$ complexity. The + * popLeaf function by itself has \f$ O(1) \f$ complexity. + * + * If the program calling this function is + * compiled without NDEBUG defined, this function will check that the node + * is indeed a leaf, but otherwise will not check, because the check has + * \f$ O(n^2) \f$ complexity. + * + * Example 1: + \code + // Remove a leaf node with a known conditional + GaussianBayesNet gbn = ... + GaussianBayesNet::iterator leafConditional = ... + gbn.popLeaf(leafConditional); + \endcode + * Example 2: + \code + // Remove the leaf node containing variable index 14 + GaussianBayesNet gbn = ... + gbn.popLeaf(gbn.find(14)); + \endcode + * @param conditional The iterator pointing to the leaf conditional to remove + */ + void popLeaf(iterator conditional); - /** SLOW O(n) random access to Conditional by key */ - sharedConditional operator[](Index key) const; + /** Permute the variables in the BayesNet */ + void permuteWithInverse(const Permutation& inversePermutation); - /** return last node in ordering */ - sharedConditional& front() { return conditionals_.front(); } + /** + * Permute the variables when only separator variables need to be permuted. + * Returns true if any reordered variables appeared in the separator and + * false if not. + */ + bool permuteSeparatorWithInverse(const Permutation& inversePermutation); - /** return last node in ordering */ - boost::shared_ptr front() const { return conditionals_.front(); } + /** size is the number of nodes */ + size_t size() const { + return conditionals_.size(); + } - /** return last node in ordering */ - sharedConditional& back() { return conditionals_.back(); } + /** return keys in reverse topological sort order, i.e., elimination order */ + FastList ordering() const; - /** return last node in ordering */ - boost::shared_ptr back() const { return conditionals_.back(); } + /** SLOW O(n) random access to Conditional by key */ + sharedConditional operator[](Index key) const; - /** return iterators. FD: breaks encapsulation? */ - inline const_iterator const begin() const {return conditionals_.begin();} - inline const_iterator const end() const {return conditionals_.end();} - inline const_reverse_iterator const rbegin() const {return conditionals_.rbegin();} - inline const_reverse_iterator const rend() const {return conditionals_.rend();} + /** return last node in ordering */ + sharedConditional& front() { return conditionals_.front(); } - /** saves the bayes to a text file in GraphViz format */ -// void saveGraph(const std::string& s) const; + /** return last node in ordering */ + boost::shared_ptr front() const { return conditionals_.front(); } - private: - /** Serialization function */ - friend class boost::serialization::access; - template - void serialize(ARCHIVE & ar, const unsigned int version) { - ar & BOOST_SERIALIZATION_NVP(conditionals_); - } - }; // BayesNet + /** return last node in ordering */ + sharedConditional& back() { return conditionals_.back(); } + + /** return last node in ordering */ + boost::shared_ptr back() const { return conditionals_.back(); } + + /** return iterators. FD: breaks encapsulation? */ + const_iterator begin() const {return conditionals_.begin();} + const_iterator end() const {return conditionals_.end();} + const_reverse_iterator rbegin() const {return conditionals_.rbegin();} + const_reverse_iterator rend() const {return conditionals_.rend();} + iterator begin() {return conditionals_.begin();} + iterator end() {return conditionals_.end();} + reverse_iterator rbegin() {return conditionals_.rbegin();} + reverse_iterator rend() {return conditionals_.rend();} + + /** saves the bayes to a text file in GraphViz format */ + // void saveGraph(const std::string& s) const; + +private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE & ar, const unsigned int version) { + ar & BOOST_SERIALIZATION_NVP(conditionals_); + } +}; // BayesNet } /// namespace gtsam diff --git a/gtsam/inference/tests/testSymbolicBayesNet.cpp b/gtsam/inference/tests/testSymbolicBayesNet.cpp index bdff9116f..628770456 100644 --- a/gtsam/inference/tests/testSymbolicBayesNet.cpp +++ b/gtsam/inference/tests/testSymbolicBayesNet.cpp @@ -22,7 +22,6 @@ using namespace boost::assign; #include #include -#ifdef ALL #include using namespace std; @@ -32,6 +31,8 @@ static const Index _L_ = 0; static const Index _A_ = 1; static const Index _B_ = 2; static const Index _C_ = 3; +static const Index _D_ = 4; +static const Index _E_ = 5; IndexConditional::shared_ptr B(new IndexConditional(_B_)), @@ -100,8 +101,92 @@ TEST( SymbolicBayesNet, combine ) CHECK(assert_equal(expected,p_ABC)); } + +/* ************************************************************************* */ +TEST(SymbolicBayesNet, find) { + SymbolicBayesNet bn; + bn += IndexConditional::shared_ptr(new IndexConditional(_A_, _B_)); + std::vector keys; + keys.push_back(_B_); + keys.push_back(_C_); + keys.push_back(_D_); + bn += IndexConditional::shared_ptr(new IndexConditional(keys,2)); + bn += IndexConditional::shared_ptr(new IndexConditional(_D_)); + + SymbolicBayesNet::iterator expected = bn.begin(); ++ expected; + SymbolicBayesNet::iterator actual = bn.find(_C_); + EXPECT(assert_equal(**expected, **actual)); +} + +/* ************************************************************************* */ +TEST_UNSAFE(SymbolicBayesNet, popLeaf) { + IndexConditional::shared_ptr + A(new IndexConditional(_A_,_E_)), + B(new IndexConditional(_B_,_E_)), + C(new IndexConditional(_C_,_D_)), + D(new IndexConditional(_D_,_E_)), + E(new IndexConditional(_E_)); + + // BayesNet after popping A + SymbolicBayesNet expected1; + expected1 += B, C, D, E; + + // BayesNet after popping C + SymbolicBayesNet expected2; + expected2 += A, B, D, E; + + // BayesNet after popping C and D + SymbolicBayesNet expected3; + expected3 += A, B, E; + + // BayesNet after popping C and A + SymbolicBayesNet expected4; + expected4 += B, D, E; + + + // BayesNet after popping A + SymbolicBayesNet actual1; + actual1 += A, B, C, D, E; + actual1.popLeaf(actual1.find(_A_)); + + // BayesNet after popping C + SymbolicBayesNet actual2; + actual2 += A, B, C, D, E; + actual2.popLeaf(actual2.find(_C_)); + + // BayesNet after popping C and D + SymbolicBayesNet actual3; + actual3 += A, B, C, D, E; + actual3.popLeaf(actual3.find(_C_)); + actual3.popLeaf(actual3.find(_D_)); + + // BayesNet after popping C and A + SymbolicBayesNet actual4; + actual4 += A, B, C, D, E; + actual4.popLeaf(actual4.find(_C_)); + actual4.popLeaf(actual4.find(_A_)); + + EXPECT(assert_equal(expected1, actual1)); + EXPECT(assert_equal(expected2, actual2)); + EXPECT(assert_equal(expected3, actual3)); + EXPECT(assert_equal(expected4, actual4)); + + // Try to remove a non-leaf node +#undef NDEBUG_SAVED +#ifdef NDEBUG +#define NDEBUG_SAVED #endif +#undef NDEBUG + SymbolicBayesNet actual5; + actual5 += A, B, C, D, E; + CHECK_EXCEPTION(actual5.popLeaf(actual5.find(_D_)), std::invalid_argument); + +#ifdef NDEBUG_SAVED +#define NDEBUG +#endif +} + /* ************************************************************************* */ int main() { TestResult tr;