From 0afc1984118d8bec1c2327f542cac8b22d11b96f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 16:26:03 -0500 Subject: [PATCH] revert some DiscreteFactorGraph changes --- gtsam/discrete/DiscreteFactorGraph.cpp | 29 +++++++++++++------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 9e64b0f6d..04849985f 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -118,16 +118,17 @@ namespace gtsam { * @param product The product discrete factor. * @return DiscreteFactor::shared_ptr */ - static DiscreteFactor::shared_ptr Normalize( - const DiscreteFactor::shared_ptr& product) { + static DecisionTreeFactor Normalize(const DecisionTreeFactor& product) { // Max over all the potentials by pretending all keys are frontal: gttic_(DiscreteFindMax); - auto normalization = product->max(product->size()); + auto normalization = product.max(product.size()); gttoc_(DiscreteFindMax); gttic_(DiscreteNormalization); // Normalize the product factor to prevent underflow. - auto normalized_product = product->operator/(normalization); + auto normalized_product = + product / + (*std::dynamic_pointer_cast(normalization)); gttoc_(DiscreteNormalization); return normalized_product; @@ -140,7 +141,7 @@ namespace gtsam { const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic_(MPEProduct); - DiscreteFactor::shared_ptr product = factors.product(); + DecisionTreeFactor product = factors.product(); gttoc_(MPEProduct); gttic_(Normalize); @@ -151,23 +152,22 @@ namespace gtsam { // max out frontals, this is the factor on the separator gttic(max); - DiscreteFactor::shared_ptr max = product->max(frontalKeys); + DecisionTreeFactor::shared_ptr max = + std::dynamic_pointer_cast(product.max(frontalKeys)); gttoc(max); // Ordering keys for the conditional so that frontalKeys are really in front DiscreteKeys orderedKeys; for (auto&& key : frontalKeys) - orderedKeys.emplace_back(key, product->cardinality(key)); + orderedKeys.emplace_back(key, product.cardinality(key)); for (auto&& key : max->keys()) - orderedKeys.emplace_back(key, product->cardinality(key)); + orderedKeys.emplace_back(key, product.cardinality(key)); // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - //TODO(Varun): Should accept a DiscreteFactor::shared_ptr - auto lookup = std::make_shared( - nrFrontals, orderedKeys, - *std::dynamic_pointer_cast(product)); + auto lookup = + std::make_shared(nrFrontals, orderedKeys, product); gttoc(lookup); return {std::dynamic_pointer_cast(lookup), max}; @@ -230,7 +230,7 @@ namespace gtsam { const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic_(product); - DiscreteFactor::shared_ptr product = factors.product(); + DecisionTreeFactor product = factors.product(); gttoc_(product); gttic_(Normalize); @@ -240,7 +240,8 @@ namespace gtsam { // sum out frontals, this is the factor on the separator gttic_(sum); - DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); + DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast( + product.sum(frontalKeys)); gttoc_(sum); // Ordering keys for the conditional so that frontalKeys are really in front