diff --git a/gtsam/discrete/DiscreteTableConditional.cpp b/gtsam/discrete/DiscreteTableConditional.cpp index a4fdcef5d..2bad12d2b 100644 --- a/gtsam/discrete/DiscreteTableConditional.cpp +++ b/gtsam/discrete/DiscreteTableConditional.cpp @@ -151,6 +151,33 @@ TableFactor::shared_ptr DiscreteTableConditional::likelihood( throw std::runtime_error("Likelihood not implemented"); } +/* ****************************************************************************/ +DiscreteConditional::shared_ptr DiscreteTableConditional::max( + const Ordering& keys) const override { + auto m = *table_.max(keys); + + return std::make_shared(m.discreteKeys().size(), m); +} + +/* ****************************************************************************/ +void DiscreteTableConditional::setData( + const DiscreteConditional::shared_ptr& dc) override { + if (auto dtc = std::dynamic_pointer_cast(dc)) { + this->table_ = dtc->table_; + } else { + this->table_ = TableFactor(dc->discreteKeys(), *dc); + } +} + +/* ****************************************************************************/ +DiscreteConditional::shared_ptr DiscreteTableConditional::prune( + size_t maxNrAssignments) const { + TableFactor pruned = table_.prune(maxNrAssignments); + + return std::make_shared( + this->nrFrontals(), this->discreteKeys(), pruned.sparseTable()); +} + /* ************************************************************************** */ size_t DiscreteTableConditional::argmax( const DiscreteValues& parentsValues) const { diff --git a/gtsam/discrete/DiscreteTableConditional.h b/gtsam/discrete/DiscreteTableConditional.h index 8a03dc361..f34cad2a3 100644 --- a/gtsam/discrete/DiscreteTableConditional.h +++ b/gtsam/discrete/DiscreteTableConditional.h @@ -181,6 +181,16 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional { */ size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; + /** + * @brief Create new conditional by maximizing over all + * values with the same separator. + * + * @param keys The keys to sum over. + * @return DiscreteConditional::shared_ptr + */ + virtual DiscreteConditional::shared_ptr max( + const Ordering& keys) const override; + /// @} /// @name Advanced Interface /// @{ @@ -213,6 +223,13 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional { return table_.evaluate(values); } + /// Set the underlying data from the DiscreteConditional + virtual void setData(const DiscreteConditional::shared_ptr& dc) override; + + /// Prune the conditional + virtual DiscreteConditional::shared_ptr prune( + size_t maxNrAssignments) const override; + /// @} private: