Merge branch 'discrete-multiply' into discrete-elimination-refactor
commit
5e9c1300db
|
@ -70,6 +70,9 @@ namespace gtsam {
|
||||||
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
|
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
|
||||||
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
|
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
|
||||||
|
} else {
|
||||||
|
// Simulate double dispatch in C++
|
||||||
|
result = std::make_shared<DecisionTreeFactor>(f->operator*(*this));
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,11 +65,18 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||||
DecisionTreeFactor result;
|
DiscreteFactor::shared_ptr result;
|
||||||
for (const sharedFactor& factor : *this) {
|
for (auto it = this->begin(); it != this->end(); ++it) {
|
||||||
if (factor) result = (*factor) * result;
|
if (*it) {
|
||||||
|
if (result) {
|
||||||
|
result = result->multiply(*it);
|
||||||
|
} else {
|
||||||
|
// Assign to the first non-null factor
|
||||||
|
result = *it;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result->toDecisionTreeFactor();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
@ -256,12 +256,16 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteFactor::shared_ptr TableFactor::multiply(
|
DiscreteFactor::shared_ptr TableFactor::multiply(
|
||||||
const DiscreteFactor::shared_ptr& f) const override {
|
const DiscreteFactor::shared_ptr& f) const {
|
||||||
DiscreteFactor::shared_ptr result;
|
DiscreteFactor::shared_ptr result;
|
||||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||||
result = std::make_shared<TableFactor>(this->operator*(*tf));
|
result = std::make_shared<TableFactor>(this->operator*(*tf));
|
||||||
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
|
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
|
||||||
|
} else {
|
||||||
|
// Simulate double dispatch in C++
|
||||||
|
result = std::make_shared<DecisionTreeFactor>(
|
||||||
|
f->operator*(this->toDecisionTreeFactor()));
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,6 +53,13 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
|
||||||
/// Multiply into a decisiontree
|
/// Multiply into a decisiontree
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||||
|
|
||||||
|
/// Multiply factors, DiscreteFactor::shared_ptr edition
|
||||||
|
DiscreteFactor::shared_ptr multiply(
|
||||||
|
const DiscreteFactor::shared_ptr& df) const override {
|
||||||
|
return std::make_shared<DecisionTreeFactor>(
|
||||||
|
this->operator*(df->toDecisionTreeFactor()));
|
||||||
|
}
|
||||||
|
|
||||||
/// Compute error for each assignment and return as a tree
|
/// Compute error for each assignment and return as a tree
|
||||||
AlgebraicDecisionTree<Key> errorTree() const override {
|
AlgebraicDecisionTree<Key> errorTree() const override {
|
||||||
throw std::runtime_error("AllDiff::error not implemented");
|
throw std::runtime_error("AllDiff::error not implemented");
|
||||||
|
|
|
@ -69,6 +69,13 @@ class BinaryAllDiff : public Constraint {
|
||||||
return toDecisionTreeFactor() * f;
|
return toDecisionTreeFactor() * f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Multiply factors, DiscreteFactor::shared_ptr edition
|
||||||
|
DiscreteFactor::shared_ptr multiply(
|
||||||
|
const DiscreteFactor::shared_ptr& df) const override {
|
||||||
|
return std::make_shared<DecisionTreeFactor>(
|
||||||
|
this->operator*(df->toDecisionTreeFactor()));
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Ensure Arc-consistency by checking every possible value of domain j.
|
* Ensure Arc-consistency by checking every possible value of domain j.
|
||||||
* @param j domain to be checked
|
* @param j domain to be checked
|
||||||
|
|
|
@ -90,6 +90,13 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
||||||
/// Multiply into a decisiontree
|
/// Multiply into a decisiontree
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||||
|
|
||||||
|
/// Multiply factors, DiscreteFactor::shared_ptr edition
|
||||||
|
DiscreteFactor::shared_ptr multiply(
|
||||||
|
const DiscreteFactor::shared_ptr& df) const override {
|
||||||
|
return std::make_shared<DecisionTreeFactor>(
|
||||||
|
this->operator*(df->toDecisionTreeFactor()));
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Ensure Arc-consistency by checking every possible value of domain j.
|
* Ensure Arc-consistency by checking every possible value of domain j.
|
||||||
* @param j domain to be checked
|
* @param j domain to be checked
|
||||||
|
|
|
@ -63,6 +63,13 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
|
||||||
/// Multiply into a decisiontree
|
/// Multiply into a decisiontree
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||||
|
|
||||||
|
/// Multiply factors, DiscreteFactor::shared_ptr edition
|
||||||
|
DiscreteFactor::shared_ptr multiply(
|
||||||
|
const DiscreteFactor::shared_ptr& df) const override {
|
||||||
|
return std::make_shared<DecisionTreeFactor>(
|
||||||
|
this->operator*(df->toDecisionTreeFactor()));
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Ensure Arc-consistency: just sets domain[j] to {value_}.
|
* Ensure Arc-consistency: just sets domain[j] to {value_}.
|
||||||
* @param j domain to be checked
|
* @param j domain to be checked
|
||||||
|
|
Loading…
Reference in New Issue