diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index e31f94eae..50072f547 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -14,6 +14,7 @@ * @date Feb 14, 2011 * @author Duy-Nguyen Ta * @author Frank Dellaert + * @author Varun Agrawal */ #include @@ -35,13 +36,12 @@ namespace gtsam { template class FactorGraph; template class EliminateableFactorGraph; - /* ************************************************************************* */ - bool DiscreteFactorGraph::equals(const This& fg, double tol) const - { + /* ************************************************************************ */ + bool DiscreteFactorGraph::equals(const This& fg, double tol) const { return Base::equals(fg, tol); } - /* ************************************************************************* */ + /* ************************************************************************ */ KeySet DiscreteFactorGraph::keys() const { KeySet keys; for (const sharedFactor& factor : *this) { @@ -50,7 +50,7 @@ namespace gtsam { return keys; } - /* ************************************************************************* */ + /* ************************************************************************ */ DiscreteKeys DiscreteFactorGraph::discreteKeys() const { DiscreteKeys result; for (auto&& factor : *this) { @@ -63,7 +63,7 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor result; for (const sharedFactor& factor : *this) { @@ -72,18 +72,18 @@ namespace gtsam { return result; } - /* ************************************************************************* */ - double DiscreteFactorGraph::operator()( - const DiscreteValues &values) const { + /* ************************************************************************ */ + double DiscreteFactorGraph::operator()(const DiscreteValues& values) const { double product = 1.0; - for( const sharedFactor& factor: factors_ ) - product *= (*factor)(values); + for (const sharedFactor& factor : factors_) { + if (factor) product *= (*factor)(values); + } return product; } - /* ************************************************************************* */ + /* ************************************************************************ */ void DiscreteFactorGraph::print(const string& s, - const KeyFormatter& formatter) const { + const KeyFormatter& formatter) const { std::cout << s << std::endl; std::cout << "size: " << size() << std::endl; for (size_t i = 0; i < factors_.size(); i++) { @@ -112,43 +112,36 @@ namespace gtsam { // } /** - * @brief Helper method to normalize the product factor by - * the max value to prevent underflow + * @brief Multiply all the `factors` and normalize the + * product to prevent underflow. * - * @param product The product discrete factor. - * @return DiscreteFactor::shared_ptr + * @param factors The factors to multiply as a DiscreteFactorGraph. + * @return DecisionTreeFactor */ - static DecisionTreeFactor Normalize(const DecisionTreeFactor& product) { - // Max over all the potentials by pretending all keys are frontal: - gttic_(DiscreteFindMax); - auto normalization = product.max(product.size()); - gttoc_(DiscreteFindMax); + static DecisionTreeFactor ProductAndNormalize( + const DiscreteFactorGraph& factors) { + // PRODUCT: multiply all factors + gttic(product); + DecisionTreeFactor product = factors.product(); + gttoc(product); + + // Max over all the potentials by pretending all keys are frontal: + auto normalization = product.max(product.size()); - gttic_(DiscreteNormalization); // Normalize the product factor to prevent underflow. auto normalized_product = product / (*std::dynamic_pointer_cast(normalization)); - gttoc_(DiscreteNormalization); return normalized_product; } /* ************************************************************************ */ // Alternate eliminate function for MPE - std::pair + std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - // PRODUCT: multiply all factors - gttic_(MPEProduct); - DecisionTreeFactor product = factors.product(); - gttoc_(MPEProduct); - - gttic_(Normalize); - - // Normalize the product - product = Normalize(product); - gttoc_(Normalize); + DecisionTreeFactor product = ProductAndNormalize(factors); // max out frontals, this is the factor on the separator gttic(max); @@ -225,18 +218,10 @@ namespace gtsam { } /* ************************************************************************ */ - std::pair + std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - // PRODUCT: multiply all factors - gttic_(product); - DecisionTreeFactor product = factors.product(); - gttoc_(product); - - gttic_(Normalize); - // Normalize the product - product = Normalize(product); - gttoc_(Normalize); + DecisionTreeFactor product = ProductAndNormalize(factors); // sum out frontals, this is the factor on the separator gttic_(sum); diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 43c48c2d0..b311cb78b 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -14,6 +14,7 @@ * @date Feb 14, 2011 * @author Duy-Nguyen Ta * @author Frank Dellaert + * @author Varun Agrawal */ #pragma once diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index e0d696d91..71c88bb7d 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -113,7 +113,8 @@ TEST(DiscreteFactorGraph, test) { const Ordering frontalKeys{0}; const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys); - auto newFactor = *std::dynamic_pointer_cast(newFactorPtr); + DecisionTreeFactor newFactor = + *std::dynamic_pointer_cast(newFactorPtr); // Normalize newFactor by max for comparison with expected auto normalization = newFactor.max(newFactor.size());