diff --git a/gtsam/inference/JunctionTree-inst.h b/gtsam/inference/JunctionTree-inst.h index 5735a3bd2..232246d5e 100644 --- a/gtsam/inference/JunctionTree-inst.h +++ b/gtsam/inference/JunctionTree-inst.h @@ -73,7 +73,8 @@ struct ConstructorTraversalData { // Do symbolic elimination for this node SymbolicFactors symbolicFactors; - symbolicFactors.reserve(ETreeNode->factors.size() + myData.childSymbolicFactors.size()); + symbolicFactors.reserve( + ETreeNode->factors.size() + myData.childSymbolicFactors.size()); // Add ETree node factors symbolicFactors += ETreeNode->factors; // Add symbolic factors passed up from children @@ -96,29 +97,42 @@ struct ConstructorTraversalData { // Merge our children if they are in our clique - if our conditional has // exactly one fewer parent than our child's conditional. - size_t myNrFrontals = 1; const size_t myNrParents = myConditional->nrParents(); - assert(node->newChildren.size() == childConditionals.size()); + const size_t nrChildren = node->children.size(); + assert(childConditionals.size() == nrChildren); gttic(merge_children); - // First count how many keys, factors and children we'll end up with + + // decide which children to merge, as index into children + std::vector merge(nrChildren, false); + { + size_t myNrFrontals = 1; + for (size_t i = 0; i < nrChildren; ++i) { + // Check if we should merge the i^th child + if (myNrParents + myNrFrontals == childConditionals[i]->nrParents()) { + sharedNode child = node->children[i]; + // Increment number of frontal variables + myNrFrontals += child->orderedFrontalKeys.size(); + merge[i] = true; + } + } + } + + // Count how many keys, factors and children we'll end up with size_t nrKeys = node->orderedFrontalKeys.size(); size_t nrFactors = node->factors.size(); - size_t nrChildren = 0; + size_t nrNewChildren = 0; // Loop over children - for (size_t i = 0; i < childConditionals.size(); ++i) { - // Check if we should merge the i^th child - if (myNrParents + myNrFrontals == childConditionals[i]->nrParents()) { + for (size_t i = 0; i < nrChildren; ++i) { + if (merge[i]) { // Get a reference to the i, adjusting the index to account for children // previously merged and removed from the i list. sharedNode child = node->children[i]; nrKeys += child->orderedFrontalKeys.size(); nrFactors += child->factors.size(); - nrChildren += child->children.size(); - // Increment number of frontal variables - myNrFrontals += child->orderedFrontalKeys.size(); + nrNewChildren += child->children.size(); } else { - nrChildren += 1; // we keep the child + nrNewChildren += 1; // we keep the child } } @@ -126,14 +140,14 @@ struct ConstructorTraversalData { node->orderedFrontalKeys.reserve(nrKeys); node->factors.reserve(nrFactors); typename Node::Children newChildren; - newChildren.reserve(nrChildren); - myNrFrontals = 1; - int combinedProblemSize = (int) (myConditional->size() * symbolicFactors.size()); + newChildren.reserve(nrNewChildren); + int combinedProblemSize = (int) (myConditional->size() + * symbolicFactors.size()); // Loop over newChildren - for (size_t i = 0; i < childConditionals.size(); ++i) { + for (size_t i = 0; i < nrChildren; ++i) { // Check if we should merge the i^th child sharedNode child = node->children[i]; - if (myNrParents + myNrFrontals == childConditionals[i]->nrParents()) { + if (merge[i]) { // Get a reference to the i, adjusting the index to account for newChildren // previously merged and removed from the i list. // Merge keys. For efficiency, we add keys in reverse order at end, calling reverse after.. @@ -141,18 +155,21 @@ struct ConstructorTraversalData { child->orderedFrontalKeys.rbegin(), child->orderedFrontalKeys.rend()); // Merge keys, factors, and children. - node->factors.insert(node->factors.end(), child->factors.begin(), child->factors.end()); - newChildren.insert(newChildren.end(), child->children.begin(), child->children.end()); + node->factors.insert(node->factors.end(), child->factors.begin(), + child->factors.end()); + newChildren.insert(newChildren.end(), child->children.begin(), + child->children.end()); // Increment problem size - combinedProblemSize = std::max(combinedProblemSize, child->problemSize_); + combinedProblemSize = std::max(combinedProblemSize, + child->problemSize_); // Increment number of frontal variables - myNrFrontals += child->orderedFrontalKeys.size(); } else { newChildren.push_back(child); // we keep the child } } node->children = newChildren; - std::reverse(node->orderedFrontalKeys.begin(), node->orderedFrontalKeys.end()); + std::reverse(node->orderedFrontalKeys.begin(), + node->orderedFrontalKeys.end()); gttoc(merge_children); node->problemSize_ = combinedProblemSize; }