customize discrete elimination in Hybrid

release/4.3a0
Varun Agrawal 2024-12-31 14:27:01 -05:00
parent 71ea8c5d4c
commit 42f8e54c2a
1 changed files with 68 additions and 2 deletions

View File

@ -255,6 +255,48 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
return std::make_shared<TableFactor>(discreteKeys, potentials);
}
/**
* @brief Multiply all the `factors` and normalize the
* product to prevent underflow.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return TableFactor
*/
static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
TableFactor product;
for (const sharedFactor &factor : factors) {
if (factor) {
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
product = product * (*f);
} else if (auto dtf =
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
product = TableFactor(product * (*dtf));
}
}
}
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto normalizer = product.max(product.size());
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Normalize the product factor to prevent underflow.
product = product / (*normalizer);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
return product;
}
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
@ -299,8 +341,32 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
/**** NOTE: This does sum-product. ****/
// Get product factor
TableFactor product = ProductAndNormalize(factors);
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
// All the discrete variables should form a single clique,
// so we can sum out on all the variables as frontals.
// This should give an empty separator.
Ordering orderedKeys(product.keys());
DecisionTreeFactor::shared_ptr sum = product.sum(orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
// Finally, get the conditional
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif