BayesNet is now list-based for fast bi-directional access

SLOW O(n) random access operator[key] provided 
(should maybe be called [at] as it does bounds checking)
I also fixed a bug in equals.
release/4.3a0
Frank Dellaert 2009-11-03 06:29:56 +00:00
parent eab038651e
commit e9d942f81e
9 changed files with 61 additions and 65 deletions

View File

@ -490,9 +490,10 @@
<useDefaultCommand>false</useDefaultCommand> <useDefaultCommand>false</useDefaultCommand>
<runAllBuilders>true</runAllBuilders> <runAllBuilders>true</runAllBuilders>
</target> </target>
<target name="testSymbolicBayesChain.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testSymbolicBayesNet.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildTarget>testSymbolicBayesChain.run</buildTarget> <buildArguments/>
<buildTarget>testSymbolicBayesNet.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>false</useDefaultCommand> <useDefaultCommand>false</useDefaultCommand>
<runAllBuilders>true</runAllBuilders> <runAllBuilders>true</runAllBuilders>

View File

@ -30,29 +30,8 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template<class Conditional> template<class Conditional>
bool BayesNet<Conditional>::equals(const BayesNet& cbn, double tol) const { bool BayesNet<Conditional>::equals(const BayesNet& cbn, double tol) const {
if(indices_ != cbn.indices_) return false;
if(size() != cbn.size()) return false; if(size() != cbn.size()) return false;
return equal(conditionals_.begin(),conditionals_.begin(),conditionals_.begin(),equals_star<Conditional>); return equal(conditionals_.begin(),conditionals_.end(),cbn.conditionals_.begin(),equals_star<Conditional>(tol));
}
/* ************************************************************************* */
template<class Conditional>
void BayesNet<Conditional>::push_back
(const boost::shared_ptr<Conditional>& conditional) {
indices_.insert(make_pair(conditional->key(),conditionals_.size()));
conditionals_.push_back(conditional);
}
/* ************************************************************************* *
template<class Conditional>
void BayesNet<Conditional>::erase(const string& key) {
list<string>::iterator it;
for (it=keys_.begin(); it != keys_.end(); ++it){
if( strcmp(key.c_str(), (*it).c_str()) == 0 )
break;
}
keys_.erase(it);
conditionals_.erase(key);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -65,5 +44,24 @@ namespace gtsam {
} }
/* ************************************************************************* */ /* ************************************************************************* */
// predicate to check whether a conditional has the sought key
template<class Conditional>
class HasKey {
const string& key_;
public:
HasKey(const std::string& key):key_(key) {}
bool operator()(const boost::shared_ptr<Conditional>& conditional) {
return (conditional->key()==key_);
}
};
template<class Conditional>
boost::shared_ptr<Conditional> BayesNet<Conditional>::operator[](const std::string& key) const {
const_iterator it = find_if(conditionals_.begin(),conditionals_.end(),HasKey<Conditional>(key));
if (it == conditionals_.end()) throw(invalid_argument(
"BayesNet::operator['"+key+"']: not found"));
return *it;
}
/* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -8,10 +8,9 @@
#pragma once #pragma once
#include <vector> #include <list>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <boost/serialization/map.hpp> #include <boost/serialization/list.hpp>
#include <boost/serialization/vector.hpp>
#include <boost/serialization/shared_ptr.hpp> #include <boost/serialization/shared_ptr.hpp>
#include "Testable.h" #include "Testable.h"
@ -33,7 +32,8 @@ namespace gtsam {
/** We store shared pointers to Conditional densities */ /** We store shared pointers to Conditional densities */
typedef typename boost::shared_ptr<Conditional> conditional_ptr; typedef typename boost::shared_ptr<Conditional> conditional_ptr;
typedef typename std::vector<conditional_ptr> Conditionals; typedef typename std::list<conditional_ptr> Conditionals;
typedef typename Conditionals::const_iterator const_iterator; typedef typename Conditionals::const_iterator const_iterator;
typedef typename Conditionals::const_reverse_iterator const_reverse_iterator; typedef typename Conditionals::const_reverse_iterator const_reverse_iterator;
@ -46,19 +46,6 @@ namespace gtsam {
*/ */
Conditionals conditionals_; Conditionals conditionals_;
/**
* O(log n) random access on keys will provided by a map from keys to vector index.
*/
typedef std::map<std::string, int> Indices;
Indices indices_;
/** O(log n) lookup from key to node index */
inline int index(const std::string& key) const {
Indices::const_iterator it = indices_.find(key); // get node index
assert( it != indices_.end() );
return it->second;
}
public: public:
/** print */ /** print */
@ -68,7 +55,14 @@ namespace gtsam {
bool equals(const BayesNet& other, double tol = 1e-9) const; bool equals(const BayesNet& other, double tol = 1e-9) const;
/** push_back: use reverse topological sort (i.e. parents last / elimination order) */ /** push_back: use reverse topological sort (i.e. parents last / elimination order) */
void push_back(const boost::shared_ptr<Conditional>& conditional); inline void push_back(const boost::shared_ptr<Conditional>& conditional) {
conditionals_.push_back(conditional);
}
/** push_front: use topological sort (i.e. parents first / reverse elimination order) */
inline void push_front(const boost::shared_ptr<Conditional>& conditional) {
conditionals_.push_front(conditional);
}
/** size is the number of nodes */ /** size is the number of nodes */
inline size_t size() const { inline size_t size() const {
@ -78,11 +72,8 @@ namespace gtsam {
/** return keys in reverse topological sort order, i.e., elimination order */ /** return keys in reverse topological sort order, i.e., elimination order */
Ordering ordering() const; Ordering ordering() const;
/** O(log n) random access to Conditional by key */ /** SLOW O(n) random access to Conditional by key */
inline conditional_ptr operator[](const std::string& key) const { conditional_ptr operator[](const std::string& key) const;
int i = index(key);
return conditionals_[i];
}
/** return iterators. FD: breaks encapsulation? */ /** return iterators. FD: breaks encapsulation? */
const_iterator const begin() const {return conditionals_.begin();} const_iterator const begin() const {return conditionals_.begin();}
@ -96,7 +87,6 @@ namespace gtsam {
template<class Archive> template<class Archive>
void serialize(Archive & ar, const unsigned int version) { void serialize(Archive & ar, const unsigned int version) {
ar & BOOST_SERIALIZATION_NVP(conditionals_); ar & BOOST_SERIALIZATION_NVP(conditionals_);
ar & BOOST_SERIALIZATION_NVP(indices_);
} }
}; };

View File

@ -50,9 +50,8 @@ namespace gtsam {
template<class Conditional> template<class Conditional>
BayesTree<Conditional>::BayesTree(const BayesNet<Conditional>& bayesNet, bool verbose) { BayesTree<Conditional>::BayesTree(const BayesNet<Conditional>& bayesNet, bool verbose) {
typename BayesNet<Conditional>::const_reverse_iterator rit; typename BayesNet<Conditional>::const_reverse_iterator rit;
for ( rit=bayesNet.rbegin(); rit < bayesNet.rend(); ++rit ) { for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit )
insert(*rit,verbose); insert(*rit,verbose);
}
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -69,7 +68,7 @@ namespace gtsam {
double tol) const { double tol) const {
return size()==other.size() && return size()==other.size() &&
equal(nodeMap_.begin(),nodeMap_.end(),other.nodeMap_.begin()) && equal(nodeMap_.begin(),nodeMap_.end(),other.nodeMap_.begin()) &&
equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),equals_star<Node>); equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),equals_star<Node>(tol));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -106,7 +105,7 @@ namespace gtsam {
if (parent_clique->size() == parents.size()) { if (parent_clique->size() == parents.size()) {
if (verbose) cout << "Adding to clique " << index << endl; if (verbose) cout << "Adding to clique " << index << endl;
nodeMap_.insert(make_pair(key, index)); nodeMap_.insert(make_pair(key, index));
parent_clique->push_back(conditional); parent_clique->push_front(conditional);
return; return;
} }

View File

@ -25,7 +25,7 @@ using namespace gtsam;
typedef pair<const string, Matrix>& mypair; typedef pair<const string, Matrix>& mypair;
/* ************************************************************************* */ /* ************************************************************************* */
LinearFactor::LinearFactor(const boost::shared_ptr<ConditionalGaussian> cg) : LinearFactor::LinearFactor(const boost::shared_ptr<ConditionalGaussian>& cg) :
b(cg->get_d()) { b(cg->get_d()) {
As.insert(make_pair(cg->key(), cg->get_R())); As.insert(make_pair(cg->key(), cg->get_R()));
std::map<std::string, Matrix>::const_iterator it = cg->parentsBegin(); std::map<std::string, Matrix>::const_iterator it = cg->parentsBegin();

View File

@ -82,7 +82,7 @@ public:
} }
/** Construct from Conditional Gaussian */ /** Construct from Conditional Gaussian */
LinearFactor(const boost::shared_ptr<ConditionalGaussian> cg); LinearFactor(const boost::shared_ptr<ConditionalGaussian>& cg);
/** /**
* Constructor that combines a set of factors * Constructor that combines a set of factors

View File

@ -65,7 +65,8 @@ namespace gtsam {
/** print */ /** print */
void print(const std::string& s = "SymbolicConditional") const { void print(const std::string& s = "SymbolicConditional") const {
std::cout << s << " P(" << key_ << " |"; std::cout << s << " P(" << key_;
if (parents_.size()>0) std::cout << " |";
BOOST_FOREACH(std::string parent, parents_) std::cout << " " << parent; BOOST_FOREACH(std::string parent, parents_) std::cout << " " << parent;
std::cout << ")" << std::endl; std::cout << ")" << std::endl;
} }

View File

@ -55,17 +55,24 @@ namespace gtsam {
* Template to create a binary predicate * Template to create a binary predicate
*/ */
template<class V> template<class V>
bool equals(const V& expected, const V& actual, double tol = 1e-9) { struct equals : public std::binary_function<const V&, const V&, bool> {
return (actual.equals(expected, tol)); double tol_;
} equals(double tol = 1e-9) : tol_(tol) {}
bool operator()(const V& expected, const V& actual) {
return (actual.equals(expected, tol_));
}
};
/** /**
* Binary predicate on shared pointers * Binary predicate on shared pointers
*/ */
template<class V> template<class V>
bool equals_star(const boost::shared_ptr<V>& expected, struct equals_star : public std::binary_function<const boost::shared_ptr<V>&, const boost::shared_ptr<V>&, bool> {
const boost::shared_ptr<V>& actual, double tol = 1e-9) { double tol_;
return (actual->equals(*expected, tol)); equals_star(double tol = 1e-9) : tol_(tol) {}
} bool operator()(const boost::shared_ptr<V>& expected, const boost::shared_ptr<V>& actual) {
return (actual->equals(*expected, tol_));
}
};
} // gtsam } // gtsam

View File

@ -54,9 +54,9 @@ TEST( BayesTree, constructor )
// Check root // Check root
BayesNet<SymbolicConditional> expected_root; BayesNet<SymbolicConditional> expected_root;
expected_root.push_back(B);
expected_root.push_back(L);
expected_root.push_back(E); expected_root.push_back(E);
expected_root.push_back(L);
expected_root.push_back(B);
BayesNet<SymbolicConditional> actual_root = bayesTree.root(); BayesNet<SymbolicConditional> actual_root = bayesTree.root();
CHECK(assert_equal(expected_root,actual_root)); CHECK(assert_equal(expected_root,actual_root));
@ -68,7 +68,7 @@ TEST( BayesTree, constructor )
ASIA.push_back(E); ASIA.push_back(E);
ASIA.push_back(L); ASIA.push_back(L);
ASIA.push_back(B); ASIA.push_back(B);
bool verbose = true; bool verbose = false;
BayesTree<SymbolicConditional> bayesTree2(ASIA,verbose); BayesTree<SymbolicConditional> bayesTree2(ASIA,verbose);
if (verbose) bayesTree2.print("bayesTree2"); if (verbose) bayesTree2.print("bayesTree2");