diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 1ac782b88..19deccd78 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -18,9 +18,10 @@ */ #include -#include #include #include +#include +#include #include @@ -62,6 +63,49 @@ namespace gtsam { return error(values.discrete()); } + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::multiply( + const DiscreteFactor::shared_ptr& f) const { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + // If f is a TableFactor, we convert `this` to a TableFactor since this + // conversion is cheaper than converting `f` to a DecisionTreeFactor. We + // then return a TableFactor. + result = std::make_shared((*tf) * TableFactor(*this)); + + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + // If `f` is a DecisionTreeFactor, simply call operator*. + result = std::make_shared(this->operator*(*dtf)); + + } else { + // Simulate double dispatch in C++ + // Useful for other classes which inherit from DiscreteFactor and have + // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't + // need to be updated. + result = std::make_shared(f->operator*(*this)); + } + return result; + } + + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::operator/( + const DiscreteFactor::shared_ptr& f) const { + if (auto tf = std::dynamic_pointer_cast(f)) { + // Check if `f` is a TableFactor. If yes, then + // convert `this` to a TableFactor which is cheaper. + return std::make_shared(tf->operator/(TableFactor(*this))); + + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + // If `f` is a DecisionTreeFactor, divide normally. + return std::make_shared(this->operator/(*dtf)); + + } else { + // Else, convert `f` to a DecisionTreeFactor so we can divide + return std::make_shared( + this->operator/(f->toDecisionTreeFactor())); + } + } + /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 80ee10a7b..4da5a7c17 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -147,6 +147,23 @@ namespace gtsam { /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; + /** + * @brief Multiply factors, DiscreteFactor::shared_ptr edition. + * + * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic + * dispatch and specializations to perform the most efficient + * multiplication. + * + * While converting a DecisionTreeFactor to a TableFactor is efficient, the + * reverse is not. Hence we specialize the code to return a TableFactor if + * `f` is a TableFactor, and DecisionTreeFactor otherwise. + * + * @param f The factor to multiply with. + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + /// multiply two factors DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { return apply(f, Ring::mul); @@ -154,31 +171,43 @@ namespace gtsam { static double safe_div(const double& a, const double& b); - /// divide by factor f (safely) + /** + * @brief Divide by factor f (safely). + * Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$ + * results in a function which involves all keys + * \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$ + * + * @param f The DecisinTreeFactor to divide by. + * @return DecisionTreeFactor + */ DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { return apply(f, safe_div); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override; + /// Convert into a decision tree DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } /// Create new factor by summing all values with the same separator values - shared_ptr sum(size_t nrFrontals) const { + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, Ring::add); } /// Create new factor by summing all values with the same separator values - shared_ptr sum(const Ordering& keys) const { + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { return combine(keys, Ring::add); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { return combine(nrFrontals, Ring::max); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { return combine(keys, Ring::max); } @@ -259,6 +288,12 @@ namespace gtsam { */ DecisionTreeFactor prune(size_t maxNrAssignments) const; + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + uint64_t nrValues() const override { return nrLeaves(); } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 0eea8b4bd..19dcdc729 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -44,8 +44,9 @@ template class GTSAM_EXPORT /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, - const DecisionTreeFactor& f) - : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} + const DiscreteFactor& f) + : BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()), + BaseConditional(nrFrontals) {} /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(size_t nrFrontals, @@ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s, /* ************************************************************************** */ bool DiscreteConditional::equals(const DiscreteFactor& other, double tol) const { - if (!dynamic_cast(&other)) { + if (!dynamic_cast(&other)) { return false; } else { - const DecisionTreeFactor& f(static_cast(other)); - return DecisionTreeFactor::equals(f, tol); + const BaseFactor& f(static_cast(other)); + return BaseFactor::equals(f, tol); } } @@ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, ss << "*\n" << std::endl; if (nrParents() == 0) { // We have no parents, call factor method. - ss << DecisionTreeFactor::markdown(keyFormatter, names); + ss << BaseFactor::markdown(keyFormatter, names); return ss.str(); } @@ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, ss << "

