divide operator for DiscreteFactor::shared_ptr

release/4.3a0
Varun Agrawal 2025-01-05 20:44:10 -05:00
parent b5128b2c9f
commit 4ebca71146
12 changed files with 71 additions and 12 deletions

View File

@ -77,6 +77,19 @@ namespace gtsam {
return result;
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
return std::make_shared<TableFactor>(tf->operator/(TableFactor(*this)));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return std::make_shared<DecisionTreeFactor>(this->operator/(*dtf));
} else {
return std::make_shared<DecisionTreeFactor>(
this->operator/(this->toDecisionTreeFactor()));
}
}
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum

View File

@ -165,9 +165,8 @@ namespace gtsam {
}
/// divide by DiscreteFactor::shared_ptr f (safely)
DecisionTreeFactor operator/(const DiscreteFactor::shared_ptr& f) const {
return apply(*std::dynamic_pointer_cast<DecisionTreeFactor>(f), safe_div);
}
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;
/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }

View File

@ -140,6 +140,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const = 0;
/// divide by DiscreteFactor::shared_ptr f (safely)
virtual DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// Create new factor by summing all values with the same separator values

View File

@ -270,6 +270,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply(
return result;
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
return std::make_shared<TableFactor>(this->operator/(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return std::make_shared<TableFactor>(
this->operator/(TableFactor(f->discreteKeys(), *dtf)));
} else {
TableFactor divisor(f->toDecisionTreeFactor());
return std::make_shared<TableFactor>(this->operator/(divisor));
}
}
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();

View File

@ -191,15 +191,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
}
/// divide by DiscreteFactor::shared_ptr f (safely)
TableFactor operator/(const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
return apply(*tf, safe_div);
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return apply(TableFactor(f->discreteKeys(), *dtf), safe_div);
} else {
throw std::runtime_error("Unknown derived type for DiscreteFactor");
}
}
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;

View File

@ -56,6 +56,12 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
/* ************************************************************************* */
DiscreteFactor::shared_ptr AllDiff::operator/(
const DiscreteFactor::shared_ptr& df) const {
return this->toDecisionTreeFactor() / df;
}
/* ************************************************************************* */
bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const {
Domain& Dj = domains->at(j);

View File

@ -60,6 +60,10 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
this->operator*(df->toDecisionTreeFactor()));
}
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const override;
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("AllDiff::error not implemented");

View File

@ -76,6 +76,12 @@ class BinaryAllDiff : public Constraint {
this->operator*(df->toDecisionTreeFactor()));
}
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const override {
return this->toDecisionTreeFactor() / df;
}
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked

View File

@ -49,6 +49,12 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
/* ************************************************************************* */
DiscreteFactor::shared_ptr Domain::operator/(
const DiscreteFactor::shared_ptr& df) const {
return this->toDecisionTreeFactor() / df;
}
/* ************************************************************************* */
bool Domain::ensureArcConsistency(Key j, Domains* domains) const {
if (j != key()) throw invalid_argument("Domain check on wrong domain");

View File

@ -97,6 +97,10 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
this->operator*(df->toDecisionTreeFactor()));
}
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const override;
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked

View File

@ -41,6 +41,12 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
/* ************************************************************************* */
DiscreteFactor::shared_ptr SingleValue::operator/(
const DiscreteFactor::shared_ptr& df) const {
return this->toDecisionTreeFactor() / df;
}
/* ************************************************************************* */
bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const {
if (j != keys_[0])

View File

@ -70,6 +70,10 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
this->operator*(df->toDecisionTreeFactor()));
}
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const override;
/*
* Ensure Arc-consistency: just sets domain[j] to {value_}.
* @param j domain to be checked