fix equality to check for matching keys

release/4.3a0
Varun Agrawal 2024-12-11 15:36:25 -05:00
parent b5cd82d0b4
commit ab3f48bbe9
5 changed files with 17 additions and 8 deletions

View File

@ -48,7 +48,7 @@ namespace gtsam {
return false;
} else {
const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol);
return Base::equals(other, tol) && ADT::equals(f, tol);
}
}

View File

@ -28,6 +28,11 @@ using namespace std;
namespace gtsam {
/* ************************************************************************* */
bool DiscreteFactor::equals(const DiscreteFactor& lf, double tol) const {
return Base::equals(lf, tol) && cardinalities_ == lf.cardinalities_;
}
/* ************************************************************************ */
DiscreteKeys DiscreteFactor::discreteKeys() const {
DiscreteKeys result;

View File

@ -77,7 +77,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// @{
/// equals
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0;
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const;
/// print
void print(

View File

@ -144,7 +144,8 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
return false;
} else {
const auto& f(static_cast<const TableFactor&>(other));
return sparse_table_.isApprox(f.sparse_table_, tol);
return Base::equals(other, tol) &&
sparse_table_.isApprox(f.sparse_table_, tol);
}
}

View File

@ -134,14 +134,17 @@ TEST(TableFactor, constructors) {
EXPECT(assert_equal(expected, f4));
// Test for 9=3x3 values.
DiscreteKey V(0, 3), W(1, 3);
DiscreteKey V(0, 3), W(1, 3), O(100, 3);
DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
TableFactor f5(conditional5);
// GTSAM_PRINT(f5);
TableFactor expected_f5(
X & Y,
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667");
std::string expected_values =
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667";
TableFactor expected_f5(V & W, expected_values);
EXPECT(assert_equal(expected_f5, f5, 1e-6));
TableFactor f5_with_wrong_keys(V & O, expected_values);
EXPECT(!assert_equal(f5_with_wrong_keys, f5, 1e-9));
}
/* ************************************************************************* */