check if input size matches the total cardinality of the keys

release/4.3a0
Varun Agrawal 2024-12-11 15:00:53 -05:00
parent b76c6d8250
commit 425c3ac42c
2 changed files with 26 additions and 10 deletions

View File

@ -98,7 +98,17 @@ TableFactor::TableFactor(const DiscreteConditional& c)
/* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) {
const DiscreteKeys& keys, const std::vector<double>& table) {
size_t max_size = 1;
for (auto&& [_, cardinality] : keys.cardinalities()) {
max_size *= cardinality;
}
if (table.size() != max_size) {
throw std::runtime_error(
"The cardinalities of the keys don't match the number of values in the "
"input.");
}
Eigen::SparseVector<double> sparse_table(table.size());
// Count number of nonzero elements in table and reserve the space.
const uint64_t nnz = std::count_if(table.begin(), table.end(),
@ -113,13 +123,14 @@ Eigen::SparseVector<double> TableFactor::Convert(
}
/* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) {
Eigen::SparseVector<double> TableFactor::Convert(const DiscreteKeys& keys,
const std::string& table) {
// Convert string to doubles.
std::vector<double> ys;
std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
std::back_inserter(ys));
return Convert(ys);
return Convert(keys, ys);
}
/* ************************************************************************ */
@ -250,7 +261,8 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
for (auto&& kv : assignment) {
cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
}
cout << " | " << it.value() << " | " << it.index() << endl;
cout << " | " << std::setw(10) << std::left << it.value() << " | "
<< it.index() << endl;
}
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
}

View File

@ -80,12 +80,16 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
}
/// Convert probability table given as doubles to SparseVector.
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
/**
* Convert probability table given as doubles to SparseVector.
* Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
*/
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
const std::vector<double>& table);
/// Convert probability table given as string to SparseVector.
static Eigen::SparseVector<double> Convert(const std::string& table);
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
const std::string& table);
public:
// typedefs needed to play nice with gtsam
@ -111,11 +115,11 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/** Constructor from doubles */
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
: TableFactor(keys, Convert(table)) {}
: TableFactor(keys, Convert(keys, table)) {}
/** Constructor from string */
TableFactor(const DiscreteKeys& keys, const std::string& table)
: TableFactor(keys, Convert(table)) {}
: TableFactor(keys, Convert(keys, table)) {}
/// Single-key specialization
template <class SOURCE>