conditional switch for hybrid timing

release/4.3a0
Varun Agrawal 2024-12-27 12:02:21 -05:00
parent 4cf0727f2a
commit 7c9d04fb65
4 changed files with 73 additions and 8 deletions

View File

@ -42,6 +42,10 @@
// Whether to enable merging of equal leaf nodes in the Discrete Decision Tree.
#cmakedefine GTSAM_DT_MERGING
// Whether to enable timing in hybrid factor graph machinery
// #cmakedefine01 GTSAM_HYBRID_TIMING
#define GTSAM_HYBRID_TIMING
// Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
#cmakedefine GTSAM_USE_TBB

View File

@ -121,15 +121,25 @@ namespace gtsam {
static DecisionTreeFactor ProductAndNormalize(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
DecisionTreeFactor product = factors.product();
gttoc(product);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto normalizer = product.max(product.size());
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Normalize the product factor to prevent underflow.
product = product / (*normalizer);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
return product;
}
@ -220,9 +230,13 @@ namespace gtsam {
DecisionTreeFactor product = ProductAndNormalize(factors);
// sum out frontals, this is the factor on the separator
gttic(sum);
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
gttoc(sum);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
@ -232,10 +246,14 @@ namespace gtsam {
sum->keys().end());
// now divide product/sum to get conditional
gttic(divide);
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttoc(divide);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif
return {conditional, sum};
}

View File

@ -282,14 +282,28 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
auto dc = hc->asDiscrete();
if (!dc) throwRuntimeError("discreteElimination", dc);
dfg.push_back(dc);
#if GTSAM_HYBRID_TIMING
gttic_(ConvertConditionalToTableFactor);
#endif
// Convert DiscreteConditional to TableFactor
auto tdc = std::make_shared<TableFactor>(*dc);
#if GTSAM_HYBRID_TIMING
gttoc_(ConvertConditionalToTableFactor);
#endif
dfg.push_back(tdc);
} else {
throwRuntimeError("discreteElimination", f);
}
}
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif
return {std::make_shared<HybridConditional>(result.first), result.second};
}
@ -319,8 +333,19 @@ static std::shared_ptr<Factor> createDiscreteFactor(
}
};
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteBoundaryErrors);
#endif
AlgebraicDecisionTree<Key> errors(eliminationResults, calculateError);
return DiscreteFactorFromErrors(discreteSeparator, errors);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteBoundaryErrors);
gttic_(DiscreteBoundaryResult);
#endif
auto result = DiscreteFactorFromErrors(discreteSeparator, errors);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteBoundaryResult);
#endif
return result;
}
/* *******************************************************************************/
@ -360,12 +385,18 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// the discrete separator will be *all* the discrete keys.
DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);
#if GTSAM_HYBRID_TIMING
gttic_(HybridCollectProductFactor);
#endif
// Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved. Just like any hybrid
// factor, every assignment also has a scalar error, in this case the sum of
// all errors in the graph. This error is assignment-specific and accounts for
// any difference in noise models used.
HybridGaussianProductFactor productFactor = collectProductFactor();
#if GTSAM_HYBRID_TIMING
gttoc_(HybridCollectProductFactor);
#endif
// Check if a factor is null
auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };
@ -393,8 +424,14 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
return {conditional, conditional->negLogConstant(), factor, scalar};
};
#if GTSAM_HYBRID_TIMING
gttic_(HybridEliminate);
#endif
// Perform elimination!
const ResultTree eliminationResults(productFactor, eliminate);
#if GTSAM_HYBRID_TIMING
gttoc_(HybridEliminate);
#endif
// If there are no more continuous parents we create a DiscreteFactor with the
// error for each discrete choice. Otherwise, create a HybridGaussianFactor

View File

@ -104,7 +104,13 @@ void HybridGaussianISAM::updateInternal(
elimination_ordering, function, std::cref(index));
if (maxNrLeaves) {
#if GTSAM_HYBRID_TIMING
gttic_(HybridBayesTreePrune);
#endif
bayesTree->prune(*maxNrLeaves);
#if GTSAM_HYBRID_TIMING
gttoc_(HybridBayesTreePrune);
#endif
}
// Re-add into Bayes tree data structures