reshape matrix to tensor

release/4.3a0
Chris Beall 2012-06-19 18:03:01 +00:00
parent 9fe0b66be5
commit a0851f0eb4
2 changed files with 16 additions and 0 deletions

View File

@ -60,6 +60,18 @@ namespace gtsam {
return tensors::Tensor2<N1, N2>(data);
}
/** Reshape Matrix into rank 2 tensor */
template<int N1, int N2>
tensors::Tensor2<N1, N2> reshape2matrix(const Matrix& m) {
if (m.rows() * m.cols() != N1 * N2) throw std::invalid_argument(
"reshape2: incompatible dimensions");
double data[N2][N1];
for (int j = 0; j < N2; j++)
for (int i = 0; i < N1; i++)
data[j][i] = m(j,i);
return tensors::Tensor2<N1, N2>(data);
}
/** Reshape rank 3 tensor into Matrix */
template<class A, class I, class J, class K>
Matrix reshape(const tensors::Tensor3Expression<A, I, J, K>& T, int m, int n) {

View File

@ -188,6 +188,10 @@ TEST( Tensor2, reshape2 )
{
Tensor2<3,4> actual = reshape2<3,4>(camera::vector);
CHECK(assert_equality(camera::M(a,A),actual(a,A)));
// reshape Matrix to rank 2 tensor
Tensor2<3,4> actual_m = reshape2matrix<3,4>(camera::matrix);
CHECK(assert_equality(camera::M(a,A), actual_m(a,A)));
}
/* ************************************************************************* */