unary apply methods for TableFactor
parent
cf1292791e
commit
aafc33db69
|
@ -64,7 +64,7 @@ TableFactor::TableFactor(const DiscreteConditional& c)
|
||||||
Eigen::SparseVector<double> TableFactor::Convert(
|
Eigen::SparseVector<double> TableFactor::Convert(
|
||||||
const std::vector<double>& table) {
|
const std::vector<double>& table) {
|
||||||
Eigen::SparseVector<double> sparse_table(table.size());
|
Eigen::SparseVector<double> sparse_table(table.size());
|
||||||
// Count number of nonzero elements in table and reserving the space.
|
// Count number of nonzero elements in table and reserve the space.
|
||||||
const uint64_t nnz = std::count_if(table.begin(), table.end(),
|
const uint64_t nnz = std::count_if(table.begin(), table.end(),
|
||||||
[](uint64_t i) { return i != 0; });
|
[](uint64_t i) { return i != 0; });
|
||||||
sparse_table.reserve(nnz);
|
sparse_table.reserve(nnz);
|
||||||
|
@ -218,6 +218,45 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
|
||||||
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
|
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
TableFactor TableFactor::apply(Unary op) const {
|
||||||
|
// Initialize new factor.
|
||||||
|
uint64_t cardi = 1;
|
||||||
|
for (auto [key, c] : cardinalities_) cardi *= c;
|
||||||
|
Eigen::SparseVector<double> sparse_table(cardi);
|
||||||
|
sparse_table.reserve(sparse_table_.nonZeros());
|
||||||
|
|
||||||
|
// Populate
|
||||||
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
|
sparse_table.coeffRef(it.index()) = op(it.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free unused memory and return.
|
||||||
|
sparse_table.pruned();
|
||||||
|
sparse_table.data().squeeze();
|
||||||
|
return TableFactor(discreteKeys(), sparse_table);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
TableFactor TableFactor::apply(UnaryAssignment op) const {
|
||||||
|
// Initialize new factor.
|
||||||
|
uint64_t cardi = 1;
|
||||||
|
for (auto [key, c] : cardinalities_) cardi *= c;
|
||||||
|
Eigen::SparseVector<double> sparse_table(cardi);
|
||||||
|
sparse_table.reserve(sparse_table_.nonZeros());
|
||||||
|
|
||||||
|
// Populate
|
||||||
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
|
DiscreteValues assignment = findAssignments(it.index());
|
||||||
|
sparse_table.coeffRef(it.index()) = op(assignment, it.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free unused memory and return.
|
||||||
|
sparse_table.pruned();
|
||||||
|
sparse_table.data().squeeze();
|
||||||
|
return TableFactor(discreteKeys(), sparse_table);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
|
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
|
||||||
if (keys_.empty() && sparse_table_.nonZeros() == 0)
|
if (keys_.empty() && sparse_table_.nonZeros() == 0)
|
||||||
|
|
|
@ -93,6 +93,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
typedef std::shared_ptr<TableFactor> shared_ptr;
|
typedef std::shared_ptr<TableFactor> shared_ptr;
|
||||||
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
|
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
|
||||||
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
|
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
|
||||||
|
using Unary = std::function<double(const double&)>;
|
||||||
|
using UnaryAssignment =
|
||||||
|
std::function<double(const Assignment<Key>&, const double&)>;
|
||||||
using Binary = std::function<double(const double, const double)>;
|
using Binary = std::function<double(const double, const double)>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -218,6 +221,18 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply unary operator `op(*this)` where `op` accepts the discrete value.
|
||||||
|
* @param op a unary operator that operates on TableFactor
|
||||||
|
*/
|
||||||
|
TableFactor apply(Unary op) const;
|
||||||
|
/**
|
||||||
|
* Apply unary operator `op(*this)` where `op` accepts the discrete assignment
|
||||||
|
* and the value at that assignment.
|
||||||
|
* @param op a unary operator that operates on TableFactor
|
||||||
|
*/
|
||||||
|
TableFactor apply(UnaryAssignment op) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Apply binary operator (*this) "op" f
|
* Apply binary operator (*this) "op" f
|
||||||
* @param f the second argument for op
|
* @param f the second argument for op
|
||||||
|
@ -225,10 +240,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
*/
|
*/
|
||||||
TableFactor apply(const TableFactor& f, Binary op) const;
|
TableFactor apply(const TableFactor& f, Binary op) const;
|
||||||
|
|
||||||
/// Return keys in contract mode.
|
/**
|
||||||
|
* Return keys in contract mode.
|
||||||
|
*
|
||||||
|
* Modes are each of the dimensions of a sparse tensor,
|
||||||
|
* and the contract modes represent which dimensions will
|
||||||
|
* be involved in contraction (aka tensor multiplication).
|
||||||
|
*/
|
||||||
DiscreteKeys contractDkeys(const TableFactor& f) const;
|
DiscreteKeys contractDkeys(const TableFactor& f) const;
|
||||||
|
|
||||||
/// Return keys in free mode.
|
/**
|
||||||
|
* @brief Return keys in free mode which are the dimensions
|
||||||
|
* not involved in the contraction operation.
|
||||||
|
*/
|
||||||
DiscreteKeys freeDkeys(const TableFactor& f) const;
|
DiscreteKeys freeDkeys(const TableFactor& f) const;
|
||||||
|
|
||||||
/// Return union of DiscreteKeys in two factors.
|
/// Return union of DiscreteKeys in two factors.
|
||||||
|
|
|
@ -93,8 +93,7 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
|
||||||
for (auto&& kv : measured_time) {
|
for (auto&& kv : measured_time) {
|
||||||
cout << "dropout: " << kv.first
|
cout << "dropout: " << kv.first
|
||||||
<< " | TableFactor time: " << kv.second.first.count()
|
<< " | TableFactor time: " << kv.second.first.count()
|
||||||
<< " | DecisionTreeFactor time: " << kv.second.second.count() <<
|
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
|
||||||
endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -361,6 +360,39 @@ TEST(TableFactor, htmlWithValueFormatter) {
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(TableFactor, Unary) {
|
||||||
|
// Declare a bunch of keys
|
||||||
|
DiscreteKey X(0, 2), Y(1, 3);
|
||||||
|
|
||||||
|
// Create factors
|
||||||
|
TableFactor f(X & Y, "2 5 3 6 2 7");
|
||||||
|
auto op = [](const double x) { return 2 * x; };
|
||||||
|
auto g = f.apply(op);
|
||||||
|
|
||||||
|
TableFactor expected(X & Y, "4 10 6 12 4 14");
|
||||||
|
EXPECT(assert_equal(g, expected));
|
||||||
|
|
||||||
|
auto sq_op = [](const double x) { return x * x; };
|
||||||
|
auto g_sq = f.apply(sq_op);
|
||||||
|
TableFactor expected_sq(X & Y, "4 25 9 36 4 49");
|
||||||
|
EXPECT(assert_equal(g_sq, expected_sq));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(TableFactor, UnaryAssignment) {
|
||||||
|
// Declare a bunch of keys
|
||||||
|
DiscreteKey X(0, 2), Y(1, 3);
|
||||||
|
|
||||||
|
// Create factors
|
||||||
|
TableFactor f(X & Y, "2 5 3 6 2 7");
|
||||||
|
auto op = [](const Assignment<Key>& key, const double x) { return 2 * x; };
|
||||||
|
auto g = f.apply(op);
|
||||||
|
|
||||||
|
TableFactor expected(X & Y, "4 10 6 12 4 14");
|
||||||
|
EXPECT(assert_equal(g, expected));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue