diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index f68d2ae00..1ac782b88 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -48,7 +48,7 @@ namespace gtsam { return false; } else { const auto& f(static_cast(other)); - return ADT::equals(f, tol); + return Base::equals(other, tol) && ADT::equals(f, tol); } } diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index 2b11046f4..68cc1df7d 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -28,6 +28,11 @@ using namespace std; namespace gtsam { +/* ************************************************************************* */ +bool DiscreteFactor::equals(const DiscreteFactor& lf, double tol) const { + return Base::equals(lf, tol) && cardinalities_ == lf.cardinalities_; +} + /* ************************************************************************ */ DiscreteKeys DiscreteFactor::discreteKeys() const { DiscreteKeys result; diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index fa1179c39..29981e94b 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -77,7 +77,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// @{ /// equals - virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0; + virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const; /// print void print( diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index ea51a996c..8e185eb3b 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -92,13 +92,28 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, const DecisionTreeFactor& dtf) : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {} +/* ************************************************************************ */ +TableFactor::TableFactor(const DecisionTreeFactor& dtf) + : TableFactor(dtf.discreteKeys(), + ComputeLeafOrdering(dtf.discreteKeys(), dtf)) {} + /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c) : TableFactor(c.discreteKeys(), c) {} /* ************************************************************************ */ Eigen::SparseVector TableFactor::Convert( - const std::vector& table) { + const DiscreteKeys& keys, const std::vector& table) { + size_t max_size = 1; + for (auto&& [_, cardinality] : keys.cardinalities()) { + max_size *= cardinality; + } + if (table.size() != max_size) { + throw std::runtime_error( + "The cardinalities of the keys don't match the number of values in the " + "input."); + } + Eigen::SparseVector sparse_table(table.size()); // Count number of nonzero elements in table and reserve the space. const uint64_t nnz = std::count_if(table.begin(), table.end(), @@ -113,13 +128,14 @@ Eigen::SparseVector TableFactor::Convert( } /* ************************************************************************ */ -Eigen::SparseVector TableFactor::Convert(const std::string& table) { +Eigen::SparseVector TableFactor::Convert(const DiscreteKeys& keys, + const std::string& table) { // Convert string to doubles. std::vector ys; std::istringstream iss(table); std::copy(std::istream_iterator(iss), std::istream_iterator(), std::back_inserter(ys)); - return Convert(ys); + return Convert(keys, ys); } /* ************************************************************************ */ @@ -128,7 +144,8 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const { return false; } else { const auto& f(static_cast(other)); - return sparse_table_.isApprox(f.sparse_table_, tol); + return Base::equals(other, tol) && + sparse_table_.isApprox(f.sparse_table_, tol); } } @@ -250,7 +267,8 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const { for (auto&& kv : assignment) { cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; } - cout << " | " << it.value() << " | " << it.index() << endl; + cout << " | " << std::setw(10) << std::left << it.value() << " | " + << it.index() << endl; } cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; } diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 07564892f..d27c4740c 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -80,12 +80,16 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); } - /// Convert probability table given as doubles to SparseVector. - /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} - static Eigen::SparseVector Convert(const std::vector& table); + /** + * Convert probability table given as doubles to SparseVector. + * Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} + */ + static Eigen::SparseVector Convert(const DiscreteKeys& keys, + const std::vector& table); /// Convert probability table given as string to SparseVector. - static Eigen::SparseVector Convert(const std::string& table); + static Eigen::SparseVector Convert(const DiscreteKeys& keys, + const std::string& table); public: // typedefs needed to play nice with gtsam @@ -111,11 +115,11 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /** Constructor from doubles */ TableFactor(const DiscreteKeys& keys, const std::vector& table) - : TableFactor(keys, Convert(table)) {} + : TableFactor(keys, Convert(keys, table)) {} /** Constructor from string */ TableFactor(const DiscreteKeys& keys, const std::string& table) - : TableFactor(keys, Convert(table)) {} + : TableFactor(keys, Convert(keys, table)) {} /// Single-key specialization template @@ -128,6 +132,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// Constructor from DecisionTreeFactor TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf); + TableFactor(const DecisionTreeFactor& dtf); /// Constructor from DecisionTree/AlgebraicDecisionTree TableFactor(const DiscreteKeys& keys, const DecisionTree& dtree); diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 0f7d7a615..212067cb3 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -134,14 +134,17 @@ TEST(TableFactor, constructors) { EXPECT(assert_equal(expected, f4)); // Test for 9=3x3 values. - DiscreteKey V(0, 3), W(1, 3); + DiscreteKey V(0, 3), W(1, 3), O(100, 3); DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11"); TableFactor f5(conditional5); - // GTSAM_PRINT(f5); - TableFactor expected_f5( - X & Y, - "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"); + + std::string expected_values = + "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"; + TableFactor expected_f5(V & W, expected_values); EXPECT(assert_equal(expected_f5, f5, 1e-6)); + + TableFactor f5_with_wrong_keys(V & O, expected_values); + EXPECT(assert_inequal(f5_with_wrong_keys, f5, 1e-9)); } /* ************************************************************************* */