improve operator/ documentation and also showcase understanding in test

release/4.3a0
Varun Agrawal 2025-01-05 11:54:09 -05:00
parent a142556c52
commit e309bf370b
2 changed files with 25 additions and 13 deletions

View File

@ -154,17 +154,17 @@ namespace gtsam {
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
/// divide by factor f (safely) /**
* @brief Divide by factor f (safely).
* Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$
* results in a function which involves all keys
* \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$
*
* @param f The DecisinTreeFactor to divide by.
* @return DecisionTreeFactor
*/
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
KeyVector diff; return apply(f, safe_div);
std::set_difference(this->keys().begin(), this->keys().end(),
f.keys().begin(), f.keys().end(),
std::back_inserter(diff));
DiscreteKeys keys;
for (Key key : diff) {
keys.push_back({key, this->cardinality(key)});
}
return DecisionTreeFactor(keys, apply(f, safe_div));
} }
/// Convert into a decision tree /// Convert into a decision tree
@ -181,12 +181,12 @@ namespace gtsam {
} }
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const { DiscreteFactor::shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, Ring::max); return combine(nrFrontals, Ring::max);
} }
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const { DiscreteFactor::shared_ptr max(const Ordering& keys) const {
return combine(keys, Ring::max); return combine(keys, Ring::max);
} }

View File

@ -116,8 +116,20 @@ TEST(DecisionTreeFactor, Divide) {
DiscreteKey A(0, 2), S(1, 2); DiscreteKey A(0, 2), S(1, 2);
DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
DecisionTreeFactor joint = pA * pS; DecisionTreeFactor joint = pA * pS;
DecisionTreeFactor s = joint / pA; DecisionTreeFactor s = joint / pA;
EXPECT(assert_equal(pS, s));
// Factors are not equal due to difference in keys
EXPECT(assert_inequal(pS, s));
// The underlying data should be the same
using ADT = AlgebraicDecisionTree<Key>;
EXPECT(assert_equal(ADT(pS), ADT(s)));
KeySet keys(joint.keys());
keys.insert(pA.keys().begin(), pA.keys().end());
EXPECT(assert_inequal(KeySet(pS.keys()), keys));
} }
/* ************************************************************************* */ /* ************************************************************************* */