revert changes to make code generic

release/4.3a0
Varun Agrawal 2024-12-08 15:58:07 -05:00
parent 5e86f7ee51
commit 1c14a56f5d
4 changed files with 14 additions and 35 deletions

View File

@ -55,17 +55,15 @@ DiscreteConditional::DiscreteConditional(size_t nrFrontals,
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::DiscreteConditional( DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DiscreteFactor::shared_ptr& joint, const DecisionTreeFactor& marginal)
const DiscreteFactor::shared_ptr& marginal) : BaseFactor(joint / marginal),
: BaseFactor(*std::dynamic_pointer_cast<DecisionTreeFactor>( BaseConditional(joint.size() - marginal.size()) {}
joint->operator/(marginal))),
BaseConditional(joint->size() - marginal->size()) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::DiscreteConditional( DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DiscreteFactor::shared_ptr& joint, const DecisionTreeFactor& marginal,
const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys) const Ordering& orderedKeys)
: DiscreteConditional(joint, marginal) { : DiscreteConditional(joint, marginal) {
keys_.clear(); keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());

View File

@ -126,10 +126,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree<Key> errorTree() const; virtual AlgebraicDecisionTree<Key> errorTree() const;
/// Multiply in a DiscreteFactor and return the result as /// Multiply in a DecisionTreeFactor and return the result as
/// DiscreteFactor /// DecisionTreeFactor
virtual DiscreteFactor::shared_ptr operator*( virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
const DiscreteFactor::shared_ptr&) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
@ -145,9 +144,6 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0; virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
/// divide by factor f (safely)
virtual DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const = 0;
/** /**
* Get the number of non-zero values contained in this factor. * Get the number of non-zero values contained in this factor.

View File

@ -171,13 +171,8 @@ double TableFactor::error(const HybridValues& values) const {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::operator*( DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
const DiscreteFactor::shared_ptr& f) const { return toDecisionTreeFactor() * f;
if (auto derived = std::dynamic_pointer_cast<TableFactor>(f)) {
return std::make_shared<TableFactor>(this->operator*(*derived));
} else {
throw std::runtime_error("Cannot convert DiscreteFactor to TableFactor");
}
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -161,9 +161,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, Ring::mul); return apply(f, Ring::mul);
}; };
/// multiply with DiscreteFactor /// multiply with DecisionTreeFactor
DiscreteFactor::shared_ptr operator*( DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
const DiscreteFactor::shared_ptr& f) const override;
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
@ -172,15 +171,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, safe_div); return apply(f, safe_div);
} }
/// divide by factor f (pointer version)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override {
if (auto derived = std::dynamic_pointer_cast<TableFactor>(f)) {
return std::make_shared<TableFactor>(apply(*derived, safe_div));
} else {
throw std::runtime_error("Cannot convert DiscreteFactor to Table Factor");
}
}
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override; DecisionTreeFactor toDecisionTreeFactor() const override;