Added BayesNet::popLeaf

release/4.3a0
Richard Roberts 2011-10-03 17:39:36 +00:00
parent 8fe0f6a501
commit f19c9c2da4
3 changed files with 269 additions and 95 deletions

View File

@ -49,6 +49,24 @@ namespace gtsam {
return equal(conditionals_.begin(),conditionals_.end(),cbn.conditionals_.begin(),equals_star<CONDITIONAL>(tol)); return equal(conditionals_.begin(),conditionals_.end(),cbn.conditionals_.begin(),equals_star<CONDITIONAL>(tol));
} }
/* ************************************************************************* */
template<class CONDITIONAL>
typename BayesNet<CONDITIONAL>::const_iterator BayesNet<CONDITIONAL>::find(Index key) const {
for(const_iterator it = begin(); it != end(); ++it)
if(std::find((*it)->beginFrontals(), (*it)->endFrontals(), key) != (*it)->endFrontals())
return it;
return end();
}
/* ************************************************************************* */
template<class CONDITIONAL>
typename BayesNet<CONDITIONAL>::iterator BayesNet<CONDITIONAL>::find(Index key) {
for(iterator it = begin(); it != end(); ++it)
if(std::find((*it)->beginFrontals(), (*it)->endFrontals(), key) != (*it)->endFrontals())
return it;
return end();
}
/* ************************************************************************* */ /* ************************************************************************* */
template<class CONDITIONAL> template<class CONDITIONAL>
void BayesNet<CONDITIONAL>::permuteWithInverse(const Permutation& inversePermutation) { void BayesNet<CONDITIONAL>::permuteWithInverse(const Permutation& inversePermutation) {
@ -82,6 +100,21 @@ namespace gtsam {
push_front(conditional); push_front(conditional);
} }
/* ************************************************************************* */
template<class CONDITIONAL>
void BayesNet<CONDITIONAL>::popLeaf(iterator conditional) {
#ifndef NDEBUG
BOOST_FOREACH(typename CONDITIONAL::shared_ptr checkConditional, conditionals_) {
BOOST_FOREACH(Index key, (*conditional)->frontals()) {
if(std::find(checkConditional->beginParents(), checkConditional->endParents(), key) != checkConditional->endParents())
throw std::invalid_argument(
"Debug mode exception: in BayesNet::popLeaf, the requested conditional is not a leaf.");
}
}
#endif
conditionals_.erase(conditional);
}
/* ************************************************************************* */ /* ************************************************************************* */
template<class CONDITIONAL> template<class CONDITIONAL>
FastList<Index> BayesNet<CONDITIONAL>::ordering() const { FastList<Index> BayesNet<CONDITIONAL>::ordering() const {

View File

@ -22,6 +22,7 @@
#include <list> #include <list>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <boost/serialization/nvp.hpp> #include <boost/serialization/nvp.hpp>
#include <boost/assign/list_inserter.hpp>
#include <gtsam/base/types.h> #include <gtsam/base/types.h>
#include <gtsam/base/FastList.h> #include <gtsam/base/FastList.h>
@ -51,8 +52,8 @@ namespace gtsam {
typedef typename boost::shared_ptr<const CONDITIONAL> const_sharedConditional; typedef typename boost::shared_ptr<const CONDITIONAL> const_sharedConditional;
typedef typename std::list<sharedConditional> Conditionals; typedef typename std::list<sharedConditional> Conditionals;
typedef typename Conditionals::const_iterator iterator; typedef typename Conditionals::iterator iterator;
typedef typename Conditionals::const_reverse_iterator reverse_iterator; typedef typename Conditionals::reverse_iterator reverse_iterator;
typedef typename Conditionals::const_iterator const_iterator; typedef typename Conditionals::const_iterator const_iterator;
typedef typename Conditionals::const_reverse_iterator const_reverse_iterator; typedef typename Conditionals::const_reverse_iterator const_reverse_iterator;
@ -79,6 +80,22 @@ namespace gtsam {
/** check equality */ /** check equality */
bool equals(const BayesNet& other, double tol = 1e-9) const; bool equals(const BayesNet& other, double tol = 1e-9) const;
/** Find an iterator pointing to the conditional where the specified key
* appears as a frontal variable, or end() if no conditional contains this
* key. Running time is approximately \f$ O(n) \f$ in the number of
* conditionals in the BayesNet.
* @param key The index to find in the frontal variables of a conditional.
*/
const_iterator find(Index key) const;
/** Find an iterator pointing to the conditional where the specified key
* appears as a frontal variable, or end() if no conditional contains this
* key. Running time is approximately \f$ O(n) \f$ in the number of
* conditionals in the BayesNet.
* @param key The index to find in the frontal variables of a conditional.
*/
iterator find(Index key);
/** push_back: use reverse topological sort (i.e. parents last / elimination order) */ /** push_back: use reverse topological sort (i.e. parents last / elimination order) */
inline void push_back(const sharedConditional& conditional) { inline void push_back(const sharedConditional& conditional) {
conditionals_.push_back(conditional); conditionals_.push_back(conditional);
@ -89,17 +106,52 @@ namespace gtsam {
conditionals_.push_front(conditional); conditionals_.push_front(conditional);
} }
// push_back an entire Bayes net */ /// push_back an entire Bayes net
void push_back(const BayesNet<CONDITIONAL> bn); void push_back(const BayesNet<CONDITIONAL> bn);
// push_front an entire Bayes net */ /// push_front an entire Bayes net
void push_front(const BayesNet<CONDITIONAL> bn); void push_front(const BayesNet<CONDITIONAL> bn);
/** += syntax for push_back, e.g. bayesNet += c1, c2, c3
* @param conditional The conditional to add to the back of the BayesNet
*/
boost::assign::list_inserter<boost::assign_detail::call_push_back<BayesNet<CONDITIONAL> >, sharedConditional>
operator+=(const sharedConditional& conditional) {
return boost::assign::make_list_inserter(boost::assign_detail::call_push_back<BayesNet<CONDITIONAL> >(*this))(conditional); }
/** /**
* pop_front: remove node at the bottom, used in marginalization * pop_front: remove node at the bottom, used in marginalization
* For example P(ABC)=P(A|BC)P(B|C)P(C) becomes P(BC)=P(B|C)P(C) * For example P(ABC)=P(A|BC)P(B|C)P(C) becomes P(BC)=P(B|C)P(C)
*/ */
inline void pop_front() {conditionals_.pop_front();} void pop_front() {conditionals_.pop_front();}
/**
* Remove any leaf conditional. The conditional to remove is specified by
* iterator. To find the iterator pointing to the conditional containing a
* particular key, use find(), which has \f$ O(n) \f$ complexity. The
* popLeaf function by itself has \f$ O(1) \f$ complexity.
*
* If the program calling this function is
* compiled without NDEBUG defined, this function will check that the node
* is indeed a leaf, but otherwise will not check, because the check has
* \f$ O(n^2) \f$ complexity.
*
* Example 1:
\code
// Remove a leaf node with a known conditional
GaussianBayesNet gbn = ...
GaussianBayesNet::iterator leafConditional = ...
gbn.popLeaf(leafConditional);
\endcode
* Example 2:
\code
// Remove the leaf node containing variable index 14
GaussianBayesNet gbn = ...
gbn.popLeaf(gbn.find(14));
\endcode
* @param conditional The iterator pointing to the leaf conditional to remove
*/
void popLeaf(iterator conditional);
/** Permute the variables in the BayesNet */ /** Permute the variables in the BayesNet */
void permuteWithInverse(const Permutation& inversePermutation); void permuteWithInverse(const Permutation& inversePermutation);
@ -112,7 +164,7 @@ namespace gtsam {
bool permuteSeparatorWithInverse(const Permutation& inversePermutation); bool permuteSeparatorWithInverse(const Permutation& inversePermutation);
/** size is the number of nodes */ /** size is the number of nodes */
inline size_t size() const { size_t size() const {
return conditionals_.size(); return conditionals_.size();
} }
@ -135,10 +187,14 @@ namespace gtsam {
boost::shared_ptr<const CONDITIONAL> back() const { return conditionals_.back(); } boost::shared_ptr<const CONDITIONAL> back() const { return conditionals_.back(); }
/** return iterators. FD: breaks encapsulation? */ /** return iterators. FD: breaks encapsulation? */
inline const_iterator const begin() const {return conditionals_.begin();} const_iterator begin() const {return conditionals_.begin();}
inline const_iterator const end() const {return conditionals_.end();} const_iterator end() const {return conditionals_.end();}
inline const_reverse_iterator const rbegin() const {return conditionals_.rbegin();} const_reverse_iterator rbegin() const {return conditionals_.rbegin();}
inline const_reverse_iterator const rend() const {return conditionals_.rend();} const_reverse_iterator rend() const {return conditionals_.rend();}
iterator begin() {return conditionals_.begin();}
iterator end() {return conditionals_.end();}
reverse_iterator rbegin() {return conditionals_.rbegin();}
reverse_iterator rend() {return conditionals_.rend();}
/** saves the bayes to a text file in GraphViz format */ /** saves the bayes to a text file in GraphViz format */
// void saveGraph(const std::string& s) const; // void saveGraph(const std::string& s) const;

View File

@ -22,7 +22,6 @@ using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/inference/IndexConditional.h> #include <gtsam/inference/IndexConditional.h>
#ifdef ALL
#include <gtsam/inference/SymbolicFactorGraph.h> #include <gtsam/inference/SymbolicFactorGraph.h>
using namespace std; using namespace std;
@ -32,6 +31,8 @@ static const Index _L_ = 0;
static const Index _A_ = 1; static const Index _A_ = 1;
static const Index _B_ = 2; static const Index _B_ = 2;
static const Index _C_ = 3; static const Index _C_ = 3;
static const Index _D_ = 4;
static const Index _E_ = 5;
IndexConditional::shared_ptr IndexConditional::shared_ptr
B(new IndexConditional(_B_)), B(new IndexConditional(_B_)),
@ -100,8 +101,92 @@ TEST( SymbolicBayesNet, combine )
CHECK(assert_equal(expected,p_ABC)); CHECK(assert_equal(expected,p_ABC));
} }
/* ************************************************************************* */
TEST(SymbolicBayesNet, find) {
SymbolicBayesNet bn;
bn += IndexConditional::shared_ptr(new IndexConditional(_A_, _B_));
std::vector<Index> keys;
keys.push_back(_B_);
keys.push_back(_C_);
keys.push_back(_D_);
bn += IndexConditional::shared_ptr(new IndexConditional(keys,2));
bn += IndexConditional::shared_ptr(new IndexConditional(_D_));
SymbolicBayesNet::iterator expected = bn.begin(); ++ expected;
SymbolicBayesNet::iterator actual = bn.find(_C_);
EXPECT(assert_equal(**expected, **actual));
}
/* ************************************************************************* */
TEST_UNSAFE(SymbolicBayesNet, popLeaf) {
IndexConditional::shared_ptr
A(new IndexConditional(_A_,_E_)),
B(new IndexConditional(_B_,_E_)),
C(new IndexConditional(_C_,_D_)),
D(new IndexConditional(_D_,_E_)),
E(new IndexConditional(_E_));
// BayesNet after popping A
SymbolicBayesNet expected1;
expected1 += B, C, D, E;
// BayesNet after popping C
SymbolicBayesNet expected2;
expected2 += A, B, D, E;
// BayesNet after popping C and D
SymbolicBayesNet expected3;
expected3 += A, B, E;
// BayesNet after popping C and A
SymbolicBayesNet expected4;
expected4 += B, D, E;
// BayesNet after popping A
SymbolicBayesNet actual1;
actual1 += A, B, C, D, E;
actual1.popLeaf(actual1.find(_A_));
// BayesNet after popping C
SymbolicBayesNet actual2;
actual2 += A, B, C, D, E;
actual2.popLeaf(actual2.find(_C_));
// BayesNet after popping C and D
SymbolicBayesNet actual3;
actual3 += A, B, C, D, E;
actual3.popLeaf(actual3.find(_C_));
actual3.popLeaf(actual3.find(_D_));
// BayesNet after popping C and A
SymbolicBayesNet actual4;
actual4 += A, B, C, D, E;
actual4.popLeaf(actual4.find(_C_));
actual4.popLeaf(actual4.find(_A_));
EXPECT(assert_equal(expected1, actual1));
EXPECT(assert_equal(expected2, actual2));
EXPECT(assert_equal(expected3, actual3));
EXPECT(assert_equal(expected4, actual4));
// Try to remove a non-leaf node
#undef NDEBUG_SAVED
#ifdef NDEBUG
#define NDEBUG_SAVED
#endif #endif
#undef NDEBUG
SymbolicBayesNet actual5;
actual5 += A, B, C, D, E;
CHECK_EXCEPTION(actual5.popLeaf(actual5.find(_D_)), std::invalid_argument);
#ifdef NDEBUG_SAVED
#define NDEBUG
#endif
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;