fix equality to check for matching keys
parent
b5cd82d0b4
commit
ab3f48bbe9
|
@ -48,7 +48,7 @@ namespace gtsam {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
||||||
return ADT::equals(f, tol);
|
return Base::equals(other, tol) && ADT::equals(f, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,11 @@ using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
bool DiscreteFactor::equals(const DiscreteFactor& lf, double tol) const {
|
||||||
|
return Base::equals(lf, tol) && cardinalities_ == lf.cardinalities_;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteKeys DiscreteFactor::discreteKeys() const {
|
DiscreteKeys DiscreteFactor::discreteKeys() const {
|
||||||
DiscreteKeys result;
|
DiscreteKeys result;
|
||||||
|
|
|
@ -77,7 +77,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// equals
|
/// equals
|
||||||
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0;
|
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const;
|
||||||
|
|
||||||
/// print
|
/// print
|
||||||
void print(
|
void print(
|
||||||
|
|
|
@ -144,7 +144,8 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
const auto& f(static_cast<const TableFactor&>(other));
|
const auto& f(static_cast<const TableFactor&>(other));
|
||||||
return sparse_table_.isApprox(f.sparse_table_, tol);
|
return Base::equals(other, tol) &&
|
||||||
|
sparse_table_.isApprox(f.sparse_table_, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -134,14 +134,17 @@ TEST(TableFactor, constructors) {
|
||||||
EXPECT(assert_equal(expected, f4));
|
EXPECT(assert_equal(expected, f4));
|
||||||
|
|
||||||
// Test for 9=3x3 values.
|
// Test for 9=3x3 values.
|
||||||
DiscreteKey V(0, 3), W(1, 3);
|
DiscreteKey V(0, 3), W(1, 3), O(100, 3);
|
||||||
DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
|
DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
|
||||||
TableFactor f5(conditional5);
|
TableFactor f5(conditional5);
|
||||||
// GTSAM_PRINT(f5);
|
|
||||||
TableFactor expected_f5(
|
std::string expected_values =
|
||||||
X & Y,
|
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667";
|
||||||
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667");
|
TableFactor expected_f5(V & W, expected_values);
|
||||||
EXPECT(assert_equal(expected_f5, f5, 1e-6));
|
EXPECT(assert_equal(expected_f5, f5, 1e-6));
|
||||||
|
|
||||||
|
TableFactor f5_with_wrong_keys(V & O, expected_values);
|
||||||
|
EXPECT(!assert_equal(f5_with_wrong_keys, f5, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue