get rid of setData and make prune() imperative for non-factors

release/4.3a0
Varun Agrawal 2025-01-04 14:39:18 -05:00
parent 7cb818136f
commit d39641d8ac
4 changed files with 8 additions and 33 deletions

View File

@ -478,11 +478,6 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->evaluate(x.discrete()); return this->evaluate(x.discrete());
} }
/* ************************************************************************* */
void DiscreteConditional::setData(const DiscreteConditional::shared_ptr& dc) {
this->root_ = dc->root_;
}
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteConditional::shared_ptr DiscreteConditional::max( DiscreteConditional::shared_ptr DiscreteConditional::max(
const Ordering& keys) const { const Ordering& keys) const {
@ -491,10 +486,10 @@ DiscreteConditional::shared_ptr DiscreteConditional::max(
} }
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteConditional::shared_ptr DiscreteConditional::prune( void DiscreteConditional::prune(size_t maxNrAssignments) {
size_t maxNrAssignments) const { // Get as DiscreteConditional so the probabilities are normalized
return std::make_shared<DiscreteConditional>( DiscreteConditional pruned(nrFrontals(), BaseFactor::prune(maxNrAssignments));
this->nrFrontals(), BaseFactor::prune(maxNrAssignments)); this->root_ = pruned.root_;
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -276,11 +276,8 @@ class GTSAM_EXPORT DiscreteConditional
*/ */
double negLogConstant() const override; double negLogConstant() const override;
/// Set the data from another DiscreteConditional.
virtual void setData(const DiscreteConditional::shared_ptr& dc);
/// Prune the conditional /// Prune the conditional
virtual DiscreteConditional::shared_ptr prune(size_t maxNrAssignments) const; virtual void prune(size_t maxNrAssignments);
/// @} /// @}

View File

@ -122,21 +122,8 @@ DiscreteConditional::shared_ptr TableDistribution::max(
} }
/* ****************************************************************************/ /* ****************************************************************************/
void TableDistribution::setData(const DiscreteConditional::shared_ptr& dc) { void TableDistribution::prune(size_t maxNrAssignments) {
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) { table_ = table_.prune(maxNrAssignments);
this->table_ = dtc->table_;
} else {
this->table_ = TableFactor(dc->discreteKeys(), *dc);
}
}
/* ****************************************************************************/
DiscreteConditional::shared_ptr TableDistribution::prune(
size_t maxNrAssignments) const {
TableFactor pruned = table_.prune(maxNrAssignments);
return std::make_shared<TableDistribution>(this->discreteKeys(),
pruned.sparseTable());
} }
} // namespace gtsam } // namespace gtsam

View File

@ -145,12 +145,8 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
/// Set the underlying data from the DiscreteConditional
virtual void setData(const DiscreteConditional::shared_ptr& dc) override;
/// Prune the conditional /// Prune the conditional
virtual DiscreteConditional::shared_ptr prune( virtual void prune(size_t maxNrAssignments) override;
size_t maxNrAssignments) const override;
/// Get a DecisionTreeFactor representation. /// Get a DecisionTreeFactor representation.
DecisionTreeFactor toDecisionTreeFactor() const override { DecisionTreeFactor toDecisionTreeFactor() const override {