diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 055503b8e..981986ea1 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -488,6 +488,20 @@ void DiscreteConditional::setData(const DiscreteConditional::shared_ptr& dc) { this->root_ = dc->root_; } +/* ************************************************************************* */ +DiscreteConditional::shared_ptr DiscreteConditional::max( + const Ordering& keys) const { + auto m = *BaseFactor::max(keys); + return std::make_shared(m.discreteKeys().size(), m); +} + +/* ************************************************************************* */ +DiscreteConditional::shared_ptr DiscreteConditional::prune( + size_t maxNrAssignments) const { + return std::make_shared( + this->nrFrontals(), BaseFactor::prune(maxNrAssignments)); +} + /* ************************************************************************* */ double DiscreteConditional::negLogConstant() const { return 0.0; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 318024faa..12b5d457c 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -214,6 +214,15 @@ class GTSAM_EXPORT DiscreteConditional */ size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; + /** + * @brief Create new conditional by maximizing over all + * values with the same separator. + * + * @param keys The keys to sum over. + * @return DiscreteConditional::shared_ptr + */ + virtual DiscreteConditional::shared_ptr max(const Ordering& keys) const; + /// @} /// @name Advanced Interface /// @{ @@ -273,6 +282,9 @@ class GTSAM_EXPORT DiscreteConditional /// Set the data from another DiscreteConditional. virtual void setData(const DiscreteConditional::shared_ptr& dc); + /// Prune the conditional + virtual DiscreteConditional::shared_ptr prune(size_t maxNrAssignments) const; + /// @} protected: diff --git a/gtsam/discrete/DiscreteTableConditional.h b/gtsam/discrete/DiscreteTableConditional.h index fae0c0761..8a03dc361 100644 --- a/gtsam/discrete/DiscreteTableConditional.h +++ b/gtsam/discrete/DiscreteTableConditional.h @@ -205,6 +205,14 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional { return -error(x); } + /// Return the underlying TableFactor + TableFactor table() const { return table_; } + + /// Evaluate the conditional given the values. + virtual double evaluate(const Assignment& values) const override { + return table_.evaluate(values); + } + /// @} private: