conditional switch for hybrid timing

release/4.3a0
Varun Agrawal 2024-12-27 12:02:21 -05:00
parent 4cf0727f2a
commit 7c9d04fb65
4 changed files with 73 additions and 8 deletions

View File

@ -42,6 +42,10 @@
// Whether to enable merging of equal leaf nodes in the Discrete Decision Tree. // Whether to enable merging of equal leaf nodes in the Discrete Decision Tree.
#cmakedefine GTSAM_DT_MERGING #cmakedefine GTSAM_DT_MERGING
// Whether to enable timing in hybrid factor graph machinery
// #cmakedefine01 GTSAM_HYBRID_TIMING
#define GTSAM_HYBRID_TIMING
// Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake) // Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
#cmakedefine GTSAM_USE_TBB #cmakedefine GTSAM_USE_TBB

View File

@ -121,15 +121,25 @@ namespace gtsam {
static DecisionTreeFactor ProductAndNormalize( static DecisionTreeFactor ProductAndNormalize(
const DiscreteFactorGraph& factors) { const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
gttic(product); #if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
DecisionTreeFactor product = factors.product(); DecisionTreeFactor product = factors.product();
gttoc(product); #if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif
// Max over all the potentials by pretending all keys are frontal: // Max over all the potentials by pretending all keys are frontal:
auto normalizer = product.max(product.size()); auto normalizer = product.max(product.size());
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Normalize the product factor to prevent underflow. // Normalize the product factor to prevent underflow.
product = product / (*normalizer); product = product / (*normalizer);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
return product; return product;
} }
@ -220,9 +230,13 @@ namespace gtsam {
DecisionTreeFactor product = ProductAndNormalize(factors); DecisionTreeFactor product = ProductAndNormalize(factors);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
gttic(sum); #if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
gttoc(sum); #if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys; Ordering orderedKeys;
@ -232,10 +246,14 @@ namespace gtsam {
sum->keys().end()); sum->keys().end());
// now divide product/sum to get conditional // now divide product/sum to get conditional
gttic(divide); #if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
auto conditional = auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys); std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttoc(divide); #if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif
return {conditional, sum}; return {conditional, sum};
} }

View File

@ -282,14 +282,28 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) { } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
auto dc = hc->asDiscrete(); auto dc = hc->asDiscrete();
if (!dc) throwRuntimeError("discreteElimination", dc); if (!dc) throwRuntimeError("discreteElimination", dc);
dfg.push_back(dc); #if GTSAM_HYBRID_TIMING
gttic_(ConvertConditionalToTableFactor);
#endif
// Convert DiscreteConditional to TableFactor
auto tdc = std::make_shared<TableFactor>(*dc);
#if GTSAM_HYBRID_TIMING
gttoc_(ConvertConditionalToTableFactor);
#endif
dfg.push_back(tdc);
} else { } else {
throwRuntimeError("discreteElimination", f); throwRuntimeError("discreteElimination", f);
} }
} }
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
// NOTE: This does sum-product. For max-product, use EliminateForMPE. // NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys); auto result = EliminateDiscrete(dfg, frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif
return {std::make_shared<HybridConditional>(result.first), result.second}; return {std::make_shared<HybridConditional>(result.first), result.second};
} }
@ -319,8 +333,19 @@ static std::shared_ptr<Factor> createDiscreteFactor(
} }
}; };
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteBoundaryErrors);
#endif
AlgebraicDecisionTree<Key> errors(eliminationResults, calculateError); AlgebraicDecisionTree<Key> errors(eliminationResults, calculateError);
return DiscreteFactorFromErrors(discreteSeparator, errors); #if GTSAM_HYBRID_TIMING
gttoc_(DiscreteBoundaryErrors);
gttic_(DiscreteBoundaryResult);
#endif
auto result = DiscreteFactorFromErrors(discreteSeparator, errors);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteBoundaryResult);
#endif
return result;
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -360,12 +385,18 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// the discrete separator will be *all* the discrete keys. // the discrete separator will be *all* the discrete keys.
DiscreteKeys discreteSeparator = GetDiscreteKeys(*this); DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);
#if GTSAM_HYBRID_TIMING
gttic_(HybridCollectProductFactor);
#endif
// Collect all the factors to create a set of Gaussian factor graphs in a // Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved. Just like any hybrid // decision tree indexed by all discrete keys involved. Just like any hybrid
// factor, every assignment also has a scalar error, in this case the sum of // factor, every assignment also has a scalar error, in this case the sum of
// all errors in the graph. This error is assignment-specific and accounts for // all errors in the graph. This error is assignment-specific and accounts for
// any difference in noise models used. // any difference in noise models used.
HybridGaussianProductFactor productFactor = collectProductFactor(); HybridGaussianProductFactor productFactor = collectProductFactor();
#if GTSAM_HYBRID_TIMING
gttoc_(HybridCollectProductFactor);
#endif
// Check if a factor is null // Check if a factor is null
auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }; auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };
@ -393,8 +424,14 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
return {conditional, conditional->negLogConstant(), factor, scalar}; return {conditional, conditional->negLogConstant(), factor, scalar};
}; };
#if GTSAM_HYBRID_TIMING
gttic_(HybridEliminate);
#endif
// Perform elimination! // Perform elimination!
const ResultTree eliminationResults(productFactor, eliminate); const ResultTree eliminationResults(productFactor, eliminate);
#if GTSAM_HYBRID_TIMING
gttoc_(HybridEliminate);
#endif
// If there are no more continuous parents we create a DiscreteFactor with the // If there are no more continuous parents we create a DiscreteFactor with the
// error for each discrete choice. Otherwise, create a HybridGaussianFactor // error for each discrete choice. Otherwise, create a HybridGaussianFactor

View File

@ -104,7 +104,13 @@ void HybridGaussianISAM::updateInternal(
elimination_ordering, function, std::cref(index)); elimination_ordering, function, std::cref(index));
if (maxNrLeaves) { if (maxNrLeaves) {
#if GTSAM_HYBRID_TIMING
gttic_(HybridBayesTreePrune);
#endif
bayesTree->prune(*maxNrLeaves); bayesTree->prune(*maxNrLeaves);
#if GTSAM_HYBRID_TIMING
gttoc_(HybridBayesTreePrune);
#endif
} }
// Re-add into Bayes tree data structures // Re-add into Bayes tree data structures