check if input size matches the total cardinality of the keys
parent
b76c6d8250
commit
425c3ac42c
|
@ -98,7 +98,17 @@ TableFactor::TableFactor(const DiscreteConditional& c)
|
|||
|
||||
/* ************************************************************************ */
|
||||
Eigen::SparseVector<double> TableFactor::Convert(
|
||||
const std::vector<double>& table) {
|
||||
const DiscreteKeys& keys, const std::vector<double>& table) {
|
||||
size_t max_size = 1;
|
||||
for (auto&& [_, cardinality] : keys.cardinalities()) {
|
||||
max_size *= cardinality;
|
||||
}
|
||||
if (table.size() != max_size) {
|
||||
throw std::runtime_error(
|
||||
"The cardinalities of the keys don't match the number of values in the "
|
||||
"input.");
|
||||
}
|
||||
|
||||
Eigen::SparseVector<double> sparse_table(table.size());
|
||||
// Count number of nonzero elements in table and reserve the space.
|
||||
const uint64_t nnz = std::count_if(table.begin(), table.end(),
|
||||
|
@ -113,13 +123,14 @@ Eigen::SparseVector<double> TableFactor::Convert(
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) {
|
||||
Eigen::SparseVector<double> TableFactor::Convert(const DiscreteKeys& keys,
|
||||
const std::string& table) {
|
||||
// Convert string to doubles.
|
||||
std::vector<double> ys;
|
||||
std::istringstream iss(table);
|
||||
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
|
||||
std::back_inserter(ys));
|
||||
return Convert(ys);
|
||||
return Convert(keys, ys);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
@ -250,7 +261,8 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
|
|||
for (auto&& kv : assignment) {
|
||||
cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
|
||||
}
|
||||
cout << " | " << it.value() << " | " << it.index() << endl;
|
||||
cout << " | " << std::setw(10) << std::left << it.value() << " | "
|
||||
<< it.index() << endl;
|
||||
}
|
||||
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
|
||||
}
|
||||
|
|
|
@ -80,12 +80,16 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
||||
}
|
||||
|
||||
/// Convert probability table given as doubles to SparseVector.
|
||||
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
|
||||
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
|
||||
/**
|
||||
* Convert probability table given as doubles to SparseVector.
|
||||
* Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
|
||||
*/
|
||||
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
|
||||
const std::vector<double>& table);
|
||||
|
||||
/// Convert probability table given as string to SparseVector.
|
||||
static Eigen::SparseVector<double> Convert(const std::string& table);
|
||||
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
|
||||
const std::string& table);
|
||||
|
||||
public:
|
||||
// typedefs needed to play nice with gtsam
|
||||
|
@ -111,11 +115,11 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
|
||||
/** Constructor from doubles */
|
||||
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
|
||||
: TableFactor(keys, Convert(table)) {}
|
||||
: TableFactor(keys, Convert(keys, table)) {}
|
||||
|
||||
/** Constructor from string */
|
||||
TableFactor(const DiscreteKeys& keys, const std::string& table)
|
||||
: TableFactor(keys, Convert(table)) {}
|
||||
: TableFactor(keys, Convert(keys, table)) {}
|
||||
|
||||
/// Single-key specialization
|
||||
template <class SOURCE>
|
||||
|
|
Loading…
Reference in New Issue