Merge pull request #1580 from borglab/tablefactor-apply
commit
0f7bc5cf2d
|
@ -64,7 +64,7 @@ TableFactor::TableFactor(const DiscreteConditional& c)
|
|||
Eigen::SparseVector<double> TableFactor::Convert(
|
||||
const std::vector<double>& table) {
|
||||
Eigen::SparseVector<double> sparse_table(table.size());
|
||||
// Count number of nonzero elements in table and reserving the space.
|
||||
// Count number of nonzero elements in table and reserve the space.
|
||||
const uint64_t nnz = std::count_if(table.begin(), table.end(),
|
||||
[](uint64_t i) { return i != 0; });
|
||||
sparse_table.reserve(nnz);
|
||||
|
@ -218,6 +218,45 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
|
|||
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor TableFactor::apply(Unary op) const {
|
||||
// Initialize new factor.
|
||||
uint64_t cardi = 1;
|
||||
for (auto [key, c] : cardinalities_) cardi *= c;
|
||||
Eigen::SparseVector<double> sparse_table(cardi);
|
||||
sparse_table.reserve(sparse_table_.nonZeros());
|
||||
|
||||
// Populate
|
||||
for (SparseIt it(sparse_table_); it; ++it) {
|
||||
sparse_table.coeffRef(it.index()) = op(it.value());
|
||||
}
|
||||
|
||||
// Free unused memory and return.
|
||||
sparse_table.pruned();
|
||||
sparse_table.data().squeeze();
|
||||
return TableFactor(discreteKeys(), sparse_table);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor TableFactor::apply(UnaryAssignment op) const {
|
||||
// Initialize new factor.
|
||||
uint64_t cardi = 1;
|
||||
for (auto [key, c] : cardinalities_) cardi *= c;
|
||||
Eigen::SparseVector<double> sparse_table(cardi);
|
||||
sparse_table.reserve(sparse_table_.nonZeros());
|
||||
|
||||
// Populate
|
||||
for (SparseIt it(sparse_table_); it; ++it) {
|
||||
DiscreteValues assignment = findAssignments(it.index());
|
||||
sparse_table.coeffRef(it.index()) = op(assignment, it.value());
|
||||
}
|
||||
|
||||
// Free unused memory and return.
|
||||
sparse_table.pruned();
|
||||
sparse_table.data().squeeze();
|
||||
return TableFactor(discreteKeys(), sparse_table);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
|
||||
if (keys_.empty() && sparse_table_.nonZeros() == 0)
|
||||
|
|
|
@ -93,6 +93,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
typedef std::shared_ptr<TableFactor> shared_ptr;
|
||||
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
|
||||
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
|
||||
using Unary = std::function<double(const double&)>;
|
||||
using UnaryAssignment =
|
||||
std::function<double(const Assignment<Key>&, const double&)>;
|
||||
using Binary = std::function<double(const double, const double)>;
|
||||
|
||||
public:
|
||||
|
@ -218,6 +221,18 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* Apply unary operator `op(*this)` where `op` accepts the discrete value.
|
||||
* @param op a unary operator that operates on TableFactor
|
||||
*/
|
||||
TableFactor apply(Unary op) const;
|
||||
/**
|
||||
* Apply unary operator `op(*this)` where `op` accepts the discrete assignment
|
||||
* and the value at that assignment.
|
||||
* @param op a unary operator that operates on TableFactor
|
||||
*/
|
||||
TableFactor apply(UnaryAssignment op) const;
|
||||
|
||||
/**
|
||||
* Apply binary operator (*this) "op" f
|
||||
* @param f the second argument for op
|
||||
|
@ -225,10 +240,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
*/
|
||||
TableFactor apply(const TableFactor& f, Binary op) const;
|
||||
|
||||
/// Return keys in contract mode.
|
||||
/**
|
||||
* Return keys in contract mode.
|
||||
*
|
||||
* Modes are each of the dimensions of a sparse tensor,
|
||||
* and the contract modes represent which dimensions will
|
||||
* be involved in contraction (aka tensor multiplication).
|
||||
*/
|
||||
DiscreteKeys contractDkeys(const TableFactor& f) const;
|
||||
|
||||
/// Return keys in free mode.
|
||||
/**
|
||||
* @brief Return keys in free mode which are the dimensions
|
||||
* not involved in the contraction operation.
|
||||
*/
|
||||
DiscreteKeys freeDkeys(const TableFactor& f) const;
|
||||
|
||||
/// Return union of DiscreteKeys in two factors.
|
||||
|
|
|
@ -93,8 +93,7 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
|
|||
for (auto&& kv : measured_time) {
|
||||
cout << "dropout: " << kv.first
|
||||
<< " | TableFactor time: " << kv.second.first.count()
|
||||
<< " | DecisionTreeFactor time: " << kv.second.second.count() <<
|
||||
endl;
|
||||
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -361,6 +360,39 @@ TEST(TableFactor, htmlWithValueFormatter) {
|
|||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(TableFactor, Unary) {
|
||||
// Declare a bunch of keys
|
||||
DiscreteKey X(0, 2), Y(1, 3);
|
||||
|
||||
// Create factors
|
||||
TableFactor f(X & Y, "2 5 3 6 2 7");
|
||||
auto op = [](const double x) { return 2 * x; };
|
||||
auto g = f.apply(op);
|
||||
|
||||
TableFactor expected(X & Y, "4 10 6 12 4 14");
|
||||
EXPECT(assert_equal(g, expected));
|
||||
|
||||
auto sq_op = [](const double x) { return x * x; };
|
||||
auto g_sq = f.apply(sq_op);
|
||||
TableFactor expected_sq(X & Y, "4 25 9 36 4 49");
|
||||
EXPECT(assert_equal(g_sq, expected_sq));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(TableFactor, UnaryAssignment) {
|
||||
// Declare a bunch of keys
|
||||
DiscreteKey X(0, 2), Y(1, 3);
|
||||
|
||||
// Create factors
|
||||
TableFactor f(X & Y, "2 5 3 6 2 7");
|
||||
auto op = [](const Assignment<Key>& key, const double x) { return 2 * x; };
|
||||
auto g = f.apply(op);
|
||||
|
||||
TableFactor expected(X & Y, "4 10 6 12 4 14");
|
||||
EXPECT(assert_equal(g, expected));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -176,6 +176,7 @@ class HybridGaussianFactorGraph {
|
|||
void push_back(const gtsam::HybridBayesTree& bayesTree);
|
||||
void push_back(const gtsam::GaussianMixtureFactor* gmm);
|
||||
void push_back(gtsam::DecisionTreeFactor* factor);
|
||||
void push_back(gtsam::TableFactor* factor);
|
||||
void push_back(gtsam::JacobianFactor* factor);
|
||||
|
||||
bool empty() const;
|
||||
|
|
Loading…
Reference in New Issue