Merge pull request #1962 from borglab/fix-dtf-division
Fix DecisionTreeFactor divisionrelease/4.3a0
commit
ffd04fd454
|
@ -154,7 +154,15 @@ namespace gtsam {
|
||||||
|
|
||||||
static double safe_div(const double& a, const double& b);
|
static double safe_div(const double& a, const double& b);
|
||||||
|
|
||||||
/// divide by factor f (safely)
|
/**
|
||||||
|
* @brief Divide by factor f (safely).
|
||||||
|
* Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$
|
||||||
|
* results in a function which involves all keys
|
||||||
|
* \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$
|
||||||
|
*
|
||||||
|
* @param f The DecisinTreeFactor to divide by.
|
||||||
|
* @return DecisionTreeFactor
|
||||||
|
*/
|
||||||
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
||||||
return apply(f, safe_div);
|
return apply(f, safe_div);
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,6 +30,12 @@
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
/** Convert Signature into CPT */
|
||||||
|
DecisionTreeFactor create(const Signature& signature) {
|
||||||
|
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DecisionTreeFactor, ConstructorsMatch) {
|
TEST(DecisionTreeFactor, ConstructorsMatch) {
|
||||||
// Declare two keys
|
// Declare two keys
|
||||||
|
@ -105,6 +111,27 @@ TEST(DecisionTreeFactor, multiplication) {
|
||||||
CHECK(assert_equal(expected2, actual));
|
CHECK(assert_equal(expected2, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DecisionTreeFactor, Divide) {
|
||||||
|
DiscreteKey A(0, 2), S(1, 2);
|
||||||
|
DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
|
||||||
|
DecisionTreeFactor joint = pA * pS;
|
||||||
|
|
||||||
|
DecisionTreeFactor s = joint / pA;
|
||||||
|
|
||||||
|
// Factors are not equal due to difference in keys
|
||||||
|
EXPECT(assert_inequal(pS, s));
|
||||||
|
|
||||||
|
// The underlying data should be the same
|
||||||
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
|
EXPECT(assert_equal(ADT(pS), ADT(s)));
|
||||||
|
|
||||||
|
KeySet keys(joint.keys());
|
||||||
|
keys.insert(pA.keys().begin(), pA.keys().end());
|
||||||
|
EXPECT(assert_inequal(KeySet(pS.keys()), keys));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DecisionTreeFactor, sum_max) {
|
TEST(DecisionTreeFactor, sum_max) {
|
||||||
DiscreteKey v0(0, 3), v1(1, 2);
|
DiscreteKey v0(0, 3), v1(1, 2);
|
||||||
|
@ -217,12 +244,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Convert Signature into CPT */
|
|
||||||
DecisionTreeFactor create(const Signature& signature) {
|
|
||||||
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
|
|
||||||
return p;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test Asia Joint
|
// test Asia Joint
|
||||||
TEST(DecisionTreeFactor, joint) {
|
TEST(DecisionTreeFactor, joint) {
|
||||||
|
|
Loading…
Reference in New Issue