diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index ef7979d0a..19deccd78 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -87,6 +87,25 @@ namespace gtsam {
return result;
}
+ /* ************************************************************************ */
+ DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
+ const DiscreteFactor::shared_ptr& f) const {
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ // Check if `f` is a TableFactor. If yes, then
+ // convert `this` to a TableFactor which is cheaper.
+ return std::make_shared(tf->operator/(TableFactor(*this)));
+
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ // If `f` is a DecisionTreeFactor, divide normally.
+ return std::make_shared(this->operator/(*dtf));
+
+ } else {
+ // Else, convert `f` to a DecisionTreeFactor so we can divide
+ return std::make_shared(
+ this->operator/(f->toDecisionTreeFactor()));
+ }
+ }
+
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index ec1c2f8a2..4da5a7c17 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -184,26 +184,30 @@ namespace gtsam {
return apply(f, safe_div);
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& f) const override;
+
/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}
@@ -284,6 +288,12 @@ namespace gtsam {
*/
DecisionTreeFactor prune(size_t maxNrAssignments) const;
+ /**
+ * Get the number of non-zero values contained in this factor.
+ * It could be much smaller than `prod_{key}(cardinality(key))`.
+ */
+ uint64_t nrValues() const override { return nrLeaves(); }
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index eeb5dca3f..19dcdc729 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -24,13 +24,13 @@
#include
#include
+#include
#include
#include
#include
#include
#include
#include
-#include
using namespace std;
using std::pair;
@@ -44,8 +44,9 @@ template class GTSAM_EXPORT
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
- const DecisionTreeFactor& f)
- : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
+ const DiscreteFactor& f)
+ : BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()),
+ BaseConditional(nrFrontals) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
@@ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s,
/* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
- if (!dynamic_cast(&other)) {
+ if (!dynamic_cast(&other)) {
return false;
} else {
- const DecisionTreeFactor& f(static_cast(other));
- return DecisionTreeFactor::equals(f, tol);
+ const BaseFactor& f(static_cast(other));
+ return BaseFactor::equals(f, tol);
}
}
@@ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
ss << "*\n" << std::endl;
if (nrParents() == 0) {
// We have no parents, call factor method.
- ss << DecisionTreeFactor::markdown(keyFormatter, names);
+ ss << BaseFactor::markdown(keyFormatter, names);
return ss.str();
}
@@ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
ss << "
\n";
if (nrParents() == 0) {
// We have no parents, call factor method.
- ss << DecisionTreeFactor::html(keyFormatter, names);
+ ss << BaseFactor::html(keyFormatter, names);
return ss.str();
}
@@ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const {
- return this->evaluate(x.discrete());
+ return this->operator()(x.discrete());
}
/* ************************************************************************* */
diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h
index 3ec9ae590..67f8a0056 100644
--- a/gtsam/discrete/DiscreteConditional.h
+++ b/gtsam/discrete/DiscreteConditional.h
@@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional() {}
/// Construct from factor, taking the first `nFrontals` keys as frontals.
- DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
+ DiscreteConditional(size_t nFrontals, const DiscreteFactor& f);
/**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index c18eaae2f..6cbc00d09 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -22,6 +22,7 @@
#include
#include
#include
+#include
#include
namespace gtsam {
@@ -139,8 +140,30 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const = 0;
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ virtual DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& df) const = 0;
+
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
+ /// Create new factor by summing all values with the same separator values
+ virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;
+
+ /// Create new factor by summing all values with the same separator values
+ virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;
+
+ /// Create new factor by maximizing over all values with the same separator.
+ virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;
+
+ /// Create new factor by maximizing over all values with the same separator.
+ virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
+
+ /**
+ * Get the number of non-zero values contained in this factor.
+ * It could be much smaller than `prod_{key}(cardinality(key))`.
+ */
+ virtual uint64_t nrValues() const = 0;
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index a2b896286..eb3221819 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -64,7 +64,7 @@ namespace gtsam {
}
/* ************************************************************************ */
- DecisionTreeFactor DiscreteFactorGraph::product() const {
+ DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
DiscreteFactor::shared_ptr result;
for (auto it = this->begin(); it != this->end(); ++it) {
if (*it) {
@@ -76,7 +76,7 @@ namespace gtsam {
}
}
}
- return result->toDecisionTreeFactor();
+ return result;
}
/* ************************************************************************ */
@@ -122,20 +122,20 @@ namespace gtsam {
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
- * @return DecisionTreeFactor
+ * @return DiscreteFactor::shared_ptr
*/
- static DecisionTreeFactor DiscreteProduct(
+ static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
- DecisionTreeFactor product = factors.product();
+ DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);
// Max over all the potentials by pretending all keys are frontal:
- auto denominator = product.max(product.size());
+ auto denominator = product->max(product->size());
// Normalize the product factor to prevent underflow.
- product = product / (*denominator);
+ product = product->operator/(denominator);
return product;
}
@@ -145,25 +145,25 @@ namespace gtsam {
std::pair //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DecisionTreeFactor product = DiscreteProduct(factors);
+ DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// max out frontals, this is the factor on the separator
gttic(max);
- DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
+ DiscreteFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
- orderedKeys.emplace_back(key, product.cardinality(key));
+ orderedKeys.emplace_back(key, product->cardinality(key));
for (auto&& key : max->keys())
- orderedKeys.emplace_back(key, product.cardinality(key));
+ orderedKeys.emplace_back(key, product->cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
- auto lookup =
- std::make_shared(nrFrontals, orderedKeys, product);
+ auto lookup = std::make_shared(
+ nrFrontals, orderedKeys, product->toDecisionTreeFactor());
gttoc(lookup);
return {std::dynamic_pointer_cast(lookup), max};
@@ -223,11 +223,11 @@ namespace gtsam {
std::pair //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DecisionTreeFactor product = DiscreteProduct(factors);
+ DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// sum out frontals, this is the factor on the separator
gttic(sum);
- DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
+ DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
@@ -239,8 +239,9 @@ namespace gtsam {
// now divide product/sum to get conditional
gttic(divide);
- auto conditional =
- std::make_shared(product, *sum, orderedKeys);
+ auto conditional = std::make_shared(
+ product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
+ orderedKeys);
gttoc(divide);
return {conditional, sum};
diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h
index c57d2258c..3d9e86cd1 100644
--- a/gtsam/discrete/DiscreteFactorGraph.h
+++ b/gtsam/discrete/DiscreteFactorGraph.h
@@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @}
+ //TODO(Varun): Make compatible with TableFactor
/** Add a decision-tree factor */
template
void add(Args&&... args) {
@@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */
- DecisionTreeFactor product() const;
+ DiscreteFactor::shared_ptr product() const;
/**
* Evaluates the factor graph given values, returns the joint probability of
diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h
index afb4d4208..1c0be2f06 100644
--- a/gtsam/discrete/DiscreteLookupDAG.h
+++ b/gtsam/discrete/DiscreteLookupDAG.h
@@ -18,6 +18,7 @@
#pragma once
#include
+#include
#include
#include
@@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
+ /**
+ * @brief Construct a new Discrete Lookup Table object
+ *
+ * @param nFrontals number of frontal variables
+ * @param keys a sorted list of gtsam::Keys
+ * @param potentials Discrete potentials as a TableFactor.
+ */
+ DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
+ const TableFactor& potentials)
+ : DiscreteConditional(nFrontals, keys,
+ potentials.toDecisionTreeFactor()) {}
+
/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index fe901aac1..d1cedc9ef 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -280,6 +280,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply(
return result;
}
+/* ************************************************************************ */
+DiscreteFactor::shared_ptr TableFactor::operator/(
+ const DiscreteFactor::shared_ptr& f) const {
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(this->operator/(*tf));
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(
+ this->operator/(TableFactor(f->discreteKeys(), *dtf)));
+ } else {
+ TableFactor divisor(f->toDecisionTreeFactor());
+ return std::make_shared(this->operator/(divisor));
+ }
+}
+
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index a2e89b302..5a804b6a6 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -17,6 +17,7 @@
#pragma once
+#include
#include
#include
#include
@@ -202,6 +203,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, safe_div);
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& f) const override;
+
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
@@ -210,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}
@@ -330,6 +335,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
TableFactor prune(size_t maxNrAssignments) const;
+ /**
+ * Get the number of non-zero values contained in this factor.
+ * It could be much smaller than `prod_{key}(cardinality(key))`.
+ */
+ uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
index ba8714783..ec9185ecb 100644
--- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp
+++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
@@ -138,15 +138,18 @@ TEST(DecisionTreeFactor, sum_max) {
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
DecisionTreeFactor expected(v1, "9 12");
- DecisionTreeFactor::shared_ptr actual = f1.sum(1);
+ auto actual = std::dynamic_pointer_cast(f1.sum(1));
+ CHECK(actual);
CHECK(assert_equal(expected, *actual, 1e-5));
DecisionTreeFactor expected2(v1, "5 6");
- DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
+ auto actual2 = std::dynamic_pointer_cast(f1.max(1));
+ CHECK(actual2);
CHECK(assert_equal(expected2, *actual2));
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
- DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
+ auto actual22 = std::dynamic_pointer_cast(f2.sum(1));
+ CHECK(actual22);
}
/* ************************************************************************* */
diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp
index 2482a86a2..b91e1bd8a 100644
--- a/gtsam/discrete/tests/testDiscreteConditional.cpp
+++ b/gtsam/discrete/tests/testDiscreteConditional.cpp
@@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
- DecisionTreeFactor expected2 = f2 / *f2.sum(1);
+ DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast(actual2)));
std::vector probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
@@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
- DecisionTreeFactor expected2 = f2 / *f2.sum(1);
+ DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast(actual2)));
}
diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
index cbcf5234e..0c1dd7a2a 100644
--- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
+++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
@@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
// Check if graph product works
- DecisionTreeFactor product = graph.product();
+ DecisionTreeFactor product = graph.product()->toDecisionTreeFactor();
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
}
@@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast(newFactorPtr);
// Normalize newFactor by max for comparison with expected
- auto normalizer = newFactor.max(newFactor.size());
+ auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor();
- newFactor = newFactor / *normalizer;
+ newFactor = newFactor / denominator;
// Check Conditional
CHECK(conditional);
@@ -131,9 +131,10 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max.
- normalizer = expectedFactor.max(expectedFactor.size());
- // Ensure normalizer is correct.
- expectedFactor = expectedFactor / *normalizer;
+ denominator =
+ expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor();
+ // Ensure denominator is correct.
+ expectedFactor = expectedFactor / denominator;
EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree
diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp
index e6c71e15c..76a0f2b5c 100644
--- a/gtsam/discrete/tests/testTableFactor.cpp
+++ b/gtsam/discrete/tests/testTableFactor.cpp
@@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
TEST(TableFactor, Empty) {
DiscreteKey X(1, 2);
- TableFactor single = *TableFactor({X}, "1 1").sum(1);
+ auto single = TableFactor({X}, "1 1").sum(1);
// Should not throw a segfault
- EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1),
- single.toDecisionTreeFactor()));
+ auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
+ EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
+ single->toDecisionTreeFactor()));
- TableFactor empty = *TableFactor({X}, "0 0").sum(1);
+ auto empty = TableFactor({X}, "0 0").sum(1);
// Should not throw a segfault
- EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1),
- empty.toDecisionTreeFactor()));
+ auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
+ EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
+ empty->toDecisionTreeFactor()));
}
/* ************************************************************************* */
@@ -303,15 +305,18 @@ TEST(TableFactor, sum_max) {
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12");
- TableFactor::shared_ptr actual = f1.sum(1);
+ auto actual = std::dynamic_pointer_cast(f1.sum(1));
+ CHECK(actual);
CHECK(assert_equal(expected, *actual, 1e-5));
TableFactor expected2(v1, "5 6");
- TableFactor::shared_ptr actual2 = f1.max(1);
+ auto actual2 = std::dynamic_pointer_cast(f1.max(1));
+ CHECK(actual2);
CHECK(assert_equal(expected2, *actual2));
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
- TableFactor::shared_ptr actual22 = f2.sum(1);
+ auto actual22 = std::dynamic_pointer_cast(f2.sum(1));
+ CHECK(actual22);
}
/* ************************************************************************* */
diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h
index 71ed7647a..328fabbea 100644
--- a/gtsam_unstable/discrete/Constraint.h
+++ b/gtsam_unstable/discrete/Constraint.h
@@ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
- * @param (in/out) domains all domains, but only domains->at(j) will be checked.
+ * @param (in/out) domains all domains, but only domains->at(j) will be
+ * checked.
* @return true if domains->at(j) was changed, false otherwise.
*/
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
@@ -86,6 +87,31 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
this->operator*(df->toDecisionTreeFactor()));
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& df) const override {
+ return this->toDecisionTreeFactor() / df;
+ }
+
+ /// Get the number of non-zero values contained in this factor.
+ uint64_t nrValues() const override { return 1; };
+
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
+ return toDecisionTreeFactor().sum(nrFrontals);
+ }
+
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
+ return toDecisionTreeFactor().sum(keys);
+ }
+
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
+ return toDecisionTreeFactor().max(nrFrontals);
+ }
+
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
+ return toDecisionTreeFactor().max(keys);
+ }
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index 23a566d24..6ce846201 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Erase a value, non const :-(
void erase(size_t value) { values_.erase(value); }
- size_t nrValues() const { return values_.size(); }
+ uint64_t nrValues() const override { return values_.size(); }
bool isSingleton() const { return nrValues() == 1; }
diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp
index 2b9a20ca6..6806bfe58 100644
--- a/gtsam_unstable/discrete/tests/testCSP.cpp
+++ b/gtsam_unstable/discrete/tests/testCSP.cpp
@@ -124,7 +124,7 @@ TEST(CSP, allInOne) {
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Just for fun, create the product and check it
- DecisionTreeFactor product = csp.product();
+ DecisionTreeFactor product = csp.product()->toDecisionTreeFactor();
// product.dot("product");
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
EXPECT(assert_equal(expectedProduct, product));
diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp
index f868abb5e..5f9b7f287 100644
--- a/gtsam_unstable/discrete/tests/testScheduler.cpp
+++ b/gtsam_unstable/discrete/tests/testScheduler.cpp
@@ -113,7 +113,7 @@ TEST(schedulingExample, test) {
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));
// Do brute force product and output that to file
- DecisionTreeFactor product = s.product();
+ DecisionTreeFactor product = s.product()->toDecisionTreeFactor();
// product.dot("scheduling", false);
// Do exact inference