diff --git a/gtsam/inference/JunctionTree-inst.h b/gtsam/inference/JunctionTree-inst.h index d7ff0281a..5735a3bd2 100644 --- a/gtsam/inference/JunctionTree-inst.h +++ b/gtsam/inference/JunctionTree-inst.h @@ -27,7 +27,7 @@ namespace gtsam { -template +template struct ConstructorTraversalData { typedef typename JunctionTree::Node Node; typedef typename JunctionTree::sharedNode sharedNode; @@ -37,8 +37,13 @@ struct ConstructorTraversalData { FastVector childSymbolicConditionals; FastVector childSymbolicFactors; - ConstructorTraversalData(ConstructorTraversalData* _parentData) - : parentData(_parentData) {} + // Small inner class to store symbolic factors + class SymbolicFactors: public FactorGraph { + }; + + ConstructorTraversalData(ConstructorTraversalData* _parentData) : + parentData(_parentData) { + } // Pre-order visitor function static ConstructorTraversalData ConstructorTraversalVisitorPre( @@ -64,13 +69,11 @@ struct ConstructorTraversalData { // our number of symbolic elimination parents is exactly 1 less than // our child's symbolic elimination parents - this condition indicates that // eliminating the current node did not introduce any parents beyond those - // already in the child. + // already in the child-> // Do symbolic elimination for this node - class : public FactorGraph {} - symbolicFactors; - symbolicFactors.reserve(ETreeNode->factors.size() + - myData.childSymbolicFactors.size()); + SymbolicFactors symbolicFactors; + symbolicFactors.reserve(ETreeNode->factors.size() + myData.childSymbolicFactors.size()); // Add ETree node factors symbolicFactors += ETreeNode->factors; // Add symbolic factors passed up from children @@ -78,60 +81,87 @@ struct ConstructorTraversalData { Ordering keyAsOrdering; keyAsOrdering.push_back(ETreeNode->key); - std::pair - symbolicElimResult = - internal::EliminateSymbolic(symbolicFactors, keyAsOrdering); + SymbolicConditional::shared_ptr myConditional; + SymbolicFactor::shared_ptr mySeparatorFactor; + boost::tie(myConditional, mySeparatorFactor) = internal::EliminateSymbolic( + symbolicFactors, keyAsOrdering); // Store symbolic elimination results in the parent - myData.parentData->childSymbolicConditionals.push_back( - symbolicElimResult.first); - myData.parentData->childSymbolicFactors.push_back( - symbolicElimResult.second); + myData.parentData->childSymbolicConditionals.push_back(myConditional); + myData.parentData->childSymbolicFactors.push_back(mySeparatorFactor); + sharedNode node = myData.myJTNode; + const FastVector& childConditionals = + myData.childSymbolicConditionals; // Merge our children if they are in our clique - if our conditional has // exactly one fewer parent than our child's conditional. size_t myNrFrontals = 1; - const size_t myNrParents = symbolicElimResult.first->nrParents(); - size_t nrMergedChildren = 0; - assert(node->children.size() == myData.childSymbolicConditionals.size()); - // Loop over children - int combinedProblemSize = - (int)(symbolicElimResult.first->size() * symbolicFactors.size()); + const size_t myNrParents = myConditional->nrParents(); + assert(node->newChildren.size() == childConditionals.size()); + gttic(merge_children); - for (size_t i = 0; i < myData.childSymbolicConditionals.size(); ++i) { + // First count how many keys, factors and children we'll end up with + size_t nrKeys = node->orderedFrontalKeys.size(); + size_t nrFactors = node->factors.size(); + size_t nrChildren = 0; + // Loop over children + for (size_t i = 0; i < childConditionals.size(); ++i) { // Check if we should merge the i^th child - if (myNrParents + myNrFrontals == - myData.childSymbolicConditionals[i]->nrParents()) { + if (myNrParents + myNrFrontals == childConditionals[i]->nrParents()) { // Get a reference to the i, adjusting the index to account for children // previously merged and removed from the i list. - const Node& child = *node->children[i - nrMergedChildren]; - // Merge keys. For efficiency, we add keys in reverse order at end, calling reverse after.. - node->orderedFrontalKeys.insert(node->orderedFrontalKeys.end(), - child.orderedFrontalKeys.rbegin(), - child.orderedFrontalKeys.rend()); - // Merge keys, factors, and children. - node->factors.insert(node->factors.end(), child.factors.begin(), child.factors.end()); - node->children.insert(node->children.end(), child.children.begin(), child.children.end()); - // Increment problem size - combinedProblemSize = std::max(combinedProblemSize, child.problemSize_); + sharedNode child = node->children[i]; + nrKeys += child->orderedFrontalKeys.size(); + nrFactors += child->factors.size(); + nrChildren += child->children.size(); // Increment number of frontal variables - myNrFrontals += child.orderedFrontalKeys.size(); - // Remove i from list. - node->children.erase(node->children.begin() + (i - nrMergedChildren)); - // Increment number of merged children - ++nrMergedChildren; + myNrFrontals += child->orderedFrontalKeys.size(); + } else { + nrChildren += 1; // we keep the child } } + + // now reserve space, and really merge + node->orderedFrontalKeys.reserve(nrKeys); + node->factors.reserve(nrFactors); + typename Node::Children newChildren; + newChildren.reserve(nrChildren); + myNrFrontals = 1; + int combinedProblemSize = (int) (myConditional->size() * symbolicFactors.size()); + // Loop over newChildren + for (size_t i = 0; i < childConditionals.size(); ++i) { + // Check if we should merge the i^th child + sharedNode child = node->children[i]; + if (myNrParents + myNrFrontals == childConditionals[i]->nrParents()) { + // Get a reference to the i, adjusting the index to account for newChildren + // previously merged and removed from the i list. + // Merge keys. For efficiency, we add keys in reverse order at end, calling reverse after.. + node->orderedFrontalKeys.insert(node->orderedFrontalKeys.end(), + child->orderedFrontalKeys.rbegin(), + child->orderedFrontalKeys.rend()); + // Merge keys, factors, and children. + node->factors.insert(node->factors.end(), child->factors.begin(), child->factors.end()); + newChildren.insert(newChildren.end(), child->children.begin(), child->children.end()); + // Increment problem size + combinedProblemSize = std::max(combinedProblemSize, child->problemSize_); + // Increment number of frontal variables + myNrFrontals += child->orderedFrontalKeys.size(); + } else { + newChildren.push_back(child); // we keep the child + } + } + node->children = newChildren; std::reverse(node->orderedFrontalKeys.begin(), node->orderedFrontalKeys.end()); gttoc(merge_children); node->problemSize_ = combinedProblemSize; } -}; +} +; /* ************************************************************************* */ -template -template +template +template JunctionTree::JunctionTree( const EliminationTree& eliminationTree) { gttic(JunctionTree_FromEliminationTree); @@ -147,12 +177,11 @@ JunctionTree::JunctionTree( typedef typename EliminationTree::Node ETreeNode; typedef ConstructorTraversalData Data; Data rootData(0); - rootData.myJTNode = - boost::make_shared(); // Make a dummy node to gather - // the junction tree roots + rootData.myJTNode = boost::make_shared(); // Make a dummy node to gather + // the junction tree roots treeTraversal::DepthFirstForest(eliminationTree, rootData, - Data::ConstructorTraversalVisitorPre, - Data::ConstructorTraversalVisitorPostAlg2); + Data::ConstructorTraversalVisitorPre, + Data::ConstructorTraversalVisitorPostAlg2); // Assign roots from the dummy node Base::roots_ = rootData.myJTNode->children; @@ -161,4 +190,4 @@ JunctionTree::JunctionTree( Base::remainingFactors_ = eliminationTree.remainingFactors(); } -} // namespace gtsam +} // namespace gtsam