diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index e05cf9e33..eb3221819 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -64,7 +64,7 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor DiscreteFactorGraph::product() const { + DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const { DiscreteFactor::shared_ptr result; for (auto it = this->begin(); it != this->end(); ++it) { if (*it) { @@ -76,7 +76,7 @@ namespace gtsam { } } } - return result->toDecisionTreeFactor(); + return result; } /* ************************************************************************ */ @@ -122,20 +122,20 @@ namespace gtsam { * @brief Multiply all the `factors`. * * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return DecisionTreeFactor + * @return DiscreteFactor::shared_ptr */ - static DecisionTreeFactor DiscreteProduct( + static DiscreteFactor::shared_ptr DiscreteProduct( const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors gttic(product); - DecisionTreeFactor product = factors.product(); + DiscreteFactor::shared_ptr product = factors.product(); gttoc(product); // Max over all the potentials by pretending all keys are frontal: - auto denominator = product.max(product.size()); + auto denominator = product->max(product->size()); // Normalize the product factor to prevent underflow. - product = product / denominator; + product = product->operator/(denominator); return product; } @@ -145,26 +145,25 @@ namespace gtsam { std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = DiscreteProduct(factors); // max out frontals, this is the factor on the separator gttic(max); - DecisionTreeFactor::shared_ptr max = - std::dynamic_pointer_cast(product.max(frontalKeys)); + DiscreteFactor::shared_ptr max = product->max(frontalKeys); gttoc(max); // Ordering keys for the conditional so that frontalKeys are really in front DiscreteKeys orderedKeys; for (auto&& key : frontalKeys) - orderedKeys.emplace_back(key, product.cardinality(key)); + orderedKeys.emplace_back(key, product->cardinality(key)); for (auto&& key : max->keys()) - orderedKeys.emplace_back(key, product.cardinality(key)); + orderedKeys.emplace_back(key, product->cardinality(key)); // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - auto lookup = - std::make_shared(nrFrontals, orderedKeys, product); + auto lookup = std::make_shared( + nrFrontals, orderedKeys, product->toDecisionTreeFactor()); gttoc(lookup); return {std::dynamic_pointer_cast(lookup), max}; @@ -224,12 +223,11 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = DiscreteProduct(factors); // sum out frontals, this is the factor on the separator gttic(sum); - DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast( - product.sum(frontalKeys)); + DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); gttoc(sum); // Ordering keys for the conditional so that frontalKeys are really in front @@ -241,8 +239,9 @@ namespace gtsam { // now divide product/sum to get conditional gttic(divide); - auto conditional = - std::make_shared(product, *sum, orderedKeys); + auto conditional = std::make_shared( + product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(), + orderedKeys); gttoc(divide); return {conditional, sum}; diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index b311cb78b..3d9e86cd1 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -148,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteKeys discreteKeys() const; /** return product of all factors as a single factor */ - DecisionTreeFactor product() const; + DiscreteFactor::shared_ptr product() const; /** * Evaluates the factor graph given values, returns the joint probability of