add override methods to DiscreteTableConditional

release/4.3a0
Varun Agrawal 2024-12-31 00:19:22 -05:00
parent d9faa820de
commit 60945c8e32
2 changed files with 44 additions and 0 deletions

View File

@ -151,6 +151,33 @@ TableFactor::shared_ptr DiscreteTableConditional::likelihood(
throw std::runtime_error("Likelihood not implemented");
}
/* ****************************************************************************/
DiscreteConditional::shared_ptr DiscreteTableConditional::max(
const Ordering& keys) const override {
auto m = *table_.max(keys);
return std::make_shared<DiscreteTableConditional>(m.discreteKeys().size(), m);
}
/* ****************************************************************************/
void DiscreteTableConditional::setData(
const DiscreteConditional::shared_ptr& dc) override {
if (auto dtc = std::dynamic_pointer_cast<DiscreteTableConditional>(dc)) {
this->table_ = dtc->table_;
} else {
this->table_ = TableFactor(dc->discreteKeys(), *dc);
}
}
/* ****************************************************************************/
DiscreteConditional::shared_ptr DiscreteTableConditional::prune(
size_t maxNrAssignments) const {
TableFactor pruned = table_.prune(maxNrAssignments);
return std::make_shared<DiscreteTableConditional>(
this->nrFrontals(), this->discreteKeys(), pruned.sparseTable());
}
/* ************************************************************************** */
size_t DiscreteTableConditional::argmax(
const DiscreteValues& parentsValues) const {

View File

@ -181,6 +181,16 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
*/
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
/**
* @brief Create new conditional by maximizing over all
* values with the same separator.
*
* @param keys The keys to sum over.
* @return DiscreteConditional::shared_ptr
*/
virtual DiscreteConditional::shared_ptr max(
const Ordering& keys) const override;
/// @}
/// @name Advanced Interface
/// @{
@ -213,6 +223,13 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
return table_.evaluate(values);
}
/// Set the underlying data from the DiscreteConditional
virtual void setData(const DiscreteConditional::shared_ptr& dc) override;
/// Prune the conditional
virtual DiscreteConditional::shared_ptr prune(
size_t maxNrAssignments) const override;
/// @}
private: