Converted elimination and printing of elimination trees to use generic DFS code
parent
df728a969c
commit
5b1ac91c85
|
|
@ -56,7 +56,7 @@ namespace gtsam {
|
||||||
void no_op(const NODE& node, const DATA& data) {}
|
void no_op(const NODE& node, const DATA& data) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Traverse a forest depth-first.
|
/** Traverse a forest depth-first with pre-order and post-order visits.
|
||||||
* @param forest The forest of trees to traverse. The method \c forest.roots() should exist
|
* @param forest The forest of trees to traverse. The method \c forest.roots() should exist
|
||||||
* and return a collection of (shared) pointers to \c FOREST::Node.
|
* and return a collection of (shared) pointers to \c FOREST::Node.
|
||||||
* @param visitorPre \c visitorPre(node, parentData) will be called at every node, before
|
* @param visitorPre \c visitorPre(node, parentData) will be called at every node, before
|
||||||
|
|
@ -70,9 +70,8 @@ namespace gtsam {
|
||||||
* call to \c visitorPre (the \c DATA object may be modified by visiting the children).
|
* call to \c visitorPre (the \c DATA object may be modified by visiting the children).
|
||||||
* @param rootData The data to pass by reference to \c visitorPre when it is called on each
|
* @param rootData The data to pass by reference to \c visitorPre when it is called on each
|
||||||
* root node. */
|
* root node. */
|
||||||
template<class FOREST, typename VISITOR_PRE, typename VISITOR_POST, typename DATA>
|
template<class FOREST, typename DATA, typename VISITOR_PRE, typename VISITOR_POST>
|
||||||
void DepthFirstForest(FOREST& forest, DATA& rootData,
|
void DepthFirstForest(FOREST& forest, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost)
|
||||||
VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost = no_op<typename FOREST::Node, DATA>)
|
|
||||||
{
|
{
|
||||||
// Depth first traversal stack
|
// Depth first traversal stack
|
||||||
typedef TraversalNode<typename FOREST::Node, DATA> TraversalNode;
|
typedef TraversalNode<typename FOREST::Node, DATA> TraversalNode;
|
||||||
|
|
@ -83,7 +82,7 @@ namespace gtsam {
|
||||||
// Add roots to stack (use reverse iterators so children are processed in the order they
|
// Add roots to stack (use reverse iterators so children are processed in the order they
|
||||||
// appear)
|
// appear)
|
||||||
(void) std::for_each(forest.roots().rbegin(), forest.roots().rend(),
|
(void) std::for_each(forest.roots().rbegin(), forest.roots().rend(),
|
||||||
Expander(visitor, &rootData, stack));
|
Expander(visitorPre, &rootData, stack));
|
||||||
|
|
||||||
// Traverse
|
// Traverse
|
||||||
while(!stack.empty())
|
while(!stack.empty())
|
||||||
|
|
@ -105,6 +104,25 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Traverse a forest depth-first, with a pre-order visit but no post-order visit.
|
||||||
|
* @param forest The forest of trees to traverse. The method \c forest.roots() should exist
|
||||||
|
* and return a collection of (shared) pointers to \c FOREST::Node.
|
||||||
|
* @param visitorPre \c visitorPre(node, parentData) will be called at every node, before
|
||||||
|
* visiting its children, and will be passed, by reference, the \c DATA object returned
|
||||||
|
* by the visit to its parent. Likewise, \c visitorPre should return the \c DATA object
|
||||||
|
* to pass to the children. The returned \c DATA object will be copy-constructed only
|
||||||
|
* upon returning to store internally, thus may be modified by visiting the children.
|
||||||
|
* Regarding efficiency, this copy-on-return is usually optimized out by the compiler.
|
||||||
|
* @param rootData The data to pass by reference to \c visitorPre when it is called on each
|
||||||
|
* root node. */
|
||||||
|
template<class FOREST, typename DATA, typename VISITOR_PRE>
|
||||||
|
void DepthFirstForest(FOREST& forest, DATA& rootData, VISITOR_PRE& visitorPre)
|
||||||
|
{
|
||||||
|
DepthFirstForest<FOREST, DATA, VISITOR_PRE, void(&)(const typename FOREST::Node&, const DATA&)>(
|
||||||
|
forest, rootData, visitorPre, no_op<typename FOREST::Node, DATA>);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -27,21 +27,9 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
namespace inference {
|
namespace inference {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
namespace {
|
namespace {
|
||||||
template<class ELIMINATIONTREE>
|
|
||||||
struct EliminationNode {
|
|
||||||
bool expanded;
|
|
||||||
const typename ELIMINATIONTREE::Node* const treeNode;
|
|
||||||
std::vector<typename ELIMINATIONTREE::sharedFactor> childrenFactors;
|
|
||||||
EliminationNode<ELIMINATIONTREE>* const parent;
|
|
||||||
EliminationNode(const typename ELIMINATIONTREE::Node* _treeNode, EliminationNode<ELIMINATIONTREE>* _parent) :
|
|
||||||
expanded(false), treeNode(_treeNode), parent(_parent) {
|
|
||||||
childrenFactors.reserve(treeNode->children.size()); }
|
|
||||||
};
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<TREE>
|
template<class TREE>
|
||||||
struct EliminationData {
|
struct EliminationData {
|
||||||
EliminationData* const parentData;
|
EliminationData* const parentData;
|
||||||
std::vector<typename TREE::sharedFactor> childFactors;
|
std::vector<typename TREE::sharedFactor> childFactors;
|
||||||
|
|
@ -50,22 +38,22 @@ namespace gtsam {
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<TREE>
|
template<class TREE>
|
||||||
EliminationData<TREE> eliminationPreOrderVisitor(
|
EliminationData<TREE> eliminationPreOrderVisitor(
|
||||||
const typename TREE::sharedNode& node, EliminationData<TREE>* parentData)
|
const typename TREE::sharedNode& node, EliminationData<TREE>& parentData)
|
||||||
{
|
{
|
||||||
// This function is called before visiting the children. Here, we create this node's data,
|
// This function is called before visiting the children. Here, we create this node's data,
|
||||||
// which includes a pointer to the parent data and space for the factors of the children.
|
// which includes a pointer to the parent data and space for the factors of the children.
|
||||||
return EliminationData<TREE>(parentData, node->children.size());
|
return EliminationData<TREE>(&parentData, node->children.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<TREE, RESULT>
|
template<class TREE, class RESULT>
|
||||||
void eliminationPostOrderVisitor(const TREE::Node* const node, EliminationData<TREE>& myData,
|
void eliminationPostOrderVisitor(const typename TREE::Node& const node, EliminationData<TREE>& myData,
|
||||||
RESULT& result, const typename TREE::Eliminate& eliminationFunction)
|
RESULT& result, const typename TREE::Eliminate& eliminationFunction)
|
||||||
{
|
{
|
||||||
// Call eliminate on the node and add the result to the parent's gathered factors
|
// Call eliminate on the node and add the result to the parent's gathered factors
|
||||||
myData.parentData->childFactors.push_back(node->eliminate(result, eliminationFunction, myData.childFactors));
|
myData.parentData->childFactors.push_back(node.eliminate(result, eliminationFunction, myData.childFactors));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -81,60 +69,18 @@ namespace gtsam {
|
||||||
typedef typename TREE::sharedNode sharedNode;
|
typedef typename TREE::sharedNode sharedNode;
|
||||||
typedef typename TREE::sharedFactor sharedFactor;
|
typedef typename TREE::sharedFactor sharedFactor;
|
||||||
|
|
||||||
// Allocate remaining factors
|
// Do elimination using a depth-first traversal. During the pre-order visit (see
|
||||||
std::vector<sharedFactor> remainingFactors;
|
// eliminationPreOrderVisitor), we store a pointer to the parent data (where we'll put the
|
||||||
remainingFactors.reserve(tree.roots().size());
|
// remaining factor) and reserve a vector of factors to store the children elimination
|
||||||
|
// results. During the post-order visit (see eliminationPostOrderVisitor), we call dense
|
||||||
treeTraversal::DepthFirstForest(tree, remainingFactors, )
|
// elimination (using the gathered child factors) and store the result in the parent's
|
||||||
|
// gathered factors.
|
||||||
// Stack for eliminating nodes. We use this stack instead of recursive function calls to
|
EliminationData<TREE> rootData(0, tree.roots().size());
|
||||||
// avoid call stack overflow due to very long trees that arise from chain-like graphs.
|
treeTraversal::DepthFirstForest(tree, rootData, eliminationPreOrderVisitor<TREE>,
|
||||||
// TODO: Check whether this is faster as a vector (then use indices instead of parent pointers).
|
boost::bind(eliminationPostOrderVisitor<TREE,RESULT>, _1, _2, result, function));
|
||||||
typedef EliminationNode<TREE> EliminationNode;
|
|
||||||
std::stack<EliminationNode, FastList<EliminationNode> > eliminationStack;
|
|
||||||
|
|
||||||
// Allocate remaining factors
|
|
||||||
std::vector<sharedFactor> remainingFactors;
|
|
||||||
remainingFactors.reserve(tree.roots().size());
|
|
||||||
|
|
||||||
// Add roots to the stack (use reverse foreach so conditionals to appear in elimination order -
|
|
||||||
// doesn't matter for computation but can make printouts easier to interpret by hand).
|
|
||||||
BOOST_REVERSE_FOREACH(const sharedNode& root, tree.roots()) {
|
|
||||||
eliminationStack.push(
|
|
||||||
EliminationNode(root.get(), 0)); }
|
|
||||||
|
|
||||||
// Until the stack is empty
|
|
||||||
while(!eliminationStack.empty()) {
|
|
||||||
// Process the next node. If it has children, add its children to the stack and mark it
|
|
||||||
// expanded - we'll come back and eliminate it later after the children have been processed.
|
|
||||||
EliminationNode& node = eliminationStack.top();
|
|
||||||
if(node.expanded)
|
|
||||||
{
|
|
||||||
// Do elimination step
|
|
||||||
sharedFactor remainingFactor = node.treeNode->eliminate(result, function, node.childrenFactors);
|
|
||||||
|
|
||||||
// TODO: Don't add null factor?
|
|
||||||
if(node.parent)
|
|
||||||
node.parent->childrenFactors.push_back(remainingFactor);
|
|
||||||
else
|
|
||||||
remainingFactors.push_back(remainingFactor);
|
|
||||||
|
|
||||||
// Remove from stack
|
|
||||||
eliminationStack.pop();
|
|
||||||
} else
|
|
||||||
{
|
|
||||||
// Expand children and mark as expanded (use reverse foreach so conditionals to appear in
|
|
||||||
// elimination order - doesn't matter for computation but can make printouts easier to
|
|
||||||
// interpret by hand).
|
|
||||||
node.expanded = true;
|
|
||||||
BOOST_REVERSE_FOREACH(const sharedNode& child, node.treeNode->children) {
|
|
||||||
eliminationStack.push(
|
|
||||||
EliminationNode(child.get(), &node)); }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return remaining factors
|
// Return remaining factors
|
||||||
return remainingFactors;
|
return rootData.childFactors;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue