/* * Tensor3Expression.h * @brief Tensor expression templates based on http://www.gps.caltech.edu/~walter/FTensor/FTensor.pdf * Created on: Feb 10, 2010 * @author: Frank Dellaert */ #pragma once #include #include "tensors.h" namespace tensors { /** templated class to interface to an object A as a rank 3 tensor */ template class Tensor3Expression { A iter; typedef Tensor3Expression This; /** Helper class for instantiating one index */ class FixK_ { const int k; const A iter; public: FixK_(int k_, const A &a) : k(k_), iter(a) { } double operator()(int i, int j) const { return iter(i, j, k); } }; /** Helper class for contracting rank3 and rank1 tensor */ template class TimesRank1_ { typedef Tensor1Expression Rank1; const This T; const Rank1 t; public: TimesRank1_(const This &a, const Rank1 &b) : T(a), t(b) { } double operator()(int j, int k) const { double sum = 0.0; for (int i = 0; i < I::dim; i++) sum += T(i, j, k) * t(i); return sum; } }; public: /** constructor */ Tensor3Expression(const A &a) : iter(a) { } /** Print */ void print(const std::string& s = "Tensor3:") const { std::cout << s << "{"; (*this)(0).print(""); for (int k = 1; k < K::dim; k++) { std::cout << ","; (*this)(k).print(""); } std::cout << "}" << std::endl; } template bool equals(const Tensor3Expression & q, double tol) const { for (int k = 0; k < K::dim; k++) if (!(*this)(k).equals(q(k), tol)) return false; return true; } /** element access */ double operator()(int i, int j, int k) const { return iter(i, j, k); } /** Fix a single index */ Tensor2Expression operator()(int k) const { return FixK_(k, iter); } /** Contracting with rank1 tensor */ template inline Tensor2Expression , J, K> operator*( const Tensor1Expression &b) const { return TimesRank1_ (*this, b); } }; // Tensor3Expression /** Print */ template void print(const Tensor3Expression& T, const std::string& s = "Tensor3:") { T.print(s); } /** Helper class for outer product of rank2 and rank1 tensor */ template class Rank2Rank1_ { typedef Tensor2Expression Rank2; typedef Tensor1Expression Rank1; const Rank2 iterA; const Rank1 iterB; public: Rank2Rank1_(const Rank2 &a, const Rank1 &b) : iterA(a), iterB(b) { } double operator()(int i, int j, int k) const { return iterA(i, j) * iterB(k); } }; /** outer product of rank2 and rank1 tensor */ template inline Tensor3Expression , I, J, K> operator*( const Tensor2Expression& a, const Tensor1Expression &b) { return Rank2Rank1_ (a, b); } /** Helper class for outer product of rank1 and rank2 tensor */ template class Rank1Rank2_ { typedef Tensor1Expression Rank1; typedef Tensor2Expression Rank2; const Rank1 iterA; const Rank2 iterB; public: Rank1Rank2_(const Rank1 &a, const Rank2 &b) : iterA(a), iterB(b) { } double operator()(int i, int j, int k) const { return iterA(i) * iterB(j, k); } }; /** outer product of rank2 and rank1 tensor */ template inline Tensor3Expression , I, J, K> operator*( const Tensor1Expression& a, const Tensor2Expression &b) { return Rank1Rank2_ (a, b); } /** * This template works for any two expressions */ template bool assert_equality(const Tensor3Expression& expected, const Tensor3Expression& actual, double tol = 1e-9) { if (actual.equals(expected, tol)) return true; std::cout << "Not equal:\n"; expected.print("expected:\n"); actual.print("actual:\n"); return false; } } // namespace tensors