Additional arithmetic

release/4.3a0
Frank Dellaert 2024-09-30 15:19:05 -07:00
parent d054a041ed
commit 5fb3b37771
2 changed files with 43 additions and 20 deletions

View File

@ -70,6 +70,7 @@ namespace gtsam {
return a / b;
}
static inline double id(const double& x) { return x; }
static inline double negate(const double& x) { return -x; }
};
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
@ -186,6 +187,16 @@ namespace gtsam {
return this->apply(g, &Ring::add);
}
/** negation */
AlgebraicDecisionTree operator-() const {
return this->apply(&Ring::negate);
}
/** subtract */
AlgebraicDecisionTree operator-(const AlgebraicDecisionTree& g) const {
return *this + (-g);
}
/** product */
AlgebraicDecisionTree operator*(const AlgebraicDecisionTree& g) const {
return this->apply(g, &Ring::mul);

View File

@ -10,8 +10,8 @@
* -------------------------------------------------------------------------- */
/*
* @file testDecisionTree.cpp
* @brief Develop DecisionTree
* @file testAlgebraicDecisionTree.cpp
* @brief Unit tests for Algebraic decision tree
* @author Frank Dellaert
* @date Mar 6, 2011
*/
@ -46,23 +46,35 @@ void dot(const T& f, const string& filename) {
#endif
}
/** I can't get this to work !
class Mul: std::function<double(const double&, const double&)> {
inline double operator()(const double& a, const double& b) {
return a * b;
}
};
/* ************************************************************************** */
// Test arithmetic:
TEST(ADT, arithmetic) {
DiscreteKey A(0, 2), B(1, 2);
ADT zero{0}, one{1};
ADT a(A, 1, 2);
ADT b(B, 3, 4);
// If second argument of binary op is Leaf
template<typename L>
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
Ptr h(new Choice(label(), cardinality()));
for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(cache, gL, op));
return Unique(cache, h);
// Addition
CHECK(assert_equal(a, zero + a));
// Negate and subtraction
CHECK(assert_equal(-a, zero - a));
CHECK(assert_equal({zero}, a - a));
CHECK(assert_equal(a + b, b + a));
CHECK(assert_equal({A, 3, 4}, a + 2));
CHECK(assert_equal({B, 1, 2}, b - 2));
// Multiplication
CHECK(assert_equal(zero, zero * a));
CHECK(assert_equal(zero, a * zero));
CHECK(assert_equal(a, one * a));
CHECK(assert_equal(a, a * one));
CHECK(assert_equal(a * b, b * a));
// division
// CHECK(assert_equal(a, (a * b) / b)); // not true because no pruning
CHECK(assert_equal(b, (a * b) / a));
}
*/
/* ************************************************************************** */
// instrumented operators