diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index f68d2ae00..1ac782b88 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -48,7 +48,7 @@ namespace gtsam { return false; } else { const auto& f(static_cast(other)); - return ADT::equals(f, tol); + return Base::equals(other, tol) && ADT::equals(f, tol); } } diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index 2b11046f4..68cc1df7d 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -28,6 +28,11 @@ using namespace std; namespace gtsam { +/* ************************************************************************* */ +bool DiscreteFactor::equals(const DiscreteFactor& lf, double tol) const { + return Base::equals(lf, tol) && cardinalities_ == lf.cardinalities_; +} + /* ************************************************************************ */ DiscreteKeys DiscreteFactor::discreteKeys() const { DiscreteKeys result; diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index fa1179c39..29981e94b 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -77,7 +77,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// @{ /// equals - virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0; + virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const; /// print void print( diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 486641e3f..8e185eb3b 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -144,7 +144,8 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const { return false; } else { const auto& f(static_cast(other)); - return sparse_table_.isApprox(f.sparse_table_, tol); + return Base::equals(other, tol) && + sparse_table_.isApprox(f.sparse_table_, tol); } } diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 0f7d7a615..35dcaf093 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -134,14 +134,17 @@ TEST(TableFactor, constructors) { EXPECT(assert_equal(expected, f4)); // Test for 9=3x3 values. - DiscreteKey V(0, 3), W(1, 3); + DiscreteKey V(0, 3), W(1, 3), O(100, 3); DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11"); TableFactor f5(conditional5); - // GTSAM_PRINT(f5); - TableFactor expected_f5( - X & Y, - "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"); + + std::string expected_values = + "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"; + TableFactor expected_f5(V & W, expected_values); EXPECT(assert_equal(expected_f5, f5, 1e-6)); + + TableFactor f5_with_wrong_keys(V & O, expected_values); + EXPECT(!assert_equal(f5_with_wrong_keys, f5, 1e-9)); } /* ************************************************************************* */