fix equality to check for matching keys
parent
b5cd82d0b4
commit
ab3f48bbe9
|
@ -48,7 +48,7 @@ namespace gtsam {
|
|||
return false;
|
||||
} else {
|
||||
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
||||
return ADT::equals(f, tol);
|
||||
return Base::equals(other, tol) && ADT::equals(f, tol);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -28,6 +28,11 @@ using namespace std;
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool DiscreteFactor::equals(const DiscreteFactor& lf, double tol) const {
|
||||
return Base::equals(lf, tol) && cardinalities_ == lf.cardinalities_;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys DiscreteFactor::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
|
|
|
@ -77,7 +77,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
|||
/// @{
|
||||
|
||||
/// equals
|
||||
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0;
|
||||
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const;
|
||||
|
||||
/// print
|
||||
void print(
|
||||
|
|
|
@ -144,7 +144,8 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
|
|||
return false;
|
||||
} else {
|
||||
const auto& f(static_cast<const TableFactor&>(other));
|
||||
return sparse_table_.isApprox(f.sparse_table_, tol);
|
||||
return Base::equals(other, tol) &&
|
||||
sparse_table_.isApprox(f.sparse_table_, tol);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -134,14 +134,17 @@ TEST(TableFactor, constructors) {
|
|||
EXPECT(assert_equal(expected, f4));
|
||||
|
||||
// Test for 9=3x3 values.
|
||||
DiscreteKey V(0, 3), W(1, 3);
|
||||
DiscreteKey V(0, 3), W(1, 3), O(100, 3);
|
||||
DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
|
||||
TableFactor f5(conditional5);
|
||||
// GTSAM_PRINT(f5);
|
||||
TableFactor expected_f5(
|
||||
X & Y,
|
||||
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667");
|
||||
|
||||
std::string expected_values =
|
||||
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667";
|
||||
TableFactor expected_f5(V & W, expected_values);
|
||||
EXPECT(assert_equal(expected_f5, f5, 1e-6));
|
||||
|
||||
TableFactor f5_with_wrong_keys(V & O, expected_values);
|
||||
EXPECT(!assert_equal(f5_with_wrong_keys, f5, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue