customize discrete elimination in Hybrid
parent
71ea8c5d4c
commit
42f8e54c2a
|
|
@ -255,6 +255,48 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
|
|||
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Multiply all the `factors` and normalize the
|
||||
* product to prevent underflow.
|
||||
*
|
||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||
* @return TableFactor
|
||||
*/
|
||||
static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) {
|
||||
// PRODUCT: multiply all factors
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteProduct);
|
||||
#endif
|
||||
TableFactor product;
|
||||
for (const sharedFactor &factor : factors) {
|
||||
if (factor) {
|
||||
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
|
||||
product = product * (*f);
|
||||
} else if (auto dtf =
|
||||
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||
product = TableFactor(product * (*dtf));
|
||||
}
|
||||
}
|
||||
}
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteProduct);
|
||||
#endif
|
||||
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
auto normalizer = product.max(product.size());
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteNormalize);
|
||||
#endif
|
||||
// Normalize the product factor to prevent underflow.
|
||||
product = product / (*normalizer);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteNormalize);
|
||||
#endif
|
||||
|
||||
return product;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||
|
|
@ -299,8 +341,32 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscrete);
|
||||
#endif
|
||||
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
/**** NOTE: This does sum-product. ****/
|
||||
// Get product factor
|
||||
TableFactor product = ProductAndNormalize(factors);
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteSum);
|
||||
#endif
|
||||
// All the discrete variables should form a single clique,
|
||||
// so we can sum out on all the variables as frontals.
|
||||
// This should give an empty separator.
|
||||
Ordering orderedKeys(product.keys());
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(orderedKeys);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteSum);
|
||||
#endif
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteToDiscreteConditional);
|
||||
#endif
|
||||
// Finally, get the conditional
|
||||
auto conditional =
|
||||
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteToDiscreteConditional);
|
||||
#endif
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscrete);
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in New Issue