diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index 1ac782b88..19deccd78 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -18,9 +18,10 @@
*/
#include
-#include
#include
#include
+#include
+#include
#include
@@ -62,6 +63,49 @@ namespace gtsam {
return error(values.discrete());
}
+ /* ************************************************************************ */
+ DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
+ const DiscreteFactor::shared_ptr& f) const {
+ DiscreteFactor::shared_ptr result;
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ // If f is a TableFactor, we convert `this` to a TableFactor since this
+ // conversion is cheaper than converting `f` to a DecisionTreeFactor. We
+ // then return a TableFactor.
+ result = std::make_shared((*tf) * TableFactor(*this));
+
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ // If `f` is a DecisionTreeFactor, simply call operator*.
+ result = std::make_shared(this->operator*(*dtf));
+
+ } else {
+ // Simulate double dispatch in C++
+ // Useful for other classes which inherit from DiscreteFactor and have
+ // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
+ // need to be updated.
+ result = std::make_shared(f->operator*(*this));
+ }
+ return result;
+ }
+
+ /* ************************************************************************ */
+ DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
+ const DiscreteFactor::shared_ptr& f) const {
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ // Check if `f` is a TableFactor. If yes, then
+ // convert `this` to a TableFactor which is cheaper.
+ return std::make_shared(tf->operator/(TableFactor(*this)));
+
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ // If `f` is a DecisionTreeFactor, divide normally.
+ return std::make_shared(this->operator/(*dtf));
+
+ } else {
+ // Else, convert `f` to a DecisionTreeFactor so we can divide
+ return std::make_shared(
+ this->operator/(f->toDecisionTreeFactor()));
+ }
+ }
+
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 80ee10a7b..4da5a7c17 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -147,6 +147,23 @@ namespace gtsam {
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
+ /**
+ * @brief Multiply factors, DiscreteFactor::shared_ptr edition.
+ *
+ * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
+ * dispatch and specializations to perform the most efficient
+ * multiplication.
+ *
+ * While converting a DecisionTreeFactor to a TableFactor is efficient, the
+ * reverse is not. Hence we specialize the code to return a TableFactor if
+ * `f` is a TableFactor, and DecisionTreeFactor otherwise.
+ *
+ * @param f The factor to multiply with.
+ * @return DiscreteFactor::shared_ptr
+ */
+ virtual DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& f) const override;
+
/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul);
@@ -154,31 +171,43 @@ namespace gtsam {
static double safe_div(const double& a, const double& b);
- /// divide by factor f (safely)
+ /**
+ * @brief Divide by factor f (safely).
+ * Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$
+ * results in a function which involves all keys
+ * \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$
+ *
+ * @param f The DecisinTreeFactor to divide by.
+ * @return DecisionTreeFactor
+ */
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& f) const override;
+
/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}
@@ -259,6 +288,12 @@ namespace gtsam {
*/
DecisionTreeFactor prune(size_t maxNrAssignments) const;
+ /**
+ * Get the number of non-zero values contained in this factor.
+ * It could be much smaller than `prod_{key}(cardinality(key))`.
+ */
+ uint64_t nrValues() const override { return nrLeaves(); }
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index e433243e1..606f4c13c 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -44,8 +44,9 @@ template class GTSAM_EXPORT
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
- const DecisionTreeFactor& f)
- : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
+ const DiscreteFactor& f)
+ : BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()),
+ BaseConditional(nrFrontals) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
@@ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s,
/* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
- if (!dynamic_cast(&other)) {
+ if (!dynamic_cast(&other)) {
return false;
} else {
- const DecisionTreeFactor& f(static_cast(other));
- return DecisionTreeFactor::equals(f, tol);
+ const BaseFactor& f(static_cast(other));
+ return BaseFactor::equals(f, tol);
}
}
@@ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
ss << "*\n" << std::endl;
if (nrParents() == 0) {
// We have no parents, call factor method.
- ss << DecisionTreeFactor::markdown(keyFormatter, names);
+ ss << BaseFactor::markdown(keyFormatter, names);
return ss.str();
}
@@ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
ss << "
\n";
if (nrParents() == 0) {
// We have no parents, call factor method.
- ss << DecisionTreeFactor::html(keyFormatter, names);
+ ss << BaseFactor::html(keyFormatter, names);
return ss.str();
}
@@ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const {
- return this->evaluate(x.discrete());
+ return this->operator()(x.discrete());
}
/* ************************************************************************* */
diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h
index c92a69050..5a26c45e0 100644
--- a/gtsam/discrete/DiscreteConditional.h
+++ b/gtsam/discrete/DiscreteConditional.h
@@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional() {}
/// Construct from factor, taking the first `nFrontals` keys as frontals.
- DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
+ DiscreteConditional(size_t nFrontals, const DiscreteFactor& f);
/**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index a1fde0f86..6cbc00d09 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -22,6 +22,7 @@
#include
#include
#include
+#include
#include
namespace gtsam {
@@ -129,8 +130,40 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
+ /**
+ * @brief Multiply in a DiscreteFactor and return the result as
+ * DiscreteFactor, both via shared pointers.
+ *
+ * @param df DiscreteFactor shared_ptr
+ * @return DiscreteFactor::shared_ptr
+ */
+ virtual DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& df) const = 0;
+
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ virtual DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& df) const = 0;
+
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
+ /// Create new factor by summing all values with the same separator values
+ virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;
+
+ /// Create new factor by summing all values with the same separator values
+ virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;
+
+ /// Create new factor by maximizing over all values with the same separator.
+ virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;
+
+ /// Create new factor by maximizing over all values with the same separator.
+ virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
+
+ /**
+ * Get the number of non-zero values contained in this factor.
+ * It could be much smaller than `prod_{key}(cardinality(key))`.
+ */
+ virtual uint64_t nrValues() const = 0;
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index b48e09b03..9b1774f49 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -64,10 +64,17 @@ namespace gtsam {
}
/* ************************************************************************ */
- DecisionTreeFactor DiscreteFactorGraph::product() const {
- DecisionTreeFactor result;
- for (const sharedFactor& factor : *this) {
- if (factor) result = (*factor) * result;
+ DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
+ DiscreteFactor::shared_ptr result;
+ for (auto it = this->begin(); it != this->end(); ++it) {
+ if (*it) {
+ if (result) {
+ result = result->multiply(*it);
+ } else {
+ // Assign to the first non-null factor
+ result = *it;
+ }
+ }
}
return result;
}
@@ -115,21 +122,23 @@ namespace gtsam {
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
- * @return DecisionTreeFactor
+ * @return DiscreteFactor::shared_ptr
*/
- static DecisionTreeFactor DiscreteProduct(
+ static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
- DecisionTreeFactor product = factors.product();
+ gttic(product);
+ DiscreteFactor::shared_ptr product = factors.product();
+ gttoc(product);
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Max over all the potentials by pretending all keys are frontal:
- auto denominator = product.max(product.size());
+ auto denominator = product->max(product->size());
// Normalize the product factor to prevent underflow.
- product = product / (*denominator);
+ product = product->operator/(denominator);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
@@ -142,25 +151,25 @@ namespace gtsam {
std::pair //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DecisionTreeFactor product = DiscreteProduct(factors);
+ DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// max out frontals, this is the factor on the separator
gttic(max);
- DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
+ DiscreteFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
- orderedKeys.emplace_back(key, product.cardinality(key));
+ orderedKeys.emplace_back(key, product->cardinality(key));
for (auto&& key : max->keys())
- orderedKeys.emplace_back(key, product.cardinality(key));
+ orderedKeys.emplace_back(key, product->cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
- auto lookup =
- std::make_shared(nrFrontals, orderedKeys, product);
+ auto lookup = std::make_shared(
+ nrFrontals, orderedKeys, product->toDecisionTreeFactor());
gttoc(lookup);
return {std::dynamic_pointer_cast(lookup), max};
@@ -220,10 +229,12 @@ namespace gtsam {
std::pair //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DecisionTreeFactor product = DiscreteProduct(factors);
+ DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// sum out frontals, this is the factor on the separator
- DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
+ gttic(sum);
+ DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
+ gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
@@ -233,8 +244,11 @@ namespace gtsam {
sum->keys().end());
// now divide product/sum to get conditional
- auto conditional =
- std::make_shared(product, *sum, orderedKeys);
+ gttic(divide);
+ auto conditional = std::make_shared(
+ product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
+ orderedKeys);
+ gttoc(divide);
return {conditional, sum};
}
diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h
index c57d2258c..3d9e86cd1 100644
--- a/gtsam/discrete/DiscreteFactorGraph.h
+++ b/gtsam/discrete/DiscreteFactorGraph.h
@@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @}
+ //TODO(Varun): Make compatible with TableFactor
/** Add a decision-tree factor */
template
void add(Args&&... args) {
@@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */
- DecisionTreeFactor product() const;
+ DiscreteFactor::shared_ptr product() const;
/**
* Evaluates the factor graph given values, returns the joint probability of
diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h
index afb4d4208..1c0be2f06 100644
--- a/gtsam/discrete/DiscreteLookupDAG.h
+++ b/gtsam/discrete/DiscreteLookupDAG.h
@@ -18,6 +18,7 @@
#pragma once
#include
+#include
#include
#include
@@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
+ /**
+ * @brief Construct a new Discrete Lookup Table object
+ *
+ * @param nFrontals number of frontal variables
+ * @param keys a sorted list of gtsam::Keys
+ * @param potentials Discrete potentials as a TableFactor.
+ */
+ DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
+ const TableFactor& potentials)
+ : DiscreteConditional(nFrontals, keys,
+ potentials.toDecisionTreeFactor()) {}
+
/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index a59095d40..d1cedc9ef 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -254,6 +254,46 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
+/* ************************************************************************ */
+DiscreteFactor::shared_ptr TableFactor::multiply(
+ const DiscreteFactor::shared_ptr& f) const {
+ DiscreteFactor::shared_ptr result;
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ // If `f` is a TableFactor, we can simply call `operator*`.
+ result = std::make_shared(this->operator*(*tf));
+
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ // If `f` is a DecisionTreeFactor, we convert to a TableFactor which is
+ // cheaper than converting `this` to a DecisionTreeFactor.
+ result = std::make_shared(this->operator*(TableFactor(*dtf)));
+
+ } else {
+ // Simulate double dispatch in C++
+ // Useful for other classes which inherit from DiscreteFactor and have
+ // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
+ // need to be updated to know about TableFactor.
+ // Those classes can be specialized to use TableFactor
+ // if efficiency is a problem.
+ result = std::make_shared(
+ f->operator*(this->toDecisionTreeFactor()));
+ }
+ return result;
+}
+
+/* ************************************************************************ */
+DiscreteFactor::shared_ptr TableFactor::operator/(
+ const DiscreteFactor::shared_ptr& f) const {
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(this->operator/(*tf));
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(
+ this->operator/(TableFactor(f->discreteKeys(), *dtf)));
+ } else {
+ TableFactor divisor(f->toDecisionTreeFactor());
+ return std::make_shared(this->operator/(divisor));
+ }
+}
+
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 72778d711..43f84f874 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -17,6 +17,7 @@
#pragma once
+#include
#include
#include
#include
@@ -178,6 +179,23 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
+ /**
+ * @brief Multiply factors, DiscreteFactor::shared_ptr edition.
+ *
+ * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
+ * dispatch and specializations to perform the most efficient
+ * multiplication.
+ *
+ * While converting a DecisionTreeFactor to a TableFactor is efficient, the
+ * reverse is not.
+ * Hence we specialize the code to return a TableFactor always.
+ *
+ * @param f The factor to multiply with.
+ * @return DiscreteFactor::shared_ptr
+ */
+ virtual DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& f) const override;
+
static double safe_div(const double& a, const double& b);
/// divide by factor f (safely)
@@ -185,6 +203,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, safe_div);
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& f) const override;
+
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
@@ -193,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}
@@ -313,6 +335,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
TableFactor prune(size_t maxNrAssignments) const;
+ /**
+ * Get the number of non-zero values contained in this factor.
+ * It could be much smaller than `prod_{key}(cardinality(key))`.
+ */
+ uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
index 756a0cebe..ec9185ecb 100644
--- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp
+++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
@@ -30,6 +30,12 @@
using namespace std;
using namespace gtsam;
+/** Convert Signature into CPT */
+DecisionTreeFactor create(const Signature& signature) {
+ DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
+ return p;
+}
+
/* ************************************************************************* */
TEST(DecisionTreeFactor, ConstructorsMatch) {
// Declare two keys
@@ -105,21 +111,45 @@ TEST(DecisionTreeFactor, multiplication) {
CHECK(assert_equal(expected2, actual));
}
+/* ************************************************************************* */
+TEST(DecisionTreeFactor, Divide) {
+ DiscreteKey A(0, 2), S(1, 2);
+ DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
+ DecisionTreeFactor joint = pA * pS;
+
+ DecisionTreeFactor s = joint / pA;
+
+ // Factors are not equal due to difference in keys
+ EXPECT(assert_inequal(pS, s));
+
+ // The underlying data should be the same
+ using ADT = AlgebraicDecisionTree;
+ EXPECT(assert_equal(ADT(pS), ADT(s)));
+
+ KeySet keys(joint.keys());
+ keys.insert(pA.keys().begin(), pA.keys().end());
+ EXPECT(assert_inequal(KeySet(pS.keys()), keys));
+
+}
+
/* ************************************************************************* */
TEST(DecisionTreeFactor, sum_max) {
DiscreteKey v0(0, 3), v1(1, 2);
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
DecisionTreeFactor expected(v1, "9 12");
- DecisionTreeFactor::shared_ptr actual = f1.sum(1);
+ auto actual = std::dynamic_pointer_cast(f1.sum(1));
+ CHECK(actual);
CHECK(assert_equal(expected, *actual, 1e-5));
DecisionTreeFactor expected2(v1, "5 6");
- DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
+ auto actual2 = std::dynamic_pointer_cast(f1.max(1));
+ CHECK(actual2);
CHECK(assert_equal(expected2, *actual2));
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
- DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
+ auto actual22 = std::dynamic_pointer_cast(f2.sum(1));
+ CHECK(actual22);
}
/* ************************************************************************* */
@@ -217,12 +247,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#endif
}
-/** Convert Signature into CPT */
-DecisionTreeFactor create(const Signature& signature) {
- DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
- return p;
-}
-
/* ************************************************************************* */
// test Asia Joint
TEST(DecisionTreeFactor, joint) {
diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp
index 2482a86a2..b91e1bd8a 100644
--- a/gtsam/discrete/tests/testDiscreteConditional.cpp
+++ b/gtsam/discrete/tests/testDiscreteConditional.cpp
@@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
- DecisionTreeFactor expected2 = f2 / *f2.sum(1);
+ DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast(actual2)));
std::vector probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
@@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
- DecisionTreeFactor expected2 = f2 / *f2.sum(1);
+ DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast(actual2)));
}
diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
index cbcf5234e..0c1dd7a2a 100644
--- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
+++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
@@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
// Check if graph product works
- DecisionTreeFactor product = graph.product();
+ DecisionTreeFactor product = graph.product()->toDecisionTreeFactor();
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
}
@@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast(newFactorPtr);
// Normalize newFactor by max for comparison with expected
- auto normalizer = newFactor.max(newFactor.size());
+ auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor();
- newFactor = newFactor / *normalizer;
+ newFactor = newFactor / denominator;
// Check Conditional
CHECK(conditional);
@@ -131,9 +131,10 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max.
- normalizer = expectedFactor.max(expectedFactor.size());
- // Ensure normalizer is correct.
- expectedFactor = expectedFactor / *normalizer;
+ denominator =
+ expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor();
+ // Ensure denominator is correct.
+ expectedFactor = expectedFactor / denominator;
EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree
diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp
index e6c71e15c..76a0f2b5c 100644
--- a/gtsam/discrete/tests/testTableFactor.cpp
+++ b/gtsam/discrete/tests/testTableFactor.cpp
@@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
TEST(TableFactor, Empty) {
DiscreteKey X(1, 2);
- TableFactor single = *TableFactor({X}, "1 1").sum(1);
+ auto single = TableFactor({X}, "1 1").sum(1);
// Should not throw a segfault
- EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1),
- single.toDecisionTreeFactor()));
+ auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
+ EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
+ single->toDecisionTreeFactor()));
- TableFactor empty = *TableFactor({X}, "0 0").sum(1);
+ auto empty = TableFactor({X}, "0 0").sum(1);
// Should not throw a segfault
- EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1),
- empty.toDecisionTreeFactor()));
+ auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
+ EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
+ empty->toDecisionTreeFactor()));
}
/* ************************************************************************* */
@@ -303,15 +305,18 @@ TEST(TableFactor, sum_max) {
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12");
- TableFactor::shared_ptr actual = f1.sum(1);
+ auto actual = std::dynamic_pointer_cast(f1.sum(1));
+ CHECK(actual);
CHECK(assert_equal(expected, *actual, 1e-5));
TableFactor expected2(v1, "5 6");
- TableFactor::shared_ptr actual2 = f1.max(1);
+ auto actual2 = std::dynamic_pointer_cast(f1.max(1));
+ CHECK(actual2);
CHECK(assert_equal(expected2, *actual2));
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
- TableFactor::shared_ptr actual22 = f2.sum(1);
+ auto actual22 = std::dynamic_pointer_cast(f2.sum(1));
+ CHECK(actual22);
}
/* ************************************************************************* */
diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h
index 3526a282d..328fabbea 100644
--- a/gtsam_unstable/discrete/Constraint.h
+++ b/gtsam_unstable/discrete/Constraint.h
@@ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
- * @param (in/out) domains all domains, but only domains->at(j) will be checked.
+ * @param (in/out) domains all domains, but only domains->at(j) will be
+ * checked.
* @return true if domains->at(j) was changed, false otherwise.
*/
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
@@ -78,6 +79,39 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
/// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const Domains&) const = 0;
+
+ /// Multiply factors, DiscreteFactor::shared_ptr edition
+ DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& df) const override {
+ return std::make_shared(
+ this->operator*(df->toDecisionTreeFactor()));
+ }
+
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& df) const override {
+ return this->toDecisionTreeFactor() / df;
+ }
+
+ /// Get the number of non-zero values contained in this factor.
+ uint64_t nrValues() const override { return 1; };
+
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
+ return toDecisionTreeFactor().sum(nrFrontals);
+ }
+
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
+ return toDecisionTreeFactor().sum(keys);
+ }
+
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
+ return toDecisionTreeFactor().max(nrFrontals);
+ }
+
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
+ return toDecisionTreeFactor().max(keys);
+ }
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index 23a566d24..6ce846201 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Erase a value, non const :-(
void erase(size_t value) { values_.erase(value); }
- size_t nrValues() const { return values_.size(); }
+ uint64_t nrValues() const override { return values_.size(); }
bool isSingleton() const { return nrValues() == 1; }
diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp
index 2b9a20ca6..6806bfe58 100644
--- a/gtsam_unstable/discrete/tests/testCSP.cpp
+++ b/gtsam_unstable/discrete/tests/testCSP.cpp
@@ -124,7 +124,7 @@ TEST(CSP, allInOne) {
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Just for fun, create the product and check it
- DecisionTreeFactor product = csp.product();
+ DecisionTreeFactor product = csp.product()->toDecisionTreeFactor();
// product.dot("product");
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
EXPECT(assert_equal(expectedProduct, product));
diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp
index f868abb5e..5f9b7f287 100644
--- a/gtsam_unstable/discrete/tests/testScheduler.cpp
+++ b/gtsam_unstable/discrete/tests/testScheduler.cpp
@@ -113,7 +113,7 @@ TEST(schedulingExample, test) {
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));
// Do brute force product and output that to file
- DecisionTreeFactor product = s.product();
+ DecisionTreeFactor product = s.product()->toDecisionTreeFactor();
// product.dot("scheduling", false);
// Do exact inference