diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 804b956fe..a5b82f277 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -154,17 +154,17 @@ namespace gtsam { static double safe_div(const double& a, const double& b); - /// divide by factor f (safely) + /** + * @brief Divide by factor f (safely). + * Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$ + * results in a function which involves all keys + * \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$ + * + * @param f The DecisinTreeFactor to divide by. + * @return DecisionTreeFactor + */ DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { - KeyVector diff; - std::set_difference(this->keys().begin(), this->keys().end(), - f.keys().begin(), f.keys().end(), - std::back_inserter(diff)); - DiscreteKeys keys; - for (Key key : diff) { - keys.push_back({key, this->cardinality(key)}); - } - return DecisionTreeFactor(keys, apply(f, safe_div)); + return apply(f, safe_div); } /// Convert into a decision tree @@ -181,12 +181,12 @@ namespace gtsam { } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { + DiscreteFactor::shared_ptr max(size_t nrFrontals) const { return combine(nrFrontals, Ring::max); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { + DiscreteFactor::shared_ptr max(const Ordering& keys) const { return combine(keys, Ring::max); } diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 7210622d8..ba8714783 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -116,8 +116,20 @@ TEST(DecisionTreeFactor, Divide) { DiscreteKey A(0, 2), S(1, 2); DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); DecisionTreeFactor joint = pA * pS; + DecisionTreeFactor s = joint / pA; - EXPECT(assert_equal(pS, s)); + + // Factors are not equal due to difference in keys + EXPECT(assert_inequal(pS, s)); + + // The underlying data should be the same + using ADT = AlgebraicDecisionTree; + EXPECT(assert_equal(ADT(pS), ADT(s))); + + KeySet keys(joint.keys()); + keys.insert(pA.keys().begin(), pA.keys().end()); + EXPECT(assert_inequal(KeySet(pS.keys()), keys)); + } /* ************************************************************************* */