diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 12c607223..7c6da3dac 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -71,16 +71,14 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { /* ************************************************************************* */ // The implementation is: build the entire joint into one factor and then prune. -// TODO(Frank): This can be quite expensive *unless* the factors have already +// NOTE: This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. DiscreteBayesNet DiscreteBayesNet::prune( size_t maxNrLeaves, const std::optional& marginalThreshold, DiscreteValues* fixedValues) const { // Multiply into one big conditional. NOTE: possibly quite expensive. - DiscreteConditional joint; - for (const DiscreteConditional::shared_ptr& conditional : *this) - joint = joint * (*conditional); + DiscreteConditional joint = this->joint(); // Prune the joint. NOTE: imperative and, again, possibly quite expensive. DiscreteConditional pruned = joint; @@ -122,6 +120,15 @@ DiscreteBayesNet DiscreteBayesNet::prune( return result; } +/* *********************************************************************** */ +DiscreteConditional DiscreteBayesNet::joint() const { + DiscreteConditional joint; + for (const DiscreteConditional::shared_ptr& conditional : *this) + joint = joint * (*conditional); + + return joint; +} + /* *********************************************************************** */ std::string DiscreteBayesNet::markdown( const KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index eea1739f6..e15576b37 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -136,6 +136,16 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { const std::optional& marginalThreshold = {}, DiscreteValues* fixedValues = nullptr) const; + /** + * @brief Multiply all conditionals into one big joint conditional + * and return it. + * + * NOTE: possibly quite expensive. + * + * @return DiscreteConditional + */ + DiscreteConditional joint() const; + ///@} /// @name Wrapper support /// @{