customize discrete elimination in Hybrid
parent
71ea8c5d4c
commit
42f8e54c2a
|
|
@ -255,6 +255,48 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
|
||||||
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Multiply all the `factors` and normalize the
|
||||||
|
* product to prevent underflow.
|
||||||
|
*
|
||||||
|
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||||
|
* @return TableFactor
|
||||||
|
*/
|
||||||
|
static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) {
|
||||||
|
// PRODUCT: multiply all factors
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(DiscreteProduct);
|
||||||
|
#endif
|
||||||
|
TableFactor product;
|
||||||
|
for (const sharedFactor &factor : factors) {
|
||||||
|
if (factor) {
|
||||||
|
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
|
||||||
|
product = product * (*f);
|
||||||
|
} else if (auto dtf =
|
||||||
|
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||||
|
product = TableFactor(product * (*dtf));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(DiscreteProduct);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Max over all the potentials by pretending all keys are frontal:
|
||||||
|
auto normalizer = product.max(product.size());
|
||||||
|
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(DiscreteNormalize);
|
||||||
|
#endif
|
||||||
|
// Normalize the product factor to prevent underflow.
|
||||||
|
product = product / (*normalizer);
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(DiscreteNormalize);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return product;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
@ -299,8 +341,32 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttic_(EliminateDiscrete);
|
gttic_(EliminateDiscrete);
|
||||||
#endif
|
#endif
|
||||||
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
|
/**** NOTE: This does sum-product. ****/
|
||||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
// Get product factor
|
||||||
|
TableFactor product = ProductAndNormalize(factors);
|
||||||
|
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(EliminateDiscreteSum);
|
||||||
|
#endif
|
||||||
|
// All the discrete variables should form a single clique,
|
||||||
|
// so we can sum out on all the variables as frontals.
|
||||||
|
// This should give an empty separator.
|
||||||
|
Ordering orderedKeys(product.keys());
|
||||||
|
DecisionTreeFactor::shared_ptr sum = product.sum(orderedKeys);
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(EliminateDiscreteSum);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(EliminateDiscreteToDiscreteConditional);
|
||||||
|
#endif
|
||||||
|
// Finally, get the conditional
|
||||||
|
auto conditional =
|
||||||
|
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(EliminateDiscreteToDiscreteConditional);
|
||||||
|
#endif
|
||||||
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttoc_(EliminateDiscrete);
|
gttoc_(EliminateDiscrete);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue