make normalization code common

release/4.3a0
Varun Agrawal 2024-12-07 11:15:06 -05:00
parent 27bbce150a
commit 84e419456a
1 changed files with 32 additions and 18 deletions

View File

@ -116,6 +116,28 @@ namespace gtsam {
// }
// }
/**
* @brief Helper method to normalize the product factor by
* the max value to prevent underflow
*
* @param product The product discrete factor.
* @return DiscreteFactor::shared_ptr
*/
static DiscreteFactor::shared_ptr Normalize(
const DiscreteFactor::shared_ptr& product) {
// Max over all the potentials by pretending all keys are frontal:
gttic(DiscreteFindMax);
auto normalization = product->max(product->size());
gttoc(DiscreteFindMax);
gttic(DiscreteNormalization);
// Normalize the product factor to prevent underflow.
auto normalized_product = product->operator/(normalization);
gttoc(DiscreteNormalization);
return normalized_product;
}
/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>
@ -123,27 +145,23 @@ namespace gtsam {
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);
// Max over all the potentials by pretending all keys are frontal:
auto normalization = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / (*normalization);
// Normalize the product
product = Normalize(product);
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
DecisionTreeFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
orderedKeys.emplace_back(key, product->cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
orderedKeys.emplace_back(key, product->cardinality(key));
// Make lookup with product
gttic(lookup);
@ -212,19 +230,15 @@ namespace gtsam {
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);
// Max over all the potentials by pretending all keys are frontal:
auto normalization = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / (*normalization);
// Normalize the product
product = Normalize(product);
// sum out frontals, this is the factor on the separator
gttic(sum);
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
DecisionTreeFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front