\n"; if (nrParents() == 0) { // We have no parents, call factor method. - ss << DecisionTreeFactor::html(keyFormatter, names); + ss << BaseFactor::html(keyFormatter, names); return ss.str(); } @@ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, /* ************************************************************************* */ double DiscreteConditional::evaluate(const HybridValues& x) const { - return this->evaluate(x.discrete()); + return this->operator()(x.discrete()); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 3ec9ae590..67f8a0056 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional DiscreteConditional() {} /// Construct from factor, taking the first `nFrontals` keys as frontals. - DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + DiscreteConditional(size_t nFrontals, const DiscreteFactor& f); /** * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index a1fde0f86..6cbc00d09 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace gtsam { @@ -129,8 +130,40 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; + /** + * @brief Multiply in a DiscreteFactor and return the result as + * DiscreteFactor, both via shared pointers. + * + * @param df DiscreteFactor shared_ptr + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const = 0; + + /// divide by DiscreteFactor::shared_ptr f (safely) + virtual DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const = 0; + virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /// Create new factor by summing all values with the same separator values + virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0; + + /// Create new factor by summing all values with the same separator values + virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0; + + /// Create new factor by maximizing over all values with the same separator. + virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0; + + /// Create new factor by maximizing over all values with the same separator. + virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0; + + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + virtual uint64_t nrValues() const = 0; + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index b48e09b03..9b1774f49 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -64,10 +64,17 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor DiscreteFactorGraph::product() const { - DecisionTreeFactor result; - for (const sharedFactor& factor : *this) { - if (factor) result = (*factor) * result; + DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const { + DiscreteFactor::shared_ptr result; + for (auto it = this->begin(); it != this->end(); ++it) { + if (*it) { + if (result) { + result = result->multiply(*it); + } else { + // Assign to the first non-null factor + result = *it; + } + } } return result; } @@ -115,21 +122,23 @@ namespace gtsam { * @brief Multiply all the `factors`. * * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return DecisionTreeFactor + * @return DiscreteFactor::shared_ptr */ - static DecisionTreeFactor DiscreteProduct( + static DiscreteFactor::shared_ptr DiscreteProduct( const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors - DecisionTreeFactor product = factors.product(); + gttic(product); + DiscreteFactor::shared_ptr product = factors.product(); + gttoc(product); #if GTSAM_HYBRID_TIMING gttic_(DiscreteNormalize); #endif // Max over all the potentials by pretending all keys are frontal: - auto denominator = product.max(product.size()); + auto denominator = product->max(product->size()); // Normalize the product factor to prevent underflow. - product = product / (*denominator); + product = product->operator/(denominator); #if GTSAM_HYBRID_TIMING gttoc_(DiscreteNormalize); #endif @@ -142,25 +151,25 @@ namespace gtsam { std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = DiscreteProduct(factors); // max out frontals, this is the factor on the separator gttic(max); - DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); + DiscreteFactor::shared_ptr max = product->max(frontalKeys); gttoc(max); // Ordering keys for the conditional so that frontalKeys are really in front DiscreteKeys orderedKeys; for (auto&& key : frontalKeys) - orderedKeys.emplace_back(key, product.cardinality(key)); + orderedKeys.emplace_back(key, product->cardinality(key)); for (auto&& key : max->keys()) - orderedKeys.emplace_back(key, product.cardinality(key)); + orderedKeys.emplace_back(key, product->cardinality(key)); // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - auto lookup = - std::make_shared(nrFrontals, orderedKeys, product); + auto lookup = std::make_shared( + nrFrontals, orderedKeys, product->toDecisionTreeFactor()); gttoc(lookup); return {std::dynamic_pointer_cast(lookup), max}; @@ -220,10 +229,12 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = DiscreteProduct(factors); // sum out frontals, this is the factor on the separator - DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); + gttic(sum); + DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); + gttoc(sum); // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; @@ -233,8 +244,11 @@ namespace gtsam { sum->keys().end()); // now divide product/sum to get conditional - auto conditional = - std::make_shared(product, *sum, orderedKeys); + gttic(divide); + auto conditional = std::make_shared( + product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(), + orderedKeys); + gttoc(divide); return {conditional, sum}; } diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index c57d2258c..3d9e86cd1 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph /// @} + //TODO(Varun): Make compatible with TableFactor /** Add a decision-tree factor */ template void add(Args&&... args) { @@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteKeys discreteKeys() const; /** return product of all factors as a single factor */ - DecisionTreeFactor product() const; + DiscreteFactor::shared_ptr product() const; /** * Evaluates the factor graph given values, returns the joint probability of diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index afb4d4208..1c0be2f06 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include @@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { const ADT& potentials) : DiscreteConditional(nFrontals, keys, potentials) {} + /** + * @brief Construct a new Discrete Lookup Table object + * + * @param nFrontals number of frontal variables + * @param keys a sorted list of gtsam::Keys + * @param potentials Discrete potentials as a TableFactor. + */ + DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, + const TableFactor& potentials) + : DiscreteConditional(nFrontals, keys, + potentials.toDecisionTreeFactor()) {} + /// GTSAM-style print void print( const std::string& s = "Discrete Lookup Table: ", diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a59095d40..d1cedc9ef 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -254,6 +254,46 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::multiply( + const DiscreteFactor::shared_ptr& f) const { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + // If `f` is a TableFactor, we can simply call `operator*`. + result = std::make_shared(this->operator*(*tf)); + + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + // If `f` is a DecisionTreeFactor, we convert to a TableFactor which is + // cheaper than converting `this` to a DecisionTreeFactor. + result = std::make_shared(this->operator*(TableFactor(*dtf))); + + } else { + // Simulate double dispatch in C++ + // Useful for other classes which inherit from DiscreteFactor and have + // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't + // need to be updated to know about TableFactor. + // Those classes can be specialized to use TableFactor + // if efficiency is a problem. + result = std::make_shared( + f->operator*(this->toDecisionTreeFactor())); + } + return result; +} + +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::operator/( + const DiscreteFactor::shared_ptr& f) const { + if (auto tf = std::dynamic_pointer_cast(f)) { + return std::make_shared(this->operator/(*tf)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + return std::make_shared( + this->operator/(TableFactor(f->discreteKeys(), *dtf))); + } else { + TableFactor divisor(f->toDecisionTreeFactor()); + return std::make_shared(this->operator/(divisor)); + } +} + /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index a2fdb4d32..5a804b6a6 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include #include @@ -178,6 +179,23 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /** + * @brief Multiply factors, DiscreteFactor::shared_ptr edition. + * + * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic + * dispatch and specializations to perform the most efficient + * multiplication. + * + * While converting a DecisionTreeFactor to a TableFactor is efficient, the + * reverse is not. + * Hence we specialize the code to return a TableFactor always. + * + * @param f The factor to multiply with. + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + static double safe_div(const double& a, const double& b); /// divide by factor f (safely) @@ -185,6 +203,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, safe_div); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override; + /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; @@ -193,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { DiscreteKeys parent_keys) const; /// Create new factor by summing all values with the same separator values - shared_ptr sum(size_t nrFrontals) const { + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, Ring::add); } /// Create new factor by summing all values with the same separator values - shared_ptr sum(const Ordering& keys) const { + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { return combine(keys, Ring::add); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { return combine(nrFrontals, Ring::max); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { return combine(keys, Ring::max); } @@ -313,6 +335,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { */ TableFactor prune(size_t maxNrAssignments) const; + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + uint64_t nrValues() const override { return sparse_table_.nonZeros(); } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 756a0cebe..ec9185ecb 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -30,6 +30,12 @@ using namespace std; using namespace gtsam; +/** Convert Signature into CPT */ +DecisionTreeFactor create(const Signature& signature) { + DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); + return p; +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, ConstructorsMatch) { // Declare two keys @@ -105,21 +111,45 @@ TEST(DecisionTreeFactor, multiplication) { CHECK(assert_equal(expected2, actual)); } +/* ************************************************************************* */ +TEST(DecisionTreeFactor, Divide) { + DiscreteKey A(0, 2), S(1, 2); + DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); + DecisionTreeFactor joint = pA * pS; + + DecisionTreeFactor s = joint / pA; + + // Factors are not equal due to difference in keys + EXPECT(assert_inequal(pS, s)); + + // The underlying data should be the same + using ADT = AlgebraicDecisionTree; + EXPECT(assert_equal(ADT(pS), ADT(s))); + + KeySet keys(joint.keys()); + keys.insert(pA.keys().begin(), pA.keys().end()); + EXPECT(assert_inequal(KeySet(pS.keys()), keys)); + +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, sum_max) { DiscreteKey v0(0, 3), v1(1, 2); DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); DecisionTreeFactor expected(v1, "9 12"); - DecisionTreeFactor::shared_ptr actual = f1.sum(1); + auto actual = std::dynamic_pointer_cast(f1.sum(1)); + CHECK(actual); CHECK(assert_equal(expected, *actual, 1e-5)); DecisionTreeFactor expected2(v1, "5 6"); - DecisionTreeFactor::shared_ptr actual2 = f1.max(1); + auto actual2 = std::dynamic_pointer_cast(f1.max(1)); + CHECK(actual2); CHECK(assert_equal(expected2, *actual2)); DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6"); - DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); + auto actual22 = std::dynamic_pointer_cast(f2.sum(1)); + CHECK(actual22); } /* ************************************************************************* */ @@ -217,12 +247,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) { #endif } -/** Convert Signature into CPT */ -DecisionTreeFactor create(const Signature& signature) { - DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); - return p; -} - /* ************************************************************************* */ // test Asia Joint TEST(DecisionTreeFactor, joint) { diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 2482a86a2..b91e1bd8a 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor expected2 = f2 / *f2.sum(1); + DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor(); EXPECT(assert_equal(expected2, static_cast(actual2))); std::vector probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75}; @@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor expected2 = f2 / *f2.sum(1); + DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor(); EXPECT(assert_equal(expected2, static_cast(actual2))); } diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index cbcf5234e..0c1dd7a2a 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9); // Check if graph product works - DecisionTreeFactor product = graph.product(); + DecisionTreeFactor product = graph.product()->toDecisionTreeFactor(); EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9); } @@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) { *std::dynamic_pointer_cast(newFactorPtr); // Normalize newFactor by max for comparison with expected - auto normalizer = newFactor.max(newFactor.size()); + auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor(); - newFactor = newFactor / *normalizer; + newFactor = newFactor / denominator; // Check Conditional CHECK(conditional); @@ -131,9 +131,10 @@ TEST(DiscreteFactorGraph, test) { CHECK(&newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); // Normalize by max. - normalizer = expectedFactor.max(expectedFactor.size()); - // Ensure normalizer is correct. - expectedFactor = expectedFactor / *normalizer; + denominator = + expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor(); + // Ensure denominator is correct. + expectedFactor = expectedFactor / denominator; EXPECT(assert_equal(expectedFactor, newFactor)); // Test using elimination tree diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index e6c71e15c..76a0f2b5c 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) { TEST(TableFactor, Empty) { DiscreteKey X(1, 2); - TableFactor single = *TableFactor({X}, "1 1").sum(1); + auto single = TableFactor({X}, "1 1").sum(1); // Should not throw a segfault - EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1), - single.toDecisionTreeFactor())); + auto expected_single = DecisionTreeFactor(X, "1 1").sum(1); + EXPECT(assert_equal(expected_single->toDecisionTreeFactor(), + single->toDecisionTreeFactor())); - TableFactor empty = *TableFactor({X}, "0 0").sum(1); + auto empty = TableFactor({X}, "0 0").sum(1); // Should not throw a segfault - EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1), - empty.toDecisionTreeFactor())); + auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1); + EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(), + empty->toDecisionTreeFactor())); } /* ************************************************************************* */ @@ -303,15 +305,18 @@ TEST(TableFactor, sum_max) { TableFactor f1(v0 & v1, "1 2 3 4 5 6"); TableFactor expected(v1, "9 12"); - TableFactor::shared_ptr actual = f1.sum(1); + auto actual = std::dynamic_pointer_cast(f1.sum(1)); + CHECK(actual); CHECK(assert_equal(expected, *actual, 1e-5)); TableFactor expected2(v1, "5 6"); - TableFactor::shared_ptr actual2 = f1.max(1); + auto actual2 = std::dynamic_pointer_cast(f1.max(1)); + CHECK(actual2); CHECK(assert_equal(expected2, *actual2)); TableFactor f2(v1 & v0, "1 2 3 4 5 6"); - TableFactor::shared_ptr actual22 = f2.sum(1); + auto actual22 = std::dynamic_pointer_cast(f2.sum(1)); + CHECK(actual22); } /* ************************************************************************* */ diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index 3526a282d..328fabbea 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @param (in/out) domains all domains, but only domains->at(j) will be + * checked. * @return true if domains->at(j) was changed, false otherwise. */ virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; @@ -78,6 +79,39 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { /// Partially apply known values, domain version virtual shared_ptr partiallyApply(const Domains&) const = 0; + + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const override { + return this->toDecisionTreeFactor() / df; + } + + /// Get the number of non-zero values contained in this factor. + uint64_t nrValues() const override { return 1; }; + + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { + return toDecisionTreeFactor().sum(nrFrontals); + } + + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { + return toDecisionTreeFactor().sum(keys); + } + + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { + return toDecisionTreeFactor().max(nrFrontals); + } + + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { + return toDecisionTreeFactor().max(keys); + } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 23a566d24..6ce846201 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Erase a value, non const :-( void erase(size_t value) { values_.erase(value); } - size_t nrValues() const { return values_.size(); } + uint64_t nrValues() const override { return values_.size(); } bool isSingleton() const { return nrValues() == 1; } diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 2b9a20ca6..6806bfe58 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -124,7 +124,7 @@ TEST(CSP, allInOne) { EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // Just for fun, create the product and check it - DecisionTreeFactor product = csp.product(); + DecisionTreeFactor product = csp.product()->toDecisionTreeFactor(); // product.dot("product"); DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); EXPECT(assert_equal(expectedProduct, product)); diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp index f868abb5e..5f9b7f287 100644 --- a/gtsam_unstable/discrete/tests/testScheduler.cpp +++ b/gtsam_unstable/discrete/tests/testScheduler.cpp @@ -113,7 +113,7 @@ TEST(schedulingExample, test) { EXPECT(assert_equal(expected, (DiscreteFactorGraph)s)); // Do brute force product and output that to file - DecisionTreeFactor product = s.product(); + DecisionTreeFactor product = s.product()->toDecisionTreeFactor(); // product.dot("scheduling", false); // Do exact inference