diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ef7979d0a..19deccd78 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -87,6 +87,25 @@ namespace gtsam { return result; } + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::operator/( + const DiscreteFactor::shared_ptr& f) const { + if (auto tf = std::dynamic_pointer_cast(f)) { + // Check if `f` is a TableFactor. If yes, then + // convert `this` to a TableFactor which is cheaper. + return std::make_shared(tf->operator/(TableFactor(*this))); + + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + // If `f` is a DecisionTreeFactor, divide normally. + return std::make_shared(this->operator/(*dtf)); + + } else { + // Else, convert `f` to a DecisionTreeFactor so we can divide + return std::make_shared( + this->operator/(f->toDecisionTreeFactor())); + } + } + /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index ec1c2f8a2..4da5a7c17 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -184,26 +184,30 @@ namespace gtsam { return apply(f, safe_div); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override; + /// Convert into a decision tree DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } /// Create new factor by summing all values with the same separator values - shared_ptr sum(size_t nrFrontals) const { + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, Ring::add); } /// Create new factor by summing all values with the same separator values - shared_ptr sum(const Ordering& keys) const { + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { return combine(keys, Ring::add); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { return combine(nrFrontals, Ring::max); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { return combine(keys, Ring::max); } @@ -284,6 +288,12 @@ namespace gtsam { */ DecisionTreeFactor prune(size_t maxNrAssignments) const; + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + uint64_t nrValues() const override { return nrLeaves(); } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index eeb5dca3f..19dcdc729 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -24,13 +24,13 @@ #include #include +#include #include #include #include #include #include #include -#include using namespace std; using std::pair; @@ -44,8 +44,9 @@ template class GTSAM_EXPORT /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, - const DecisionTreeFactor& f) - : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} + const DiscreteFactor& f) + : BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()), + BaseConditional(nrFrontals) {} /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(size_t nrFrontals, @@ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s, /* ************************************************************************** */ bool DiscreteConditional::equals(const DiscreteFactor& other, double tol) const { - if (!dynamic_cast(&other)) { + if (!dynamic_cast(&other)) { return false; } else { - const DecisionTreeFactor& f(static_cast(other)); - return DecisionTreeFactor::equals(f, tol); + const BaseFactor& f(static_cast(other)); + return BaseFactor::equals(f, tol); } } @@ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, ss << "*\n" << std::endl; if (nrParents() == 0) { // We have no parents, call factor method. - ss << DecisionTreeFactor::markdown(keyFormatter, names); + ss << BaseFactor::markdown(keyFormatter, names); return ss.str(); } @@ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, ss << "

\n"; if (nrParents() == 0) { // We have no parents, call factor method. - ss << DecisionTreeFactor::html(keyFormatter, names); + ss << BaseFactor::html(keyFormatter, names); return ss.str(); } @@ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, /* ************************************************************************* */ double DiscreteConditional::evaluate(const HybridValues& x) const { - return this->evaluate(x.discrete()); + return this->operator()(x.discrete()); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 3ec9ae590..67f8a0056 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional DiscreteConditional() {} /// Construct from factor, taking the first `nFrontals` keys as frontals. - DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + DiscreteConditional(size_t nFrontals, const DiscreteFactor& f); /** * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index c18eaae2f..6cbc00d09 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace gtsam { @@ -139,8 +140,30 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const = 0; + /// divide by DiscreteFactor::shared_ptr f (safely) + virtual DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const = 0; + virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /// Create new factor by summing all values with the same separator values + virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0; + + /// Create new factor by summing all values with the same separator values + virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0; + + /// Create new factor by maximizing over all values with the same separator. + virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0; + + /// Create new factor by maximizing over all values with the same separator. + virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0; + + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + virtual uint64_t nrValues() const = 0; + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index a2b896286..eb3221819 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -64,7 +64,7 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor DiscreteFactorGraph::product() const { + DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const { DiscreteFactor::shared_ptr result; for (auto it = this->begin(); it != this->end(); ++it) { if (*it) { @@ -76,7 +76,7 @@ namespace gtsam { } } } - return result->toDecisionTreeFactor(); + return result; } /* ************************************************************************ */ @@ -122,20 +122,20 @@ namespace gtsam { * @brief Multiply all the `factors`. * * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return DecisionTreeFactor + * @return DiscreteFactor::shared_ptr */ - static DecisionTreeFactor DiscreteProduct( + static DiscreteFactor::shared_ptr DiscreteProduct( const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors gttic(product); - DecisionTreeFactor product = factors.product(); + DiscreteFactor::shared_ptr product = factors.product(); gttoc(product); // Max over all the potentials by pretending all keys are frontal: - auto denominator = product.max(product.size()); + auto denominator = product->max(product->size()); // Normalize the product factor to prevent underflow. - product = product / (*denominator); + product = product->operator/(denominator); return product; } @@ -145,25 +145,25 @@ namespace gtsam { std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = DiscreteProduct(factors); // max out frontals, this is the factor on the separator gttic(max); - DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); + DiscreteFactor::shared_ptr max = product->max(frontalKeys); gttoc(max); // Ordering keys for the conditional so that frontalKeys are really in front DiscreteKeys orderedKeys; for (auto&& key : frontalKeys) - orderedKeys.emplace_back(key, product.cardinality(key)); + orderedKeys.emplace_back(key, product->cardinality(key)); for (auto&& key : max->keys()) - orderedKeys.emplace_back(key, product.cardinality(key)); + orderedKeys.emplace_back(key, product->cardinality(key)); // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - auto lookup = - std::make_shared(nrFrontals, orderedKeys, product); + auto lookup = std::make_shared( + nrFrontals, orderedKeys, product->toDecisionTreeFactor()); gttoc(lookup); return {std::dynamic_pointer_cast(lookup), max}; @@ -223,11 +223,11 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = DiscreteProduct(factors); // sum out frontals, this is the factor on the separator gttic(sum); - DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); + DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); gttoc(sum); // Ordering keys for the conditional so that frontalKeys are really in front @@ -239,8 +239,9 @@ namespace gtsam { // now divide product/sum to get conditional gttic(divide); - auto conditional = - std::make_shared(product, *sum, orderedKeys); + auto conditional = std::make_shared( + product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(), + orderedKeys); gttoc(divide); return {conditional, sum}; diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index c57d2258c..3d9e86cd1 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph /// @} + //TODO(Varun): Make compatible with TableFactor /** Add a decision-tree factor */ template void add(Args&&... args) { @@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteKeys discreteKeys() const; /** return product of all factors as a single factor */ - DecisionTreeFactor product() const; + DiscreteFactor::shared_ptr product() const; /** * Evaluates the factor graph given values, returns the joint probability of diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index afb4d4208..1c0be2f06 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include @@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { const ADT& potentials) : DiscreteConditional(nFrontals, keys, potentials) {} + /** + * @brief Construct a new Discrete Lookup Table object + * + * @param nFrontals number of frontal variables + * @param keys a sorted list of gtsam::Keys + * @param potentials Discrete potentials as a TableFactor. + */ + DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, + const TableFactor& potentials) + : DiscreteConditional(nFrontals, keys, + potentials.toDecisionTreeFactor()) {} + /// GTSAM-style print void print( const std::string& s = "Discrete Lookup Table: ", diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index fe901aac1..d1cedc9ef 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -280,6 +280,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply( return result; } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::operator/( + const DiscreteFactor::shared_ptr& f) const { + if (auto tf = std::dynamic_pointer_cast(f)) { + return std::make_shared(this->operator/(*tf)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + return std::make_shared( + this->operator/(TableFactor(f->discreteKeys(), *dtf))); + } else { + TableFactor divisor(f->toDecisionTreeFactor()); + return std::make_shared(this->operator/(divisor)); + } +} + /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index a2e89b302..5a804b6a6 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include #include @@ -202,6 +203,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, safe_div); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override; + /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; @@ -210,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { DiscreteKeys parent_keys) const; /// Create new factor by summing all values with the same separator values - shared_ptr sum(size_t nrFrontals) const { + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, Ring::add); } /// Create new factor by summing all values with the same separator values - shared_ptr sum(const Ordering& keys) const { + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { return combine(keys, Ring::add); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { return combine(nrFrontals, Ring::max); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { return combine(keys, Ring::max); } @@ -330,6 +335,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { */ TableFactor prune(size_t maxNrAssignments) const; + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + uint64_t nrValues() const override { return sparse_table_.nonZeros(); } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index ba8714783..ec9185ecb 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -138,15 +138,18 @@ TEST(DecisionTreeFactor, sum_max) { DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); DecisionTreeFactor expected(v1, "9 12"); - DecisionTreeFactor::shared_ptr actual = f1.sum(1); + auto actual = std::dynamic_pointer_cast(f1.sum(1)); + CHECK(actual); CHECK(assert_equal(expected, *actual, 1e-5)); DecisionTreeFactor expected2(v1, "5 6"); - DecisionTreeFactor::shared_ptr actual2 = f1.max(1); + auto actual2 = std::dynamic_pointer_cast(f1.max(1)); + CHECK(actual2); CHECK(assert_equal(expected2, *actual2)); DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6"); - DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); + auto actual22 = std::dynamic_pointer_cast(f2.sum(1)); + CHECK(actual22); } /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 2482a86a2..b91e1bd8a 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor expected2 = f2 / *f2.sum(1); + DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor(); EXPECT(assert_equal(expected2, static_cast(actual2))); std::vector probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75}; @@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor expected2 = f2 / *f2.sum(1); + DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor(); EXPECT(assert_equal(expected2, static_cast(actual2))); } diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index cbcf5234e..0c1dd7a2a 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9); // Check if graph product works - DecisionTreeFactor product = graph.product(); + DecisionTreeFactor product = graph.product()->toDecisionTreeFactor(); EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9); } @@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) { *std::dynamic_pointer_cast(newFactorPtr); // Normalize newFactor by max for comparison with expected - auto normalizer = newFactor.max(newFactor.size()); + auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor(); - newFactor = newFactor / *normalizer; + newFactor = newFactor / denominator; // Check Conditional CHECK(conditional); @@ -131,9 +131,10 @@ TEST(DiscreteFactorGraph, test) { CHECK(&newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); // Normalize by max. - normalizer = expectedFactor.max(expectedFactor.size()); - // Ensure normalizer is correct. - expectedFactor = expectedFactor / *normalizer; + denominator = + expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor(); + // Ensure denominator is correct. + expectedFactor = expectedFactor / denominator; EXPECT(assert_equal(expectedFactor, newFactor)); // Test using elimination tree diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index e6c71e15c..76a0f2b5c 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) { TEST(TableFactor, Empty) { DiscreteKey X(1, 2); - TableFactor single = *TableFactor({X}, "1 1").sum(1); + auto single = TableFactor({X}, "1 1").sum(1); // Should not throw a segfault - EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1), - single.toDecisionTreeFactor())); + auto expected_single = DecisionTreeFactor(X, "1 1").sum(1); + EXPECT(assert_equal(expected_single->toDecisionTreeFactor(), + single->toDecisionTreeFactor())); - TableFactor empty = *TableFactor({X}, "0 0").sum(1); + auto empty = TableFactor({X}, "0 0").sum(1); // Should not throw a segfault - EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1), - empty.toDecisionTreeFactor())); + auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1); + EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(), + empty->toDecisionTreeFactor())); } /* ************************************************************************* */ @@ -303,15 +305,18 @@ TEST(TableFactor, sum_max) { TableFactor f1(v0 & v1, "1 2 3 4 5 6"); TableFactor expected(v1, "9 12"); - TableFactor::shared_ptr actual = f1.sum(1); + auto actual = std::dynamic_pointer_cast(f1.sum(1)); + CHECK(actual); CHECK(assert_equal(expected, *actual, 1e-5)); TableFactor expected2(v1, "5 6"); - TableFactor::shared_ptr actual2 = f1.max(1); + auto actual2 = std::dynamic_pointer_cast(f1.max(1)); + CHECK(actual2); CHECK(assert_equal(expected2, *actual2)); TableFactor f2(v1 & v0, "1 2 3 4 5 6"); - TableFactor::shared_ptr actual22 = f2.sum(1); + auto actual22 = std::dynamic_pointer_cast(f2.sum(1)); + CHECK(actual22); } /* ************************************************************************* */ diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index 71ed7647a..328fabbea 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @param (in/out) domains all domains, but only domains->at(j) will be + * checked. * @return true if domains->at(j) was changed, false otherwise. */ virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; @@ -86,6 +87,31 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { this->operator*(df->toDecisionTreeFactor())); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const override { + return this->toDecisionTreeFactor() / df; + } + + /// Get the number of non-zero values contained in this factor. + uint64_t nrValues() const override { return 1; }; + + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { + return toDecisionTreeFactor().sum(nrFrontals); + } + + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { + return toDecisionTreeFactor().sum(keys); + } + + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { + return toDecisionTreeFactor().max(nrFrontals); + } + + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { + return toDecisionTreeFactor().max(keys); + } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 23a566d24..6ce846201 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Erase a value, non const :-( void erase(size_t value) { values_.erase(value); } - size_t nrValues() const { return values_.size(); } + uint64_t nrValues() const override { return values_.size(); } bool isSingleton() const { return nrValues() == 1; } diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 2b9a20ca6..6806bfe58 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -124,7 +124,7 @@ TEST(CSP, allInOne) { EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // Just for fun, create the product and check it - DecisionTreeFactor product = csp.product(); + DecisionTreeFactor product = csp.product()->toDecisionTreeFactor(); // product.dot("product"); DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); EXPECT(assert_equal(expectedProduct, product)); diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp index f868abb5e..5f9b7f287 100644 --- a/gtsam_unstable/discrete/tests/testScheduler.cpp +++ b/gtsam_unstable/discrete/tests/testScheduler.cpp @@ -113,7 +113,7 @@ TEST(schedulingExample, test) { EXPECT(assert_equal(expected, (DiscreteFactorGraph)s)); // Do brute force product and output that to file - DecisionTreeFactor product = s.product(); + DecisionTreeFactor product = s.product()->toDecisionTreeFactor(); // product.dot("scheduling", false); // Do exact inference