Count then merge
parent
c3811a5488
commit
2dd83fd92c
|
@ -27,7 +27,7 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
template <class BAYESTREE, class GRAPH, class ETREE_NODE>
|
template<class BAYESTREE, class GRAPH, class ETREE_NODE>
|
||||||
struct ConstructorTraversalData {
|
struct ConstructorTraversalData {
|
||||||
typedef typename JunctionTree<BAYESTREE, GRAPH>::Node Node;
|
typedef typename JunctionTree<BAYESTREE, GRAPH>::Node Node;
|
||||||
typedef typename JunctionTree<BAYESTREE, GRAPH>::sharedNode sharedNode;
|
typedef typename JunctionTree<BAYESTREE, GRAPH>::sharedNode sharedNode;
|
||||||
|
@ -37,8 +37,13 @@ struct ConstructorTraversalData {
|
||||||
FastVector<SymbolicConditional::shared_ptr> childSymbolicConditionals;
|
FastVector<SymbolicConditional::shared_ptr> childSymbolicConditionals;
|
||||||
FastVector<SymbolicFactor::shared_ptr> childSymbolicFactors;
|
FastVector<SymbolicFactor::shared_ptr> childSymbolicFactors;
|
||||||
|
|
||||||
ConstructorTraversalData(ConstructorTraversalData* _parentData)
|
// Small inner class to store symbolic factors
|
||||||
: parentData(_parentData) {}
|
class SymbolicFactors: public FactorGraph<Factor> {
|
||||||
|
};
|
||||||
|
|
||||||
|
ConstructorTraversalData(ConstructorTraversalData* _parentData) :
|
||||||
|
parentData(_parentData) {
|
||||||
|
}
|
||||||
|
|
||||||
// Pre-order visitor function
|
// Pre-order visitor function
|
||||||
static ConstructorTraversalData ConstructorTraversalVisitorPre(
|
static ConstructorTraversalData ConstructorTraversalVisitorPre(
|
||||||
|
@ -64,13 +69,11 @@ struct ConstructorTraversalData {
|
||||||
// our number of symbolic elimination parents is exactly 1 less than
|
// our number of symbolic elimination parents is exactly 1 less than
|
||||||
// our child's symbolic elimination parents - this condition indicates that
|
// our child's symbolic elimination parents - this condition indicates that
|
||||||
// eliminating the current node did not introduce any parents beyond those
|
// eliminating the current node did not introduce any parents beyond those
|
||||||
// already in the child.
|
// already in the child->
|
||||||
|
|
||||||
// Do symbolic elimination for this node
|
// Do symbolic elimination for this node
|
||||||
class : public FactorGraph<Factor> {}
|
SymbolicFactors symbolicFactors;
|
||||||
symbolicFactors;
|
symbolicFactors.reserve(ETreeNode->factors.size() + myData.childSymbolicFactors.size());
|
||||||
symbolicFactors.reserve(ETreeNode->factors.size() +
|
|
||||||
myData.childSymbolicFactors.size());
|
|
||||||
// Add ETree node factors
|
// Add ETree node factors
|
||||||
symbolicFactors += ETreeNode->factors;
|
symbolicFactors += ETreeNode->factors;
|
||||||
// Add symbolic factors passed up from children
|
// Add symbolic factors passed up from children
|
||||||
|
@ -78,60 +81,87 @@ struct ConstructorTraversalData {
|
||||||
|
|
||||||
Ordering keyAsOrdering;
|
Ordering keyAsOrdering;
|
||||||
keyAsOrdering.push_back(ETreeNode->key);
|
keyAsOrdering.push_back(ETreeNode->key);
|
||||||
std::pair<SymbolicConditional::shared_ptr, SymbolicFactor::shared_ptr>
|
SymbolicConditional::shared_ptr myConditional;
|
||||||
symbolicElimResult =
|
SymbolicFactor::shared_ptr mySeparatorFactor;
|
||||||
internal::EliminateSymbolic(symbolicFactors, keyAsOrdering);
|
boost::tie(myConditional, mySeparatorFactor) = internal::EliminateSymbolic(
|
||||||
|
symbolicFactors, keyAsOrdering);
|
||||||
|
|
||||||
// Store symbolic elimination results in the parent
|
// Store symbolic elimination results in the parent
|
||||||
myData.parentData->childSymbolicConditionals.push_back(
|
myData.parentData->childSymbolicConditionals.push_back(myConditional);
|
||||||
symbolicElimResult.first);
|
myData.parentData->childSymbolicFactors.push_back(mySeparatorFactor);
|
||||||
myData.parentData->childSymbolicFactors.push_back(
|
|
||||||
symbolicElimResult.second);
|
|
||||||
sharedNode node = myData.myJTNode;
|
sharedNode node = myData.myJTNode;
|
||||||
|
const FastVector<SymbolicConditional::shared_ptr>& childConditionals =
|
||||||
|
myData.childSymbolicConditionals;
|
||||||
|
|
||||||
// Merge our children if they are in our clique - if our conditional has
|
// Merge our children if they are in our clique - if our conditional has
|
||||||
// exactly one fewer parent than our child's conditional.
|
// exactly one fewer parent than our child's conditional.
|
||||||
size_t myNrFrontals = 1;
|
size_t myNrFrontals = 1;
|
||||||
const size_t myNrParents = symbolicElimResult.first->nrParents();
|
const size_t myNrParents = myConditional->nrParents();
|
||||||
size_t nrMergedChildren = 0;
|
assert(node->newChildren.size() == childConditionals.size());
|
||||||
assert(node->children.size() == myData.childSymbolicConditionals.size());
|
|
||||||
// Loop over children
|
|
||||||
int combinedProblemSize =
|
|
||||||
(int)(symbolicElimResult.first->size() * symbolicFactors.size());
|
|
||||||
gttic(merge_children);
|
gttic(merge_children);
|
||||||
for (size_t i = 0; i < myData.childSymbolicConditionals.size(); ++i) {
|
// First count how many keys, factors and children we'll end up with
|
||||||
|
size_t nrKeys = node->orderedFrontalKeys.size();
|
||||||
|
size_t nrFactors = node->factors.size();
|
||||||
|
size_t nrChildren = 0;
|
||||||
|
// Loop over children
|
||||||
|
for (size_t i = 0; i < childConditionals.size(); ++i) {
|
||||||
// Check if we should merge the i^th child
|
// Check if we should merge the i^th child
|
||||||
if (myNrParents + myNrFrontals ==
|
if (myNrParents + myNrFrontals == childConditionals[i]->nrParents()) {
|
||||||
myData.childSymbolicConditionals[i]->nrParents()) {
|
|
||||||
// Get a reference to the i, adjusting the index to account for children
|
// Get a reference to the i, adjusting the index to account for children
|
||||||
// previously merged and removed from the i list.
|
// previously merged and removed from the i list.
|
||||||
const Node& child = *node->children[i - nrMergedChildren];
|
sharedNode child = node->children[i];
|
||||||
// Merge keys. For efficiency, we add keys in reverse order at end, calling reverse after..
|
nrKeys += child->orderedFrontalKeys.size();
|
||||||
node->orderedFrontalKeys.insert(node->orderedFrontalKeys.end(),
|
nrFactors += child->factors.size();
|
||||||
child.orderedFrontalKeys.rbegin(),
|
nrChildren += child->children.size();
|
||||||
child.orderedFrontalKeys.rend());
|
|
||||||
// Merge keys, factors, and children.
|
|
||||||
node->factors.insert(node->factors.end(), child.factors.begin(), child.factors.end());
|
|
||||||
node->children.insert(node->children.end(), child.children.begin(), child.children.end());
|
|
||||||
// Increment problem size
|
|
||||||
combinedProblemSize = std::max(combinedProblemSize, child.problemSize_);
|
|
||||||
// Increment number of frontal variables
|
// Increment number of frontal variables
|
||||||
myNrFrontals += child.orderedFrontalKeys.size();
|
myNrFrontals += child->orderedFrontalKeys.size();
|
||||||
// Remove i from list.
|
} else {
|
||||||
node->children.erase(node->children.begin() + (i - nrMergedChildren));
|
nrChildren += 1; // we keep the child
|
||||||
// Increment number of merged children
|
|
||||||
++nrMergedChildren;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// now reserve space, and really merge
|
||||||
|
node->orderedFrontalKeys.reserve(nrKeys);
|
||||||
|
node->factors.reserve(nrFactors);
|
||||||
|
typename Node::Children newChildren;
|
||||||
|
newChildren.reserve(nrChildren);
|
||||||
|
myNrFrontals = 1;
|
||||||
|
int combinedProblemSize = (int) (myConditional->size() * symbolicFactors.size());
|
||||||
|
// Loop over newChildren
|
||||||
|
for (size_t i = 0; i < childConditionals.size(); ++i) {
|
||||||
|
// Check if we should merge the i^th child
|
||||||
|
sharedNode child = node->children[i];
|
||||||
|
if (myNrParents + myNrFrontals == childConditionals[i]->nrParents()) {
|
||||||
|
// Get a reference to the i, adjusting the index to account for newChildren
|
||||||
|
// previously merged and removed from the i list.
|
||||||
|
// Merge keys. For efficiency, we add keys in reverse order at end, calling reverse after..
|
||||||
|
node->orderedFrontalKeys.insert(node->orderedFrontalKeys.end(),
|
||||||
|
child->orderedFrontalKeys.rbegin(),
|
||||||
|
child->orderedFrontalKeys.rend());
|
||||||
|
// Merge keys, factors, and children.
|
||||||
|
node->factors.insert(node->factors.end(), child->factors.begin(), child->factors.end());
|
||||||
|
newChildren.insert(newChildren.end(), child->children.begin(), child->children.end());
|
||||||
|
// Increment problem size
|
||||||
|
combinedProblemSize = std::max(combinedProblemSize, child->problemSize_);
|
||||||
|
// Increment number of frontal variables
|
||||||
|
myNrFrontals += child->orderedFrontalKeys.size();
|
||||||
|
} else {
|
||||||
|
newChildren.push_back(child); // we keep the child
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node->children = newChildren;
|
||||||
std::reverse(node->orderedFrontalKeys.begin(), node->orderedFrontalKeys.end());
|
std::reverse(node->orderedFrontalKeys.begin(), node->orderedFrontalKeys.end());
|
||||||
gttoc(merge_children);
|
gttoc(merge_children);
|
||||||
node->problemSize_ = combinedProblemSize;
|
node->problemSize_ = combinedProblemSize;
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template <class BAYESTREE, class GRAPH>
|
template<class BAYESTREE, class GRAPH>
|
||||||
template <class ETREE_BAYESNET, class ETREE_GRAPH>
|
template<class ETREE_BAYESNET, class ETREE_GRAPH>
|
||||||
JunctionTree<BAYESTREE, GRAPH>::JunctionTree(
|
JunctionTree<BAYESTREE, GRAPH>::JunctionTree(
|
||||||
const EliminationTree<ETREE_BAYESNET, ETREE_GRAPH>& eliminationTree) {
|
const EliminationTree<ETREE_BAYESNET, ETREE_GRAPH>& eliminationTree) {
|
||||||
gttic(JunctionTree_FromEliminationTree);
|
gttic(JunctionTree_FromEliminationTree);
|
||||||
|
@ -147,12 +177,11 @@ JunctionTree<BAYESTREE, GRAPH>::JunctionTree(
|
||||||
typedef typename EliminationTree<ETREE_BAYESNET, ETREE_GRAPH>::Node ETreeNode;
|
typedef typename EliminationTree<ETREE_BAYESNET, ETREE_GRAPH>::Node ETreeNode;
|
||||||
typedef ConstructorTraversalData<BAYESTREE, GRAPH, ETreeNode> Data;
|
typedef ConstructorTraversalData<BAYESTREE, GRAPH, ETreeNode> Data;
|
||||||
Data rootData(0);
|
Data rootData(0);
|
||||||
rootData.myJTNode =
|
rootData.myJTNode = boost::make_shared<typename Base::Node>(); // Make a dummy node to gather
|
||||||
boost::make_shared<typename Base::Node>(); // Make a dummy node to gather
|
// the junction tree roots
|
||||||
// the junction tree roots
|
|
||||||
treeTraversal::DepthFirstForest(eliminationTree, rootData,
|
treeTraversal::DepthFirstForest(eliminationTree, rootData,
|
||||||
Data::ConstructorTraversalVisitorPre,
|
Data::ConstructorTraversalVisitorPre,
|
||||||
Data::ConstructorTraversalVisitorPostAlg2);
|
Data::ConstructorTraversalVisitorPostAlg2);
|
||||||
|
|
||||||
// Assign roots from the dummy node
|
// Assign roots from the dummy node
|
||||||
Base::roots_ = rootData.myJTNode->children;
|
Base::roots_ = rootData.myJTNode->children;
|
||||||
|
@ -161,4 +190,4 @@ JunctionTree<BAYESTREE, GRAPH>::JunctionTree(
|
||||||
Base::remainingFactors_ = eliminationTree.remainingFactors();
|
Base::remainingFactors_ = eliminationTree.remainingFactors();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
Loading…
Reference in New Issue