divide operator for DiscreteFactor::shared_ptr
parent
b5128b2c9f
commit
4ebca71146
|
@ -77,6 +77,19 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
|
||||
const DiscreteFactor::shared_ptr& f) const {
|
||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||
return std::make_shared<TableFactor>(tf->operator/(TableFactor(*this)));
|
||||
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
return std::make_shared<DecisionTreeFactor>(this->operator/(*dtf));
|
||||
} else {
|
||||
return std::make_shared<DecisionTreeFactor>(
|
||||
this->operator/(this->toDecisionTreeFactor()));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||
// The use for safe_div is when we divide the product factor by the sum
|
||||
|
|
|
@ -165,9 +165,8 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
DecisionTreeFactor operator/(const DiscreteFactor::shared_ptr& f) const {
|
||||
return apply(*std::dynamic_pointer_cast<DecisionTreeFactor>(f), safe_div);
|
||||
}
|
||||
DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& f) const override;
|
||||
|
||||
/// Convert into a decision tree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
||||
|
|
|
@ -140,6 +140,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
|||
virtual DiscreteFactor::shared_ptr multiply(
|
||||
const DiscreteFactor::shared_ptr& df) const = 0;
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
virtual DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& df) const = 0;
|
||||
|
||||
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
|
|
|
@ -270,6 +270,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply(
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteFactor::shared_ptr TableFactor::operator/(
|
||||
const DiscreteFactor::shared_ptr& f) const {
|
||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||
return std::make_shared<TableFactor>(this->operator/(*tf));
|
||||
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
return std::make_shared<TableFactor>(
|
||||
this->operator/(TableFactor(f->discreteKeys(), *dtf)));
|
||||
} else {
|
||||
TableFactor divisor(f->toDecisionTreeFactor());
|
||||
return std::make_shared<TableFactor>(this->operator/(divisor));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
||||
DiscreteKeys dkeys = discreteKeys();
|
||||
|
|
|
@ -191,15 +191,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
}
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
TableFactor operator/(const DiscreteFactor::shared_ptr& f) const {
|
||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||
return apply(*tf, safe_div);
|
||||
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
return apply(TableFactor(f->discreteKeys(), *dtf), safe_div);
|
||||
} else {
|
||||
throw std::runtime_error("Unknown derived type for DiscreteFactor");
|
||||
}
|
||||
}
|
||||
DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& f) const override;
|
||||
|
||||
/// Convert into a decisiontree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||
|
|
|
@ -56,6 +56,12 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
|
|||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactor::shared_ptr AllDiff::operator/(
|
||||
const DiscreteFactor::shared_ptr& df) const {
|
||||
return this->toDecisionTreeFactor() / df;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const {
|
||||
Domain& Dj = domains->at(j);
|
||||
|
|
|
@ -60,6 +60,10 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
|
|||
this->operator*(df->toDecisionTreeFactor()));
|
||||
}
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& df) const override;
|
||||
|
||||
/// Compute error for each assignment and return as a tree
|
||||
AlgebraicDecisionTree<Key> errorTree() const override {
|
||||
throw std::runtime_error("AllDiff::error not implemented");
|
||||
|
|
|
@ -76,6 +76,12 @@ class BinaryAllDiff : public Constraint {
|
|||
this->operator*(df->toDecisionTreeFactor()));
|
||||
}
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& df) const override {
|
||||
return this->toDecisionTreeFactor() / df;
|
||||
}
|
||||
|
||||
/*
|
||||
* Ensure Arc-consistency by checking every possible value of domain j.
|
||||
* @param j domain to be checked
|
||||
|
|
|
@ -49,6 +49,12 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
|
|||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactor::shared_ptr Domain::operator/(
|
||||
const DiscreteFactor::shared_ptr& df) const {
|
||||
return this->toDecisionTreeFactor() / df;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool Domain::ensureArcConsistency(Key j, Domains* domains) const {
|
||||
if (j != key()) throw invalid_argument("Domain check on wrong domain");
|
||||
|
|
|
@ -97,6 +97,10 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
|||
this->operator*(df->toDecisionTreeFactor()));
|
||||
}
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& df) const override;
|
||||
|
||||
/*
|
||||
* Ensure Arc-consistency by checking every possible value of domain j.
|
||||
* @param j domain to be checked
|
||||
|
|
|
@ -41,6 +41,12 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
|
|||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactor::shared_ptr SingleValue::operator/(
|
||||
const DiscreteFactor::shared_ptr& df) const {
|
||||
return this->toDecisionTreeFactor() / df;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const {
|
||||
if (j != keys_[0])
|
||||
|
|
|
@ -70,6 +70,10 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
|
|||
this->operator*(df->toDecisionTreeFactor()));
|
||||
}
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& df) const override;
|
||||
|
||||
/*
|
||||
* Ensure Arc-consistency: just sets domain[j] to {value_}.
|
||||
* @param j domain to be checked
|
||||
|
|
Loading…
Reference in New Issue