revert some DiscreteFactorGraph changes

release/4.3a0
Varun Agrawal 2024-12-08 16:26:03 -05:00
parent b325150b37
commit 0afc198411
1 changed files with 15 additions and 14 deletions

View File

@ -118,16 +118,17 @@ namespace gtsam {
* @param product The product discrete factor.
* @return DiscreteFactor::shared_ptr
*/
static DiscreteFactor::shared_ptr Normalize(
const DiscreteFactor::shared_ptr& product) {
static DecisionTreeFactor Normalize(const DecisionTreeFactor& product) {
// Max over all the potentials by pretending all keys are frontal:
gttic_(DiscreteFindMax);
auto normalization = product->max(product->size());
auto normalization = product.max(product.size());
gttoc_(DiscreteFindMax);
gttic_(DiscreteNormalization);
// Normalize the product factor to prevent underflow.
auto normalized_product = product->operator/(normalization);
auto normalized_product =
product /
(*std::dynamic_pointer_cast<DecisionTreeFactor>(normalization));
gttoc_(DiscreteNormalization);
return normalized_product;
@ -140,7 +141,7 @@ namespace gtsam {
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic_(MPEProduct);
DiscreteFactor::shared_ptr product = factors.product();
DecisionTreeFactor product = factors.product();
gttoc_(MPEProduct);
gttic_(Normalize);
@ -151,23 +152,22 @@ namespace gtsam {
// max out frontals, this is the factor on the separator
gttic(max);
DiscreteFactor::shared_ptr max = product->max(frontalKeys);
DecisionTreeFactor::shared_ptr max =
std::dynamic_pointer_cast<DecisionTreeFactor>(product.max(frontalKeys));
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product->cardinality(key));
orderedKeys.emplace_back(key, product.cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product->cardinality(key));
orderedKeys.emplace_back(key, product.cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
//TODO(Varun): Should accept a DiscreteFactor::shared_ptr
auto lookup = std::make_shared<DiscreteLookupTable>(
nrFrontals, orderedKeys,
*std::dynamic_pointer_cast<TableFactor>(product));
auto lookup =
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
gttoc(lookup);
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
@ -230,7 +230,7 @@ namespace gtsam {
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic_(product);
DiscreteFactor::shared_ptr product = factors.product();
DecisionTreeFactor product = factors.product();
gttoc_(product);
gttic_(Normalize);
@ -240,7 +240,8 @@ namespace gtsam {
// sum out frontals, this is the factor on the separator
gttic_(sum);
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast<DecisionTreeFactor>(
product.sum(frontalKeys));
gttoc_(sum);
// Ordering keys for the conditional so that frontalKeys are really in front