Merge pull request #1928 from borglab/fix-table-factor

release/4.3a0
Varun Agrawal 2024-12-11 23:35:47 -05:00 committed by GitHub
commit 137a503746
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 49 additions and 18 deletions

View File

@ -48,7 +48,7 @@ namespace gtsam {
return false; return false;
} else { } else {
const auto& f(static_cast<const DecisionTreeFactor&>(other)); const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol); return Base::equals(other, tol) && ADT::equals(f, tol);
} }
} }

View File

@ -28,6 +28,11 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
bool DiscreteFactor::equals(const DiscreteFactor& lf, double tol) const {
return Base::equals(lf, tol) && cardinalities_ == lf.cardinalities_;
}
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteKeys DiscreteFactor::discreteKeys() const { DiscreteKeys DiscreteFactor::discreteKeys() const {
DiscreteKeys result; DiscreteKeys result;

View File

@ -77,7 +77,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// @{ /// @{
/// equals /// equals
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0; virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const;
/// print /// print
void print( void print(

View File

@ -92,13 +92,28 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
const DecisionTreeFactor& dtf) const DecisionTreeFactor& dtf)
: TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {} : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {}
/* ************************************************************************ */
TableFactor::TableFactor(const DecisionTreeFactor& dtf)
: TableFactor(dtf.discreteKeys(),
ComputeLeafOrdering(dtf.discreteKeys(), dtf)) {}
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DiscreteConditional& c) TableFactor::TableFactor(const DiscreteConditional& c)
: TableFactor(c.discreteKeys(), c) {} : TableFactor(c.discreteKeys(), c) {}
/* ************************************************************************ */ /* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert( Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) { const DiscreteKeys& keys, const std::vector<double>& table) {
size_t max_size = 1;
for (auto&& [_, cardinality] : keys.cardinalities()) {
max_size *= cardinality;
}
if (table.size() != max_size) {
throw std::runtime_error(
"The cardinalities of the keys don't match the number of values in the "
"input.");
}
Eigen::SparseVector<double> sparse_table(table.size()); Eigen::SparseVector<double> sparse_table(table.size());
// Count number of nonzero elements in table and reserve the space. // Count number of nonzero elements in table and reserve the space.
const uint64_t nnz = std::count_if(table.begin(), table.end(), const uint64_t nnz = std::count_if(table.begin(), table.end(),
@ -113,13 +128,14 @@ Eigen::SparseVector<double> TableFactor::Convert(
} }
/* ************************************************************************ */ /* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) { Eigen::SparseVector<double> TableFactor::Convert(const DiscreteKeys& keys,
const std::string& table) {
// Convert string to doubles. // Convert string to doubles.
std::vector<double> ys; std::vector<double> ys;
std::istringstream iss(table); std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(), std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
std::back_inserter(ys)); std::back_inserter(ys));
return Convert(ys); return Convert(keys, ys);
} }
/* ************************************************************************ */ /* ************************************************************************ */
@ -128,7 +144,8 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
return false; return false;
} else { } else {
const auto& f(static_cast<const TableFactor&>(other)); const auto& f(static_cast<const TableFactor&>(other));
return sparse_table_.isApprox(f.sparse_table_, tol); return Base::equals(other, tol) &&
sparse_table_.isApprox(f.sparse_table_, tol);
} }
} }
@ -250,7 +267,8 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
for (auto&& kv : assignment) { for (auto&& kv : assignment) {
cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
} }
cout << " | " << it.value() << " | " << it.index() << endl; cout << " | " << std::setw(10) << std::left << it.value() << " | "
<< it.index() << endl;
} }
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
} }

View File

@ -80,12 +80,16 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
} }
/// Convert probability table given as doubles to SparseVector. /**
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} * Convert probability table given as doubles to SparseVector.
static Eigen::SparseVector<double> Convert(const std::vector<double>& table); * Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
*/
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
const std::vector<double>& table);
/// Convert probability table given as string to SparseVector. /// Convert probability table given as string to SparseVector.
static Eigen::SparseVector<double> Convert(const std::string& table); static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
const std::string& table);
public: public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
@ -111,11 +115,11 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/** Constructor from doubles */ /** Constructor from doubles */
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table) TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
: TableFactor(keys, Convert(table)) {} : TableFactor(keys, Convert(keys, table)) {}
/** Constructor from string */ /** Constructor from string */
TableFactor(const DiscreteKeys& keys, const std::string& table) TableFactor(const DiscreteKeys& keys, const std::string& table)
: TableFactor(keys, Convert(table)) {} : TableFactor(keys, Convert(keys, table)) {}
/// Single-key specialization /// Single-key specialization
template <class SOURCE> template <class SOURCE>
@ -128,6 +132,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// Constructor from DecisionTreeFactor /// Constructor from DecisionTreeFactor
TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf); TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf);
TableFactor(const DecisionTreeFactor& dtf);
/// Constructor from DecisionTree<Key, double>/AlgebraicDecisionTree /// Constructor from DecisionTree<Key, double>/AlgebraicDecisionTree
TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree); TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree);

View File

@ -134,14 +134,17 @@ TEST(TableFactor, constructors) {
EXPECT(assert_equal(expected, f4)); EXPECT(assert_equal(expected, f4));
// Test for 9=3x3 values. // Test for 9=3x3 values.
DiscreteKey V(0, 3), W(1, 3); DiscreteKey V(0, 3), W(1, 3), O(100, 3);
DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11"); DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
TableFactor f5(conditional5); TableFactor f5(conditional5);
// GTSAM_PRINT(f5);
TableFactor expected_f5( std::string expected_values =
X & Y, "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667";
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"); TableFactor expected_f5(V & W, expected_values);
EXPECT(assert_equal(expected_f5, f5, 1e-6)); EXPECT(assert_equal(expected_f5, f5, 1e-6));
TableFactor f5_with_wrong_keys(V & O, expected_values);
EXPECT(assert_inequal(f5_with_wrong_keys, f5, 1e-9));
} }
/* ************************************************************************* */ /* ************************************************************************* */