revert changes to make code generic
parent
5e86f7ee51
commit
1c14a56f5d
|
|
@ -55,17 +55,15 @@ DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||||
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
|
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(
|
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DiscreteFactor::shared_ptr& joint,
|
const DecisionTreeFactor& marginal)
|
||||||
const DiscreteFactor::shared_ptr& marginal)
|
: BaseFactor(joint / marginal),
|
||||||
: BaseFactor(*std::dynamic_pointer_cast<DecisionTreeFactor>(
|
BaseConditional(joint.size() - marginal.size()) {}
|
||||||
joint->operator/(marginal))),
|
|
||||||
BaseConditional(joint->size() - marginal->size()) {}
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(
|
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DiscreteFactor::shared_ptr& joint,
|
const DecisionTreeFactor& marginal,
|
||||||
const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys)
|
const Ordering& orderedKeys)
|
||||||
: DiscreteConditional(joint, marginal) {
|
: DiscreteConditional(joint, marginal) {
|
||||||
keys_.clear();
|
keys_.clear();
|
||||||
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
||||||
|
|
|
||||||
|
|
@ -126,10 +126,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
/// Compute error for each assignment and return as a tree
|
/// Compute error for each assignment and return as a tree
|
||||||
virtual AlgebraicDecisionTree<Key> errorTree() const;
|
virtual AlgebraicDecisionTree<Key> errorTree() const;
|
||||||
|
|
||||||
/// Multiply in a DiscreteFactor and return the result as
|
/// Multiply in a DecisionTreeFactor and return the result as
|
||||||
/// DiscreteFactor
|
/// DecisionTreeFactor
|
||||||
virtual DiscreteFactor::shared_ptr operator*(
|
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||||
const DiscreteFactor::shared_ptr&) const = 0;
|
|
||||||
|
|
||||||
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
||||||
|
|
||||||
|
|
@ -145,9 +144,6 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
|
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
|
||||||
|
|
||||||
/// divide by factor f (safely)
|
|
||||||
virtual DiscreteFactor::shared_ptr operator/(
|
|
||||||
const DiscreteFactor::shared_ptr& f) const = 0;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the number of non-zero values contained in this factor.
|
* Get the number of non-zero values contained in this factor.
|
||||||
|
|
|
||||||
|
|
@ -171,13 +171,8 @@ double TableFactor::error(const HybridValues& values) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteFactor::shared_ptr TableFactor::operator*(
|
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
||||||
const DiscreteFactor::shared_ptr& f) const {
|
return toDecisionTreeFactor() * f;
|
||||||
if (auto derived = std::dynamic_pointer_cast<TableFactor>(f)) {
|
|
||||||
return std::make_shared<TableFactor>(this->operator*(*derived));
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error("Cannot convert DiscreteFactor to TableFactor");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
||||||
|
|
@ -161,9 +161,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
return apply(f, Ring::mul);
|
return apply(f, Ring::mul);
|
||||||
};
|
};
|
||||||
|
|
||||||
/// multiply with DiscreteFactor
|
/// multiply with DecisionTreeFactor
|
||||||
DiscreteFactor::shared_ptr operator*(
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||||
const DiscreteFactor::shared_ptr& f) const override;
|
|
||||||
|
|
||||||
static double safe_div(const double& a, const double& b);
|
static double safe_div(const double& a, const double& b);
|
||||||
|
|
||||||
|
|
@ -172,15 +171,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
return apply(f, safe_div);
|
return apply(f, safe_div);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// divide by factor f (pointer version)
|
|
||||||
DiscreteFactor::shared_ptr operator/(
|
|
||||||
const DiscreteFactor::shared_ptr& f) const override {
|
|
||||||
if (auto derived = std::dynamic_pointer_cast<TableFactor>(f)) {
|
|
||||||
return std::make_shared<TableFactor>(apply(*derived, safe_div));
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error("Cannot convert DiscreteFactor to Table Factor");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert into a decisiontree
|
/// Convert into a decisiontree
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue