common function for product and normalization

release/4.3a0
Varun Agrawal 2024-12-10 10:49:08 -05:00
parent 92e0a55e78
commit dea9c7f765
1 changed files with 17 additions and 15 deletions

View File

@ -112,11 +112,14 @@ namespace gtsam {
// }
// }
/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
/**
* @brief Multiply all the `factors` and normalize the
* product to prevent underflow.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor
*/
static DecisionTreeFactor ProductAndNormalize(const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product = factors.product();
@ -127,6 +130,14 @@ namespace gtsam {
// Normalize the product factor to prevent underflow.
product = product / (*normalization);
}
/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors);
// max out frontals, this is the factor on the separator
gttic(max);
@ -205,16 +216,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product = factors.product();
gttoc(product);
// Max over all the potentials by pretending all keys are frontal:
auto normalization = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / (*normalization);
DecisionTreeFactor product = ProductAndNormalize(factors);
// sum out frontals, this is the factor on the separator
gttic(sum);