diff --git a/gtsam/base/treeTraversal-inst.h b/gtsam/base/treeTraversal-inst.h index bcc0c281d..7268f36b2 100644 --- a/gtsam/base/treeTraversal-inst.h +++ b/gtsam/base/treeTraversal-inst.h @@ -56,7 +56,7 @@ namespace gtsam { void no_op(const NODE& node, const DATA& data) {} } - /** Traverse a forest depth-first. + /** Traverse a forest depth-first with pre-order and post-order visits. * @param forest The forest of trees to traverse. The method \c forest.roots() should exist * and return a collection of (shared) pointers to \c FOREST::Node. * @param visitorPre \c visitorPre(node, parentData) will be called at every node, before @@ -70,9 +70,8 @@ namespace gtsam { * call to \c visitorPre (the \c DATA object may be modified by visiting the children). * @param rootData The data to pass by reference to \c visitorPre when it is called on each * root node. */ - template - void DepthFirstForest(FOREST& forest, DATA& rootData, - VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost = no_op) + template + void DepthFirstForest(FOREST& forest, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost) { // Depth first traversal stack typedef TraversalNode TraversalNode; @@ -83,7 +82,7 @@ namespace gtsam { // Add roots to stack (use reverse iterators so children are processed in the order they // appear) (void) std::for_each(forest.roots().rbegin(), forest.roots().rend(), - Expander(visitor, &rootData, stack)); + Expander(visitorPre, &rootData, stack)); // Traverse while(!stack.empty()) @@ -105,6 +104,25 @@ namespace gtsam { } } } + + /** Traverse a forest depth-first, with a pre-order visit but no post-order visit. + * @param forest The forest of trees to traverse. The method \c forest.roots() should exist + * and return a collection of (shared) pointers to \c FOREST::Node. + * @param visitorPre \c visitorPre(node, parentData) will be called at every node, before + * visiting its children, and will be passed, by reference, the \c DATA object returned + * by the visit to its parent. Likewise, \c visitorPre should return the \c DATA object + * to pass to the children. The returned \c DATA object will be copy-constructed only + * upon returning to store internally, thus may be modified by visiting the children. + * Regarding efficiency, this copy-on-return is usually optimized out by the compiler. + * @param rootData The data to pass by reference to \c visitorPre when it is called on each + * root node. */ + template + void DepthFirstForest(FOREST& forest, DATA& rootData, VISITOR_PRE& visitorPre) + { + DepthFirstForest( + forest, rootData, visitorPre, no_op); + } + } } \ No newline at end of file diff --git a/gtsam/inference/inference-inst.h b/gtsam/inference/inference-inst.h index 7d37b06eb..668f6e6e3 100644 --- a/gtsam/inference/inference-inst.h +++ b/gtsam/inference/inference-inst.h @@ -27,21 +27,9 @@ namespace gtsam { namespace inference { - /* ************************************************************************* */ namespace { - template - struct EliminationNode { - bool expanded; - const typename ELIMINATIONTREE::Node* const treeNode; - std::vector childrenFactors; - EliminationNode* const parent; - EliminationNode(const typename ELIMINATIONTREE::Node* _treeNode, EliminationNode* _parent) : - expanded(false), treeNode(_treeNode), parent(_parent) { - childrenFactors.reserve(treeNode->children.size()); } - }; - /* ************************************************************************* */ - template + template struct EliminationData { EliminationData* const parentData; std::vector childFactors; @@ -50,22 +38,22 @@ namespace gtsam { }; /* ************************************************************************* */ - template + template EliminationData eliminationPreOrderVisitor( - const typename TREE::sharedNode& node, EliminationData* parentData) + const typename TREE::sharedNode& node, EliminationData& parentData) { // This function is called before visiting the children. Here, we create this node's data, // which includes a pointer to the parent data and space for the factors of the children. - return EliminationData(parentData, node->children.size()); + return EliminationData(&parentData, node->children.size()); } /* ************************************************************************* */ - template - void eliminationPostOrderVisitor(const TREE::Node* const node, EliminationData& myData, + template + void eliminationPostOrderVisitor(const typename TREE::Node& const node, EliminationData& myData, RESULT& result, const typename TREE::Eliminate& eliminationFunction) { // Call eliminate on the node and add the result to the parent's gathered factors - myData.parentData->childFactors.push_back(node->eliminate(result, eliminationFunction, myData.childFactors)); + myData.parentData->childFactors.push_back(node.eliminate(result, eliminationFunction, myData.childFactors)); } } @@ -81,60 +69,18 @@ namespace gtsam { typedef typename TREE::sharedNode sharedNode; typedef typename TREE::sharedFactor sharedFactor; - // Allocate remaining factors - std::vector remainingFactors; - remainingFactors.reserve(tree.roots().size()); - - treeTraversal::DepthFirstForest(tree, remainingFactors, ) - - // Stack for eliminating nodes. We use this stack instead of recursive function calls to - // avoid call stack overflow due to very long trees that arise from chain-like graphs. - // TODO: Check whether this is faster as a vector (then use indices instead of parent pointers). - typedef EliminationNode EliminationNode; - std::stack > eliminationStack; - - // Allocate remaining factors - std::vector remainingFactors; - remainingFactors.reserve(tree.roots().size()); - - // Add roots to the stack (use reverse foreach so conditionals to appear in elimination order - - // doesn't matter for computation but can make printouts easier to interpret by hand). - BOOST_REVERSE_FOREACH(const sharedNode& root, tree.roots()) { - eliminationStack.push( - EliminationNode(root.get(), 0)); } - - // Until the stack is empty - while(!eliminationStack.empty()) { - // Process the next node. If it has children, add its children to the stack and mark it - // expanded - we'll come back and eliminate it later after the children have been processed. - EliminationNode& node = eliminationStack.top(); - if(node.expanded) - { - // Do elimination step - sharedFactor remainingFactor = node.treeNode->eliminate(result, function, node.childrenFactors); - - // TODO: Don't add null factor? - if(node.parent) - node.parent->childrenFactors.push_back(remainingFactor); - else - remainingFactors.push_back(remainingFactor); - - // Remove from stack - eliminationStack.pop(); - } else - { - // Expand children and mark as expanded (use reverse foreach so conditionals to appear in - // elimination order - doesn't matter for computation but can make printouts easier to - // interpret by hand). - node.expanded = true; - BOOST_REVERSE_FOREACH(const sharedNode& child, node.treeNode->children) { - eliminationStack.push( - EliminationNode(child.get(), &node)); } - } - } + // Do elimination using a depth-first traversal. During the pre-order visit (see + // eliminationPreOrderVisitor), we store a pointer to the parent data (where we'll put the + // remaining factor) and reserve a vector of factors to store the children elimination + // results. During the post-order visit (see eliminationPostOrderVisitor), we call dense + // elimination (using the gathered child factors) and store the result in the parent's + // gathered factors. + EliminationData rootData(0, tree.roots().size()); + treeTraversal::DepthFirstForest(tree, rootData, eliminationPreOrderVisitor, + boost::bind(eliminationPostOrderVisitor, _1, _2, result, function)); // Return remaining factors - return remainingFactors; + return rootData.childFactors; } }