some name cleaning in the HybridJunctionTree

release/4.3a0
Varun Agrawal 2022-09-02 15:11:49 -04:00
parent 773af1ed44
commit ba4720709b
1 changed files with 13 additions and 14 deletions

View File

@ -31,9 +31,7 @@ template class EliminatableClusterTree<HybridBayesTree,
template class JunctionTree<HybridBayesTree, HybridGaussianFactorGraph>; template class JunctionTree<HybridBayesTree, HybridGaussianFactorGraph>;
struct HybridConstructorTraversalData { struct HybridConstructorTraversalData {
typedef typedef HybridJunctionTree::Node Node;
typename JunctionTree<HybridBayesTree, HybridGaussianFactorGraph>::Node
Node;
typedef typedef
typename JunctionTree<HybridBayesTree, typename JunctionTree<HybridBayesTree,
HybridGaussianFactorGraph>::sharedNode sharedNode; HybridGaussianFactorGraph>::sharedNode sharedNode;
@ -62,6 +60,7 @@ struct HybridConstructorTraversalData {
data.junctionTreeNode = boost::make_shared<Node>(node->key, node->factors); data.junctionTreeNode = boost::make_shared<Node>(node->key, node->factors);
parentData.junctionTreeNode->addChild(data.junctionTreeNode); parentData.junctionTreeNode->addChild(data.junctionTreeNode);
// Add all the discrete keys in the hybrid factors to the current data
for (HybridFactor::shared_ptr& f : node->factors) { for (HybridFactor::shared_ptr& f : node->factors) {
for (auto& k : f->discreteKeys()) { for (auto& k : f->discreteKeys()) {
data.discreteKeys.insert(k.first); data.discreteKeys.insert(k.first);
@ -72,8 +71,8 @@ struct HybridConstructorTraversalData {
} }
// Post-order visitor function // Post-order visitor function
static void ConstructorTraversalVisitorPostAlg2( static void ConstructorTraversalVisitorPost(
const boost::shared_ptr<HybridEliminationTree::Node>& ETreeNode, const boost::shared_ptr<HybridEliminationTree::Node>& node,
const HybridConstructorTraversalData& data) { const HybridConstructorTraversalData& data) {
// In this post-order visitor, we combine the symbolic elimination results // In this post-order visitor, we combine the symbolic elimination results
// from the elimination tree children and symbolically eliminate the current // from the elimination tree children and symbolically eliminate the current
@ -86,15 +85,15 @@ struct HybridConstructorTraversalData {
// Do symbolic elimination for this node // Do symbolic elimination for this node
SymbolicFactors symbolicFactors; SymbolicFactors symbolicFactors;
symbolicFactors.reserve(ETreeNode->factors.size() + symbolicFactors.reserve(node->factors.size() +
data.childSymbolicFactors.size()); data.childSymbolicFactors.size());
// Add ETree node factors // Add ETree node factors
symbolicFactors += ETreeNode->factors; symbolicFactors += node->factors;
// Add symbolic factors passed up from children // Add symbolic factors passed up from children
symbolicFactors += data.childSymbolicFactors; symbolicFactors += data.childSymbolicFactors;
Ordering keyAsOrdering; Ordering keyAsOrdering;
keyAsOrdering.push_back(ETreeNode->key); keyAsOrdering.push_back(node->key);
SymbolicConditional::shared_ptr conditional; SymbolicConditional::shared_ptr conditional;
SymbolicFactor::shared_ptr separatorFactor; SymbolicFactor::shared_ptr separatorFactor;
boost::tie(conditional, separatorFactor) = boost::tie(conditional, separatorFactor) =
@ -105,19 +104,19 @@ struct HybridConstructorTraversalData {
data.parentData->childSymbolicFactors.push_back(separatorFactor); data.parentData->childSymbolicFactors.push_back(separatorFactor);
data.parentData->discreteKeys.merge(data.discreteKeys); data.parentData->discreteKeys.merge(data.discreteKeys);
sharedNode node = data.junctionTreeNode; sharedNode jt_node = data.junctionTreeNode;
const FastVector<SymbolicConditional::shared_ptr>& childConditionals = const FastVector<SymbolicConditional::shared_ptr>& childConditionals =
data.childSymbolicConditionals; data.childSymbolicConditionals;
node->problemSize_ = (int)(conditional->size() * symbolicFactors.size()); jt_node->problemSize_ = (int)(conditional->size() * symbolicFactors.size());
// Merge our children if they are in our clique - if our conditional has // Merge our children if they are in our clique - if our conditional has
// exactly one fewer parent than our child's conditional. // exactly one fewer parent than our child's conditional.
const size_t nrParents = conditional->nrParents(); const size_t nrParents = conditional->nrParents();
const size_t nrChildren = node->nrChildren(); const size_t nrChildren = jt_node->nrChildren();
assert(childConditionals.size() == nrChildren); assert(childConditionals.size() == nrChildren);
// decide which children to merge, as index into children // decide which children to merge, as index into children
std::vector<size_t> nrChildrenFrontals = node->nrFrontalsOfChildren(); std::vector<size_t> nrChildrenFrontals = jt_node->nrFrontalsOfChildren();
std::vector<bool> merge(nrChildren, false); std::vector<bool> merge(nrChildren, false);
size_t nrFrontals = 1; size_t nrFrontals = 1;
for (size_t i = 0; i < nrChildren; i++) { for (size_t i = 0; i < nrChildren; i++) {
@ -137,7 +136,7 @@ struct HybridConstructorTraversalData {
} }
// now really merge // now really merge
node->mergeChildren(merge); jt_node->mergeChildren(merge);
} }
}; };
@ -161,7 +160,7 @@ HybridJunctionTree::HybridJunctionTree(
// the junction tree roots // the junction tree roots
treeTraversal::DepthFirstForest(eliminationTree, rootData, treeTraversal::DepthFirstForest(eliminationTree, rootData,
Data::ConstructorTraversalVisitorPre, Data::ConstructorTraversalVisitorPre,
Data::ConstructorTraversalVisitorPostAlg2); Data::ConstructorTraversalVisitorPost);
// Assign roots from the dummy node // Assign roots from the dummy node
this->addChildrenAsRoots(rootData.junctionTreeNode); this->addChildrenAsRoots(rootData.junctionTreeNode);