Count then merge

release/4.3a0
dellaert 2015-06-21 11:44:17 -07:00
parent c3811a5488
commit 2dd83fd92c
1 changed files with 77 additions and 48 deletions

View File

@ -27,7 +27,7 @@
namespace gtsam { namespace gtsam {
template <class BAYESTREE, class GRAPH, class ETREE_NODE> template<class BAYESTREE, class GRAPH, class ETREE_NODE>
struct ConstructorTraversalData { struct ConstructorTraversalData {
typedef typename JunctionTree<BAYESTREE, GRAPH>::Node Node; typedef typename JunctionTree<BAYESTREE, GRAPH>::Node Node;
typedef typename JunctionTree<BAYESTREE, GRAPH>::sharedNode sharedNode; typedef typename JunctionTree<BAYESTREE, GRAPH>::sharedNode sharedNode;
@ -37,8 +37,13 @@ struct ConstructorTraversalData {
FastVector<SymbolicConditional::shared_ptr> childSymbolicConditionals; FastVector<SymbolicConditional::shared_ptr> childSymbolicConditionals;
FastVector<SymbolicFactor::shared_ptr> childSymbolicFactors; FastVector<SymbolicFactor::shared_ptr> childSymbolicFactors;
ConstructorTraversalData(ConstructorTraversalData* _parentData) // Small inner class to store symbolic factors
: parentData(_parentData) {} class SymbolicFactors: public FactorGraph<Factor> {
};
ConstructorTraversalData(ConstructorTraversalData* _parentData) :
parentData(_parentData) {
}
// Pre-order visitor function // Pre-order visitor function
static ConstructorTraversalData ConstructorTraversalVisitorPre( static ConstructorTraversalData ConstructorTraversalVisitorPre(
@ -64,13 +69,11 @@ struct ConstructorTraversalData {
// our number of symbolic elimination parents is exactly 1 less than // our number of symbolic elimination parents is exactly 1 less than
// our child's symbolic elimination parents - this condition indicates that // our child's symbolic elimination parents - this condition indicates that
// eliminating the current node did not introduce any parents beyond those // eliminating the current node did not introduce any parents beyond those
// already in the child. // already in the child->
// Do symbolic elimination for this node // Do symbolic elimination for this node
class : public FactorGraph<Factor> {} SymbolicFactors symbolicFactors;
symbolicFactors; symbolicFactors.reserve(ETreeNode->factors.size() + myData.childSymbolicFactors.size());
symbolicFactors.reserve(ETreeNode->factors.size() +
myData.childSymbolicFactors.size());
// Add ETree node factors // Add ETree node factors
symbolicFactors += ETreeNode->factors; symbolicFactors += ETreeNode->factors;
// Add symbolic factors passed up from children // Add symbolic factors passed up from children
@ -78,60 +81,87 @@ struct ConstructorTraversalData {
Ordering keyAsOrdering; Ordering keyAsOrdering;
keyAsOrdering.push_back(ETreeNode->key); keyAsOrdering.push_back(ETreeNode->key);
std::pair<SymbolicConditional::shared_ptr, SymbolicFactor::shared_ptr> SymbolicConditional::shared_ptr myConditional;
symbolicElimResult = SymbolicFactor::shared_ptr mySeparatorFactor;
internal::EliminateSymbolic(symbolicFactors, keyAsOrdering); boost::tie(myConditional, mySeparatorFactor) = internal::EliminateSymbolic(
symbolicFactors, keyAsOrdering);
// Store symbolic elimination results in the parent // Store symbolic elimination results in the parent
myData.parentData->childSymbolicConditionals.push_back( myData.parentData->childSymbolicConditionals.push_back(myConditional);
symbolicElimResult.first); myData.parentData->childSymbolicFactors.push_back(mySeparatorFactor);
myData.parentData->childSymbolicFactors.push_back(
symbolicElimResult.second);
sharedNode node = myData.myJTNode; sharedNode node = myData.myJTNode;
const FastVector<SymbolicConditional::shared_ptr>& childConditionals =
myData.childSymbolicConditionals;
// Merge our children if they are in our clique - if our conditional has // Merge our children if they are in our clique - if our conditional has
// exactly one fewer parent than our child's conditional. // exactly one fewer parent than our child's conditional.
size_t myNrFrontals = 1; size_t myNrFrontals = 1;
const size_t myNrParents = symbolicElimResult.first->nrParents(); const size_t myNrParents = myConditional->nrParents();
size_t nrMergedChildren = 0; assert(node->newChildren.size() == childConditionals.size());
assert(node->children.size() == myData.childSymbolicConditionals.size());
// Loop over children
int combinedProblemSize =
(int)(symbolicElimResult.first->size() * symbolicFactors.size());
gttic(merge_children); gttic(merge_children);
for (size_t i = 0; i < myData.childSymbolicConditionals.size(); ++i) { // First count how many keys, factors and children we'll end up with
size_t nrKeys = node->orderedFrontalKeys.size();
size_t nrFactors = node->factors.size();
size_t nrChildren = 0;
// Loop over children
for (size_t i = 0; i < childConditionals.size(); ++i) {
// Check if we should merge the i^th child // Check if we should merge the i^th child
if (myNrParents + myNrFrontals == if (myNrParents + myNrFrontals == childConditionals[i]->nrParents()) {
myData.childSymbolicConditionals[i]->nrParents()) {
// Get a reference to the i, adjusting the index to account for children // Get a reference to the i, adjusting the index to account for children
// previously merged and removed from the i list. // previously merged and removed from the i list.
const Node& child = *node->children[i - nrMergedChildren]; sharedNode child = node->children[i];
// Merge keys. For efficiency, we add keys in reverse order at end, calling reverse after.. nrKeys += child->orderedFrontalKeys.size();
node->orderedFrontalKeys.insert(node->orderedFrontalKeys.end(), nrFactors += child->factors.size();
child.orderedFrontalKeys.rbegin(), nrChildren += child->children.size();
child.orderedFrontalKeys.rend());
// Merge keys, factors, and children.
node->factors.insert(node->factors.end(), child.factors.begin(), child.factors.end());
node->children.insert(node->children.end(), child.children.begin(), child.children.end());
// Increment problem size
combinedProblemSize = std::max(combinedProblemSize, child.problemSize_);
// Increment number of frontal variables // Increment number of frontal variables
myNrFrontals += child.orderedFrontalKeys.size(); myNrFrontals += child->orderedFrontalKeys.size();
// Remove i from list. } else {
node->children.erase(node->children.begin() + (i - nrMergedChildren)); nrChildren += 1; // we keep the child
// Increment number of merged children
++nrMergedChildren;
} }
} }
// now reserve space, and really merge
node->orderedFrontalKeys.reserve(nrKeys);
node->factors.reserve(nrFactors);
typename Node::Children newChildren;
newChildren.reserve(nrChildren);
myNrFrontals = 1;
int combinedProblemSize = (int) (myConditional->size() * symbolicFactors.size());
// Loop over newChildren
for (size_t i = 0; i < childConditionals.size(); ++i) {
// Check if we should merge the i^th child
sharedNode child = node->children[i];
if (myNrParents + myNrFrontals == childConditionals[i]->nrParents()) {
// Get a reference to the i, adjusting the index to account for newChildren
// previously merged and removed from the i list.
// Merge keys. For efficiency, we add keys in reverse order at end, calling reverse after..
node->orderedFrontalKeys.insert(node->orderedFrontalKeys.end(),
child->orderedFrontalKeys.rbegin(),
child->orderedFrontalKeys.rend());
// Merge keys, factors, and children.
node->factors.insert(node->factors.end(), child->factors.begin(), child->factors.end());
newChildren.insert(newChildren.end(), child->children.begin(), child->children.end());
// Increment problem size
combinedProblemSize = std::max(combinedProblemSize, child->problemSize_);
// Increment number of frontal variables
myNrFrontals += child->orderedFrontalKeys.size();
} else {
newChildren.push_back(child); // we keep the child
}
}
node->children = newChildren;
std::reverse(node->orderedFrontalKeys.begin(), node->orderedFrontalKeys.end()); std::reverse(node->orderedFrontalKeys.begin(), node->orderedFrontalKeys.end());
gttoc(merge_children); gttoc(merge_children);
node->problemSize_ = combinedProblemSize; node->problemSize_ = combinedProblemSize;
} }
}; }
;
/* ************************************************************************* */ /* ************************************************************************* */
template <class BAYESTREE, class GRAPH> template<class BAYESTREE, class GRAPH>
template <class ETREE_BAYESNET, class ETREE_GRAPH> template<class ETREE_BAYESNET, class ETREE_GRAPH>
JunctionTree<BAYESTREE, GRAPH>::JunctionTree( JunctionTree<BAYESTREE, GRAPH>::JunctionTree(
const EliminationTree<ETREE_BAYESNET, ETREE_GRAPH>& eliminationTree) { const EliminationTree<ETREE_BAYESNET, ETREE_GRAPH>& eliminationTree) {
gttic(JunctionTree_FromEliminationTree); gttic(JunctionTree_FromEliminationTree);
@ -147,12 +177,11 @@ JunctionTree<BAYESTREE, GRAPH>::JunctionTree(
typedef typename EliminationTree<ETREE_BAYESNET, ETREE_GRAPH>::Node ETreeNode; typedef typename EliminationTree<ETREE_BAYESNET, ETREE_GRAPH>::Node ETreeNode;
typedef ConstructorTraversalData<BAYESTREE, GRAPH, ETreeNode> Data; typedef ConstructorTraversalData<BAYESTREE, GRAPH, ETreeNode> Data;
Data rootData(0); Data rootData(0);
rootData.myJTNode = rootData.myJTNode = boost::make_shared<typename Base::Node>(); // Make a dummy node to gather
boost::make_shared<typename Base::Node>(); // Make a dummy node to gather // the junction tree roots
// the junction tree roots
treeTraversal::DepthFirstForest(eliminationTree, rootData, treeTraversal::DepthFirstForest(eliminationTree, rootData,
Data::ConstructorTraversalVisitorPre, Data::ConstructorTraversalVisitorPre,
Data::ConstructorTraversalVisitorPostAlg2); Data::ConstructorTraversalVisitorPostAlg2);
// Assign roots from the dummy node // Assign roots from the dummy node
Base::roots_ = rootData.myJTNode->children; Base::roots_ = rootData.myJTNode->children;
@ -161,4 +190,4 @@ JunctionTree<BAYESTREE, GRAPH>::JunctionTree(
Base::remainingFactors_ = eliminationTree.remainingFactors(); Base::remainingFactors_ = eliminationTree.remainingFactors();
} }
} // namespace gtsam } // namespace gtsam