Add missing files that were omitted due to case-insensitive ignore.

release/4.3a0
Frank Dellaert 2023-01-22 13:38:16 -08:00
parent 6743c4f6bc
commit 8fb393fdf6
37 changed files with 19520 additions and 0 deletions

View File

@@ -0,0 +1,413 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_ARITHMETIC_SEQUENCE_H
#define EIGEN_ARITHMETIC_SEQUENCE_H
namespace Eigen {
namespace internal {
#if (!EIGEN_HAS_CXX11) || !((!EIGEN_COMP_GNUC) || EIGEN_COMP_GNUC>=48)
template<typename T> struct aseq_negate {};
template<> struct aseq_negate<Index> {
typedef Index type;
};
template<int N> struct aseq_negate<FixedInt<N> > {
typedef FixedInt<-N> type;
};
// Compilation error in the following case:
template<> struct aseq_negate<FixedInt<DynamicIndex> > {};
template<typename FirstType,typename SizeType,typename IncrType,
bool FirstIsSymbolic=symbolic::is_symbolic<FirstType>::value,
bool SizeIsSymbolic =symbolic::is_symbolic<SizeType>::value>
struct aseq_reverse_first_type {
typedef Index type;
};
template<typename FirstType,typename SizeType,typename IncrType>
struct aseq_reverse_first_type<FirstType,SizeType,IncrType,true,true> {
typedef symbolic::AddExpr<FirstType,
symbolic::ProductExpr<symbolic::AddExpr<SizeType,symbolic::ValueExpr<FixedInt<-1> > >,
symbolic::ValueExpr<IncrType> >
> type;
};
template<typename SizeType,typename IncrType,typename EnableIf = void>
struct aseq_reverse_first_type_aux {
typedef Index type;
};
template<typename SizeType,typename IncrType>
struct aseq_reverse_first_type_aux<SizeType,IncrType,typename internal::enable_if<bool((SizeType::value+IncrType::value)|0x1)>::type> {
typedef FixedInt<(SizeType::value-1)*IncrType::value> type;
};
template<typename FirstType,typename SizeType,typename IncrType>
struct aseq_reverse_first_type<FirstType,SizeType,IncrType,true,false> {
typedef typename aseq_reverse_first_type_aux<SizeType,IncrType>::type Aux;
typedef symbolic::AddExpr<FirstType,symbolic::ValueExpr<Aux> > type;
};
template<typename FirstType,typename SizeType,typename IncrType>
struct aseq_reverse_first_type<FirstType,SizeType,IncrType,false,true> {
typedef symbolic::AddExpr<symbolic::ProductExpr<symbolic::AddExpr<SizeType,symbolic::ValueExpr<FixedInt<-1> > >,
symbolic::ValueExpr<IncrType> >,
symbolic::ValueExpr<> > type;
};
#endif
// Helper to cleanup the type of the increment:
template<typename T> struct cleanup_seq_incr {
typedef typename cleanup_index_type<T,DynamicIndex>::type type;
};
}
//--------------------------------------------------------------------------------
// seq(first,last,incr) and seqN(first,size,incr)
//--------------------------------------------------------------------------------
template<typename FirstType=Index,typename SizeType=Index,typename IncrType=internal::FixedInt<1> >
class ArithmeticSequence;
template<typename FirstType,typename SizeType,typename IncrType>
ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,
typename internal::cleanup_index_type<SizeType>::type,
typename internal::cleanup_seq_incr<IncrType>::type >
seqN(FirstType first, SizeType size, IncrType incr);
/** \class ArithmeticSequence
* \ingroup Core_Module
*
* This class represents an arithmetic progression \f$ a_0, a_1, a_2, ..., a_{n-1}\f$ defined by
* its \em first value \f$ a_0 \f$, its \em size (aka length) \em n, and the \em increment (aka stride)
* that is equal to \f$ a_{i+1}-a_{i}\f$ for any \em i.
*
* It is internally used as the return type of the Eigen::seq and Eigen::seqN functions, and as the input arguments
* of DenseBase::operator()(const RowIndices&, const ColIndices&), and most of the time this is the
* only way it is used.
*
* \tparam FirstType type of the first element, usually an Index,
* but internally it can be a symbolic expression
* \tparam SizeType type representing the size of the sequence, usually an Index
* or a compile time integral constant. Internally, it can also be a symbolic expression
* \tparam IncrType type of the increment, can be a runtime Index, or a compile time integral constant (default is compile-time 1)
*
* \sa Eigen::seq, Eigen::seqN, DenseBase::operator()(const RowIndices&, const ColIndices&), class IndexedView
*/
template<typename FirstType,typename SizeType,typename IncrType>
class ArithmeticSequence
{
public:
ArithmeticSequence(FirstType first, SizeType size) : m_first(first), m_size(size) {}
ArithmeticSequence(FirstType first, SizeType size, IncrType incr) : m_first(first), m_size(size), m_incr(incr) {}
enum {
SizeAtCompileTime = internal::get_fixed_value<SizeType>::value,
IncrAtCompileTime = internal::get_fixed_value<IncrType,DynamicIndex>::value
};
/** \returns the size, i.e., number of elements, of the sequence */
Index size() const { return m_size; }
/** \returns the first element \f$ a_0 \f$ in the sequence */
Index first() const { return m_first; }
/** \returns the value \f$ a_i \f$ at index \a i in the sequence. */
Index operator[](Index i) const { return m_first + i * m_incr; }
const FirstType& firstObject() const { return m_first; }
const SizeType& sizeObject() const { return m_size; }
const IncrType& incrObject() const { return m_incr; }
protected:
FirstType m_first;
SizeType m_size;
IncrType m_incr;
public:
#if EIGEN_HAS_CXX11 && ((!EIGEN_COMP_GNUC) || EIGEN_COMP_GNUC>=48)
auto reverse() const -> decltype(Eigen::seqN(m_first+(m_size+fix<-1>())*m_incr,m_size,-m_incr)) {
return seqN(m_first+(m_size+fix<-1>())*m_incr,m_size,-m_incr);
}
#else
protected:
typedef typename internal::aseq_negate<IncrType>::type ReverseIncrType;
typedef typename internal::aseq_reverse_first_type<FirstType,SizeType,IncrType>::type ReverseFirstType;
public:
ArithmeticSequence<ReverseFirstType,SizeType,ReverseIncrType>
reverse() const {
return seqN(m_first+(m_size+fix<-1>())*m_incr,m_size,-m_incr);
}
#endif
};
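// Example (illustrative sketch): typical use of seqN and ArithmeticSequence, assuming a
// C++11 build (so that auto and the decltype-based reverse() overload above are available):
//   auto s = Eigen::seqN(2, 4, 3);   // the arithmetic progression 2, 5, 8, 11
//   Eigen::Index n = s.size();       // 4
//   Eigen::Index f = s.first();      // 2
//   Eigen::Index x = s[1];           // 5 == first() + 1*incr
//   auto r = s.reverse();            // the progression 11, 8, 5, 2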
/** \returns an ArithmeticSequence starting at \a first, of length \a size, and increment \a incr
*
* \sa seqN(FirstType,SizeType), seq(FirstType,LastType,IncrType) */
template<typename FirstType,typename SizeType,typename IncrType>
ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,typename internal::cleanup_index_type<SizeType>::type,typename internal::cleanup_seq_incr<IncrType>::type >
seqN(FirstType first, SizeType size, IncrType incr) {
return ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,typename internal::cleanup_index_type<SizeType>::type,typename internal::cleanup_seq_incr<IncrType>::type>(first,size,incr);
}
/** \returns an ArithmeticSequence starting at \a first, of length \a size, and unit increment
*
* \sa seqN(FirstType,SizeType,IncrType), seq(FirstType,LastType) */
template<typename FirstType,typename SizeType>
ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,typename internal::cleanup_index_type<SizeType>::type >
seqN(FirstType first, SizeType size) {
return ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,typename internal::cleanup_index_type<SizeType>::type>(first,size);
}
#ifdef EIGEN_PARSED_BY_DOXYGEN
/** \returns an ArithmeticSequence starting at \a f, up (or down) to \a l, and with positive (or negative) increment \a incr
*
* It is essentially an alias to:
* \code
* seqN(f, (l-f+incr)/incr, incr);
* \endcode
*
* \sa seqN(FirstType,SizeType,IncrType), seq(FirstType,LastType)
*/
template<typename FirstType,typename LastType, typename IncrType>
auto seq(FirstType f, LastType l, IncrType incr);
/** \returns an ArithmeticSequence starting at \a f, up (or down) to \a l, and unit increment
*
* It is essentially an alias to:
* \code
* seqN(f,l-f+1);
* \endcode
*
* \sa seqN(FirstType,SizeType), seq(FirstType,LastType,IncrType)
*/
template<typename FirstType,typename LastType>
auto seq(FirstType f, LastType l);
#else // EIGEN_PARSED_BY_DOXYGEN
#if EIGEN_HAS_CXX11
template<typename FirstType,typename LastType>
auto seq(FirstType f, LastType l) -> decltype(seqN(typename internal::cleanup_index_type<FirstType>::type(f),
( typename internal::cleanup_index_type<LastType>::type(l)
- typename internal::cleanup_index_type<FirstType>::type(f)+fix<1>())))
{
return seqN(typename internal::cleanup_index_type<FirstType>::type(f),
(typename internal::cleanup_index_type<LastType>::type(l)
-typename internal::cleanup_index_type<FirstType>::type(f)+fix<1>()));
}
template<typename FirstType,typename LastType, typename IncrType>
auto seq(FirstType f, LastType l, IncrType incr)
-> decltype(seqN(typename internal::cleanup_index_type<FirstType>::type(f),
( typename internal::cleanup_index_type<LastType>::type(l)
- typename internal::cleanup_index_type<FirstType>::type(f)+typename internal::cleanup_seq_incr<IncrType>::type(incr)
) / typename internal::cleanup_seq_incr<IncrType>::type(incr),
typename internal::cleanup_seq_incr<IncrType>::type(incr)))
{
typedef typename internal::cleanup_seq_incr<IncrType>::type CleanedIncrType;
return seqN(typename internal::cleanup_index_type<FirstType>::type(f),
( typename internal::cleanup_index_type<LastType>::type(l)
-typename internal::cleanup_index_type<FirstType>::type(f)+CleanedIncrType(incr)) / CleanedIncrType(incr),
CleanedIncrType(incr));
}
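// Worked examples (illustrative) of the size formula (l-f+incr)/incr used above:
//   Eigen::seq(2, 7)      -> seqN(2,  7-2+1)        == 2, 3, 4, 5, 6, 7   (size 6)
//   Eigen::seq(2, 10, 3)  -> seqN(2,  (10-2+3)/3)   == 2, 5, 8            (size 3)
//   Eigen::seq(10, 2, -2) -> seqN(10, (2-10-2)/-2)  == 10, 8, 6, 4, 2     (size 5)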
#else // EIGEN_HAS_CXX11
template<typename FirstType,typename LastType>
typename internal::enable_if<!(symbolic::is_symbolic<FirstType>::value || symbolic::is_symbolic<LastType>::value),
ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,Index> >::type
seq(FirstType f, LastType l)
{
return seqN(typename internal::cleanup_index_type<FirstType>::type(f),
Index((typename internal::cleanup_index_type<LastType>::type(l)-typename internal::cleanup_index_type<FirstType>::type(f)+fix<1>())));
}
template<typename FirstTypeDerived,typename LastType>
typename internal::enable_if<!symbolic::is_symbolic<LastType>::value,
ArithmeticSequence<FirstTypeDerived, symbolic::AddExpr<symbolic::AddExpr<symbolic::NegateExpr<FirstTypeDerived>,symbolic::ValueExpr<> >,
symbolic::ValueExpr<internal::FixedInt<1> > > > >::type
seq(const symbolic::BaseExpr<FirstTypeDerived> &f, LastType l)
{
return seqN(f.derived(),(typename internal::cleanup_index_type<LastType>::type(l)-f.derived()+fix<1>()));
}
template<typename FirstType,typename LastTypeDerived>
typename internal::enable_if<!symbolic::is_symbolic<FirstType>::value,
ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,
symbolic::AddExpr<symbolic::AddExpr<LastTypeDerived,symbolic::ValueExpr<> >,
symbolic::ValueExpr<internal::FixedInt<1> > > > >::type
seq(FirstType f, const symbolic::BaseExpr<LastTypeDerived> &l)
{
return seqN(typename internal::cleanup_index_type<FirstType>::type(f),(l.derived()-typename internal::cleanup_index_type<FirstType>::type(f)+fix<1>()));
}
template<typename FirstTypeDerived,typename LastTypeDerived>
ArithmeticSequence<FirstTypeDerived,
symbolic::AddExpr<symbolic::AddExpr<LastTypeDerived,symbolic::NegateExpr<FirstTypeDerived> >,symbolic::ValueExpr<internal::FixedInt<1> > > >
seq(const symbolic::BaseExpr<FirstTypeDerived> &f, const symbolic::BaseExpr<LastTypeDerived> &l)
{
return seqN(f.derived(),(l.derived()-f.derived()+fix<1>()));
}
template<typename FirstType,typename LastType, typename IncrType>
typename internal::enable_if<!(symbolic::is_symbolic<FirstType>::value || symbolic::is_symbolic<LastType>::value),
ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,Index,typename internal::cleanup_seq_incr<IncrType>::type> >::type
seq(FirstType f, LastType l, IncrType incr)
{
typedef typename internal::cleanup_seq_incr<IncrType>::type CleanedIncrType;
return seqN(typename internal::cleanup_index_type<FirstType>::type(f),
Index((typename internal::cleanup_index_type<LastType>::type(l)-typename internal::cleanup_index_type<FirstType>::type(f)+CleanedIncrType(incr))/CleanedIncrType(incr)), incr);
}
template<typename FirstTypeDerived,typename LastType, typename IncrType>
typename internal::enable_if<!symbolic::is_symbolic<LastType>::value,
ArithmeticSequence<FirstTypeDerived,
symbolic::QuotientExpr<symbolic::AddExpr<symbolic::AddExpr<symbolic::NegateExpr<FirstTypeDerived>,
symbolic::ValueExpr<> >,
symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
typename internal::cleanup_seq_incr<IncrType>::type> >::type
seq(const symbolic::BaseExpr<FirstTypeDerived> &f, LastType l, IncrType incr)
{
typedef typename internal::cleanup_seq_incr<IncrType>::type CleanedIncrType;
return seqN(f.derived(),(typename internal::cleanup_index_type<LastType>::type(l)-f.derived()+CleanedIncrType(incr))/CleanedIncrType(incr), incr);
}
template<typename FirstType,typename LastTypeDerived, typename IncrType>
typename internal::enable_if<!symbolic::is_symbolic<FirstType>::value,
ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,
symbolic::QuotientExpr<symbolic::AddExpr<symbolic::AddExpr<LastTypeDerived,symbolic::ValueExpr<> >,
symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
typename internal::cleanup_seq_incr<IncrType>::type> >::type
seq(FirstType f, const symbolic::BaseExpr<LastTypeDerived> &l, IncrType incr)
{
typedef typename internal::cleanup_seq_incr<IncrType>::type CleanedIncrType;
return seqN(typename internal::cleanup_index_type<FirstType>::type(f),
(l.derived()-typename internal::cleanup_index_type<FirstType>::type(f)+CleanedIncrType(incr))/CleanedIncrType(incr), incr);
}
template<typename FirstTypeDerived,typename LastTypeDerived, typename IncrType>
ArithmeticSequence<FirstTypeDerived,
symbolic::QuotientExpr<symbolic::AddExpr<symbolic::AddExpr<LastTypeDerived,
symbolic::NegateExpr<FirstTypeDerived> >,
symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
typename internal::cleanup_seq_incr<IncrType>::type>
seq(const symbolic::BaseExpr<FirstTypeDerived> &f, const symbolic::BaseExpr<LastTypeDerived> &l, IncrType incr)
{
typedef typename internal::cleanup_seq_incr<IncrType>::type CleanedIncrType;
return seqN(f.derived(),(l.derived()-f.derived()+CleanedIncrType(incr))/CleanedIncrType(incr), incr);
}
#endif // EIGEN_HAS_CXX11
#endif // EIGEN_PARSED_BY_DOXYGEN
#if EIGEN_HAS_CXX11 || defined(EIGEN_PARSED_BY_DOXYGEN)
/** \cpp11
* \returns a symbolic ArithmeticSequence representing the last \a size elements with increment \a incr.
*
* It is a shortcut for: \code seqN(last-(size-fix<1>)*incr, size, incr) \endcode
*
* \sa lastN(SizeType), seqN(FirstType,SizeType), seq(FirstType,LastType,IncrType) */
template<typename SizeType,typename IncrType>
auto lastN(SizeType size, IncrType incr)
-> decltype(seqN(Eigen::last-(size-fix<1>())*incr, size, incr))
{
return seqN(Eigen::last-(size-fix<1>())*incr, size, incr);
}
/** \cpp11
* \returns a symbolic ArithmeticSequence representing the last \a size elements with a unit increment.
*
* It is a shortcut for: \code seq(last+fix<1>-size, last) \endcode
*
* \sa lastN(SizeType,IncrType), seqN(FirstType,SizeType), seq(FirstType,LastType) */
template<typename SizeType>
auto lastN(SizeType size)
-> decltype(seqN(Eigen::last+fix<1>()-size, size))
{
return seqN(Eigen::last+fix<1>()-size, size);
}
#endif
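// Example (illustrative sketch): lastN and the symbolic "last" keyword in slicing, assuming
// a dynamic-size vector v of size 10 indexed through DenseBase::operator():
//   v(Eigen::lastN(3));             // the last 3 coefficients: v(7), v(8), v(9)
//   v(Eigen::lastN(3, 2));          // 3 coefficients with stride 2: v(5), v(7), v(9)
//   v(Eigen::seq(2, Eigen::last));  // v(2) up to and including the last coefficient v(9)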
namespace internal {
// Convert a symbolic span into a usable one (i.e., remove last/end "keywords")
template<typename T>
struct make_size_type {
typedef typename internal::conditional<symbolic::is_symbolic<T>::value, Index, T>::type type;
};
template<typename FirstType,typename SizeType,typename IncrType,int XprSize>
struct IndexedViewCompatibleType<ArithmeticSequence<FirstType,SizeType,IncrType>, XprSize> {
typedef ArithmeticSequence<Index,typename make_size_type<SizeType>::type,IncrType> type;
};
template<typename FirstType,typename SizeType,typename IncrType>
ArithmeticSequence<Index,typename make_size_type<SizeType>::type,IncrType>
makeIndexedViewCompatible(const ArithmeticSequence<FirstType,SizeType,IncrType>& ids, Index size,SpecializedType) {
return ArithmeticSequence<Index,typename make_size_type<SizeType>::type,IncrType>(
eval_expr_given_size(ids.firstObject(),size),eval_expr_given_size(ids.sizeObject(),size),ids.incrObject());
}
template<typename FirstType,typename SizeType,typename IncrType>
struct get_compile_time_incr<ArithmeticSequence<FirstType,SizeType,IncrType> > {
enum { value = get_fixed_value<IncrType,DynamicIndex>::value };
};
} // end namespace internal
/** \namespace Eigen::indexing
* \ingroup Core_Module
*
* The sole purpose of this namespace is to be able to import all functions
* and symbols that are expected to be used within operator() for indexing
* and slicing. If you already imported the whole Eigen namespace:
* \code using namespace Eigen; \endcode
* then you are already all set. Otherwise, if you don't want/cannot import
* the whole Eigen namespace, the following line:
* \code using namespace Eigen::indexing; \endcode
* is equivalent to:
* \code
using Eigen::all;
using Eigen::seq;
using Eigen::seqN;
using Eigen::lastN; // c++11 only
using Eigen::last;
using Eigen::lastp1;
using Eigen::fix;
\endcode
*/
namespace indexing {
using Eigen::all;
using Eigen::seq;
using Eigen::seqN;
#if EIGEN_HAS_CXX11
using Eigen::lastN;
#endif
using Eigen::last;
using Eigen::lastp1;
using Eigen::fix;
}
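// Example (illustrative sketch): importing only the indexing helpers documented above,
// assuming a C++11 build and a hypothetical helper taking a dynamic-size matrix A:
//   void copyTopRight(const Eigen::MatrixXd& A, Eigen::MatrixXd& out) {
//     using namespace Eigen::indexing;   // brings seq, seqN, lastN, last, all, fix, ...
//     out = A(seq(0, 2), lastN(3));      // rows 0..2, last 3 columns, without Eigen:: prefixes
//   }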
} // end namespace Eigen
#endif // EIGEN_ARITHMETIC_SEQUENCE_H

View File

@@ -0,0 +1,237 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_INDEXED_VIEW_H
#define EIGEN_INDEXED_VIEW_H
namespace Eigen {
namespace internal {
template<typename XprType, typename RowIndices, typename ColIndices>
struct traits<IndexedView<XprType, RowIndices, ColIndices> >
: traits<XprType>
{
enum {
RowsAtCompileTime = int(array_size<RowIndices>::value),
ColsAtCompileTime = int(array_size<ColIndices>::value),
MaxRowsAtCompileTime = RowsAtCompileTime != Dynamic ? int(RowsAtCompileTime) : Dynamic,
MaxColsAtCompileTime = ColsAtCompileTime != Dynamic ? int(ColsAtCompileTime) : Dynamic,
XprTypeIsRowMajor = (int(traits<XprType>::Flags)&RowMajorBit) != 0,
IsRowMajor = (MaxRowsAtCompileTime==1&&MaxColsAtCompileTime!=1) ? 1
: (MaxColsAtCompileTime==1&&MaxRowsAtCompileTime!=1) ? 0
: XprTypeIsRowMajor,
RowIncr = int(get_compile_time_incr<RowIndices>::value),
ColIncr = int(get_compile_time_incr<ColIndices>::value),
InnerIncr = IsRowMajor ? ColIncr : RowIncr,
OuterIncr = IsRowMajor ? RowIncr : ColIncr,
HasSameStorageOrderAsXprType = (IsRowMajor == XprTypeIsRowMajor),
XprInnerStride = HasSameStorageOrderAsXprType ? int(inner_stride_at_compile_time<XprType>::ret) : int(outer_stride_at_compile_time<XprType>::ret),
XprOuterstride = HasSameStorageOrderAsXprType ? int(outer_stride_at_compile_time<XprType>::ret) : int(inner_stride_at_compile_time<XprType>::ret),
InnerSize = XprTypeIsRowMajor ? ColsAtCompileTime : RowsAtCompileTime,
IsBlockAlike = InnerIncr==1 && OuterIncr==1,
IsInnerPannel = HasSameStorageOrderAsXprType && is_same<AllRange<InnerSize>,typename conditional<XprTypeIsRowMajor,ColIndices,RowIndices>::type>::value,
InnerStrideAtCompileTime = InnerIncr<0 || InnerIncr==DynamicIndex || XprInnerStride==Dynamic ? Dynamic : XprInnerStride * InnerIncr,
OuterStrideAtCompileTime = OuterIncr<0 || OuterIncr==DynamicIndex || XprOuterstride==Dynamic ? Dynamic : XprOuterstride * OuterIncr,
ReturnAsScalar = is_same<RowIndices,SingleRange>::value && is_same<ColIndices,SingleRange>::value,
ReturnAsBlock = (!ReturnAsScalar) && IsBlockAlike,
ReturnAsIndexedView = (!ReturnAsScalar) && (!ReturnAsBlock),
// FIXME we deal with compile-time strides if and only if we have DirectAccessBit flag,
// but this is too strict regarding negative strides...
DirectAccessMask = (int(InnerIncr)!=UndefinedIncr && int(OuterIncr)!=UndefinedIncr && InnerIncr>=0 && OuterIncr>=0) ? DirectAccessBit : 0,
FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
Flags = (traits<XprType>::Flags & (HereditaryBits | DirectAccessMask )) | FlagsLvalueBit | FlagsRowMajorBit | FlagsLinearAccessBit
};
typedef Block<XprType,RowsAtCompileTime,ColsAtCompileTime,IsInnerPannel> BlockType;
};
}
template<typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
class IndexedViewImpl;
/** \class IndexedView
* \ingroup Core_Module
*
* \brief Expression of a non-sequential sub-matrix defined by arbitrary sequences of row and column indices
*
* \tparam XprType the type of the expression in which we are taking the intersections of sub-rows and sub-columns
* \tparam RowIndices the type of the object defining the sequence of row indices
* \tparam ColIndices the type of the object defining the sequence of column indices
*
* This class represents an expression of a sub-matrix (or sub-vector) defined as the intersection
* of subsets of rows and columns, which are themselves defined by generic sequences of row indices \f$ \{r_0,r_1,..r_{m-1}\} \f$
* and column indices \f$ \{c_0,c_1,..c_{n-1} \}\f$. Let \f$ A \f$ be the nested matrix, then the resulting matrix \f$ B \f$ has \c m
* rows and \c n columns, and its entries are given by: \f$ B(i,j) = A(r_i,c_j) \f$.
*
* The \c RowIndices and \c ColIndices types must be compatible with the following API:
* \code
* <integral type> operator[](Index) const;
* Index size() const;
* \endcode
*
* Typical supported types thus include:
* - std::vector<int>
* - std::valarray<int>
* - std::array<int>
* - Plain C arrays: int[N]
* - Eigen::ArrayXi
* - decltype(ArrayXi::LinSpaced(...))
* - Any view/expressions of the previous types
* - Eigen::ArithmeticSequence
* - Eigen::internal::AllRange (helper for Eigen::all)
* - Eigen::internal::SingleRange (helper for single index)
* - etc.
*
* In typical usages of %Eigen, this class should never be used directly. It is the return type of
* DenseBase::operator()(const RowIndices&, const ColIndices&).
*
* \sa class Block
*/
template<typename XprType, typename RowIndices, typename ColIndices>
class IndexedView : public IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind>
{
public:
typedef typename IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind>::Base Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView)
typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
typedef typename internal::remove_all<XprType>::type NestedExpression;
template<typename T0, typename T1>
IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices)
: m_xpr(xpr), m_rowIndices(rowIndices), m_colIndices(colIndices)
{}
/** \returns number of rows */
Index rows() const { return internal::size(m_rowIndices); }
/** \returns number of columns */
Index cols() const { return internal::size(m_colIndices); }
/** \returns the nested expression */
const typename internal::remove_all<XprType>::type&
nestedExpression() const { return m_xpr; }
/** \returns the nested expression */
typename internal::remove_reference<XprType>::type&
nestedExpression() { return m_xpr; }
/** \returns a const reference to the object storing/generating the row indices */
const RowIndices& rowIndices() const { return m_rowIndices; }
/** \returns a const reference to the object storing/generating the column indices */
const ColIndices& colIndices() const { return m_colIndices; }
protected:
MatrixTypeNested m_xpr;
RowIndices m_rowIndices;
ColIndices m_colIndices;
};
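// Example (illustrative sketch): an IndexedView obtained through DenseBase::operator(),
// assuming a 6x6 matrix A and std::vector<int> index arrays (one of the supported types above):
//   std::vector<int> rows{1, 3, 5};
//   std::vector<int> cols{0, 2};
//   auto B = A(rows, cols);   // 3x2 expression with B(i,j) == A(rows[i], cols[j])
//   B(0, 1) = 0.0;            // writable when A is a non-const lvalue: assigns A(1, 2)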
// Generic API dispatcher
template<typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
class IndexedViewImpl
: public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices> >::type
{
public:
typedef typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices> >::type Base;
};
namespace internal {
template<typename ArgType, typename RowIndices, typename ColIndices>
struct unary_evaluator<IndexedView<ArgType, RowIndices, ColIndices>, IndexBased>
: evaluator_base<IndexedView<ArgType, RowIndices, ColIndices> >
{
typedef IndexedView<ArgType, RowIndices, ColIndices> XprType;
enum {
CoeffReadCost = evaluator<ArgType>::CoeffReadCost /* TODO + cost of row/col index */,
FlagsLinearAccessBit = (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1) ? LinearAccessBit : 0,
FlagsRowMajorBit = traits<XprType>::FlagsRowMajorBit,
Flags = (evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit /*| LinearAccessBit | DirectAccessBit*/)) | FlagsLinearAccessBit | FlagsRowMajorBit,
Alignment = 0
};
EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index row, Index col) const
{
return m_argImpl.coeff(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar& coeffRef(Index row, Index col)
{
return m_argImpl.coeffRef(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar& coeffRef(Index index)
{
EIGEN_STATIC_ASSERT_LVALUE(XprType)
Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
return m_argImpl.coeffRef( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar& coeffRef(Index index) const
{
Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
return m_argImpl.coeffRef( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const CoeffReturnType coeff(Index index) const
{
Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
return m_argImpl.coeff( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
}
protected:
evaluator<ArgType> m_argImpl;
const XprType& m_xpr;
};
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_INDEXED_VIEW_H

View File

@@ -0,0 +1,232 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2011-2018 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_PARTIALREDUX_H
#define EIGEN_PARTIALREDUX_H
namespace Eigen {
namespace internal {
/***************************************************************************
*
* This file provides evaluators for partial reductions.
* There are two modes:
*
* - scalar path: simply calls the respective function on the column or row.
* -> nothing special here, all the tricky part is handled by the return
* types of VectorwiseOp's members. They embed the functor calling the
* respective DenseBase's member function.
*
* - vectorized path: implements a packet-wise reduction followed by
* some (optional) processing of the outcome, e.g., division by n for mean.
*
* For the vectorized path, observe that the packet size and outer unrolling
* are both decided by the assignment logic. So all we have to do is decide
* on the inner unrolling.
*
* For the unrolling, we can reuse "internal::redux_vec_unroller" from Redux.h,
* but we need to be careful to specify the correct increment.
*
***************************************************************************/
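// Example (illustrative sketch): expressions whose evaluation is handled by this file,
// assuming a dynamic-size matrix Eigen::MatrixXf A:
//   Eigen::RowVectorXf colSums  = A.colwise().sum();   // one partial reduction per column
//   Eigen::VectorXf    rowMeans = A.rowwise().mean();  // one partial reduction per row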
/* logic deciding a strategy for unrolling of vectorized paths */
template<typename Func, typename Evaluator>
struct packetwise_redux_traits
{
enum {
OuterSize = int(Evaluator::IsRowMajor) ? Evaluator::RowsAtCompileTime : Evaluator::ColsAtCompileTime,
Cost = OuterSize == Dynamic ? HugeCost
: OuterSize * Evaluator::CoeffReadCost + (OuterSize-1) * functor_traits<Func>::Cost,
Unrolling = Cost <= EIGEN_UNROLLING_LIMIT ? CompleteUnrolling : NoUnrolling
};
};
/* Value to be returned when size==0; by default we return 0 */
template<typename PacketType,typename Func>
EIGEN_DEVICE_FUNC
PacketType packetwise_redux_empty_value(const Func& ) { return pset1<PacketType>(0); }
/* For products the default is 1 */
template<typename PacketType,typename Scalar>
EIGEN_DEVICE_FUNC
PacketType packetwise_redux_empty_value(const scalar_product_op<Scalar,Scalar>& ) { return pset1<PacketType>(1); }
/* Perform the actual reduction */
template<typename Func, typename Evaluator,
int Unrolling = packetwise_redux_traits<Func, Evaluator>::Unrolling
>
struct packetwise_redux_impl;
/* Perform the actual reduction with unrolling */
template<typename Func, typename Evaluator>
struct packetwise_redux_impl<Func, Evaluator, CompleteUnrolling>
{
typedef redux_novec_unroller<Func,Evaluator, 0, Evaluator::SizeAtCompileTime> Base;
typedef typename Evaluator::Scalar Scalar;
template<typename PacketType>
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE
PacketType run(const Evaluator &eval, const Func& func, Index /*size*/)
{
return redux_vec_unroller<Func, Evaluator, 0, packetwise_redux_traits<Func, Evaluator>::OuterSize>::template run<PacketType>(eval,func);
}
};
/* Add a specialization of redux_vec_unroller for size==0 at compile time.
* This specialization is not required for general reductions, which is
* why it is defined here.
*/
template<typename Func, typename Evaluator, int Start>
struct redux_vec_unroller<Func, Evaluator, Start, 0>
{
template<typename PacketType>
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE PacketType run(const Evaluator &, const Func& f)
{
return packetwise_redux_empty_value<PacketType>(f);
}
};
/* Perform the actual reduction for dynamic sizes */
template<typename Func, typename Evaluator>
struct packetwise_redux_impl<Func, Evaluator, NoUnrolling>
{
typedef typename Evaluator::Scalar Scalar;
typedef typename redux_traits<Func, Evaluator>::PacketType PacketScalar;
template<typename PacketType>
EIGEN_DEVICE_FUNC
static PacketType run(const Evaluator &eval, const Func& func, Index size)
{
if(size==0)
return packetwise_redux_empty_value<PacketType>(func);
const Index size4 = (size-1)&(~3);
PacketType p = eval.template packetByOuterInner<Unaligned,PacketType>(0,0);
Index i = 1;
// This loop is optimized for instruction pipelining:
// - each iteration generates two independent instructions
// - thanks to branch prediction and out-of-order execution we have independent instructions across loop iterations
for(; i<size4; i+=4)
p = func.packetOp(p,
func.packetOp(
func.packetOp(eval.template packetByOuterInner<Unaligned,PacketType>(i+0,0),eval.template packetByOuterInner<Unaligned,PacketType>(i+1,0)),
func.packetOp(eval.template packetByOuterInner<Unaligned,PacketType>(i+2,0),eval.template packetByOuterInner<Unaligned,PacketType>(i+3,0))));
for(; i<size; ++i)
p = func.packetOp(p, eval.template packetByOuterInner<Unaligned,PacketType>(i,0));
return p;
}
};
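// Worked example (illustrative) for the dynamic-size run() above: with size == 10,
// size4 == (10-1) & ~3 == 8, so packet 0 seeds the accumulator p, the 4-way unrolled loop
// runs for i == 1 and i == 5 (consuming packets 1..8), and the tail loop handles i == 9.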
template< typename ArgType, typename MemberOp, int Direction>
struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
: evaluator_base<PartialReduxExpr<ArgType, MemberOp, Direction> >
{
typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType;
typedef typename internal::nested_eval<ArgType,1>::type ArgTypeNested;
typedef typename internal::add_const_on_value_type<ArgTypeNested>::type ConstArgTypeNested;
typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
typedef typename ArgType::Scalar InputScalar;
typedef typename XprType::Scalar Scalar;
enum {
TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(ArgType::ColsAtCompileTime)
};
typedef typename MemberOp::template Cost<int(TraversalSize)> CostOpType;
enum {
CoeffReadCost = TraversalSize==Dynamic ? HugeCost
: TraversalSize==0 ? 1
: int(TraversalSize) * int(evaluator<ArgType>::CoeffReadCost) + int(CostOpType::value),
_ArgFlags = evaluator<ArgType>::Flags,
_Vectorizable = bool(int(_ArgFlags)&PacketAccessBit)
&& bool(MemberOp::Vectorizable)
&& (Direction==int(Vertical) ? bool(_ArgFlags&RowMajorBit) : (_ArgFlags&RowMajorBit)==0)
&& (TraversalSize!=0),
Flags = (traits<XprType>::Flags&RowMajorBit)
| (evaluator<ArgType>::Flags&(HereditaryBits&(~RowMajorBit)))
| (_Vectorizable ? PacketAccessBit : 0)
| LinearAccessBit,
Alignment = 0 // FIXME this will need to be improved once PartialReduxExpr is vectorized
};
EIGEN_DEVICE_FUNC explicit evaluator(const XprType xpr)
: m_arg(xpr.nestedExpression()), m_functor(xpr.functor())
{
EIGEN_INTERNAL_CHECK_COST_VALUE(TraversalSize==Dynamic ? HugeCost : (TraversalSize==0 ? 1 : int(CostOpType::value)));
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
typedef typename XprType::CoeffReturnType CoeffReturnType;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar coeff(Index i, Index j) const
{
return coeff(Direction==Vertical ? j : i);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar coeff(Index index) const
{
return m_functor(m_arg.template subVector<DirectionType(Direction)>(index));
}
template<int LoadMode,typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
PacketType packet(Index i, Index j) const
{
return packet<LoadMode,PacketType>(Direction==Vertical ? j : i);
}
template<int LoadMode,typename PacketType>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
PacketType packet(Index idx) const
{
enum { PacketSize = internal::unpacket_traits<PacketType>::size };
typedef Block<const ArgTypeNestedCleaned,
Direction==Vertical ? int(ArgType::RowsAtCompileTime) : int(PacketSize),
Direction==Vertical ? int(PacketSize) : int(ArgType::ColsAtCompileTime),
true /* InnerPanel */> PanelType;
PanelType panel(m_arg,
Direction==Vertical ? 0 : idx,
Direction==Vertical ? idx : 0,
Direction==Vertical ? m_arg.rows() : Index(PacketSize),
Direction==Vertical ? Index(PacketSize) : m_arg.cols());
// FIXME
// See bug 1612: currently, if PacketSize==1 (i.e. complex<double> with 128-bit registers) then the storage order of panel gets reversed
// and methods like packetByOuterInner no longer make sense in this context.
// So let's just bypass "vectorization" in this case:
if(PacketSize==1)
return internal::pset1<PacketType>(coeff(idx));
typedef typename internal::redux_evaluator<PanelType> PanelEvaluator;
PanelEvaluator panel_eval(panel);
typedef typename MemberOp::BinaryOp BinaryOp;
PacketType p = internal::packetwise_redux_impl<BinaryOp,PanelEvaluator>::template run<PacketType>(panel_eval,m_functor.binaryFunc(),m_arg.outerSize());
return p;
}
protected:
ConstArgTypeNested m_arg;
const MemberOp m_functor;
};
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_PARTIALREDUX_H

View File

@@ -0,0 +1,454 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2017 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2014 yoco <peter.xiau@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_RESHAPED_H
#define EIGEN_RESHAPED_H
namespace Eigen {
/** \class Reshaped
* \ingroup Core_Module
*
* \brief Expression of a fixed-size or dynamic-size reshape
*
* \tparam XprType the type of the expression in which we are taking a reshape
* \tparam Rows the number of rows of the reshape we are taking at compile time (optional)
* \tparam Cols the number of columns of the reshape we are taking at compile time (optional)
* \tparam Order can be ColMajor or RowMajor, default is ColMajor.
*
* This class represents an expression of either a fixed-size or dynamic-size reshape.
* It is the return type of DenseBase::reshaped(NRowsType,NColsType) and
* most of the time this is the only way it is used.
*
* However, in C++98, if you want to directly manipulate reshaped expressions,
* for instance if you want to write a function returning such an expression, you
* will need to use this class. In C++11, it is advised to use the \em auto
* keyword for such use cases.
*
* Here is an example illustrating the dynamic case:
* \include class_Reshaped.cpp
* Output: \verbinclude class_Reshaped.out
*
* Here is an example illustrating the fixed-size case:
* \include class_FixedReshaped.cpp
* Output: \verbinclude class_FixedReshaped.out
*
* \sa DenseBase::reshaped(NRowsType,NColsType)
*/
namespace internal {
template<typename XprType, int Rows, int Cols, int Order>
struct traits<Reshaped<XprType, Rows, Cols, Order> > : traits<XprType>
{
typedef typename traits<XprType>::Scalar Scalar;
typedef typename traits<XprType>::StorageKind StorageKind;
typedef typename traits<XprType>::XprKind XprKind;
enum{
MatrixRows = traits<XprType>::RowsAtCompileTime,
MatrixCols = traits<XprType>::ColsAtCompileTime,
RowsAtCompileTime = Rows,
ColsAtCompileTime = Cols,
MaxRowsAtCompileTime = Rows,
MaxColsAtCompileTime = Cols,
XpxStorageOrder = ((int(traits<XprType>::Flags) & RowMajorBit) == RowMajorBit) ? RowMajor : ColMajor,
ReshapedStorageOrder = (RowsAtCompileTime == 1 && ColsAtCompileTime != 1) ? RowMajor
: (ColsAtCompileTime == 1 && RowsAtCompileTime != 1) ? ColMajor
: XpxStorageOrder,
HasSameStorageOrderAsXprType = (ReshapedStorageOrder == XpxStorageOrder),
InnerSize = (ReshapedStorageOrder==int(RowMajor)) ? int(ColsAtCompileTime) : int(RowsAtCompileTime),
InnerStrideAtCompileTime = HasSameStorageOrderAsXprType
? int(inner_stride_at_compile_time<XprType>::ret)
: Dynamic,
OuterStrideAtCompileTime = Dynamic,
HasDirectAccess = internal::has_direct_access<XprType>::ret
&& (Order==int(XpxStorageOrder))
&& ((evaluator<XprType>::Flags&LinearAccessBit)==LinearAccessBit),
MaskPacketAccessBit = (InnerSize == Dynamic || (InnerSize % packet_traits<Scalar>::size) == 0)
&& (InnerStrideAtCompileTime == 1)
? PacketAccessBit : 0,
//MaskAlignedBit = ((OuterStrideAtCompileTime!=Dynamic) && (((OuterStrideAtCompileTime * int(sizeof(Scalar))) % 16) == 0)) ? AlignedBit : 0,
FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
FlagsRowMajorBit = (ReshapedStorageOrder==int(RowMajor)) ? RowMajorBit : 0,
FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0,
Flags0 = traits<XprType>::Flags & ( (HereditaryBits & ~RowMajorBit) | MaskPacketAccessBit),
Flags = (Flags0 | FlagsLinearAccessBit | FlagsLvalueBit | FlagsRowMajorBit | FlagsDirectAccessBit)
};
};
template<typename XprType, int Rows, int Cols, int Order, bool HasDirectAccess> class ReshapedImpl_dense;
} // end namespace internal
template<typename XprType, int Rows, int Cols, int Order, typename StorageKind> class ReshapedImpl;
template<typename XprType, int Rows, int Cols, int Order> class Reshaped
: public ReshapedImpl<XprType, Rows, Cols, Order, typename internal::traits<XprType>::StorageKind>
{
typedef ReshapedImpl<XprType, Rows, Cols, Order, typename internal::traits<XprType>::StorageKind> Impl;
public:
//typedef typename Impl::Base Base;
typedef Impl Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(Reshaped)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Reshaped)
/** Fixed-size constructor
*/
EIGEN_DEVICE_FUNC
inline Reshaped(XprType& xpr)
: Impl(xpr)
{
EIGEN_STATIC_ASSERT(RowsAtCompileTime!=Dynamic && ColsAtCompileTime!=Dynamic,THIS_METHOD_IS_ONLY_FOR_FIXED_SIZE)
eigen_assert(Rows * Cols == xpr.rows() * xpr.cols());
}
/** Dynamic-size constructor
*/
EIGEN_DEVICE_FUNC
inline Reshaped(XprType& xpr,
Index reshapeRows, Index reshapeCols)
: Impl(xpr, reshapeRows, reshapeCols)
{
eigen_assert((RowsAtCompileTime==Dynamic || RowsAtCompileTime==reshapeRows)
&& (ColsAtCompileTime==Dynamic || ColsAtCompileTime==reshapeCols));
eigen_assert(reshapeRows * reshapeCols == xpr.rows() * xpr.cols());
}
};
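// Example (illustrative sketch): Reshaped expressions obtained through DenseBase::reshaped(),
// assuming a 4x4 matrix A and a C++11 build (auto):
//   auto M = A.reshaped(2, 8);                   // 2x8 view of A's coefficients, column-major order
//   auto v = A.reshaped();                       // all 16 coefficients as a column vector
//   auto R = A.reshaped<Eigen::RowMajor>(8, 2);  // same coefficients, traversed in row-major order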
// The generic default implementation for dense reshape simply forwards to the internal::ReshapedImpl_dense
// that must be specialized for direct and non-direct access...
template<typename XprType, int Rows, int Cols, int Order>
class ReshapedImpl<XprType, Rows, Cols, Order, Dense>
: public internal::ReshapedImpl_dense<XprType, Rows, Cols, Order,internal::traits<Reshaped<XprType,Rows,Cols,Order> >::HasDirectAccess>
{
typedef internal::ReshapedImpl_dense<XprType, Rows, Cols, Order,internal::traits<Reshaped<XprType,Rows,Cols,Order> >::HasDirectAccess> Impl;
public:
typedef Impl Base;
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl)
EIGEN_DEVICE_FUNC inline ReshapedImpl(XprType& xpr) : Impl(xpr) {}
EIGEN_DEVICE_FUNC inline ReshapedImpl(XprType& xpr, Index reshapeRows, Index reshapeCols)
: Impl(xpr, reshapeRows, reshapeCols) {}
};
namespace internal {
/** \internal Internal implementation of dense Reshaped in the general case. */
template<typename XprType, int Rows, int Cols, int Order>
class ReshapedImpl_dense<XprType,Rows,Cols,Order,false>
: public internal::dense_xpr_base<Reshaped<XprType, Rows, Cols, Order> >::type
{
typedef Reshaped<XprType, Rows, Cols, Order> ReshapedType;
public:
typedef typename internal::dense_xpr_base<ReshapedType>::type Base;
EIGEN_DENSE_PUBLIC_INTERFACE(ReshapedType)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl_dense)
typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
typedef typename internal::remove_all<XprType>::type NestedExpression;
class InnerIterator;
/** Fixed-size constructor
*/
EIGEN_DEVICE_FUNC
inline ReshapedImpl_dense(XprType& xpr)
: m_xpr(xpr), m_rows(Rows), m_cols(Cols)
{}
/** Dynamic-size constructor
*/
EIGEN_DEVICE_FUNC
inline ReshapedImpl_dense(XprType& xpr, Index nRows, Index nCols)
: m_xpr(xpr), m_rows(nRows), m_cols(nCols)
{}
EIGEN_DEVICE_FUNC Index rows() const { return m_rows; }
EIGEN_DEVICE_FUNC Index cols() const { return m_cols; }
#ifdef EIGEN_PARSED_BY_DOXYGEN
/** \sa MapBase::data() */
EIGEN_DEVICE_FUNC inline const Scalar* data() const;
EIGEN_DEVICE_FUNC inline Index innerStride() const;
EIGEN_DEVICE_FUNC inline Index outerStride() const;
#endif
/** \returns the nested expression */
EIGEN_DEVICE_FUNC
const typename internal::remove_all<XprType>::type&
nestedExpression() const { return m_xpr; }
/** \returns the nested expression */
EIGEN_DEVICE_FUNC
typename internal::remove_reference<XprType>::type&
nestedExpression() { return m_xpr; }
protected:
MatrixTypeNested m_xpr;
const internal::variable_if_dynamic<Index, Rows> m_rows;
const internal::variable_if_dynamic<Index, Cols> m_cols;
};
/** \internal Internal implementation of dense Reshaped in the direct access case. */
template<typename XprType, int Rows, int Cols, int Order>
class ReshapedImpl_dense<XprType, Rows, Cols, Order, true>
: public MapBase<Reshaped<XprType, Rows, Cols, Order> >
{
typedef Reshaped<XprType, Rows, Cols, Order> ReshapedType;
typedef typename internal::ref_selector<XprType>::non_const_type XprTypeNested;
public:
typedef MapBase<ReshapedType> Base;
EIGEN_DENSE_PUBLIC_INTERFACE(ReshapedType)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl_dense)
/** Fixed-size constructor
*/
EIGEN_DEVICE_FUNC
inline ReshapedImpl_dense(XprType& xpr)
: Base(xpr.data()), m_xpr(xpr)
{}
/** Dynamic-size constructor
*/
EIGEN_DEVICE_FUNC
inline ReshapedImpl_dense(XprType& xpr, Index nRows, Index nCols)
: Base(xpr.data(), nRows, nCols),
m_xpr(xpr)
{}
EIGEN_DEVICE_FUNC
const typename internal::remove_all<XprTypeNested>::type& nestedExpression() const
{
return m_xpr;
}
EIGEN_DEVICE_FUNC
XprType& nestedExpression() { return m_xpr; }
/** \sa MapBase::innerStride() */
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index innerStride() const
{
return m_xpr.innerStride();
}
/** \sa MapBase::outerStride() */
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index outerStride() const
{
return ((Flags&RowMajorBit)==RowMajorBit) ? this->cols() : this->rows();
}
protected:
XprTypeNested m_xpr;
};
// Evaluators
template<typename ArgType, int Rows, int Cols, int Order, bool HasDirectAccess> struct reshaped_evaluator;
template<typename ArgType, int Rows, int Cols, int Order>
struct evaluator<Reshaped<ArgType, Rows, Cols, Order> >
: reshaped_evaluator<ArgType, Rows, Cols, Order, traits<Reshaped<ArgType,Rows,Cols,Order> >::HasDirectAccess>
{
typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
typedef typename XprType::Scalar Scalar;
// TODO: should check for smaller packet types
typedef typename packet_traits<Scalar>::type PacketScalar;
enum {
CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
HasDirectAccess = traits<XprType>::HasDirectAccess,
// RowsAtCompileTime = traits<XprType>::RowsAtCompileTime,
// ColsAtCompileTime = traits<XprType>::ColsAtCompileTime,
// MaxRowsAtCompileTime = traits<XprType>::MaxRowsAtCompileTime,
// MaxColsAtCompileTime = traits<XprType>::MaxColsAtCompileTime,
//
// InnerStrideAtCompileTime = traits<XprType>::HasSameStorageOrderAsXprType
// ? int(inner_stride_at_compile_time<ArgType>::ret)
// : Dynamic,
// OuterStrideAtCompileTime = Dynamic,
FlagsLinearAccessBit = (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1 || HasDirectAccess) ? LinearAccessBit : 0,
FlagsRowMajorBit = (traits<XprType>::ReshapedStorageOrder==int(RowMajor)) ? RowMajorBit : 0,
FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0,
Flags0 = evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit),
Flags = Flags0 | FlagsLinearAccessBit | FlagsRowMajorBit | FlagsDirectAccessBit,
PacketAlignment = unpacket_traits<PacketScalar>::alignment,
Alignment = evaluator<ArgType>::Alignment
};
typedef reshaped_evaluator<ArgType, Rows, Cols, Order, HasDirectAccess> reshaped_evaluator_type;
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : reshaped_evaluator_type(xpr)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
};
template<typename ArgType, int Rows, int Cols, int Order>
struct reshaped_evaluator<ArgType, Rows, Cols, Order, /* HasDirectAccess */ false>
: evaluator_base<Reshaped<ArgType, Rows, Cols, Order> >
{
typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
enum {
CoeffReadCost = evaluator<ArgType>::CoeffReadCost /* TODO + cost of index computations */,
Flags = (evaluator<ArgType>::Flags & (HereditaryBits /*| LinearAccessBit | DirectAccessBit*/)),
Alignment = 0
};
EIGEN_DEVICE_FUNC explicit reshaped_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef std::pair<Index, Index> RowCol;
inline RowCol index_remap(Index rowId, Index colId) const
{
if(Order==ColMajor)
{
const Index nth_elem_idx = colId * m_xpr.rows() + rowId;
return RowCol(nth_elem_idx % m_xpr.nestedExpression().rows(),
nth_elem_idx / m_xpr.nestedExpression().rows());
}
else
{
const Index nth_elem_idx = colId + rowId * m_xpr.cols();
return RowCol(nth_elem_idx / m_xpr.nestedExpression().cols(),
nth_elem_idx % m_xpr.nestedExpression().cols());
}
}
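// Worked example (illustrative) of index_remap above: for a 4x4 nested expression reshaped
// to 2x8 with Order==ColMajor, coefficient (rowId,colId) == (1,3) of the view has linear
// index 3*2 + 1 == 7, which maps back to (7 % 4, 7 / 4) == (3,1) in the nested expression.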
EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index rowId, Index colId)
{
EIGEN_STATIC_ASSERT_LVALUE(XprType)
const RowCol row_col = index_remap(rowId, colId);
return m_argImpl.coeffRef(row_col.first, row_col.second);
}
EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index rowId, Index colId) const
{
const RowCol row_col = index_remap(rowId, colId);
return m_argImpl.coeffRef(row_col.first, row_col.second);
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CoeffReturnType coeff(Index rowId, Index colId) const
{
const RowCol row_col = index_remap(rowId, colId);
return m_argImpl.coeff(row_col.first, row_col.second);
}
EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index index)
{
EIGEN_STATIC_ASSERT_LVALUE(XprType)
const RowCol row_col = index_remap(Rows == 1 ? 0 : index,
Rows == 1 ? index : 0);
return m_argImpl.coeffRef(row_col.first, row_col.second);
}
EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index index) const
{
const RowCol row_col = index_remap(Rows == 1 ? 0 : index,
Rows == 1 ? index : 0);
return m_argImpl.coeffRef(row_col.first, row_col.second);
}
EIGEN_DEVICE_FUNC
inline const CoeffReturnType coeff(Index index) const
{
const RowCol row_col = index_remap(Rows == 1 ? 0 : index,
Rows == 1 ? index : 0);
return m_argImpl.coeff(row_col.first, row_col.second);
}
#if 0
EIGEN_DEVICE_FUNC
template<int LoadMode>
inline PacketScalar packet(Index rowId, Index colId) const
{
const RowCol row_col = index_remap(rowId, colId);
return m_argImpl.template packet<Unaligned>(row_col.first, row_col.second);
}
template<int LoadMode>
EIGEN_DEVICE_FUNC
inline void writePacket(Index rowId, Index colId, const PacketScalar& val)
{
const RowCol row_col = index_remap(rowId, colId);
m_argImpl.const_cast_derived().template writePacket<Unaligned>
(row_col.first, row_col.second, val);
}
template<int LoadMode>
EIGEN_DEVICE_FUNC
inline PacketScalar packet(Index index) const
{
const RowCol row_col = index_remap(RowsAtCompileTime == 1 ? 0 : index,
RowsAtCompileTime == 1 ? index : 0);
return m_argImpl.template packet<Unaligned>(row_col.first, row_col.second);
}
template<int LoadMode>
EIGEN_DEVICE_FUNC
inline void writePacket(Index index, const PacketScalar& val)
{
const RowCol row_col = index_remap(RowsAtCompileTime == 1 ? 0 : index,
RowsAtCompileTime == 1 ? index : 0);
return m_argImpl.template packet<Unaligned>(row_col.first, row_col.second, val);
}
#endif
protected:
evaluator<ArgType> m_argImpl;
const XprType& m_xpr;
};
template<typename ArgType, int Rows, int Cols, int Order>
struct reshaped_evaluator<ArgType, Rows, Cols, Order, /* HasDirectAccess */ true>
: mapbase_evaluator<Reshaped<ArgType, Rows, Cols, Order>,
typename Reshaped<ArgType, Rows, Cols, Order>::PlainObject>
{
typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
typedef typename XprType::Scalar Scalar;
EIGEN_DEVICE_FUNC explicit reshaped_evaluator(const XprType& xpr)
: mapbase_evaluator<XprType, typename XprType::PlainObject>(xpr)
{
// TODO: for the 3.4 release, this should be turned to an internal assertion, but let's keep it as is for the beta lifetime
eigen_assert(((internal::UIntPtr(xpr.data()) % EIGEN_PLAIN_ENUM_MAX(1,evaluator<XprType>::Alignment)) == 0) && "data is not aligned");
}
};
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_RESHAPED_H

View File

@@ -0,0 +1,463 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2018 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_STLITERATORS_H
#define EIGEN_STLITERATORS_H
namespace Eigen {
namespace internal {
template<typename IteratorType>
struct indexed_based_stl_iterator_traits;
template<typename Derived>
class indexed_based_stl_iterator_base
{
protected:
typedef indexed_based_stl_iterator_traits<Derived> traits;
typedef typename traits::XprType XprType;
typedef indexed_based_stl_iterator_base<typename traits::non_const_iterator> non_const_iterator;
typedef indexed_based_stl_iterator_base<typename traits::const_iterator> const_iterator;
typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
// NOTE: in C++03 we cannot declare friend classes through typedefs because we need to write friend class:
friend class indexed_based_stl_iterator_base<typename traits::const_iterator>;
friend class indexed_based_stl_iterator_base<typename traits::non_const_iterator>;
public:
typedef Index difference_type;
typedef std::random_access_iterator_tag iterator_category;
indexed_based_stl_iterator_base() EIGEN_NO_THROW : mp_xpr(0), m_index(0) {}
indexed_based_stl_iterator_base(XprType& xpr, Index index) EIGEN_NO_THROW : mp_xpr(&xpr), m_index(index) {}
indexed_based_stl_iterator_base(const non_const_iterator& other) EIGEN_NO_THROW
: mp_xpr(other.mp_xpr), m_index(other.m_index)
{}
indexed_based_stl_iterator_base& operator=(const non_const_iterator& other)
{
mp_xpr = other.mp_xpr;
m_index = other.m_index;
return *this;
}
Derived& operator++() { ++m_index; return derived(); }
Derived& operator--() { --m_index; return derived(); }
Derived operator++(int) { Derived prev(derived()); operator++(); return prev;}
Derived operator--(int) { Derived prev(derived()); operator--(); return prev;}
friend Derived operator+(const indexed_based_stl_iterator_base& a, Index b) { Derived ret(a.derived()); ret += b; return ret; }
friend Derived operator-(const indexed_based_stl_iterator_base& a, Index b) { Derived ret(a.derived()); ret -= b; return ret; }
friend Derived operator+(Index a, const indexed_based_stl_iterator_base& b) { Derived ret(b.derived()); ret += a; return ret; }
friend Derived operator-(Index a, const indexed_based_stl_iterator_base& b) { Derived ret(b.derived()); ret -= a; return ret; }
Derived& operator+=(Index b) { m_index += b; return derived(); }
Derived& operator-=(Index b) { m_index -= b; return derived(); }
difference_type operator-(const indexed_based_stl_iterator_base& other) const
{
eigen_assert(mp_xpr == other.mp_xpr);
return m_index - other.m_index;
}
difference_type operator-(const other_iterator& other) const
{
eigen_assert(mp_xpr == other.mp_xpr);
return m_index - other.m_index;
}
bool operator==(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
bool operator!=(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
bool operator< (const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
bool operator<=(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
bool operator> (const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
bool operator>=(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
bool operator==(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
bool operator!=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
bool operator< (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
bool operator<=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
bool operator> (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
bool operator>=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
protected:
Derived& derived() { return static_cast<Derived&>(*this); }
const Derived& derived() const { return static_cast<const Derived&>(*this); }
XprType *mp_xpr;
Index m_index;
};
template<typename Derived>
class indexed_based_stl_reverse_iterator_base
{
protected:
typedef indexed_based_stl_iterator_traits<Derived> traits;
typedef typename traits::XprType XprType;
typedef indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator> non_const_iterator;
typedef indexed_based_stl_reverse_iterator_base<typename traits::const_iterator> const_iterator;
typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
// NOTE: in C++03 we cannot declare friend classes through typedefs because we need to write friend class:
friend class indexed_based_stl_reverse_iterator_base<typename traits::const_iterator>;
friend class indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator>;
public:
typedef Index difference_type;
typedef std::random_access_iterator_tag iterator_category;
indexed_based_stl_reverse_iterator_base() : mp_xpr(0), m_index(0) {}
indexed_based_stl_reverse_iterator_base(XprType& xpr, Index index) : mp_xpr(&xpr), m_index(index) {}
indexed_based_stl_reverse_iterator_base(const non_const_iterator& other)
: mp_xpr(other.mp_xpr), m_index(other.m_index)
{}
indexed_based_stl_reverse_iterator_base& operator=(const non_const_iterator& other)
{
mp_xpr = other.mp_xpr;
m_index = other.m_index;
return *this;
}
Derived& operator++() { --m_index; return derived(); }
Derived& operator--() { ++m_index; return derived(); }
Derived operator++(int) { Derived prev(derived()); operator++(); return prev;}
Derived operator--(int) { Derived prev(derived()); operator--(); return prev;}
friend Derived operator+(const indexed_based_stl_reverse_iterator_base& a, Index b) { Derived ret(a.derived()); ret += b; return ret; }
friend Derived operator-(const indexed_based_stl_reverse_iterator_base& a, Index b) { Derived ret(a.derived()); ret -= b; return ret; }
friend Derived operator+(Index a, const indexed_based_stl_reverse_iterator_base& b) { Derived ret(b.derived()); ret += a; return ret; }
friend Derived operator-(Index a, const indexed_based_stl_reverse_iterator_base& b) { Derived ret(b.derived()); ret -= a; return ret; }
Derived& operator+=(Index b) { m_index -= b; return derived(); }
Derived& operator-=(Index b) { m_index += b; return derived(); }
difference_type operator-(const indexed_based_stl_reverse_iterator_base& other) const
{
eigen_assert(mp_xpr == other.mp_xpr);
return other.m_index - m_index;
}
difference_type operator-(const other_iterator& other) const
{
eigen_assert(mp_xpr == other.mp_xpr);
return other.m_index - m_index;
}
bool operator==(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
bool operator!=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
bool operator< (const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
bool operator<=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
bool operator> (const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
bool operator>=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
bool operator==(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
bool operator!=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
bool operator< (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
bool operator<=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
bool operator> (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
bool operator>=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
protected:
Derived& derived() { return static_cast<Derived&>(*this); }
const Derived& derived() const { return static_cast<const Derived&>(*this); }
XprType *mp_xpr;
Index m_index;
};
template<typename XprType>
class pointer_based_stl_iterator
{
enum { is_lvalue = internal::is_lvalue<XprType>::value };
typedef pointer_based_stl_iterator<typename internal::remove_const<XprType>::type> non_const_iterator;
typedef pointer_based_stl_iterator<typename internal::add_const<XprType>::type> const_iterator;
typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
  // NOTE: in C++03 we cannot declare friend classes through typedefs, so the 'friend class' declarations below spell out the types explicitly:
friend class pointer_based_stl_iterator<typename internal::add_const<XprType>::type>;
friend class pointer_based_stl_iterator<typename internal::remove_const<XprType>::type>;
public:
typedef Index difference_type;
typedef typename XprType::Scalar value_type;
typedef std::random_access_iterator_tag iterator_category;
typedef typename internal::conditional<bool(is_lvalue), value_type*, const value_type*>::type pointer;
typedef typename internal::conditional<bool(is_lvalue), value_type&, const value_type&>::type reference;
pointer_based_stl_iterator() EIGEN_NO_THROW : m_ptr(0) {}
pointer_based_stl_iterator(XprType& xpr, Index index) EIGEN_NO_THROW : m_incr(xpr.innerStride())
{
m_ptr = xpr.data() + index * m_incr.value();
}
pointer_based_stl_iterator(const non_const_iterator& other) EIGEN_NO_THROW
: m_ptr(other.m_ptr), m_incr(other.m_incr)
{}
pointer_based_stl_iterator& operator=(const non_const_iterator& other) EIGEN_NO_THROW
{
m_ptr = other.m_ptr;
    m_incr.setValue(other.m_incr.value());
return *this;
}
reference operator*() const { return *m_ptr; }
reference operator[](Index i) const { return *(m_ptr+i*m_incr.value()); }
pointer operator->() const { return m_ptr; }
pointer_based_stl_iterator& operator++() { m_ptr += m_incr.value(); return *this; }
pointer_based_stl_iterator& operator--() { m_ptr -= m_incr.value(); return *this; }
pointer_based_stl_iterator operator++(int) { pointer_based_stl_iterator prev(*this); operator++(); return prev;}
pointer_based_stl_iterator operator--(int) { pointer_based_stl_iterator prev(*this); operator--(); return prev;}
friend pointer_based_stl_iterator operator+(const pointer_based_stl_iterator& a, Index b) { pointer_based_stl_iterator ret(a); ret += b; return ret; }
friend pointer_based_stl_iterator operator-(const pointer_based_stl_iterator& a, Index b) { pointer_based_stl_iterator ret(a); ret -= b; return ret; }
friend pointer_based_stl_iterator operator+(Index a, const pointer_based_stl_iterator& b) { pointer_based_stl_iterator ret(b); ret += a; return ret; }
friend pointer_based_stl_iterator operator-(Index a, const pointer_based_stl_iterator& b) { pointer_based_stl_iterator ret(b); ret -= a; return ret; }
pointer_based_stl_iterator& operator+=(Index b) { m_ptr += b*m_incr.value(); return *this; }
pointer_based_stl_iterator& operator-=(Index b) { m_ptr -= b*m_incr.value(); return *this; }
difference_type operator-(const pointer_based_stl_iterator& other) const {
return (m_ptr - other.m_ptr)/m_incr.value();
}
difference_type operator-(const other_iterator& other) const {
return (m_ptr - other.m_ptr)/m_incr.value();
}
bool operator==(const pointer_based_stl_iterator& other) const { return m_ptr == other.m_ptr; }
bool operator!=(const pointer_based_stl_iterator& other) const { return m_ptr != other.m_ptr; }
bool operator< (const pointer_based_stl_iterator& other) const { return m_ptr < other.m_ptr; }
bool operator<=(const pointer_based_stl_iterator& other) const { return m_ptr <= other.m_ptr; }
bool operator> (const pointer_based_stl_iterator& other) const { return m_ptr > other.m_ptr; }
bool operator>=(const pointer_based_stl_iterator& other) const { return m_ptr >= other.m_ptr; }
bool operator==(const other_iterator& other) const { return m_ptr == other.m_ptr; }
bool operator!=(const other_iterator& other) const { return m_ptr != other.m_ptr; }
bool operator< (const other_iterator& other) const { return m_ptr < other.m_ptr; }
bool operator<=(const other_iterator& other) const { return m_ptr <= other.m_ptr; }
bool operator> (const other_iterator& other) const { return m_ptr > other.m_ptr; }
bool operator>=(const other_iterator& other) const { return m_ptr >= other.m_ptr; }
protected:
pointer m_ptr;
internal::variable_if_dynamic<Index, XprType::InnerStrideAtCompileTime> m_incr;
};
template<typename _XprType>
struct indexed_based_stl_iterator_traits<generic_randaccess_stl_iterator<_XprType> >
{
typedef _XprType XprType;
typedef generic_randaccess_stl_iterator<typename internal::remove_const<XprType>::type> non_const_iterator;
typedef generic_randaccess_stl_iterator<typename internal::add_const<XprType>::type> const_iterator;
};
template<typename XprType>
class generic_randaccess_stl_iterator : public indexed_based_stl_iterator_base<generic_randaccess_stl_iterator<XprType> >
{
public:
typedef typename XprType::Scalar value_type;
protected:
enum {
has_direct_access = (internal::traits<XprType>::Flags & DirectAccessBit) ? 1 : 0,
is_lvalue = internal::is_lvalue<XprType>::value
};
typedef indexed_based_stl_iterator_base<generic_randaccess_stl_iterator> Base;
using Base::m_index;
using Base::mp_xpr;
  // TODO: currently const Transpose/Reshape expressions never return const references,
  // so let's return by value too.
//typedef typename internal::conditional<bool(has_direct_access), const value_type&, const value_type>::type read_only_ref_t;
typedef const value_type read_only_ref_t;
public:
typedef typename internal::conditional<bool(is_lvalue), value_type *, const value_type *>::type pointer;
typedef typename internal::conditional<bool(is_lvalue), value_type&, read_only_ref_t>::type reference;
generic_randaccess_stl_iterator() : Base() {}
generic_randaccess_stl_iterator(XprType& xpr, Index index) : Base(xpr,index) {}
generic_randaccess_stl_iterator(const typename Base::non_const_iterator& other) : Base(other) {}
using Base::operator=;
reference operator*() const { return (*mp_xpr)(m_index); }
reference operator[](Index i) const { return (*mp_xpr)(m_index+i); }
pointer operator->() const { return &((*mp_xpr)(m_index)); }
};
template<typename _XprType, DirectionType Direction>
struct indexed_based_stl_iterator_traits<subvector_stl_iterator<_XprType,Direction> >
{
typedef _XprType XprType;
typedef subvector_stl_iterator<typename internal::remove_const<XprType>::type, Direction> non_const_iterator;
typedef subvector_stl_iterator<typename internal::add_const<XprType>::type, Direction> const_iterator;
};
template<typename XprType, DirectionType Direction>
class subvector_stl_iterator : public indexed_based_stl_iterator_base<subvector_stl_iterator<XprType,Direction> >
{
protected:
enum { is_lvalue = internal::is_lvalue<XprType>::value };
typedef indexed_based_stl_iterator_base<subvector_stl_iterator> Base;
using Base::m_index;
using Base::mp_xpr;
typedef typename internal::conditional<Direction==Vertical,typename XprType::ColXpr,typename XprType::RowXpr>::type SubVectorType;
typedef typename internal::conditional<Direction==Vertical,typename XprType::ConstColXpr,typename XprType::ConstRowXpr>::type ConstSubVectorType;
public:
typedef typename internal::conditional<bool(is_lvalue), SubVectorType, ConstSubVectorType>::type reference;
typedef typename reference::PlainObject value_type;
private:
class subvector_stl_iterator_ptr
{
public:
subvector_stl_iterator_ptr(const reference &subvector) : m_subvector(subvector) {}
reference* operator->() { return &m_subvector; }
private:
reference m_subvector;
};
public:
typedef subvector_stl_iterator_ptr pointer;
subvector_stl_iterator() : Base() {}
subvector_stl_iterator(XprType& xpr, Index index) : Base(xpr,index) {}
reference operator*() const { return (*mp_xpr).template subVector<Direction>(m_index); }
reference operator[](Index i) const { return (*mp_xpr).template subVector<Direction>(m_index+i); }
pointer operator->() const { return (*mp_xpr).template subVector<Direction>(m_index); }
};
template<typename _XprType, DirectionType Direction>
struct indexed_based_stl_iterator_traits<subvector_stl_reverse_iterator<_XprType,Direction> >
{
typedef _XprType XprType;
typedef subvector_stl_reverse_iterator<typename internal::remove_const<XprType>::type, Direction> non_const_iterator;
typedef subvector_stl_reverse_iterator<typename internal::add_const<XprType>::type, Direction> const_iterator;
};
template<typename XprType, DirectionType Direction>
class subvector_stl_reverse_iterator : public indexed_based_stl_reverse_iterator_base<subvector_stl_reverse_iterator<XprType,Direction> >
{
protected:
enum { is_lvalue = internal::is_lvalue<XprType>::value };
typedef indexed_based_stl_reverse_iterator_base<subvector_stl_reverse_iterator> Base;
using Base::m_index;
using Base::mp_xpr;
typedef typename internal::conditional<Direction==Vertical,typename XprType::ColXpr,typename XprType::RowXpr>::type SubVectorType;
typedef typename internal::conditional<Direction==Vertical,typename XprType::ConstColXpr,typename XprType::ConstRowXpr>::type ConstSubVectorType;
public:
typedef typename internal::conditional<bool(is_lvalue), SubVectorType, ConstSubVectorType>::type reference;
typedef typename reference::PlainObject value_type;
private:
class subvector_stl_reverse_iterator_ptr
{
public:
subvector_stl_reverse_iterator_ptr(const reference &subvector) : m_subvector(subvector) {}
reference* operator->() { return &m_subvector; }
private:
reference m_subvector;
};
public:
typedef subvector_stl_reverse_iterator_ptr pointer;
subvector_stl_reverse_iterator() : Base() {}
subvector_stl_reverse_iterator(XprType& xpr, Index index) : Base(xpr,index) {}
reference operator*() const { return (*mp_xpr).template subVector<Direction>(m_index); }
reference operator[](Index i) const { return (*mp_xpr).template subVector<Direction>(m_index+i); }
pointer operator->() const { return (*mp_xpr).template subVector<Direction>(m_index); }
};
} // namespace internal
/** returns an iterator to the first element of the 1D vector or array
* \only_for_vectors
* \sa end(), cbegin()
*/
template<typename Derived>
inline typename DenseBase<Derived>::iterator DenseBase<Derived>::begin()
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
return iterator(derived(), 0);
}
/** const version of begin() */
template<typename Derived>
inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::begin() const
{
return cbegin();
}
/** returns a read-only const_iterator to the first element of the 1D vector or array
* \only_for_vectors
* \sa cend(), begin()
*/
template<typename Derived>
inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::cbegin() const
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
return const_iterator(derived(), 0);
}
/** returns an iterator to the element following the last element of the 1D vector or array
* \only_for_vectors
* \sa begin(), cend()
*/
template<typename Derived>
inline typename DenseBase<Derived>::iterator DenseBase<Derived>::end()
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
return iterator(derived(), size());
}
/** const version of end() */
template<typename Derived>
inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::end() const
{
return cend();
}
/** returns a read-only const_iterator to the element following the last element of the 1D vector or array
* \only_for_vectors
* \sa begin(), cend()
*/
template<typename Derived>
inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::cend() const
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
return const_iterator(derived(), size());
}
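// Illustrative usage (assuming <algorithm> and <numeric> are included): the iterators above let
// 1D expressions be used directly with STL algorithms, e.g.
//   Eigen::VectorXd v = Eigen::VectorXd::Random(10);
//   std::sort(v.begin(), v.end());
//   double sum = std::accumulate(v.cbegin(), v.cend(), 0.0);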
} // namespace Eigen
#endif // EIGEN_STLITERATORS_H

View File

@ -0,0 +1,422 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2018 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_COMPLEX_AVX512_H
#define EIGEN_COMPLEX_AVX512_H
namespace Eigen {
namespace internal {
//---------- float ----------
struct Packet8cf
{
EIGEN_STRONG_INLINE Packet8cf() {}
EIGEN_STRONG_INLINE explicit Packet8cf(const __m512& a) : v(a) {}
__m512 v;
};
template<> struct packet_traits<std::complex<float> > : default_packet_traits
{
typedef Packet8cf type;
typedef Packet4cf half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 8,
HasHalfPacket = 1,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
HasMax = 0,
HasSetLinear = 0
};
};
template<> struct unpacket_traits<Packet8cf> {
typedef std::complex<float> type;
typedef Packet4cf half;
typedef Packet16f as_real;
enum {
size = 8,
alignment=unpacket_traits<Packet16f>::alignment,
vectorizable=true,
masked_load_available=false,
masked_store_available=false
};
};
template<> EIGEN_STRONG_INLINE Packet8cf ptrue<Packet8cf>(const Packet8cf& a) { return Packet8cf(ptrue(Packet16f(a.v))); }
template<> EIGEN_STRONG_INLINE Packet8cf padd<Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_add_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet8cf psub<Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_sub_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet8cf pnegate(const Packet8cf& a)
{
return Packet8cf(pnegate(a.v));
}
template<> EIGEN_STRONG_INLINE Packet8cf pconj(const Packet8cf& a)
{
const __m512 mask = _mm512_castsi512_ps(_mm512_setr_epi32(
0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,
0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000));
return Packet8cf(pxor(a.v,mask));
}
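// Complex product: moveldup broadcasts the real parts of a and movehdup its imaginary parts; with b
// pair-swapped, fmaddsub produces (a_re*b_re - a_im*b_im, a_re*b_im + a_im*b_re) in each real/imag lane pair.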
template<> EIGEN_STRONG_INLINE Packet8cf pmul<Packet8cf>(const Packet8cf& a, const Packet8cf& b)
{
__m512 tmp2 = _mm512_mul_ps(_mm512_movehdup_ps(a.v), _mm512_permute_ps(b.v, _MM_SHUFFLE(2,3,0,1)));
return Packet8cf(_mm512_fmaddsub_ps(_mm512_moveldup_ps(a.v), b.v, tmp2));
}
template<> EIGEN_STRONG_INLINE Packet8cf pand <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pand(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet8cf por <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(por(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet8cf pxor <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pxor(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet8cf pandnot<Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pandnot(a.v,b.v)); }
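// Equality on complex packets: compare all scalar lanes, then AND each lane with its pair-swapped
// neighbour so that a complex element compares equal only if both its real and imaginary parts do.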
template <>
EIGEN_STRONG_INLINE Packet8cf pcmp_eq(const Packet8cf& a, const Packet8cf& b) {
__m512 eq = pcmp_eq<Packet16f>(a.v, b.v);
return Packet8cf(pand(eq, _mm512_permute_ps(eq, 0xB1)));
}
template<> EIGEN_STRONG_INLINE Packet8cf pload <Packet8cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet8cf(pload<Packet16f>(&numext::real_ref(*from))); }
template<> EIGEN_STRONG_INLINE Packet8cf ploadu<Packet8cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet8cf(ploadu<Packet16f>(&numext::real_ref(*from))); }
template<> EIGEN_STRONG_INLINE Packet8cf pset1<Packet8cf>(const std::complex<float>& from)
{
return Packet8cf(_mm512_castpd_ps(pload1<Packet8d>((const double*)(const void*)&from)));
}
template<> EIGEN_STRONG_INLINE Packet8cf ploaddup<Packet8cf>(const std::complex<float>* from)
{
return Packet8cf( _mm512_castpd_ps( ploaddup<Packet8d>((const double*)(const void*)from )) );
}
template<> EIGEN_STRONG_INLINE Packet8cf ploadquad<Packet8cf>(const std::complex<float>* from)
{
return Packet8cf( _mm512_castpd_ps( ploadquad<Packet8d>((const double*)(const void*)from )) );
}
template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float>* to, const Packet8cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore(&numext::real_ref(*to), from.v); }
template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float>* to, const Packet8cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu(&numext::real_ref(*to), from.v); }
template<> EIGEN_DEVICE_FUNC inline Packet8cf pgather<std::complex<float>, Packet8cf>(const std::complex<float>* from, Index stride)
{
return Packet8cf(_mm512_castpd_ps(pgather<double,Packet8d>((const double*)(const void*)from, stride)));
}
template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet8cf>(std::complex<float>* to, const Packet8cf& from, Index stride)
{
pscatter((double*)(void*)to, _mm512_castps_pd(from.v), stride);
}
template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet8cf>(const Packet8cf& a)
{
return pfirst(Packet2cf(_mm512_castps512_ps128(a.v)));
}
template<> EIGEN_STRONG_INLINE Packet8cf preverse(const Packet8cf& a) {
return Packet8cf(_mm512_castsi512_ps(
_mm512_permutexvar_epi64( _mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7),
_mm512_castps_si512(a.v))));
}
template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet8cf>(const Packet8cf& a)
{
return predux(padd(Packet4cf(extract256<0>(a.v)),
Packet4cf(extract256<1>(a.v))));
}
template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet8cf>(const Packet8cf& a)
{
return predux_mul(pmul(Packet4cf(extract256<0>(a.v)),
Packet4cf(extract256<1>(a.v))));
}
template <>
EIGEN_STRONG_INLINE Packet4cf predux_half_dowto4<Packet8cf>(const Packet8cf& a) {
__m256 lane0 = extract256<0>(a.v);
__m256 lane1 = extract256<1>(a.v);
__m256 res = _mm256_add_ps(lane0, lane1);
return Packet4cf(res);
}
EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet8cf,Packet16f)
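// Complex division as a*conj(b) / |b|^2: the squared modulus is built from b.v*b.v plus its
// pair-swapped copy, which replicates re^2+im^2 into both lanes of each complex element.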
template<> EIGEN_STRONG_INLINE Packet8cf pdiv<Packet8cf>(const Packet8cf& a, const Packet8cf& b)
{
Packet8cf num = pmul(a, pconj(b));
__m512 tmp = _mm512_mul_ps(b.v, b.v);
__m512 tmp2 = _mm512_shuffle_ps(tmp,tmp,0xB1);
__m512 denom = _mm512_add_ps(tmp, tmp2);
return Packet8cf(_mm512_div_ps(num.v, denom));
}
template<> EIGEN_STRONG_INLINE Packet8cf pcplxflip<Packet8cf>(const Packet8cf& x)
{
return Packet8cf(_mm512_shuffle_ps(x.v, x.v, _MM_SHUFFLE(2, 3, 0 ,1)));
}
//---------- double ----------
struct Packet4cd
{
EIGEN_STRONG_INLINE Packet4cd() {}
EIGEN_STRONG_INLINE explicit Packet4cd(const __m512d& a) : v(a) {}
__m512d v;
};
template<> struct packet_traits<std::complex<double> > : default_packet_traits
{
typedef Packet4cd type;
typedef Packet2cd half;
enum {
Vectorizable = 1,
AlignedOnScalar = 0,
size = 4,
HasHalfPacket = 1,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
HasMax = 0,
HasSetLinear = 0
};
};
template<> struct unpacket_traits<Packet4cd> {
typedef std::complex<double> type;
typedef Packet2cd half;
typedef Packet8d as_real;
enum {
size = 4,
alignment = unpacket_traits<Packet8d>::alignment,
vectorizable=true,
masked_load_available=false,
masked_store_available=false
};
};
template<> EIGEN_STRONG_INLINE Packet4cd padd<Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_add_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cd psub<Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_sub_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cd pnegate(const Packet4cd& a) { return Packet4cd(pnegate(a.v)); }
template<> EIGEN_STRONG_INLINE Packet4cd pconj(const Packet4cd& a)
{
const __m512d mask = _mm512_castsi512_pd(
_mm512_set_epi32(0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0,
0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0));
return Packet4cd(pxor(a.v,mask));
}
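// Same complex product scheme as for Packet8cf: tmp1/tmp2 broadcast the real and imaginary parts of a,
// tmp3 swaps the real/imag lanes of b, and fmaddsub combines them into (a_re*b_re - a_im*b_im, a_re*b_im + a_im*b_re).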
template<> EIGEN_STRONG_INLINE Packet4cd pmul<Packet4cd>(const Packet4cd& a, const Packet4cd& b)
{
__m512d tmp1 = _mm512_shuffle_pd(a.v,a.v,0x0);
__m512d tmp2 = _mm512_shuffle_pd(a.v,a.v,0xFF);
__m512d tmp3 = _mm512_shuffle_pd(b.v,b.v,0x55);
__m512d odd = _mm512_mul_pd(tmp2, tmp3);
return Packet4cd(_mm512_fmaddsub_pd(tmp1, b.v, odd));
}
template<> EIGEN_STRONG_INLINE Packet4cd ptrue<Packet4cd>(const Packet4cd& a) { return Packet4cd(ptrue(Packet8d(a.v))); }
template<> EIGEN_STRONG_INLINE Packet4cd pand <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pand(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cd por <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(por(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cd pxor <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pxor(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cd pandnot<Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pandnot(a.v,b.v)); }
template <>
EIGEN_STRONG_INLINE Packet4cd pcmp_eq(const Packet4cd& a, const Packet4cd& b) {
__m512d eq = pcmp_eq<Packet8d>(a.v, b.v);
return Packet4cd(pand(eq, _mm512_permute_pd(eq, 0x55)));
}
template<> EIGEN_STRONG_INLINE Packet4cd pload <Packet4cd>(const std::complex<double>* from)
{ EIGEN_DEBUG_ALIGNED_LOAD return Packet4cd(pload<Packet8d>((const double*)from)); }
template<> EIGEN_STRONG_INLINE Packet4cd ploadu<Packet4cd>(const std::complex<double>* from)
{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet4cd(ploadu<Packet8d>((const double*)from)); }
template<> EIGEN_STRONG_INLINE Packet4cd pset1<Packet4cd>(const std::complex<double>& from)
{
#ifdef EIGEN_VECTORIZE_AVX512DQ
return Packet4cd(_mm512_broadcast_f64x2(pset1<Packet1cd>(from).v));
#else
return Packet4cd(_mm512_castps_pd(_mm512_broadcast_f32x4( _mm_castpd_ps(pset1<Packet1cd>(from).v))));
#endif
}
template<> EIGEN_STRONG_INLINE Packet4cd ploaddup<Packet4cd>(const std::complex<double>* from) {
return Packet4cd(_mm512_insertf64x4(
_mm512_castpd256_pd512(ploaddup<Packet2cd>(from).v), ploaddup<Packet2cd>(from+1).v, 1));
}
template<> EIGEN_STRONG_INLINE void pstore <std::complex<double> >(std::complex<double> * to, const Packet4cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); }
template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double> * to, const Packet4cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); }
template<> EIGEN_DEVICE_FUNC inline Packet4cd pgather<std::complex<double>, Packet4cd>(const std::complex<double>* from, Index stride)
{
return Packet4cd(_mm512_insertf64x4(_mm512_castpd256_pd512(
_mm256_insertf128_pd(_mm256_castpd128_pd256(ploadu<Packet1cd>(from+0*stride).v), ploadu<Packet1cd>(from+1*stride).v,1)),
_mm256_insertf128_pd(_mm256_castpd128_pd256(ploadu<Packet1cd>(from+2*stride).v), ploadu<Packet1cd>(from+3*stride).v,1), 1));
}
template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet4cd>(std::complex<double>* to, const Packet4cd& from, Index stride)
{
__m512i fromi = _mm512_castpd_si512(from.v);
double* tod = (double*)(void*)to;
_mm_storeu_pd(tod+0*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,0)) );
_mm_storeu_pd(tod+2*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,1)) );
_mm_storeu_pd(tod+4*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,2)) );
_mm_storeu_pd(tod+6*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,3)) );
}
template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet4cd>(const Packet4cd& a)
{
__m128d low = extract128<0>(a.v);
EIGEN_ALIGN16 double res[2];
_mm_store_pd(res, low);
return std::complex<double>(res[0],res[1]);
}
template<> EIGEN_STRONG_INLINE Packet4cd preverse(const Packet4cd& a) {
return Packet4cd(_mm512_shuffle_f64x2(a.v, a.v, (shuffle_mask<3,2,1,0>::mask)));
}
template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet4cd>(const Packet4cd& a)
{
return predux(padd(Packet2cd(_mm512_extractf64x4_pd(a.v,0)),
Packet2cd(_mm512_extractf64x4_pd(a.v,1))));
}
template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet4cd>(const Packet4cd& a)
{
return predux_mul(pmul(Packet2cd(_mm512_extractf64x4_pd(a.v,0)),
Packet2cd(_mm512_extractf64x4_pd(a.v,1))));
}
template<> struct conj_helper<Packet4cd, Packet4cd, false,true>
{
EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const
{ return padd(pmul(x,y),c); }
EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const
{
return internal::pmul(a, pconj(b));
}
};
template<> struct conj_helper<Packet4cd, Packet4cd, true,false>
{
EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const
{ return padd(pmul(x,y),c); }
EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const
{
return internal::pmul(pconj(a), b);
}
};
template<> struct conj_helper<Packet4cd, Packet4cd, true,true>
{
EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const
{ return padd(pmul(x,y),c); }
EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const
{
return pconj(internal::pmul(a, b));
}
};
EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cd,Packet8d)
template<> EIGEN_STRONG_INLINE Packet4cd pdiv<Packet4cd>(const Packet4cd& a, const Packet4cd& b)
{
Packet4cd num = pmul(a, pconj(b));
__m512d tmp = _mm512_mul_pd(b.v, b.v);
__m512d denom = padd(_mm512_permute_pd(tmp,0x55), tmp);
return Packet4cd(_mm512_div_pd(num.v, denom));
}
template<> EIGEN_STRONG_INLINE Packet4cd pcplxflip<Packet4cd>(const Packet4cd& x)
{
return Packet4cd(_mm512_permute_pd(x.v,0x55));
}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8cf,4>& kernel) {
PacketBlock<Packet8d,4> pb;
pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v);
pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v);
pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v);
pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v);
ptranspose(pb);
kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]);
kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]);
kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]);
kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]);
}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8cf,8>& kernel) {
PacketBlock<Packet8d,8> pb;
pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v);
pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v);
pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v);
pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v);
pb.packet[4] = _mm512_castps_pd(kernel.packet[4].v);
pb.packet[5] = _mm512_castps_pd(kernel.packet[5].v);
pb.packet[6] = _mm512_castps_pd(kernel.packet[6].v);
pb.packet[7] = _mm512_castps_pd(kernel.packet[7].v);
ptranspose(pb);
kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]);
kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]);
kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]);
kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]);
kernel.packet[4].v = _mm512_castpd_ps(pb.packet[4]);
kernel.packet[5].v = _mm512_castpd_ps(pb.packet[5]);
kernel.packet[6].v = _mm512_castpd_ps(pb.packet[6]);
kernel.packet[7].v = _mm512_castpd_ps(pb.packet[7]);
}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet4cd,4>& kernel) {
__m512d T0 = _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, (shuffle_mask<0,1,0,1>::mask)); // [a0 a1 b0 b1]
__m512d T1 = _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, (shuffle_mask<2,3,2,3>::mask)); // [a2 a3 b2 b3]
__m512d T2 = _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, (shuffle_mask<0,1,0,1>::mask)); // [c0 c1 d0 d1]
__m512d T3 = _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, (shuffle_mask<2,3,2,3>::mask)); // [c2 c3 d2 d3]
kernel.packet[3] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, (shuffle_mask<1,3,1,3>::mask))); // [a3 b3 c3 d3]
kernel.packet[2] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, (shuffle_mask<0,2,0,2>::mask))); // [a2 b2 c2 d2]
kernel.packet[1] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<1,3,1,3>::mask))); // [a1 b1 c1 d1]
kernel.packet[0] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<0,2,0,2>::mask))); // [a0 b0 c0 d0]
}
template<> EIGEN_STRONG_INLINE Packet4cd psqrt<Packet4cd>(const Packet4cd& a) {
return psqrt_complex<Packet4cd>(a);
}
template<> EIGEN_STRONG_INLINE Packet8cf psqrt<Packet8cf>(const Packet8cf& a) {
return psqrt_complex<Packet8cf>(a);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_COMPLEX_AVX512_H

View File

@ -0,0 +1,89 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_TYPE_CASTING_AVX512_H
#define EIGEN_TYPE_CASTING_AVX512_H
namespace Eigen {
namespace internal {
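// Vectorized conversions between the 16-lane float/int packets and the packed half/bfloat16 formats;
// the type_casting_traits below advertise a 1:1 source/target lane ratio so the vectorized paths are selected.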
template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
return _mm512_cvttps_epi32(a);
}
template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
return _mm512_cvtepi32_ps(a);
}
template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
return _mm512_castps_si512(a);
}
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(const Packet16i& a) {
return _mm512_castsi512_ps(a);
}
template <>
struct type_casting_traits<half, float> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
TgtCoeffRatio = 1
};
};
template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
return half2float(a);
}
template <>
struct type_casting_traits<float, half> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
TgtCoeffRatio = 1
};
};
template<> EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
return float2half(a);
}
template <>
struct type_casting_traits<bfloat16, float> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
TgtCoeffRatio = 1
};
};
template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
return Bf16ToF32(a);
}
template <>
struct type_casting_traits<float, bfloat16> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
TgtCoeffRatio = 1
};
};
template<> EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Packet16f& a) {
return F32ToBf16(a);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_TYPE_CASTING_AVX512_H

File diff suppressed because it is too large

View File

@ -0,0 +1,221 @@
//#define EIGEN_POWER_USE_PREFETCH // Use prefetching in gemm routines
#ifdef EIGEN_POWER_USE_PREFETCH
#define EIGEN_POWER_PREFETCH(p) prefetch(p)
#else
#define EIGEN_POWER_PREFETCH(p)
#endif
namespace Eigen {
namespace internal {
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
EIGEN_STRONG_INLINE void gemm_extra_col(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
Index offsetA,
Index row,
Index col,
Index remaining_rows,
Index remaining_cols,
const Packet& pAlpha);
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_extra_row(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
Index offsetA,
Index row,
Index col,
Index rows,
Index cols,
Index remaining_rows,
const Packet& pAlpha,
const Packet& pMask);
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_col(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
Index offsetA,
Index& row,
Index rows,
Index col,
Index remaining_cols,
const Packet& pAlpha);
template<typename Packet>
EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows);
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_extra_col(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
Index offsetA,
Index strideB,
Index row,
Index col,
Index remaining_rows,
Index remaining_cols,
const Packet& pAlphaReal,
const Packet& pAlphaImag);
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_extra_row(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
Index offsetA,
Index strideB,
Index row,
Index col,
Index rows,
Index cols,
Index remaining_rows,
const Packet& pAlphaReal,
const Packet& pAlphaImag,
const Packet& pMask);
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_unrolled_col(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
Index offsetA,
Index strideB,
Index& row,
Index rows,
Index col,
Index remaining_cols,
const Packet& pAlphaReal,
const Packet& pAlphaImag);
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs);
template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col);
template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col);
template<typename Packet>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha);
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag);
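// Interleave masks for single-precision complexes: applied to a real vector [r0 r1 r2 r3] and an
// imaginary vector [i0 i1 i2 i3], FIRST produces [r0 i0 r1 i1] and SECOND produces [r2 i2 r3 i3].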
const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3,
16, 17, 18, 19,
4, 5, 6, 7,
20, 21, 22, 23};
const static Packet16uc p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11,
24, 25, 26, 27,
12, 13, 14, 15,
28, 29, 30, 31};
//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64
const static Packet16uc p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7,
16, 17, 18, 19, 20, 21, 22, 23};
//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64
const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15,
24, 25, 26, 27, 28, 29, 30, 31};
// Take two decoupled real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
template<typename Packet, typename Packetc>
EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
{
acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST);
acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_FIRST);
acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_FIRST);
acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND);
acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND);
acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND);
}
template<typename Packet, typename Packetc>
EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc,8>& tRes, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
{
bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
acc1.packet[1] = padd<Packetc>(tRes.packet[1], acc1.packet[1]);
acc1.packet[2] = padd<Packetc>(tRes.packet[2], acc1.packet[2]);
acc1.packet[3] = padd<Packetc>(tRes.packet[3], acc1.packet[3]);
acc2.packet[0] = padd<Packetc>(tRes.packet[4], acc2.packet[0]);
acc2.packet[1] = padd<Packetc>(tRes.packet[5], acc2.packet[1]);
acc2.packet[2] = padd<Packetc>(tRes.packet[6], acc2.packet[2]);
acc2.packet[3] = padd<Packetc>(tRes.packet[7], acc2.packet[3]);
}
template<typename Packet, typename Packetc>
EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
{
acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
}
template<typename Packet, typename Packetc>
EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc,2>& tRes, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
{
bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
acc2.packet[0] = padd<Packetc>(tRes.packet[1], acc2.packet[0]);
}
template<>
EIGEN_ALWAYS_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& taccReal, PacketBlock<Packet2d,4>& taccImag, PacketBlock<Packet1cd, 4>& acc1, PacketBlock<Packet1cd, 4>& acc2)
{
acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST);
acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_FIRST);
acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_FIRST);
acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND);
acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND);
acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND);
}
template<>
EIGEN_ALWAYS_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,1>& taccReal, PacketBlock<Packet2d,1>& taccImag, PacketBlock<Packet1cd, 1>& acc1, PacketBlock<Packet1cd, 1>& acc2)
{
acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
}
// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadRhs(const Scalar* rhs)
{
return ploadu<Packet>(rhs);
}
} // end namespace internal
} // end namespace Eigen

View File

@ -0,0 +1,629 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#pragma GCC target("cpu=power10")
#ifdef __has_builtin
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#endif
namespace Eigen {
namespace internal {
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
{
__builtin_mma_xxsetaccz(acc);
}
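// Store an MMA accumulator: disassemble the __vector_quad into a 4-packet block, load the existing
// result tile, add the accumulated block scaled by alpha, and write it back through the DataMapper.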
template<typename DataMapper, typename Index, typename Packet, const Index accCols>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
{
PacketBlock<Packet, 4> result;
__builtin_mma_disassemble_acc(&result.packet, acc);
PacketBlock<Packet, 4> tRes;
bload<DataMapper, Packet, Index, accCols, 0, ColMajor>(tRes, data, i, j);
bscale<Packet>(tRes, result, alpha);
data.template storePacketBlock<Packet, 4>(i, j, tRes);
}
template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC, int N>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
{
PacketBlock<Packet, 4> resultReal, resultImag;
__builtin_mma_disassemble_acc(&resultReal.packet, accReal);
__builtin_mma_disassemble_acc(&resultImag.packet, accImag);
PacketBlock<Packetc, 8> tRes;
bload<DataMapper, Packetc, Index, accColsC, N, ColMajor>(tRes, data, i, j);
PacketBlock<Packet,4> taccReal, taccImag;
bscalec<Packet,4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);
PacketBlock<Packetc, 4> acc1, acc2;
bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc1, acc2);
data.template storePacketBlock<Packetc, 4>(i + N*accColsC, j, acc1);
data.template storePacketBlock<Packetc, 4>(i + (N+1)*accColsC, j, acc2);
}
// Defaults to float32; since Eigen still supports C++03 we cannot use default template arguments.
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
{
if(NegativeAccumulate)
{
__builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
} else {
__builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
}
}
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
{
__vector_pair* a0 = (__vector_pair *)(&a.packet[0]);
if(NegativeAccumulate)
{
__builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
} else {
__builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
}
}
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
{
if(NegativeAccumulate)
{
__builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
} else {
__builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
}
}
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
{
// Just for compilation
}
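// Complex rank-1 MMA update: the real accumulator collects lhs_re*rhs_re (plus, when both operands are
// complex, the lhs_im*rhs_im term with its sign chosen by the conjugation flags); the imaginary
// accumulator collects the lhs_re*rhs_im and lhs_im*rhs_re cross terms. Purely real operands skip the unused products.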
template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
{
pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
if(LhsIsReal) {
pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
} else {
if(!RhsIsReal) {
pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
} else {
EIGEN_UNUSED_VARIABLE(rhsVi);
}
pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
}
}
// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
rhsV = ploadRhs<Scalar, Packet>((const Scalar*)(rhs));
}
template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
{
rhsV.packet[0] = ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs ));
rhsV.packet[1] = ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1));
}
template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
{
#if EIGEN_COMP_LLVM
__builtin_vsx_assemble_pair(&rhsV,
(__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1))),
(__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs ))));
#else
__asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
#endif
}
template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
{
// Just for compilation
}
// PEEL_MMA loop factor: number of depth (k) iterations unrolled per pass of the peeled main loop.
#define PEEL_MMA 7
#define MICRO_MMA_UNROLL(func) \
func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
#define MICRO_MMA_LOAD_ONE(iter) \
if (unroll_factor > iter) { \
lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
lhs_ptr##iter += accCols; \
} else { \
EIGEN_UNUSED_VARIABLE(lhsV##iter); \
}
#define MICRO_MMA_WORK_ONE(iter, type, peel) \
if (unroll_factor > iter) { \
pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
}
#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
if (PEEL_MMA > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
MICRO_MMA_UNROLL(func2); \
func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
} else { \
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
}
#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); \
MICRO_MMA_TYPE_PEEL(func,func2,type,8); MICRO_MMA_TYPE_PEEL(func,func2,type,9);
#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
type rhsV0; \
MICRO_MMA_TYPE_PEEL(func,func2,type,0);
#define MICRO_MMA_ONE_PEEL \
if (sizeof(Scalar) == sizeof(float)) { \
MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
} else { \
MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
} \
rhs_ptr += (accRows * PEEL_MMA);
#define MICRO_MMA_ONE \
if (sizeof(Scalar) == sizeof(float)) { \
MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
} else { \
MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
} \
rhs_ptr += accRows;
#define MICRO_MMA_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
} else { \
EIGEN_UNUSED_VARIABLE(accZero##iter); \
}
#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)
#define MICRO_MMA_SRC_PTR_ONE(iter) \
if (unroll_factor > iter) { \
lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
} else { \
EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
}
#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)
#define MICRO_MMA_PREFETCH_ONE(iter) \
if (unroll_factor > iter) { \
EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
}
#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)
#define MICRO_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, col, res, pAlpha, &accZero##iter); \
}
#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
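// One unrolled GEMM micro-kernel call: set up the lhs pointers, zero up to eight accumulators, run the
// depth loop in PEEL_MMA-sized chunks with prefetching, finish the remaining k iterations one at a time,
// then scale by alpha, store the tile, and advance row by unroll_factor*accCols.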
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
Index offsetA,
Index& row,
Index col,
const Packet& pAlpha)
{
const Scalar* rhs_ptr = rhs_base;
const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
__vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
MICRO_MMA_SRC_PTR
MICRO_MMA_DST_PTR
Index k = 0;
for(; k + PEEL_MMA <= depth; k+= PEEL_MMA)
{
EIGEN_POWER_PREFETCH(rhs_ptr);
MICRO_MMA_PREFETCH
MICRO_MMA_ONE_PEEL
}
for(; k < depth; k++)
{
MICRO_MMA_ONE
}
MICRO_MMA_STORE
row += unroll_factor*accCols;
}
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
const Index remaining_rows = rows % accCols;
const Index remaining_cols = cols % accRows;
if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
const Packet pAlpha = pset1<Packet>(alpha);
const Packet pMask = bmask<Packet>((const int)(remaining_rows));
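// Column-major tiling: walk the result in panels of accRows columns; within each panel, consume blocks of
// MAX_MMA_UNROLL*accCols rows with the unrolled kernel, dispatch the smaller leftover row blocks through the
// switch below, and handle the remaining_rows/remaining_cols tails with the generic (non-MMA) helpers.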
Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
const Scalar* lhs_base = blockA;
Index row = 0;
#define MAX_MMA_UNROLL 7
while(row + MAX_MMA_UNROLL*accCols <= rows) {
gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
}
switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
case 7:
gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
break;
#endif
#if MAX_MMA_UNROLL > 6
case 6:
gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
break;
#endif
#if MAX_MMA_UNROLL > 5
case 5:
gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
break;
#endif
#if MAX_MMA_UNROLL > 4
case 4:
gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
break;
#endif
#if MAX_MMA_UNROLL > 3
case 3:
gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
break;
#endif
#if MAX_MMA_UNROLL > 2
case 2:
gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
break;
#endif
#if MAX_MMA_UNROLL > 1
case 1:
gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
break;
#endif
default:
break;
}
#undef MAX_MMA_UNROLL
if(remaining_rows > 0)
{
gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
}
}
if(remaining_cols > 0)
{
const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
const Scalar* lhs_base = blockA;
for(; col < cols; col++)
{
Index row = 0;
gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);
if (remaining_rows > 0)
{
gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
}
rhs_base++;
}
}
}
#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)
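// A complex packet holds half as many complex values as the real packet holds scalars, hence accColsC;
// advanceRows/advanceCols double the packing stride whenever the corresponding operand stores separate
// real and imaginary planes (i.e. is complex).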
// PEEL_COMPLEX_MMA loop factor: number of depth (k) iterations unrolled per pass of the peeled complex main loop.
#define PEEL_COMPLEX_MMA 7
#define MICRO_COMPLEX_MMA_UNROLL(func) \
func(0) func(1) func(2) func(3) func(4)
#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
if (unroll_factor > iter) { \
lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
lhs_ptr_real##iter += accCols; \
if(!LhsIsReal) { \
lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
lhs_ptr_imag##iter += accCols; \
} else { \
EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
} \
} else { \
EIGEN_UNUSED_VARIABLE(lhsV##iter); \
EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
}
#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
if (unroll_factor > iter) { \
pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
}
#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
if (PEEL_COMPLEX_MMA > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
if(!RhsIsReal) { \
ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
} else { \
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
} \
MICRO_COMPLEX_MMA_UNROLL(func2); \
func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) func(4,type,peel) \
} else { \
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
}
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
type rhsVi0, rhsVi1, rhsVi2, rhsVi3, rhsVi4, rhsVi5, rhsVi6, rhsVi7, rhsVi8, rhsVi9; \
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); \
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,4); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,5); \
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,6); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,7); \
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,8); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,9);
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
type rhsV0, rhsVi0; \
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);
#define MICRO_COMPLEX_MMA_ONE_PEEL \
if (sizeof(Scalar) == sizeof(float)) { \
MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
} else { \
MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
} \
rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);
#define MICRO_COMPLEX_MMA_ONE \
if (sizeof(Scalar) == sizeof(float)) { \
MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
} else { \
MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
} \
rhs_ptr_real += accRows; \
if(!RhsIsReal) rhs_ptr_imag += accRows;
#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
} else { \
EIGEN_UNUSED_VARIABLE(accReal##iter); \
EIGEN_UNUSED_VARIABLE(accImag##iter); \
}
#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)
#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
if (unroll_factor > iter) { \
lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
if(!LhsIsReal) { \
lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
} else { \
EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
} \
} else { \
EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
}
#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)
#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
if (unroll_factor > iter) { \
EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
if(!LhsIsReal) { \
EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
} \
}
#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)
#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC, 0>(row + iter*accCols, col, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
}
#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
Index offsetA,
Index strideB,
Index& row,
Index col,
const Packet& pAlphaReal,
const Packet& pAlphaImag)
{
const Scalar* rhs_ptr_real = rhs_base;
const Scalar* rhs_ptr_imag;
if(!RhsIsReal) {
rhs_ptr_imag = rhs_base + accRows*strideB;
} else {
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
}
const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
__vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3, accReal4, accImag4;
MICRO_COMPLEX_MMA_SRC_PTR
MICRO_COMPLEX_MMA_DST_PTR
Index k = 0;
for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA)
{
EIGEN_POWER_PREFETCH(rhs_ptr_real);
if(!RhsIsReal) {
EIGEN_POWER_PREFETCH(rhs_ptr_imag);
}
MICRO_COMPLEX_MMA_PREFETCH
MICRO_COMPLEX_MMA_ONE_PEEL
}
for(; k < depth; k++)
{
MICRO_COMPLEX_MMA_ONE
}
MICRO_COMPLEX_MMA_STORE
row += unroll_factor*accCols;
}
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
const Index remaining_rows = rows % accCols;
const Index remaining_cols = cols % accRows;
if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
const Packet pAlphaReal = pset1<Packet>(alpha.real());
const Packet pAlphaImag = pset1<Packet>(alpha.imag());
const Packet pMask = bmask<Packet>((const int)(remaining_rows));
const Scalar* blockA = (Scalar *) blockAc;
const Scalar* blockB = (Scalar *) blockBc;
Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
const Scalar* lhs_base = blockA;
Index row = 0;
#define MAX_COMPLEX_MMA_UNROLL 4
while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
}
switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
case 4:
gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
case 3:
gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
case 2:
gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
case 1:
gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
break;
#endif
default:
break;
}
#undef MAX_COMPLEX_MMA_UNROLL
if(remaining_rows > 0)
{
gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
}
if(remaining_cols > 0)
{
const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
const Scalar* lhs_base = blockA;
for(; col < cols; col++)
{
Index row = 0;
gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);
if (remaining_rows > 0)
{
gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
}
rhs_base++;
}
}
}
#undef accColsC
#undef advanceRows
#undef advanceCols
#pragma GCC reset_options
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

View File

@ -0,0 +1,700 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef EIGEN_BFLOAT16_H
#define EIGEN_BFLOAT16_H
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
template <> \
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& _x) { \
return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
}
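// Illustrative sketch (not part of the original header): a backend that provides a
// bfloat16 packet type typically instantiates this macro once per unary math function.
// Assuming hypothetical packet type names PacketF (float) and PacketBF (bfloat16):
//
//   BF16_PACKET_FUNCTION(PacketF, PacketBF, psin)
//   BF16_PACKET_FUNCTION(PacketF, PacketBF, pexp)
//
// Each invocation expands to a psin<PacketBF>/pexp<PacketBF> specialization that widens
// to float, applies the float packet implementation, and narrows back to bfloat16.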
namespace Eigen {
struct bfloat16;
namespace bfloat16_impl {
// Make our own __bfloat16_raw definition.
struct __bfloat16_raw {
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
unsigned short value;
};
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
template <bool AssumeArgumentIsNormalOrInfinityOrZero>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
// Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
// > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
struct bfloat16_base : public __bfloat16_raw {
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
};
} // namespace bfloat16_impl
// Class definition.
struct bfloat16 : public bfloat16_impl::bfloat16_base {
typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b)
: bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
template<class T>
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
explicit EIGEN_DEVICE_FUNC bfloat16(float f)
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
// Following the convention of numpy, converting between complex and
// float will lead to loss of imag value.
template<typename RealScalar>
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex<RealScalar>& val)
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless.
return bfloat16_impl::bfloat16_to_float(*this);
}
};
} // namespace Eigen
namespace std {
template<>
struct numeric_limits<Eigen::bfloat16> {
static const bool is_specialized = true;
static const bool is_signed = true;
static const bool is_integer = false;
static const bool is_exact = false;
static const bool has_infinity = true;
static const bool has_quiet_NaN = true;
static const bool has_signaling_NaN = true;
static const float_denorm_style has_denorm = std::denorm_present;
static const bool has_denorm_loss = false;
static const std::float_round_style round_style = numeric_limits<float>::round_style;
static const bool is_iec559 = false;
static const bool is_bounded = true;
static const bool is_modulo = false;
static const int digits = 8;
static const int digits10 = 2;
static const int max_digits10 = 4;
static const int radix = 2;
static const int min_exponent = numeric_limits<float>::min_exponent;
static const int min_exponent10 = numeric_limits<float>::min_exponent10;
static const int max_exponent = numeric_limits<float>::max_exponent;
static const int max_exponent10 = numeric_limits<float>::max_exponent10;
static const bool traps = numeric_limits<float>::traps;
static const bool tinyness_before = numeric_limits<float>::tinyness_before;
static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
static Eigen::bfloat16 round_error() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3f00); }
static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
};
// If std::numeric_limits<T> is specialized, should also specialize
// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
// std::numeric_limits<const volatile T>
// https://stackoverflow.com/a/16519653/
template<>
struct numeric_limits<const Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
template<>
struct numeric_limits<volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
template<>
struct numeric_limits<const volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
} // namespace std
namespace Eigen {
namespace bfloat16_impl {
// We need to distinguish clang as the CUDA compiler from clang as the host compiler,
// invoked by NVCC (e.g. on MacOS). The former needs to see both host and device implementation
// of the functions, while the latter can only deal with one of them.
#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
// We need to provide emulated *host-side* BF16 operators for clang.
#pragma push_macro("EIGEN_DEVICE_FUNC")
#undef EIGEN_DEVICE_FUNC
#if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
#define EIGEN_DEVICE_FUNC __host__
#else // both host and device need emulated ops.
#define EIGEN_DEVICE_FUNC __host__ __device__
#endif
#endif
// Definitions for CPUs, mostly working through conversion
// to/from fp32.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) {
return bfloat16(float(a) + float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) {
return bfloat16(float(a) + static_cast<float>(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) {
return bfloat16(static_cast<float>(a) + float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) {
return bfloat16(float(a) * float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) {
return bfloat16(float(a) - float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) {
return bfloat16(float(a) / float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
bfloat16 result;
result.value = a.value ^ 0x8000;
return result;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
a = bfloat16(float(a) + float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) {
a = bfloat16(float(a) * float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) {
a = bfloat16(float(a) - float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) {
a = bfloat16(float(a) / float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
a += bfloat16(1);
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
a -= bfloat16(1);
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
bfloat16 original_value = a;
++a;
return original_value;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
bfloat16 original_value = a;
--a;
return original_value;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) {
return numext::equal_strict(float(a),float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) {
return numext::not_equal_strict(float(a), float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) {
return float(a) < float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) {
return float(a) <= float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) {
return float(a) > float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) {
return float(a) >= float(b);
}
#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
#pragma pop_macro("EIGEN_DEVICE_FUNC")
#endif
#endif // Emulate support for bfloat16 floats
// Division by an index. Do it in full float precision to avoid accuracy
// issues in converting the denominator to bfloat16.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) {
return bfloat16(static_cast<float>(a) / static_cast<float>(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
__bfloat16_raw output;
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
return output;
}
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
output.value = p[0];
#else
output.value = p[1];
#endif
return output;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
return __bfloat16_raw(value);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(const __bfloat16_raw& bf) {
return bf.value;
}
// float_to_bfloat16_rtne template specialization that does not make any
// assumption about the value of its function argument (ff).
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
// Nothing to do here
#else
__bfloat16_raw output;
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
// If the value is a NaN, squash it to a qNaN with msb of fraction set,
// this makes sure after truncation we don't end up with an inf.
//
// qNaN magic: All exponent bits set + most significant bit of fraction
// set.
output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
} else {
// Fast rounding algorithm that rounds a half value to nearest even. This
// reduces expected error when we convert a large number of floats. Here
// is how it works:
//
// Definitions:
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
// with the following tags:
//
// Sign | Exp (8 bits) | Frac (23 bits)
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
//
// S: Sign bit.
// E: Exponent bits.
// F: First 6 bits of fraction.
// L: Least significant bit of resulting bfloat16 if we truncate away the
// rest of the float32. This is also the 7th bit of fraction
// R: Rounding bit, 8th bit of fraction.
// T: Sticky bits, rest of fraction, 15 bits.
//
// To round half to nearest even, there are 3 cases where we want to round
// down (simply truncate the rest of the bits away, which consist of the
// rounding bit and the sticky bits) and two cases where we want to round up
// (truncate then add one to the result).
//
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
// 1s) as the rounding bias, adds the rounding bias to the input, then
// truncates the last 16 bits away.
//
// To understand how it works, we can analyze this algorithm case by case:
//
// 1. L = 0, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input may create a carry, depending on
// whether any of the T bits is set to 1.
// - R may be set to 1 if there is a carry.
// - L remains 0.
// - Note that this case also handles Inf and -Inf, where all fraction
// bits, including L, R and Ts are all 0. The output remains Inf after
// this algorithm.
//
// 2. L = 1, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits but
// adds 1 to rounding bit.
// - L remains 1.
//
// 3. L = 0, R = 1, all of T are 0:
// Expect: round down, this is exactly at half, the result is already
// even (L=0).
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input sets all sticky bits to 1, but
// doesn't create a carry.
// - R remains 1.
// - L remains 0.
//
// 4. L = 1, R = 1:
// Expect: round up, this is exactly at half, the result needs to be
// rounded to the next even number.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits, but
// creates a carry from rounding bit.
// - The carry sets L to 0, creates another carry bit and propagates
// forward to the F bits.
// - If all the F bits are 1, a carry then propagates to the exponent
// bits, which then creates the minimum value with the next exponent
// value. Note that we won't have the case where exponents are all 1,
// since that's either a NaN (handled in the other if condition) or inf
// (handled in case 1).
//
// 5. L = 0, R = 1, any of T is 1:
// Expect: round up, this is greater than half.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input creates a carry from sticky bits,
// sets the rounding bit to 0, then creates another carry.
// - The second carry sets L to 1.
//
// Examples:
//
// Exact half value that is already even:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
//
// This falls into case 3. We truncate the rest of the 16 bits and no
// carry is created into F and L:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
// Exact half value, round to next even number:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
//
// This falls into case 4. We create a carry from R and T,
// which then propagates into L and F:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
//
// Max denormal value round to min normal value:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
//
// Max normal value round to Inf:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
// At this point, ff must be either a normal float, or +/-infinity.
output = float_to_bfloat16_rtne<true>(ff);
}
return output;
#endif
}
// float_to_bfloat16_rtne template specialization that assumes that its function
// argument (ff) is either a normal floating point number, or +/-infinity, or
// zero. Used to improve the runtime performance of conversion from an integer
// type to bfloat16.
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
// Nothing to do here
#else
numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
__bfloat16_raw output;
// Least significant bit of resulting bfloat.
numext::uint32_t lsb = (input >> 16) & 1;
numext::uint32_t rounding_bias = 0x7fff + lsb;
input += rounding_bias;
output.value = static_cast<numext::uint16_t>(input >> 16);
return output;
#endif
}
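// Worked example of the bias trick above (illustrative only, values not used by the code):
// converting 1.00390625f (bits 0x3F808000, exactly halfway between the bfloat16
// neighbours 1.0 and 1.0078125, with the target L bit already even):
//   lsb           = (0x3F808000 >> 16) & 1  = 0
//   rounding_bias = 0x7fff + 0              = 0x7FFF
//   input + bias  = 0x3F808000 + 0x7FFF     = 0x3F80FFFF
//   output.value  = 0x3F80FFFF >> 16        = 0x3F80   (1.0f; ties-to-even rounds down)
// whereas 1.01171875f (bits 0x3F818000, halfway with L = 1) gives
//   lsb = 1, bias = 0x8000, 0x3F818000 + 0x8000 = 0x3F820000 >> 16 = 0x3F82
//   (1.015625f; ties-to-even rounds up to the even neighbour).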
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
float result = 0;
unsigned short* q = reinterpret_cast<unsigned short*>(&result);
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
q[0] = h.value;
#else
q[1] = h.value;
#endif
return result;
}
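// Illustrative example (not used by the code): the stored 16 bits are simply the high
// half of an IEEE 754 binary32 pattern, so e.g. the raw value 0x4049 maps to float bits
// 0x40490000, i.e. 3.140625f, which is pi truncated to bfloat16's 7-bit mantissa.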
// --- standard functions ---
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
EIGEN_USING_STD(isinf);
return (isinf)(float(a));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
EIGEN_USING_STD(isnan);
return (isnan)(float(a));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
bfloat16 result;
result.value = a.value & 0x7FFF;
return result;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) {
return bfloat16(::expf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
return bfloat16(numext::expm1(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) {
return bfloat16(::logf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
return bfloat16(numext::log1p(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
return bfloat16(::log10f(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
return bfloat16(::sqrtf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
return bfloat16(::powf(float(a), float(b)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
return bfloat16(::sinf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) {
return bfloat16(::cosf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) {
return bfloat16(::tanf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) {
return bfloat16(::asinf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) {
return bfloat16(::acosf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) {
return bfloat16(::atanf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) {
return bfloat16(::sinhf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) {
return bfloat16(::coshf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
return bfloat16(::tanhf(float(a)));
}
#if EIGEN_HAS_CXX11_MATH
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
return bfloat16(::asinhf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
return bfloat16(::acoshf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
return bfloat16(::atanhf(float(a)));
}
#endif
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
return bfloat16(::floorf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
return bfloat16(::ceilf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) {
return bfloat16(::rintf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) {
return bfloat16(::roundf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
return bfloat16(::fmodf(float(a), float(b)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) {
const float f1 = static_cast<float>(a);
const float f2 = static_cast<float>(b);
return f2 < f1 ? b : a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
const float f1 = static_cast<float>(a);
const float f2 = static_cast<float>(b);
return f1 < f2 ? b : a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
const float f1 = static_cast<float>(a);
const float f2 = static_cast<float>(b);
return bfloat16(::fminf(f1, f2));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
const float f1 = static_cast<float>(a);
const float f2 = static_cast<float>(b);
return bfloat16(::fmaxf(f1, f2));
}
#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
os << static_cast<float>(v);
return os;
}
#endif
} // namespace bfloat16_impl
namespace internal {
template<>
struct random_default_impl<bfloat16, false, false>
{
static inline bfloat16 run(const bfloat16& x, const bfloat16& y)
{
return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX));
}
static inline bfloat16 run()
{
return run(bfloat16(-1.f), bfloat16(1.f));
}
};
template<> struct is_arithmetic<bfloat16> { enum { value = true }; };
} // namespace internal
template<> struct NumTraits<Eigen::bfloat16>
: GenericNumTraits<Eigen::bfloat16>
{
enum {
IsSigned = true,
IsInteger = false,
IsComplex = false,
RequireInitialization = false
};
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D); // bfloat16(5e-2f);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
}
};
} // namespace Eigen
namespace Eigen {
namespace numext {
template<>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
bool (isnan)(const Eigen::bfloat16& h) {
return (bfloat16_impl::isnan)(h);
}
template<>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
bool (isinf)(const Eigen::bfloat16& h) {
return (bfloat16_impl::isinf)(h);
}
template<>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
bool (isfinite)(const Eigen::bfloat16& h) {
return (bfloat16_impl::isfinite)(h);
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) {
return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) {
return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
}
} // namespace numext
} // namespace Eigen
#if EIGEN_HAS_STD_HASH
namespace std {
template <>
struct hash<Eigen::bfloat16> {
EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
}
};
} // namespace std
#endif
#endif // EIGEN_BFLOAT16_H

File diff suppressed because it is too large

View File

@ -0,0 +1,110 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2019 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H
#define EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H
namespace Eigen {
namespace internal {
// Forward declarations of the generic math functions
// implemented in GenericPacketMathFunctions.h
// This is needed to workaround a circular dependency.
/***************************************************************************
* Some generic implementations to be used by implementors
***************************************************************************/
/** Default implementation of pfrexp.
* It is expected to be called by implementers of template<> pfrexp.
*/
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic(const Packet& a, Packet& exponent);
// Extracts the biased exponent value from Packet p, and casts the results to
// a floating-point Packet type. Used by pfrexp_generic. Override this if
// there is no unpacket_traits<Packet>::integer_packet.
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic_get_biased_exponent(const Packet& p);
/** Default implementation of pldexp.
* It is expected to be called by implementers of template<> pldexp.
*/
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pldexp_generic(const Packet& a, const Packet& exponent);
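// Illustrative sketch (not part of the original header): a SIMD backend usually
// implements its pfrexp/pldexp specializations by simply forwarding to these generic
// helpers, e.g. for a hypothetical packet type PacketXf:
//
//   template<> EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent) {
//     return pfrexp_generic(a, exponent);
//   }
//   template<> EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent) {
//     return pldexp_generic(a, exponent);
//   }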
/** \internal \returns log(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog_float(const Packet _x);
/** \internal \returns log2(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog2_float(const Packet _x);
/** \internal \returns log(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog_double(const Packet _x);
/** \internal \returns log2(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog2_double(const Packet _x);
/** \internal \returns log(1 + x) */
template<typename Packet>
Packet generic_plog1p(const Packet& x);
/** \internal \returns exp(x)-1 */
template<typename Packet>
Packet generic_expm1(const Packet& x);
/** \internal \returns exp(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pexp_float(const Packet _x);
/** \internal \returns exp(x) for double precision real numbers */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pexp_double(const Packet _x);
/** \internal \returns sin(x) for single precision float */
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet psin_float(const Packet& x);
/** \internal \returns cos(x) for single precision float */
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pcos_float(const Packet& x);
/** \internal \returns sqrt(x) for complex types */
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet psqrt_complex(const Packet& a);
template <typename Packet, int N> struct ppolevl;
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H

View File

@ -0,0 +1,942 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
// The conversion routines are Copyright (c) Fabian Giesen, 2016.
// The original license follows:
//
// Copyright (c) Fabian Giesen, 2016
// All rights reserved.
// Redistribution and use in source and binary forms, with or without
// modification, are permitted.
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Standard 16-bit float type, mostly useful for GPUs. Defines a new
// type Eigen::half (inheriting either from CUDA's or HIP's __half struct) with
// operator overloads such that it behaves basically as an arithmetic
// type. It will be quite slow on CPUs (so it is recommended to stay
// in fp32 for CPUs, except for simple parameter conversions, I/O
// to disk and the like), but fast on GPUs.
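// Illustrative usage sketch (not part of the original header): on the host, Eigen::half
// can be mixed with ordinary float code through its implicit conversion to float, e.g.
//
//   Eigen::half a(1.5f), b(0.25f);
//   Eigen::half c = a * b + Eigen::half(2.0f);   // arithmetic via the overloads below
//   float f = c;                                 // lossless widening back to float
//
// keeping in mind the note above that such scalar arithmetic is emulated (and slow) on CPUs.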
#ifndef EIGEN_HALF_H
#define EIGEN_HALF_H
#include <sstream>
#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
// When compiling with GPU support, the "__half_raw" base class as well as
// some other routines are defined in the GPU compiler header files
// (cuda_fp16.h, hip_fp16.h), and they are not tagged constexpr.
// As a consequence, we get compile failures when compiling Eigen with
// GPU support. Hence the need to disable EIGEN_CONSTEXPR when building
// Eigen with GPU support.
#pragma push_macro("EIGEN_CONSTEXPR")
#undef EIGEN_CONSTEXPR
#define EIGEN_CONSTEXPR
#endif
#define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \
template <> \
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED \
PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
return float2half(METHOD<PACKET_F>(half2float(_x))); \
}
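// Illustrative sketch (not part of the original header): analogous to the
// BF16_PACKET_FUNCTION helper in BFloat16.h, a backend with a half packet type would
// instantiate this once per unary math function, e.g. with hypothetical packet names
// PacketF (float) and PacketH (half):
//
//   F16_PACKET_FUNCTION(PacketF, PacketH, plog)
//
// expanding to a plog<PacketH> specialization that converts to float, evaluates the
// float implementation, and converts back to half.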
namespace Eigen {
struct half;
namespace half_impl {
// We want to use the __half_raw struct from the HIP header file only during the device compile phase.
// This is required because of a quirk in the way TensorFlow GPU builds are done.
// When compiling TensorFlow source code with GPU support, files that
// * contain GPU kernels (i.e. *.cu.cc files) are compiled via hipcc
// * do not contain GPU kernels ( i.e. *.cc files) are compiled via gcc (typically)
//
// Tensorflow uses the Eigen::half type as its FP16 type, and there are functions that
// * are defined in a file that gets compiled via hipcc AND
// * have Eigen::half as a pass-by-value argument AND
// * are called in a file that gets compiled via gcc
//
// In the scenario described above the caller and callee will see different versions
// of the Eigen::half base class __half_raw, and they will be compiled by different compilers
//
// There appears to be an ABI mismatch between gcc and clang (which is called by hipcc) that results in
// the callee getting corrupted values for the Eigen::half argument.
//
// Making the host-side compile phase of hipcc use the same Eigen::half impl as the gcc compile resolves
// this error, and hence the following convoluted #if condition.
#if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE)
// Make our own __half_raw definition that is similar to CUDA's.
struct __half_raw {
#if (defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE))
// Eigen::half can be used as the datatype for shared memory declarations (in Eigen and TF)
// The element type for shared memory cannot have non-trivial constructors
// and hence the following special casing (which skips the zero-initialization).
// Note that this check gets done even in the host compilation phase, and
// hence the need for this.
EIGEN_DEVICE_FUNC __half_raw() {}
#else
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw() : x(0) {}
#endif
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {
}
__fp16 x;
#else
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(raw) {}
numext::uint16_t x;
#endif
};
#elif defined(EIGEN_HAS_HIP_FP16)
// Nothing to do here
// HIP fp16 header file has a definition for __half_raw
#elif defined(EIGEN_HAS_CUDA_FP16)
#if EIGEN_CUDA_SDK_VER < 90000
// In CUDA < 9.0, __half is the equivalent of CUDA 9's __half_raw
typedef __half __half_raw;
#endif // EIGEN_CUDA_SDK_VER < 90000
#elif defined(SYCL_DEVICE_ONLY)
typedef cl::sycl::half __half_raw;
#endif
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h);
struct half_base : public __half_raw {
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base() {}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {}
#if defined(EIGEN_HAS_GPU_FP16)
#if defined(EIGEN_HAS_HIP_FP16)
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); }
#elif defined(EIGEN_HAS_CUDA_FP16)
#if EIGEN_CUDA_SDK_VER >= 90000
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
#endif
#endif
#endif
};
} // namespace half_impl
// Class definition.
struct half : public half_impl::half_base {
// Writing this out as separate #if-else blocks to make the code easier to follow
// The same applies to most #if-else blocks in this file
#if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE)
// Use the same base class for the following two scenarios
// * when compiling without GPU support enabled
// * during host compile phase when compiling with GPU support enabled
typedef half_impl::__half_raw __half_raw;
#elif defined(EIGEN_HAS_HIP_FP16)
// Nothing to do here
// HIP fp16 header file has a definition for __half_raw
#elif defined(EIGEN_HAS_CUDA_FP16)
// Note that EIGEN_CUDA_SDK_VER is set to 0 even when compiling with HIP, so
// (EIGEN_CUDA_SDK_VER < 90000) is true even for HIP! So keeping this within
// #if defined(EIGEN_HAS_CUDA_FP16) is needed
#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000
typedef half_impl::__half_raw __half_raw;
#endif
#endif
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half() {}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {}
#if defined(EIGEN_HAS_GPU_FP16)
#if defined(EIGEN_HAS_HIP_FP16)
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
#elif defined(EIGEN_HAS_CUDA_FP16)
#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
#endif
#endif
#endif
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(bool b)
: half_impl::half_base(half_impl::raw_uint16_to_half(b ? 0x3c00 : 0)) {}
template<class T>
explicit EIGEN_DEVICE_FUNC half(T val)
: half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(val))) {}
explicit EIGEN_DEVICE_FUNC half(float f)
: half_impl::half_base(half_impl::float_to_half_rtne(f)) {}
// Following the convention of numpy, converting between complex and
// float will lead to loss of imag value.
template<typename RealScalar>
explicit EIGEN_DEVICE_FUNC half(std::complex<RealScalar> c)
: half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(c.real()))) {}
EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless.
return half_impl::half_to_float(*this);
}
#if defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE)
EIGEN_DEVICE_FUNC operator __half() const {
::__half_raw hr;
hr.x = x;
return __half(hr);
}
#endif
};
} // end namespace Eigen
namespace std {
template<>
struct numeric_limits<Eigen::half> {
static const bool is_specialized = true;
static const bool is_signed = true;
static const bool is_integer = false;
static const bool is_exact = false;
static const bool has_infinity = true;
static const bool has_quiet_NaN = true;
static const bool has_signaling_NaN = true;
static const float_denorm_style has_denorm = denorm_present;
static const bool has_denorm_loss = false;
static const std::float_round_style round_style = std::round_to_nearest;
static const bool is_iec559 = false;
static const bool is_bounded = true;
static const bool is_modulo = false;
static const int digits = 11;
static const int digits10 = 3; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
static const int max_digits10 = 5; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
static const int radix = 2;
static const int min_exponent = -13;
static const int min_exponent10 = -4;
static const int max_exponent = 16;
static const int max_exponent10 = 4;
static const bool traps = true;
static const bool tinyness_before = false;
static Eigen::half (min)() { return Eigen::half_impl::raw_uint16_to_half(0x400); }
static Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
static Eigen::half (max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
static Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x0800); }
static Eigen::half round_error() { return Eigen::half(0.5); }
static Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
static Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
static Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); }
static Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x1); }
};
// If std::numeric_limits<T> is specialized, should also specialize
// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
// std::numeric_limits<const volatile T>
// https://stackoverflow.com/a/16519653/
template<>
struct numeric_limits<const Eigen::half> : numeric_limits<Eigen::half> {};
template<>
struct numeric_limits<volatile Eigen::half> : numeric_limits<Eigen::half> {};
template<>
struct numeric_limits<const volatile Eigen::half> : numeric_limits<Eigen::half> {};
} // end namespace std
namespace Eigen {
namespace half_impl {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && \
EIGEN_CUDA_ARCH >= 530) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(HIP_DEVICE_COMPILE))
// Note: We deliberately do *not* define this to 1 even if we have Arm's native
// fp16 type since GPU halfs are rather different from native CPU halfs.
// TODO: Rename to something like EIGEN_HAS_NATIVE_GPU_FP16
#define EIGEN_HAS_NATIVE_FP16
#endif
// Intrinsics for native fp16 support. Note that on current hardware,
// these are no faster than fp32 arithmetic (you need to use the half2
// versions to get the ALU speed increased), but you do save the
// conversion steps back and forth.
#if defined(EIGEN_HAS_NATIVE_FP16)
EIGEN_STRONG_INLINE __device__ half operator + (const half& a, const half& b) {
#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
return __hadd(::__half(a), ::__half(b));
#else
return __hadd(a, b);
#endif
}
EIGEN_STRONG_INLINE __device__ half operator * (const half& a, const half& b) {
return __hmul(a, b);
}
EIGEN_STRONG_INLINE __device__ half operator - (const half& a, const half& b) {
return __hsub(a, b);
}
EIGEN_STRONG_INLINE __device__ half operator / (const half& a, const half& b) {
#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
return __hdiv(a, b);
#else
float num = __half2float(a);
float denom = __half2float(b);
return __float2half(num / denom);
#endif
}
EIGEN_STRONG_INLINE __device__ half operator - (const half& a) {
return __hneg(a);
}
EIGEN_STRONG_INLINE __device__ half& operator += (half& a, const half& b) {
a = a + b;
return a;
}
EIGEN_STRONG_INLINE __device__ half& operator *= (half& a, const half& b) {
a = a * b;
return a;
}
EIGEN_STRONG_INLINE __device__ half& operator -= (half& a, const half& b) {
a = a - b;
return a;
}
EIGEN_STRONG_INLINE __device__ half& operator /= (half& a, const half& b) {
a = a / b;
return a;
}
EIGEN_STRONG_INLINE __device__ bool operator == (const half& a, const half& b) {
return __heq(a, b);
}
EIGEN_STRONG_INLINE __device__ bool operator != (const half& a, const half& b) {
return __hne(a, b);
}
EIGEN_STRONG_INLINE __device__ bool operator < (const half& a, const half& b) {
return __hlt(a, b);
}
EIGEN_STRONG_INLINE __device__ bool operator <= (const half& a, const half& b) {
return __hle(a, b);
}
EIGEN_STRONG_INLINE __device__ bool operator > (const half& a, const half& b) {
return __hgt(a, b);
}
EIGEN_STRONG_INLINE __device__ bool operator >= (const half& a, const half& b) {
return __hge(a, b);
}
#endif
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
return half(vaddh_f16(a.x, b.x));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) {
return half(vmulh_f16(a.x, b.x));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) {
return half(vsubh_f16(a.x, b.x));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) {
return half(vdivh_f16(a.x, b.x));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) {
return half(vnegh_f16(a.x));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) {
a = half(vaddh_f16(a.x, b.x));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) {
a = half(vmulh_f16(a.x, b.x));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) {
a = half(vsubh_f16(a.x, b.x));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) {
a = half(vdivh_f16(a.x, b.x));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) {
return vceqh_f16(a.x, b.x);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) {
return !vceqh_f16(a.x, b.x);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) {
return vclth_f16(a.x, b.x);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) {
return vcleh_f16(a.x, b.x);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) {
return vcgth_f16(a.x, b.x);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) {
return vcgeh_f16(a.x, b.x);
}
// We need to distinguish clang as the CUDA compiler from clang as the host compiler,
// invoked by NVCC (e.g. on MacOS). The former needs to see both host and device implementation
// of the functions, while the latter can only deal with one of them.
#elif !defined(EIGEN_HAS_NATIVE_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
// We need to provide emulated *host-side* FP16 operators for clang.
#pragma push_macro("EIGEN_DEVICE_FUNC")
#undef EIGEN_DEVICE_FUNC
#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_FP16)
#define EIGEN_DEVICE_FUNC __host__
#else // both host and device need emulated ops.
#define EIGEN_DEVICE_FUNC __host__ __device__
#endif
#endif
// Definitions for CPUs and older HIP+CUDA, mostly working through conversion
// to/from fp32.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
return half(float(a) + float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) {
return half(float(a) * float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) {
return half(float(a) - float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) {
return half(float(a) / float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) {
half result;
result.x = a.x ^ 0x8000;
return result;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) {
a = half(float(a) + float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) {
a = half(float(a) * float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) {
a = half(float(a) - float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) {
a = half(float(a) / float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) {
return numext::equal_strict(float(a),float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) {
return numext::not_equal_strict(float(a), float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) {
return float(a) < float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) {
return float(a) <= float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) {
return float(a) > float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) {
return float(a) >= float(b);
}
#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
#pragma pop_macro("EIGEN_DEVICE_FUNC")
#endif
#endif // Emulate support for half floats
// Division by an index. Do it in full float precision to avoid accuracy
// issues in converting the denominator to half.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, Index b) {
return half(static_cast<float>(a) / static_cast<float>(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator++(half& a) {
a += half(1);
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a) {
a -= half(1);
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator++(half& a, int) {
half original_value = a;
++a;
return original_value;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a, int) {
half original_value = a;
--a;
return original_value;
}
// Conversion routines, including fallbacks for the host or older CUDA.
// Note that newer Intel CPUs (Haswell or newer) have vectorized versions of
// these in hardware. If we need more performance on older/other CPUs, it is
// also possible to vectorize them directly.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
// We cannot simply do a "return __half_raw(x)" here, because __half_raw is a union type
// in the hip_fp16 header file, and that will trigger a compile error.
// On the other hand, having anything but a return statement also triggers a compile error
// because this is a constexpr function.
// Fortunately, since we need to disable EIGEN_CONSTEXPR for GPU anyway, we can get out
// of this catch-22 by having separate bodies for GPU / non-GPU.
#if defined(EIGEN_HAS_GPU_FP16)
__half_raw h;
h.x = x;
return h;
#else
return __half_raw(x);
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const __half_raw& h) {
// HIP/CUDA/Default have a member 'x' of type uint16_t.
// For ARM64 native half, the member 'x' is of type __fp16, so we need to bit-cast.
// For SYCL, cl::sycl::half is _Float16, so cast directly.
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return numext::bit_cast<numext::uint16_t>(h.x);
#elif defined(SYCL_DEVICE_ONLY)
return numext::bit_cast<numext::uint16_t>(h);
#else
return h.x;
#endif
}
union float32_bits {
unsigned int u;
float f;
};
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
__half tmp_ff = __float2half(ff);
return *(__half_raw*)&tmp_ff;
#elif defined(EIGEN_HAS_FP16_C)
__half_raw h;
h.x = _cvtss_sh(ff, 0);
return h;
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
__half_raw h;
h.x = static_cast<__fp16>(ff);
return h;
#else
float32_bits f; f.f = ff;
const float32_bits f32infty = { 255 << 23 };
const float32_bits f16max = { (127 + 16) << 23 };
const float32_bits denorm_magic = { ((127 - 15) + (23 - 10) + 1) << 23 };
unsigned int sign_mask = 0x80000000u;
__half_raw o;
o.x = static_cast<numext::uint16_t>(0x0u);
unsigned int sign = f.u & sign_mask;
f.u ^= sign;
// NOTE all the integer compares in this function can be safely
// compiled into signed compares since all operands are below
// 0x80000000. Important if you want fast straight SSE2 code
// (since there's no unsigned PCMPGTD).
if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
} else { // (De)normalized number or zero
if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
// Use a magic value to align our 10 mantissa bits at the bottom of
// the float. As long as FP addition is round-to-nearest-even this
// just works.
f.f += denorm_magic.f;
// and one integer subtract of the bias later, we have our final float!
o.x = static_cast<numext::uint16_t>(f.u - denorm_magic.u);
} else {
unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
// update exponent, rounding bias part 1
// Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
// without arithmetic overflow.
f.u += 0xc8000fffU;
// rounding bias part 2
f.u += mant_odd;
// take the bits!
o.x = static_cast<numext::uint16_t>(f.u >> 13);
}
}
o.x |= static_cast<numext::uint16_t>(sign >> 16);
return o;
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __half2float(h);
#elif defined(EIGEN_HAS_FP16_C)
return _cvtsh_ss(h.x);
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return static_cast<float>(h.x);
#else
const float32_bits magic = { 113 << 23 };
const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
float32_bits o;
o.u = (h.x & 0x7fff) << 13; // exponent/mantissa bits
unsigned int exp = shifted_exp & o.u; // just the exponent
o.u += (127 - 15) << 23; // exponent adjust
// handle exponent special cases
if (exp == shifted_exp) { // Inf/NaN?
o.u += (128 - 16) << 23; // extra exp adjust
} else if (exp == 0) { // Zero/Denormal?
o.u += 1 << 23; // extra exp adjust
o.f -= magic.f; // renormalize
}
o.u |= (h.x & 0x8000) << 16; // sign bit
return o.f;
#endif
}
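// Illustrative only, not part of the original sources: a minimal sketch that
// round-trips a float through the two conversion routines above. The function
// name is an assumption made for this example.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float example_half_round_trip(float value) {
// float -> 16-bit half representation (round to nearest even).
__half_raw bits = float_to_half_rtne(value);
// ... and back to float; values outside the half range come back as +/-inf.
return half_to_float(bits);
}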
// --- standard functions ---
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const half& a) {
#ifdef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) == 0x7c00;
#else
return (a.x & 0x7fff) == 0x7c00;
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const half& a) {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __hisnan(a);
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) > 0x7c00;
#else
return (a.x & 0x7fff) > 0x7c00;
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const half& a) {
return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return half(vabsh_f16(a.x));
#else
half result;
result.x = a.x & 0x7FFF;
return result;
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half exp(const half& a) {
#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
defined(EIGEN_HIP_DEVICE_COMPILE)
return half(hexp(a));
#else
return half(::expf(float(a)));
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half expm1(const half& a) {
return half(numext::expm1(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log(const half& a) {
#if (defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 80000 && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return half(::hlog(a));
#else
return half(::logf(float(a)));
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log1p(const half& a) {
return half(numext::log1p(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log10(const half& a) {
return half(::log10f(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log2(const half& a) {
return half(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sqrt(const half& a) {
#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
defined(EIGEN_HIP_DEVICE_COMPILE)
return half(hsqrt(a));
#else
return half(::sqrtf(float(a)));
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half pow(const half& a, const half& b) {
return half(::powf(float(a), float(b)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) {
return half(::sinf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half cos(const half& a) {
return half(::cosf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tan(const half& a) {
return half(::tanf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tanh(const half& a) {
return half(::tanhf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half asin(const half& a) {
return half(::asinf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half acos(const half& a) {
return half(::acosf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half floor(const half& a) {
#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 300) || \
defined(EIGEN_HIP_DEVICE_COMPILE)
return half(hfloor(a));
#else
return half(::floorf(float(a)));
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half ceil(const half& a) {
#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 300) || \
defined(EIGEN_HIP_DEVICE_COMPILE)
return half(hceil(a));
#else
return half(::ceilf(float(a)));
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half rint(const half& a) {
return half(::rintf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half round(const half& a) {
return half(::roundf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half fmod(const half& a, const half& b) {
return half(::fmodf(float(a), float(b)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (min)(const half& a, const half& b) {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __hlt(b, a) ? b : a;
#else
const float f1 = static_cast<float>(a);
const float f2 = static_cast<float>(b);
return f2 < f1 ? b : a;
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (max)(const half& a, const half& b) {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __hlt(a, b) ? b : a;
#else
const float f1 = static_cast<float>(a);
const float f2 = static_cast<float>(b);
return f1 < f2 ? b : a;
#endif
}
#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const half& v) {
os << static_cast<float>(v);
return os;
}
#endif
} // end namespace half_impl
// import Eigen::half_impl::half into Eigen namespace
// using half_impl::half;
namespace internal {
template<>
struct random_default_impl<half, false, false>
{
static inline half run(const half& x, const half& y)
{
return x + (y-x) * half(float(std::rand()) / float(RAND_MAX));
}
static inline half run()
{
return run(half(-1.f), half(1.f));
}
};
template<> struct is_arithmetic<half> { enum { value = true }; };
} // end namespace internal
template<> struct NumTraits<Eigen::half>
: GenericNumTraits<Eigen::half>
{
enum {
IsSigned = true,
IsInteger = false,
IsComplex = false,
RequireInitialization = false
};
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() {
return half_impl::raw_uint16_to_half(0x0800);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() {
return half_impl::raw_uint16_to_half(0x211f); // Eigen::half(1e-2f);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() {
return half_impl::raw_uint16_to_half(0x7bff);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() {
return half_impl::raw_uint16_to_half(0xfbff);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() {
return half_impl::raw_uint16_to_half(0x7c00);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
return half_impl::raw_uint16_to_half(0x7e00);
}
};
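// Illustrative note, not part of the original sources: the raw bit patterns
// above decode to epsilon() = 2^-13 (~1.22e-4), highest() = 65504,
// lowest() = -65504, infinity() = +inf (0x7c00) and quiet_NaN() = a quiet NaN
// (0x7e00); dummy_precision() is roughly 1e-2, as already noted in the comment.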
} // end namespace Eigen
#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
#pragma pop_macro("EIGEN_CONSTEXPR")
#endif
namespace Eigen {
namespace numext {
#if defined(EIGEN_GPU_COMPILE_PHASE)
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(const Eigen::half& h) {
return (half_impl::isnan)(h);
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(const Eigen::half& h) {
return (half_impl::isinf)(h);
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const Eigen::half& h) {
return (half_impl::isfinite)(h);
}
#endif
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half bit_cast<Eigen::half, uint16_t>(const uint16_t& src) {
return Eigen::half(Eigen::half_impl::raw_uint16_to_half(src));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::half>(const Eigen::half& src) {
return Eigen::half_impl::raw_half_as_uint16(src);
}
} // namespace numext
} // namespace Eigen
// Add the missing shfl* intrinsics.
// The __shfl* functions are only valid on HIP or __CUDA_ARCH__ >= 300.
// CUDA defines them for (__CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__))
//
// HIP and CUDA prior to SDK 9.0 define
// __shfl, __shfl_up, __shfl_down, __shfl_xor for int and float
// CUDA since 9.0 deprecates those and instead defines
// __shfl_sync, __shfl_up_sync, __shfl_down_sync, __shfl_xor_sync,
// with native support for __half and __nv_bfloat16
//
// Note that the following are __device__-only functions.
#if (defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 300)) \
|| defined(EIGEN_HIPCC)
#if defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 90000
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_sync(unsigned mask, Eigen::half var, int srcLane, int width=warpSize) {
const __half h = var;
return static_cast<Eigen::half>(__shfl_sync(mask, h, srcLane, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) {
const __half h = var;
return static_cast<Eigen::half>(__shfl_up_sync(mask, h, delta, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) {
const __half h = var;
return static_cast<Eigen::half>(__shfl_down_sync(mask, h, delta, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor_sync(unsigned mask, Eigen::half var, int laneMask, int width=warpSize) {
const __half h = var;
return static_cast<Eigen::half>(__shfl_xor_sync(mask, h, laneMask, width));
}
#else // HIP or CUDA SDK < 9.0
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl(Eigen::half var, int srcLane, int width=warpSize) {
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl(ivar, srcLane, width)));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up(Eigen::half var, unsigned int delta, int width=warpSize) {
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl_up(ivar, delta, width)));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down(Eigen::half var, unsigned int delta, int width=warpSize) {
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl_down(ivar, delta, width)));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width=warpSize) {
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl_xor(ivar, laneMask, width)));
}
#endif // HIP vs CUDA
#endif // __shfl*
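// Illustrative only, not part of the original sources: a hedged sketch of a
// full-warp sum reduction over Eigen::half using the __shfl_down_sync overload
// above. It assumes CUDA >= 9.0, 32-lane warps, and a fully active mask; the
// function name is an assumption made for this example.
#if defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 300) \
    && defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 90000
__device__ inline Eigen::half example_warp_half_sum(Eigen::half v) {
// Butterfly-style reduction: after log2(32) = 5 steps, lane 0 holds the sum.
for (unsigned int delta = 16; delta > 0; delta >>= 1) {
v = v + __shfl_down_sync(0xffffffffu, v, delta);
}
return v;
}
#endif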
// ldg() has an overload for __half_raw, but we also need one for Eigen::half.
#if (defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 350)) \
|| defined(EIGEN_HIPCC)
EIGEN_STRONG_INLINE __device__ Eigen::half __ldg(const Eigen::half* ptr) {
return Eigen::half_impl::raw_uint16_to_half(__ldg(reinterpret_cast<const Eigen::numext::uint16_t*>(ptr)));
}
#endif // __ldg
#if EIGEN_HAS_STD_HASH
namespace std {
template <>
struct hash<Eigen::half> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::half& a) const {
return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
}
};
} // end namespace std
#endif
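// Illustrative note, not part of the original sources: with the specialization
// above, Eigen::half can be used directly as a key in unordered containers,
// e.g. std::unordered_map<Eigen::half, int>; the hash is simply the raw 16-bit
// representation of the value.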
#endif // EIGEN_HALF_H

View File

@ -0,0 +1,120 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_GENERIC_TYPE_CASTING_H
#define EIGEN_GENERIC_TYPE_CASTING_H
namespace Eigen {
namespace internal {
template<>
struct scalar_cast_op<float, Eigen::half> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::half result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __float2half(a);
#else
return Eigen::half(a);
#endif
}
};
template<>
struct functor_traits<scalar_cast_op<float, Eigen::half> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
template<>
struct scalar_cast_op<int, Eigen::half> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::half result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __float2half(static_cast<float>(a));
#else
return Eigen::half(static_cast<float>(a));
#endif
}
};
template<>
struct functor_traits<scalar_cast_op<int, Eigen::half> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
template<>
struct scalar_cast_op<Eigen::half, float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef float result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __half2float(a);
#else
return static_cast<float>(a);
#endif
}
};
template<>
struct functor_traits<scalar_cast_op<Eigen::half, float> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
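// Illustrative note, not part of the original sources: these functors are the
// coefficient-wise kernels behind expressions such as
//   MatrixXf f = h.cast<float>();   // h being a matrix of Eigen::half
// where each coefficient of h is passed through the operator() defined above.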
template<>
struct scalar_cast_op<float, Eigen::bfloat16> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::bfloat16 result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const {
return Eigen::bfloat16(a);
}
};
template<>
struct functor_traits<scalar_cast_op<float, Eigen::bfloat16> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
template<>
struct scalar_cast_op<int, Eigen::bfloat16> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::bfloat16 result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const {
return Eigen::bfloat16(static_cast<float>(a));
}
};
template<>
struct functor_traits<scalar_cast_op<int, Eigen::bfloat16> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
template<>
struct scalar_cast_op<Eigen::bfloat16, float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef float result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const {
return static_cast<float>(a);
}
};
template<>
struct functor_traits<scalar_cast_op<Eigen::bfloat16, float> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
}
}
#endif // EIGEN_GENERIC_TYPE_CASTING_H

View File

@ -0,0 +1,103 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_MATH_FUNCTIONS_GPU_H
#define EIGEN_MATH_FUNCTIONS_GPU_H
namespace Eigen {
namespace internal {
// Make sure this is only available when targeting a GPU: we don't want to
// introduce conflicts between these packet_traits definitions and the ones
// we'll use on the host side (SSE, AVX, ...)
#if defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU)
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 plog<float4>(const float4& a)
{
return make_float4(logf(a.x), logf(a.y), logf(a.z), logf(a.w));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 plog<double2>(const double2& a)
{
using ::log;
return make_double2(log(a.x), log(a.y));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 plog1p<float4>(const float4& a)
{
return make_float4(log1pf(a.x), log1pf(a.y), log1pf(a.z), log1pf(a.w));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 plog1p<double2>(const double2& a)
{
return make_double2(log1p(a.x), log1p(a.y));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pexp<float4>(const float4& a)
{
return make_float4(expf(a.x), expf(a.y), expf(a.z), expf(a.w));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pexp<double2>(const double2& a)
{
using ::exp;
return make_double2(exp(a.x), exp(a.y));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pexpm1<float4>(const float4& a)
{
return make_float4(expm1f(a.x), expm1f(a.y), expm1f(a.z), expm1f(a.w));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pexpm1<double2>(const double2& a)
{
return make_double2(expm1(a.x), expm1(a.y));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 psqrt<float4>(const float4& a)
{
return make_float4(sqrtf(a.x), sqrtf(a.y), sqrtf(a.z), sqrtf(a.w));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 psqrt<double2>(const double2& a)
{
using ::sqrt;
return make_double2(sqrt(a.x), sqrt(a.y));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 prsqrt<float4>(const float4& a)
{
return make_float4(rsqrtf(a.x), rsqrtf(a.y), rsqrtf(a.z), rsqrtf(a.w));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 prsqrt<double2>(const double2& a)
{
return make_double2(rsqrt(a.x), rsqrt(a.y));
}
#endif
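// Illustrative note, not part of the original sources: on the device these
// overloads let generic packet expressions evaluate element-wise, e.g.
//   float4 y = pexp<float4>(make_float4(0.f, 1.f, 2.f, 3.f));
// yields (1, e, e^2, e^3) using the CUDA/HIP math functions above.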
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_MATH_FUNCTIONS_GPU_H

File diff suppressed because it is too large

View File

@ -0,0 +1,80 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_TYPE_CASTING_GPU_H
#define EIGEN_TYPE_CASTING_GPU_H
namespace Eigen {
namespace internal {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
template <>
struct type_casting_traits<Eigen::half, float> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
TgtCoeffRatio = 2
};
};
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcast<half2, float4>(const half2& a, const half2& b) {
float2 r1 = __half22float2(a);
float2 r2 = __half22float2(b);
return make_float4(r1.x, r1.y, r2.x, r2.y);
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pcast<float4, Packet4h2>(const float4& a, const float4& b) {
Packet4h2 r;
half2* r_alias=reinterpret_cast<half2*>(&r);
r_alias[0]=__floats2half2_rn(a.x,a.y);
r_alias[1]=__floats2half2_rn(a.z,a.w);
r_alias[2]=__floats2half2_rn(b.x,b.y);
r_alias[3]=__floats2half2_rn(b.z,b.w);
return r;
}
template <>
struct type_casting_traits<float, Eigen::half> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 2,
TgtCoeffRatio = 1
};
};
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcast<Packet4h2, float4>(const Packet4h2& a) {
// Simply discard the second half of the input
float4 r;
const half2* a_alias=reinterpret_cast<const half2*>(&a);
float2 r1 = __half22float2(a_alias[0]);
float2 r2 = __half22float2(a_alias[1]);
r.x=static_cast<float>(r1.x);
r.y=static_cast<float>(r1.y);
r.z=static_cast<float>(r2.x);
r.w=static_cast<float>(r2.y);
return r;
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcast<float4, half2>(const float4& a) {
// Simply discard the second half of the input
return __floats2half2_rn(a.x, a.y);
}
#endif
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_TYPE_CASTING_GPU_H

View File

@ -0,0 +1,23 @@
/*
* math_constants.h -
* HIP equivalent of the CUDA header of the same name
*/
#ifndef __MATH_CONSTANTS_H__
#define __MATH_CONSTANTS_H__
/* single precision constants */
#define HIPRT_INF_F __int_as_float(0x7f800000)
#define HIPRT_NAN_F __int_as_float(0x7fffffff)
#define HIPRT_MIN_DENORM_F __int_as_float(0x00000001)
#define HIPRT_MAX_NORMAL_F __int_as_float(0x7f7fffff)
#define HIPRT_NEG_ZERO_F __int_as_float(0x80000000)
#define HIPRT_ZERO_F 0.0f
#define HIPRT_ONE_F 1.0f
/* double precision constants */
#define HIPRT_INF __hiloint2double(0x7ff00000, 0x00000000)
#define HIPRT_NAN __hiloint2double(0xfff80000, 0x00000000)
#endif

View File

@ -0,0 +1,648 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2018 Wave Computing, Inc.
// Written by:
// Chris Larsen
// Alexey Frunze (afrunze@wavecomp.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_COMPLEX_MSA_H
#define EIGEN_COMPLEX_MSA_H
#include <iostream>
namespace Eigen {
namespace internal {
//---------- float ----------
struct Packet2cf {
EIGEN_STRONG_INLINE Packet2cf() {
}
EIGEN_STRONG_INLINE explicit Packet2cf(const std::complex<float>& a,
const std::complex<float>& b) {
Packet4f t = { std::real(a), std::imag(a), std::real(b), std::imag(b) };
v = t;
}
EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) {
}
EIGEN_STRONG_INLINE Packet2cf(const Packet2cf& a) : v(a.v) {
}
EIGEN_STRONG_INLINE Packet2cf& operator=(const Packet2cf& b) {
v = b.v;
return *this;
}
EIGEN_STRONG_INLINE Packet2cf conjugate(void) const {
return Packet2cf((Packet4f)__builtin_msa_bnegi_d((v2u64)v, 63));
}
EIGEN_STRONG_INLINE Packet2cf& operator*=(const Packet2cf& b) {
Packet4f v1, v2;
// Get the real values of a | a1_re | a1_re | a2_re | a2_re |
v1 = (Packet4f)__builtin_msa_ilvev_w((v4i32)v, (v4i32)v);
// Get the imag values of a | a1_im | a1_im | a2_im | a2_im |
v2 = (Packet4f)__builtin_msa_ilvod_w((v4i32)v, (v4i32)v);
// Multiply the real a with b
v1 = pmul(v1, b.v);
// Multiply the imag a with b
v2 = pmul(v2, b.v);
// Conjugate v2
v2 = Packet2cf(v2).conjugate().v;
// Swap real/imag elements in v2.
v2 = (Packet4f)__builtin_msa_shf_w((v4i32)v2, EIGEN_MSA_SHF_I8(1, 0, 3, 2));
// Add and return the result
v = padd(v1, v2);
return *this;
}
EIGEN_STRONG_INLINE Packet2cf operator*(const Packet2cf& b) const {
return Packet2cf(*this) *= b;
}
EIGEN_STRONG_INLINE Packet2cf& operator+=(const Packet2cf& b) {
v = padd(v, b.v);
return *this;
}
EIGEN_STRONG_INLINE Packet2cf operator+(const Packet2cf& b) const {
return Packet2cf(*this) += b;
}
EIGEN_STRONG_INLINE Packet2cf& operator-=(const Packet2cf& b) {
v = psub(v, b.v);
return *this;
}
EIGEN_STRONG_INLINE Packet2cf operator-(const Packet2cf& b) const {
return Packet2cf(*this) -= b;
}
EIGEN_STRONG_INLINE Packet2cf& operator/=(const Packet2cf& b) {
*this *= b.conjugate();
Packet4f s = pmul<Packet4f>(b.v, b.v);
s = padd(s, (Packet4f)__builtin_msa_shf_w((v4i32)s, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
v = pdiv(v, s);
return *this;
}
EIGEN_STRONG_INLINE Packet2cf operator/(const Packet2cf& b) const {
return Packet2cf(*this) /= b;
}
EIGEN_STRONG_INLINE Packet2cf operator-(void) const {
return Packet2cf(pnegate(v));
}
Packet4f v;
};
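// Illustrative note, not part of the original sources: operator*= above
// computes (a_re + a_im i)(b_re + b_im i) lane-wise as
//   v1 = | a_re*b_re | a_re*b_im | ... ,  v2 = | a_im*b_re | a_im*b_im | ...
// conjugating v2 and swapping its real/imag lanes turns it into
//   | -a_im*b_im | a_im*b_re | ... , so the final add yields the expected
//   | a_re*b_re - a_im*b_im | a_re*b_im + a_im*b_re | ... result.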
inline std::ostream& operator<<(std::ostream& os, const Packet2cf& value) {
os << "[ (" << value.v[0] << ", " << value.v[1]
<< "i),"
" ("
<< value.v[2] << ", " << value.v[3] << "i) ]";
return os;
}
template <>
struct packet_traits<std::complex<float> > : default_packet_traits {
typedef Packet2cf type;
typedef Packet2cf half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 2,
HasHalfPacket = 0,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
HasMax = 0,
HasSetLinear = 0,
HasBlend = 1
};
};
template <>
struct unpacket_traits<Packet2cf> {
typedef std::complex<float> type;
enum { size = 2, alignment = Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false };
typedef Packet2cf half;
};
template <>
EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from) {
EIGEN_MSA_DEBUG;
float f0 = from.real(), f1 = from.imag();
Packet4f v0 = { f0, f0, f0, f0 };
Packet4f v1 = { f1, f1, f1, f1 };
return Packet2cf((Packet4f)__builtin_msa_ilvr_w((Packet4i)v1, (Packet4i)v0));
}
template <>
EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
EIGEN_MSA_DEBUG;
return a + b;
}
template <>
EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
EIGEN_MSA_DEBUG;
return a - b;
}
template <>
EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) {
EIGEN_MSA_DEBUG;
return -a;
}
template <>
EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) {
EIGEN_MSA_DEBUG;
return a.conjugate();
}
template <>
EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
EIGEN_MSA_DEBUG;
return a * b;
}
template <>
EIGEN_STRONG_INLINE Packet2cf pand<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
EIGEN_MSA_DEBUG;
return Packet2cf(pand(a.v, b.v));
}
template <>
EIGEN_STRONG_INLINE Packet2cf por<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
EIGEN_MSA_DEBUG;
return Packet2cf(por(a.v, b.v));
}
template <>
EIGEN_STRONG_INLINE Packet2cf pxor<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
EIGEN_MSA_DEBUG;
return Packet2cf(pxor(a.v, b.v));
}
template <>
EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
EIGEN_MSA_DEBUG;
return Packet2cf(pandnot(a.v, b.v));
}
template <>
EIGEN_STRONG_INLINE Packet2cf pload<Packet2cf>(const std::complex<float>* from) {
EIGEN_MSA_DEBUG;
EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload<Packet4f>((const float*)from));
}
template <>
EIGEN_STRONG_INLINE Packet2cf ploadu<Packet2cf>(const std::complex<float>* from) {
EIGEN_MSA_DEBUG;
EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu<Packet4f>((const float*)from));
}
template <>
EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<float>* from) {
EIGEN_MSA_DEBUG;
return pset1<Packet2cf>(*from);
}
template <>
EIGEN_STRONG_INLINE void pstore<std::complex<float> >(std::complex<float>* to,
const Packet2cf& from) {
EIGEN_MSA_DEBUG;
EIGEN_DEBUG_ALIGNED_STORE pstore<float>((float*)to, from.v);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float>* to,
const Packet2cf& from) {
EIGEN_MSA_DEBUG;
EIGEN_DEBUG_UNALIGNED_STORE pstoreu<float>((float*)to, from.v);
}
template <>
EIGEN_DEVICE_FUNC inline Packet2cf pgather<std::complex<float>, Packet2cf>(
const std::complex<float>* from, Index stride) {
EIGEN_MSA_DEBUG;
return Packet2cf(from[0 * stride], from[1 * stride]);
}
template <>
EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet2cf>(std::complex<float>* to,
const Packet2cf& from,
Index stride) {
EIGEN_MSA_DEBUG;
*to = std::complex<float>(from.v[0], from.v[1]);
to += stride;
*to = std::complex<float>(from.v[2], from.v[3]);
}
template <>
EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::complex<float>* addr) {
EIGEN_MSA_DEBUG;
prefetch(reinterpret_cast<const float*>(addr));
}
template <>
EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet2cf>(const Packet2cf& a) {
EIGEN_MSA_DEBUG;
return std::complex<float>(a.v[0], a.v[1]);
}
template <>
EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a) {
EIGEN_MSA_DEBUG;
return Packet2cf((Packet4f)__builtin_msa_shf_w((v4i32)a.v, EIGEN_MSA_SHF_I8(2, 3, 0, 1)));
}
template <>
EIGEN_STRONG_INLINE Packet2cf pcplxflip<Packet2cf>(const Packet2cf& a) {
EIGEN_MSA_DEBUG;
return Packet2cf((Packet4f)__builtin_msa_shf_w((v4i32)a.v, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
}
template <>
EIGEN_STRONG_INLINE std::complex<float> predux<Packet2cf>(const Packet2cf& a) {
EIGEN_MSA_DEBUG;
Packet4f value = (Packet4f)preverse((Packet2d)a.v);
value += a.v;
return std::complex<float>(value[0], value[1]);
}
template <>
EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const Packet2cf& a) {
EIGEN_MSA_DEBUG;
return std::complex<float>((a.v[0] * a.v[2]) - (a.v[1] * a.v[3]),
(a.v[0] * a.v[3]) + (a.v[1] * a.v[2]));
}
EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf, Packet4f)
template <>
EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
EIGEN_MSA_DEBUG;
return a / b;
}
inline std::ostream& operator<<(std::ostream& os, const PacketBlock<Packet2cf, 2>& value) {
os << "[ " << value.packet[0] << ", " << std::endl << " " << value.packet[1] << " ]";
return os;
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet2cf, 2>& kernel) {
EIGEN_MSA_DEBUG;
Packet4f tmp =
(Packet4f)__builtin_msa_ilvl_d((v2i64)kernel.packet[1].v, (v2i64)kernel.packet[0].v);
kernel.packet[0].v =
(Packet4f)__builtin_msa_ilvr_d((v2i64)kernel.packet[1].v, (v2i64)kernel.packet[0].v);
kernel.packet[1].v = tmp;
}
template <>
EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket,
const Packet2cf& elsePacket) {
return (Packet2cf)(Packet4f)pblend<Packet2d>(ifPacket, (Packet2d)thenPacket.v,
(Packet2d)elsePacket.v);
}
//---------- double ----------
struct Packet1cd {
EIGEN_STRONG_INLINE Packet1cd() {
}
EIGEN_STRONG_INLINE explicit Packet1cd(const std::complex<double>& a) {
v[0] = std::real(a);
v[1] = std::imag(a);
}
EIGEN_STRONG_INLINE explicit Packet1cd(const Packet2d& a) : v(a) {
}
EIGEN_STRONG_INLINE Packet1cd(const Packet1cd& a) : v(a.v) {
}
EIGEN_STRONG_INLINE Packet1cd& operator=(const Packet1cd& b) {
v = b.v;
return *this;
}
EIGEN_STRONG_INLINE Packet1cd conjugate(void) const {
static const v2u64 p2ul_CONJ_XOR = { 0x0, 0x8000000000000000 };
return (Packet1cd)pxor(v, (Packet2d)p2ul_CONJ_XOR);
}
EIGEN_STRONG_INLINE Packet1cd& operator*=(const Packet1cd& b) {
Packet2d v1, v2;
// Get the real values of a | a1_re | a1_re
v1 = (Packet2d)__builtin_msa_ilvev_d((v2i64)v, (v2i64)v);
// Get the imag values of a | a1_im | a1_im
v2 = (Packet2d)__builtin_msa_ilvod_d((v2i64)v, (v2i64)v);
// Multiply the real a with b
v1 = pmul(v1, b.v);
// Multiply the imag a with b
v2 = pmul(v2, b.v);
// Conjugate v2
v2 = Packet1cd(v2).conjugate().v;
// Swap real/imag elements in v2.
v2 = (Packet2d)__builtin_msa_shf_w((v4i32)v2, EIGEN_MSA_SHF_I8(2, 3, 0, 1));
// Add and return the result
v = padd(v1, v2);
return *this;
}
EIGEN_STRONG_INLINE Packet1cd operator*(const Packet1cd& b) const {
return Packet1cd(*this) *= b;
}
EIGEN_STRONG_INLINE Packet1cd& operator+=(const Packet1cd& b) {
v = padd(v, b.v);
return *this;
}
EIGEN_STRONG_INLINE Packet1cd operator+(const Packet1cd& b) const {
return Packet1cd(*this) += b;
}
EIGEN_STRONG_INLINE Packet1cd& operator-=(const Packet1cd& b) {
v = psub(v, b.v);
return *this;
}
EIGEN_STRONG_INLINE Packet1cd operator-(const Packet1cd& b) const {
return Packet1cd(*this) -= b;
}
EIGEN_STRONG_INLINE Packet1cd& operator/=(const Packet1cd& b) {
*this *= b.conjugate();
Packet2d s = pmul<Packet2d>(b.v, b.v);
s = padd(s, preverse<Packet2d>(s));
v = pdiv(v, s);
return *this;
}
EIGEN_STRONG_INLINE Packet1cd operator/(const Packet1cd& b) const {
return Packet1cd(*this) /= b;
}
EIGEN_STRONG_INLINE Packet1cd operator-(void) const {
return Packet1cd(pnegate(v));
}
Packet2d v;
};
inline std::ostream& operator<<(std::ostream& os, const Packet1cd& value) {
os << "[ (" << value.v[0] << ", " << value.v[1] << "i) ]";
return os;
}
template <>
struct packet_traits<std::complex<double> > : default_packet_traits {
typedef Packet1cd type;
typedef Packet1cd half;
enum {
Vectorizable = 1,
AlignedOnScalar = 0,
size = 1,
HasHalfPacket = 0,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
HasMax = 0,
HasSetLinear = 0
};
};
template <>
struct unpacket_traits<Packet1cd> {
typedef std::complex<double> type;
enum { size = 1, alignment = Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false };
typedef Packet1cd half;
};
template <>
EIGEN_STRONG_INLINE Packet1cd pload<Packet1cd>(const std::complex<double>* from) {
EIGEN_MSA_DEBUG;
EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload<Packet2d>((const double*)from));
}
template <>
EIGEN_STRONG_INLINE Packet1cd ploadu<Packet1cd>(const std::complex<double>* from) {
EIGEN_MSA_DEBUG;
EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu<Packet2d>((const double*)from));
}
template <>
EIGEN_STRONG_INLINE Packet1cd pset1<Packet1cd>(const std::complex<double>& from) {
EIGEN_MSA_DEBUG;
return Packet1cd(from);
}
template <>
EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
EIGEN_MSA_DEBUG;
return a + b;
}
template <>
EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
EIGEN_MSA_DEBUG;
return a - b;
}
template <>
EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) {
EIGEN_MSA_DEBUG;
return -a;
}
template <>
EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) {
EIGEN_MSA_DEBUG;
return a.conjugate();
}
template <>
EIGEN_STRONG_INLINE Packet1cd pmul<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
EIGEN_MSA_DEBUG;
return a * b;
}
template <>
EIGEN_STRONG_INLINE Packet1cd pand<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
EIGEN_MSA_DEBUG;
return Packet1cd(pand(a.v, b.v));
}
template <>
EIGEN_STRONG_INLINE Packet1cd por<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
EIGEN_MSA_DEBUG;
return Packet1cd(por(a.v, b.v));
}
template <>
EIGEN_STRONG_INLINE Packet1cd pxor<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
EIGEN_MSA_DEBUG;
return Packet1cd(pxor(a.v, b.v));
}
template <>
EIGEN_STRONG_INLINE Packet1cd pandnot<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
EIGEN_MSA_DEBUG;
return Packet1cd(pandnot(a.v, b.v));
}
template <>
EIGEN_STRONG_INLINE Packet1cd ploaddup<Packet1cd>(const std::complex<double>* from) {
EIGEN_MSA_DEBUG;
return pset1<Packet1cd>(*from);
}
template <>
EIGEN_STRONG_INLINE void pstore<std::complex<double> >(std::complex<double>* to,
const Packet1cd& from) {
EIGEN_MSA_DEBUG;
EIGEN_DEBUG_ALIGNED_STORE pstore<double>((double*)to, from.v);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double>* to,
const Packet1cd& from) {
EIGEN_MSA_DEBUG;
EIGEN_DEBUG_UNALIGNED_STORE pstoreu<double>((double*)to, from.v);
}
template <>
EIGEN_STRONG_INLINE void prefetch<std::complex<double> >(const std::complex<double>* addr) {
EIGEN_MSA_DEBUG;
prefetch(reinterpret_cast<const double*>(addr));
}
template <>
EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Packet1cd>(
const std::complex<double>* from, Index stride __attribute__((unused))) {
EIGEN_MSA_DEBUG;
Packet1cd res;
res.v[0] = std::real(from[0]);
res.v[1] = std::imag(from[0]);
return res;
}
template <>
EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet1cd>(std::complex<double>* to,
const Packet1cd& from,
Index stride
__attribute__((unused))) {
EIGEN_MSA_DEBUG;
pstore(to, from);
}
template <>
EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Packet1cd& a) {
EIGEN_MSA_DEBUG;
return std::complex<double>(a.v[0], a.v[1]);
}
template <>
EIGEN_STRONG_INLINE Packet1cd preverse(const Packet1cd& a) {
EIGEN_MSA_DEBUG;
return a;
}
template <>
EIGEN_STRONG_INLINE std::complex<double> predux<Packet1cd>(const Packet1cd& a) {
EIGEN_MSA_DEBUG;
return pfirst(a);
}
template <>
EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet1cd>(const Packet1cd& a) {
EIGEN_MSA_DEBUG;
return pfirst(a);
}
EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd, Packet2d)
template <>
EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
EIGEN_MSA_DEBUG;
return a / b;
}
EIGEN_STRONG_INLINE Packet1cd pcplxflip /*<Packet1cd>*/ (const Packet1cd& x) {
EIGEN_MSA_DEBUG;
return Packet1cd(preverse(Packet2d(x.v)));
}
inline std::ostream& operator<<(std::ostream& os, const PacketBlock<Packet1cd, 2>& value) {
os << "[ " << value.packet[0] << ", " << std::endl << " " << value.packet[1] << " ]";
return os;
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet1cd, 2>& kernel) {
EIGEN_MSA_DEBUG;
Packet2d v1, v2;
// Interleave the real parts of the two packets.
v1 = (Packet2d)__builtin_msa_ilvev_d((v2i64)kernel.packet[0].v, (v2i64)kernel.packet[1].v);
// Interleave the imaginary parts of the two packets.
v2 = (Packet2d)__builtin_msa_ilvod_d((v2i64)kernel.packet[0].v, (v2i64)kernel.packet[1].v);
kernel.packet[0].v = v1;
kernel.packet[1].v = v2;
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_COMPLEX_MSA_H

View File

@ -0,0 +1,387 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2007 Julien Pommier
// Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com)
// Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// Copyright (C) 2018 Wave Computing, Inc.
// Written by:
// Chris Larsen
// Alexey Frunze (afrunze@wavecomp.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
/* The sin, cos, exp, and log functions of this file come from
* Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
*/
/* The tanh function of this file is an adaptation of
* template<typename T> T generic_fast_tanh_float(const T&)
* from MathFunctionsImpl.h.
*/
#ifndef EIGEN_MATH_FUNCTIONS_MSA_H
#define EIGEN_MATH_FUNCTIONS_MSA_H
namespace Eigen {
namespace internal {
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
plog<Packet4f>(const Packet4f& _x) {
static _EIGEN_DECLARE_CONST_Packet4f(cephes_SQRTHF, 0.707106781186547524f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p0, 7.0376836292e-2f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p1, -1.1514610310e-1f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p2, 1.1676998740e-1f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p3, -1.2420140846e-1f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p4, +1.4249322787e-1f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p5, -1.6668057665e-1f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p6, +2.0000714765e-1f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p7, -2.4999993993e-1f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p8, +3.3333331174e-1f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q1, -2.12194440e-4f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q2, 0.693359375f);
static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
static _EIGEN_DECLARE_CONST_Packet4f(1, 1.0f);
// Convert negative argument into NAN (quiet negative, to be specific).
Packet4f zero = (Packet4f)__builtin_msa_ldi_w(0);
Packet4i neg_mask = __builtin_msa_fclt_w(_x, zero);
Packet4i zero_mask = __builtin_msa_fceq_w(_x, zero);
Packet4f non_neg_x_or_nan = padd(_x, (Packet4f)neg_mask); // Add 0.0 or NAN.
Packet4f x = non_neg_x_or_nan;
// Extract exponent from x = mantissa * 2**exponent, where 1.0 <= mantissa < 2.0.
// N.B. the exponent is one less than what frexpf() would return.
Packet4i e_int = __builtin_msa_ftint_s_w(__builtin_msa_flog2_w(x));
// Multiply x by 2**(-exponent-1) to get 0.5 <= x < 1.0 as from frexpf().
x = __builtin_msa_fexp2_w(x, (Packet4i)__builtin_msa_nori_b((v16u8)e_int, 0));
/*
if (x < SQRTHF) {
x = x + x - 1.0;
} else {
e += 1;
x = x - 1.0;
}
*/
Packet4f xx = padd(x, x);
Packet4i ge_mask = __builtin_msa_fcle_w(p4f_cephes_SQRTHF, x);
e_int = psub(e_int, ge_mask);
x = (Packet4f)__builtin_msa_bsel_v((v16u8)ge_mask, (v16u8)xx, (v16u8)x);
x = psub(x, p4f_1);
Packet4f e = __builtin_msa_ffint_s_w(e_int);
Packet4f x2 = pmul(x, x);
Packet4f x3 = pmul(x2, x);
Packet4f y, y1, y2;
y = pmadd(p4f_cephes_log_p0, x, p4f_cephes_log_p1);
y1 = pmadd(p4f_cephes_log_p3, x, p4f_cephes_log_p4);
y2 = pmadd(p4f_cephes_log_p6, x, p4f_cephes_log_p7);
y = pmadd(y, x, p4f_cephes_log_p2);
y1 = pmadd(y1, x, p4f_cephes_log_p5);
y2 = pmadd(y2, x, p4f_cephes_log_p8);
y = pmadd(y, x3, y1);
y = pmadd(y, x3, y2);
y = pmul(y, x3);
y = pmadd(e, p4f_cephes_log_q1, y);
x = __builtin_msa_fmsub_w(x, x2, p4f_half);
x = padd(x, y);
x = pmadd(e, p4f_cephes_log_q2, x);
// x is now the logarithm result candidate. We still need to handle the
// extreme arguments of zero and positive infinity, though.
// N.B. if the argument is +INFINITY, x is NAN because the polynomial terms
// contain infinities of both signs (see the coefficients and code above).
// INFINITY - INFINITY is NAN.
// If the argument is +INFINITY, make it the new result candidate.
// To achieve that we choose the smaller of the result candidate and the
// argument.
// This is correct for all finite pairs of values (the logarithm is smaller
// than the argument).
// This is also correct in the special case when the argument is +INFINITY
// and the result candidate is NAN. This is because the fmin.df instruction
// prefers non-NANs to NANs.
x = __builtin_msa_fmin_w(x, non_neg_x_or_nan);
// If the argument is zero (including -0.0), the result becomes -INFINITY.
Packet4i neg_infs = __builtin_msa_slli_w(zero_mask, 23);
x = (Packet4f)__builtin_msa_bsel_v((v16u8)zero_mask, (v16u8)x, (v16u8)neg_infs);
return x;
}
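// Illustrative note, not part of the original sources: the routine above uses
// the classic Cephes decomposition log(x) = log(m) + e*log(2) with
// x = m * 2^e and 0.5 <= m < 1; log(2) is split into the two constants
// cephes_log_q2 + cephes_log_q1 so the final accumulation stays accurate in
// single precision.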
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
pexp<Packet4f>(const Packet4f& _x) {
// Limiting single-precision pexp's argument to [-128, +128] lets pexp
// reach 0 and INFINITY naturally.
static _EIGEN_DECLARE_CONST_Packet4f(exp_lo, -128.0f);
static _EIGEN_DECLARE_CONST_Packet4f(exp_hi, +128.0f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500e-4f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507e-3f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073e-3f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894e-2f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459e-1f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201e-1f);
static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
static _EIGEN_DECLARE_CONST_Packet4f(1, 1.0f);
Packet4f x = _x;
// Clamp x.
x = (Packet4f)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_w(x, p4f_exp_lo), (v16u8)x,
(v16u8)p4f_exp_lo);
x = (Packet4f)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_w(p4f_exp_hi, x), (v16u8)x,
(v16u8)p4f_exp_hi);
// Round to nearest integer by adding 0.5 (with x's sign) and truncating.
Packet4f x2_add = (Packet4f)__builtin_msa_binsli_w((v4u32)p4f_half, (v4u32)x, 0);
Packet4f x2 = pmadd(x, p4f_cephes_LOG2EF, x2_add);
Packet4i x2_int = __builtin_msa_ftrunc_s_w(x2);
Packet4f x2_int_f = __builtin_msa_ffint_s_w(x2_int);
x = __builtin_msa_fmsub_w(x, x2_int_f, p4f_cephes_exp_C1);
x = __builtin_msa_fmsub_w(x, x2_int_f, p4f_cephes_exp_C2);
Packet4f z = pmul(x, x);
Packet4f y = p4f_cephes_exp_p0;
y = pmadd(y, x, p4f_cephes_exp_p1);
y = pmadd(y, x, p4f_cephes_exp_p2);
y = pmadd(y, x, p4f_cephes_exp_p3);
y = pmadd(y, x, p4f_cephes_exp_p4);
y = pmadd(y, x, p4f_cephes_exp_p5);
y = pmadd(y, z, x);
y = padd(y, p4f_1);
// y *= 2**exponent.
y = __builtin_msa_fexp2_w(y, x2_int);
return y;
}
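// Illustrative note, not part of the original sources: pexp above follows the
// usual exp(x) = 2^n * exp(r) scheme, with n = round(x * log2(e)) and
// r = x - n*ln(2), where ln(2) is split into cephes_exp_C1 + cephes_exp_C2 for
// extra precision; exp(r) is approximated by a short polynomial and the 2^n
// scaling is done with fexp2.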
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
ptanh<Packet4f>(const Packet4f& _x) {
static _EIGEN_DECLARE_CONST_Packet4f(tanh_tiny, 1e-4f);
static _EIGEN_DECLARE_CONST_Packet4f(tanh_hi, 9.0f);
// The monomial coefficients of the numerator polynomial (odd).
static _EIGEN_DECLARE_CONST_Packet4f(alpha_1, 4.89352455891786e-3f);
static _EIGEN_DECLARE_CONST_Packet4f(alpha_3, 6.37261928875436e-4f);
static _EIGEN_DECLARE_CONST_Packet4f(alpha_5, 1.48572235717979e-5f);
static _EIGEN_DECLARE_CONST_Packet4f(alpha_7, 5.12229709037114e-8f);
static _EIGEN_DECLARE_CONST_Packet4f(alpha_9, -8.60467152213735e-11f);
static _EIGEN_DECLARE_CONST_Packet4f(alpha_11, 2.00018790482477e-13f);
static _EIGEN_DECLARE_CONST_Packet4f(alpha_13, -2.76076847742355e-16f);
// The monomial coefficients of the denominator polynomial (even).
static _EIGEN_DECLARE_CONST_Packet4f(beta_0, 4.89352518554385e-3f);
static _EIGEN_DECLARE_CONST_Packet4f(beta_2, 2.26843463243900e-3f);
static _EIGEN_DECLARE_CONST_Packet4f(beta_4, 1.18534705686654e-4f);
static _EIGEN_DECLARE_CONST_Packet4f(beta_6, 1.19825839466702e-6f);
Packet4f x = pabs(_x);
Packet4i tiny_mask = __builtin_msa_fclt_w(x, p4f_tanh_tiny);
// Clamp the inputs to the range [-9, 9] since anything outside
// this range is -/+1.0f in single-precision.
x = (Packet4f)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_w(p4f_tanh_hi, x), (v16u8)x,
(v16u8)p4f_tanh_hi);
// Since the polynomials are odd/even, we need x**2.
Packet4f x2 = pmul(x, x);
// Evaluate the numerator polynomial p.
Packet4f p = pmadd(x2, p4f_alpha_13, p4f_alpha_11);
p = pmadd(x2, p, p4f_alpha_9);
p = pmadd(x2, p, p4f_alpha_7);
p = pmadd(x2, p, p4f_alpha_5);
p = pmadd(x2, p, p4f_alpha_3);
p = pmadd(x2, p, p4f_alpha_1);
p = pmul(x, p);
// Evaluate the denominator polynomial q.
Packet4f q = pmadd(x2, p4f_beta_6, p4f_beta_4);
q = pmadd(x2, q, p4f_beta_2);
q = pmadd(x2, q, p4f_beta_0);
// Divide the numerator by the denominator.
p = pdiv(p, q);
// Reinstate the sign.
p = (Packet4f)__builtin_msa_binsli_w((v4u32)p, (v4u32)_x, 0);
// When the argument is very small in magnitude it's more accurate to just return it.
p = (Packet4f)__builtin_msa_bsel_v((v16u8)tiny_mask, (v16u8)p, (v16u8)_x);
return p;
}
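// Illustrative note, not part of the original sources: ptanh above evaluates
// the rational approximation tanh(x) ~= x*P(x^2) / Q(x^2) on |x| clamped to
// [0, 9], reinstates the sign of the argument, and returns tiny arguments
// unchanged since tanh(x) ~= x near zero.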
template <bool sine>
Packet4f psincos_inner_msa_float(const Packet4f& _x) {
static _EIGEN_DECLARE_CONST_Packet4f(sincos_max_arg, 13176795.0f); // Approx. (2**24) / (4/Pi).
static _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP1, -0.78515625f);
static _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP2, -2.4187564849853515625e-4f);
static _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP3, -3.77489497744594108e-8f);
static _EIGEN_DECLARE_CONST_Packet4f(sincof_p0, -1.9515295891e-4f);
static _EIGEN_DECLARE_CONST_Packet4f(sincof_p1, 8.3321608736e-3f);
static _EIGEN_DECLARE_CONST_Packet4f(sincof_p2, -1.6666654611e-1f);
static _EIGEN_DECLARE_CONST_Packet4f(coscof_p0, 2.443315711809948e-5f);
static _EIGEN_DECLARE_CONST_Packet4f(coscof_p1, -1.388731625493765e-3f);
static _EIGEN_DECLARE_CONST_Packet4f(coscof_p2, 4.166664568298827e-2f);
static _EIGEN_DECLARE_CONST_Packet4f(cephes_FOPI, 1.27323954473516f); // 4/Pi.
static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
static _EIGEN_DECLARE_CONST_Packet4f(1, 1.0f);
Packet4f x = pabs(_x);
// Translate infinite arguments into NANs.
Packet4f zero_or_nan_if_inf = psub(_x, _x);
x = padd(x, zero_or_nan_if_inf);
// Prevent sin/cos from generating values larger than 1.0 in magnitude
// for very large arguments by setting x to 0.0.
Packet4i small_or_nan_mask = __builtin_msa_fcult_w(x, p4f_sincos_max_arg);
x = pand(x, (Packet4f)small_or_nan_mask);
// Scale x by 4/Pi to find x's octant.
Packet4f y = pmul(x, p4f_cephes_FOPI);
// Get the octant. We'll reduce x by this number of octants or by one more than it.
Packet4i y_int = __builtin_msa_ftrunc_s_w(y);
// x's from even-numbered octants will translate to octant 0: [0, +Pi/4].
// x's from odd-numbered octants will translate to octant -1: [-Pi/4, 0].
// Adjustment for odd-numbered octants: octant = (octant + 1) & (~1).
Packet4i y_int1 = __builtin_msa_addvi_w(y_int, 1);
Packet4i y_int2 = (Packet4i)__builtin_msa_bclri_w((Packet4ui)y_int1, 0); // bclri = bit-clear
y = __builtin_msa_ffint_s_w(y_int2);
// Compute the sign to apply to the polynomial.
Packet4i sign_mask = sine ? pxor(__builtin_msa_slli_w(y_int1, 29), (Packet4i)_x)
: __builtin_msa_slli_w(__builtin_msa_addvi_w(y_int, 3), 29);
// Get the polynomial selection mask.
// We'll calculate both (sin and cos) polynomials and then select from the two.
Packet4i poly_mask = __builtin_msa_ceqi_w(__builtin_msa_slli_w(y_int2, 30), 0);
// Reduce x by y octants to get: -Pi/4 <= x <= +Pi/4.
// The magic pass: "Extended precision modular arithmetic"
// x = ((x - y * DP1) - y * DP2) - y * DP3
Packet4f tmp1 = pmul(y, p4f_minus_cephes_DP1);
Packet4f tmp2 = pmul(y, p4f_minus_cephes_DP2);
Packet4f tmp3 = pmul(y, p4f_minus_cephes_DP3);
x = padd(x, tmp1);
x = padd(x, tmp2);
x = padd(x, tmp3);
// Evaluate the cos(x) polynomial.
y = p4f_coscof_p0;
Packet4f z = pmul(x, x);
y = pmadd(y, z, p4f_coscof_p1);
y = pmadd(y, z, p4f_coscof_p2);
y = pmul(y, z);
y = pmul(y, z);
y = __builtin_msa_fmsub_w(y, z, p4f_half);
y = padd(y, p4f_1);
// Evaluate the sin(x) polynomial.
Packet4f y2 = p4f_sincof_p0;
y2 = pmadd(y2, z, p4f_sincof_p1);
y2 = pmadd(y2, z, p4f_sincof_p2);
y2 = pmul(y2, z);
y2 = pmadd(y2, x, x);
// Select the correct result from the two polynomials.
y = sine ? (Packet4f)__builtin_msa_bsel_v((v16u8)poly_mask, (v16u8)y, (v16u8)y2)
: (Packet4f)__builtin_msa_bsel_v((v16u8)poly_mask, (v16u8)y2, (v16u8)y);
// Update the sign.
sign_mask = pxor(sign_mask, (Packet4i)y);
y = (Packet4f)__builtin_msa_binsli_w((v4u32)y, (v4u32)sign_mask, 0); // binsli = bit-insert-left
return y;
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
psin<Packet4f>(const Packet4f& x) {
return psincos_inner_msa_float</* sine */ true>(x);
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
pcos<Packet4f>(const Packet4f& x) {
return psincos_inner_msa_float</* sine */ false>(x);
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d
pexp<Packet2d>(const Packet2d& _x) {
// Limiting double-precision pexp's argument to [-1024, +1024] lets pexp
// reach 0 and INFINITY naturally.
static _EIGEN_DECLARE_CONST_Packet2d(exp_lo, -1024.0);
static _EIGEN_DECLARE_CONST_Packet2d(exp_hi, +1024.0);
static _EIGEN_DECLARE_CONST_Packet2d(cephes_LOG2EF, 1.4426950408889634073599);
static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C1, 0.693145751953125);
static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C2, 1.42860682030941723212e-6);
static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p0, 1.26177193074810590878e-4);
static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p1, 3.02994407707441961300e-2);
static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p2, 9.99999999999999999910e-1);
static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q0, 3.00198505138664455042e-6);
static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q1, 2.52448340349684104192e-3);
static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q2, 2.27265548208155028766e-1);
static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q3, 2.00000000000000000009e0);
static _EIGEN_DECLARE_CONST_Packet2d(half, 0.5);
static _EIGEN_DECLARE_CONST_Packet2d(1, 1.0);
static _EIGEN_DECLARE_CONST_Packet2d(2, 2.0);
Packet2d x = _x;
// Clamp x.
x = (Packet2d)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_d(x, p2d_exp_lo), (v16u8)x,
(v16u8)p2d_exp_lo);
x = (Packet2d)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_d(p2d_exp_hi, x), (v16u8)x,
(v16u8)p2d_exp_hi);
// Round to nearest integer by adding 0.5 (with x's sign) and truncating.
Packet2d x2_add = (Packet2d)__builtin_msa_binsli_d((v2u64)p2d_half, (v2u64)x, 0);
Packet2d x2 = pmadd(x, p2d_cephes_LOG2EF, x2_add);
Packet2l x2_long = __builtin_msa_ftrunc_s_d(x2);
Packet2d x2_long_d = __builtin_msa_ffint_s_d(x2_long);
x = __builtin_msa_fmsub_d(x, x2_long_d, p2d_cephes_exp_C1);
x = __builtin_msa_fmsub_d(x, x2_long_d, p2d_cephes_exp_C2);
x2 = pmul(x, x);
Packet2d px = p2d_cephes_exp_p0;
px = pmadd(px, x2, p2d_cephes_exp_p1);
px = pmadd(px, x2, p2d_cephes_exp_p2);
px = pmul(px, x);
Packet2d qx = p2d_cephes_exp_q0;
qx = pmadd(qx, x2, p2d_cephes_exp_q1);
qx = pmadd(qx, x2, p2d_cephes_exp_q2);
qx = pmadd(qx, x2, p2d_cephes_exp_q3);
x = pdiv(px, psub(qx, px));
x = pmadd(p2d_2, x, p2d_1);
// x *= 2**exponent.
x = __builtin_msa_fexp2_d(x, x2_long);
return x;
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_MATH_FUNCTIONS_MSA_H

File diff suppressed because it is too large

View File

@ -0,0 +1,183 @@
namespace Eigen {
namespace internal {
#if EIGEN_ARCH_ARM && EIGEN_COMP_CLANG
// Clang seems to excessively spill registers in the GEBP kernel on 32-bit arm.
// Here we specialize gebp_traits to eliminate these register spills.
// See #2138.
template<>
struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
: gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull>
{
EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
{
// This volatile inline ASM acts both as a barrier to prevent reordering
// and as a way to enforce strict register use.
asm volatile(
"vmla.f32 %q[r], %q[c], %q[alpha]"
: [r] "+w" (r)
: [c] "w" (c),
[alpha] "w" (alpha)
: );
}
template <typename LaneIdType>
EIGEN_STRONG_INLINE void madd(const Packet4f& a, const Packet4f& b,
Packet4f& c, Packet4f& tmp,
const LaneIdType&) const {
acc(a, b, c);
}
template <typename LaneIdType>
EIGEN_STRONG_INLINE void madd(const Packet4f& a, const QuadPacket<Packet4f>& b,
Packet4f& c, Packet4f& tmp,
const LaneIdType& lane) const {
madd(a, b.get(lane), c, tmp, lane);
}
};
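// Illustrative note, not part of the original sources: "vmla.f32" above is a
// NEON multiply-accumulate on quad registers, i.e. it performs
// r[i] += c[i] * alpha[i] for each of the four float lanes of the accumulator.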
#endif // EIGEN_ARCH_ARM && EIGEN_COMP_CLANG
#if EIGEN_ARCH_ARM64
template<>
struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
: gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull>
{
typedef float RhsPacket;
typedef float32x4_t RhsPacketx4;
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
{
dest = *b;
}
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
{
dest = vld1q_f32(b);
}
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const
{
dest = *b;
}
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
{}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
{
loadRhs(b,dest);
}
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
{
c = vfmaq_n_f32(c, a, b);
}
// NOTE: Template parameter inference failed when compiled with Android NDK:
// "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>".
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
{ madd_helper<0>(a, b, c); }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const
{ madd_helper<1>(a, b, c); }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const
{ madd_helper<2>(a, b, c); }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const
{ madd_helper<3>(a, b, c); }
private:
template<int LaneID>
EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
{
#if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
// workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
// vfmaq_laneq_f32 is implemented through a costly dup
if(LaneID==0) asm("fmla %0.4s, %1.4s, %2.s[0]\n" : "+w" (c) : "w" (a), "w" (b) : );
else if(LaneID==1) asm("fmla %0.4s, %1.4s, %2.s[1]\n" : "+w" (c) : "w" (a), "w" (b) : );
else if(LaneID==2) asm("fmla %0.4s, %1.4s, %2.s[2]\n" : "+w" (c) : "w" (a), "w" (b) : );
else if(LaneID==3) asm("fmla %0.4s, %1.4s, %2.s[3]\n" : "+w" (c) : "w" (a), "w" (b) : );
#else
c = vfmaq_laneq_f32(c, a, b, LaneID);
#endif
}
};
template<>
struct gebp_traits <double,double,false,false,Architecture::NEON>
: gebp_traits<double,double,false,false,Architecture::Generic>
{
typedef double RhsPacket;
struct RhsPacketx4 {
float64x2_t B_0, B_1;
};
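// Four consecutive rhs doubles are kept in two 128-bit registers (two lanes each);
// madd_helper() selects the register and lane for each LaneID.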
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
{
dest = *b;
}
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
{
dest.B_0 = vld1q_f64(b);
dest.B_1 = vld1q_f64(b+2);
}
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const
{
loadRhs(b,dest);
}
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
{}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
{
loadRhs(b,dest);
}
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
{
c = vfmaq_n_f64(c, a, b);
}
// NOTE: Template parameter inference failed when compiled with Android NDK:
// "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>".
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
{ madd_helper<0>(a, b, c); }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const
{ madd_helper<1>(a, b, c); }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const
{ madd_helper<2>(a, b, c); }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const
{ madd_helper<3>(a, b, c); }
private:
template <int LaneID>
EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
{
#if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
// workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
// vfmaq_laneq_f64 is implemented through a costly dup
if(LaneID==0) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
else if(LaneID==1) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
else if(LaneID==2) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : );
else if(LaneID==3) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : );
#else
if(LaneID==0) c = vfmaq_laneq_f64(c, a, b.B_0, 0);
else if(LaneID==1) c = vfmaq_laneq_f64(c, a, b.B_0, 1);
else if(LaneID==2) c = vfmaq_laneq_f64(c, a, b.B_1, 0);
else if(LaneID==3) c = vfmaq_laneq_f64(c, a, b.B_1, 1);
#endif
}
};
#endif // EIGEN_ARCH_ARM64
} // namespace internal
} // namespace Eigen

File diff suppressed because it is too large.

View File

@ -0,0 +1,44 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020, Arm Limited and Contributors
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_MATH_FUNCTIONS_SVE_H
#define EIGEN_MATH_FUNCTIONS_SVE_H
namespace Eigen {
namespace internal {
template <>
EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf pexp<PacketXf>(const PacketXf& x) {
return pexp_float(x);
}
template <>
EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf plog<PacketXf>(const PacketXf& x) {
return plog_float(x);
}
template <>
EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf psin<PacketXf>(const PacketXf& x) {
return psin_float(x);
}
template <>
EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf pcos<PacketXf>(const PacketXf& x) {
return pcos_float(x);
}
// Hyperbolic Tangent function.
template <>
EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf ptanh<PacketXf>(const PacketXf& x) {
return internal::generic_fast_tanh_float(x);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_MATH_FUNCTIONS_SVE_H

View File

@ -0,0 +1,752 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020, Arm Limited and Contributors
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_PACKET_MATH_SVE_H
#define EIGEN_PACKET_MATH_SVE_H
namespace Eigen
{
namespace internal
{
#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
#endif
#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
template <typename Scalar, int SVEVectorLength>
struct sve_packet_size_selector {
enum { size = SVEVectorLength / (sizeof(Scalar) * CHAR_BIT) };
};
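// For example, with EIGEN_ARM64_SVE_VL == 512 and Scalar == float this yields
// 512 / (4 * 8) = 16 lanes per packet.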
/********************************* int32 **************************************/
typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
template <>
struct packet_traits<numext::int32_t> : default_packet_traits {
typedef PacketXi type;
typedef PacketXi half; // Half not implemented yet
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
HasHalfPacket = 0,
HasAdd = 1,
HasSub = 1,
HasShift = 1,
HasMul = 1,
HasNegate = 1,
HasAbs = 1,
HasArg = 0,
HasAbs2 = 1,
HasMin = 1,
HasMax = 1,
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
HasReduxp = 0 // Not implemented in SVE
};
};
template <>
struct unpacket_traits<PacketXi> {
typedef numext::int32_t type;
typedef PacketXi half; // Half not yet implemented
enum {
size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
alignment = Aligned64,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
template <>
EIGEN_STRONG_INLINE void prefetch<numext::int32_t>(const numext::int32_t* addr)
{
svprfw(svptrue_b32(), addr, SV_PLDL1KEEP);
}
template <>
EIGEN_STRONG_INLINE PacketXi pset1<PacketXi>(const numext::int32_t& from)
{
return svdup_n_s32(from);
}
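// plset builds {a, a+1, a+2, ...} by filling a stack buffer with 0..size-1 and
// adding it to the broadcast of a.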
template <>
EIGEN_STRONG_INLINE PacketXi plset<PacketXi>(const numext::int32_t& a)
{
numext::int32_t c[packet_traits<numext::int32_t>::size];
for (int i = 0; i < packet_traits<numext::int32_t>::size; i++) c[i] = i;
return svadd_s32_z(svptrue_b32(), pset1<PacketXi>(a), svld1_s32(svptrue_b32(), c));
}
template <>
EIGEN_STRONG_INLINE PacketXi padd<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svadd_s32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi psub<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svsub_s32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a)
{
return svneg_s32_z(svptrue_b32(), a);
}
template <>
EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a)
{
return a;
}
template <>
EIGEN_STRONG_INLINE PacketXi pmul<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svmul_s32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pdiv<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svdiv_s32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c)
{
return svmla_s32_z(svptrue_b32(), c, a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pmin<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svmin_s32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pmax<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svmax_s32_z(svptrue_b32(), a, b);
}
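// Integer comparisons return an svbool_t predicate; svdup_n_s32_z turns active
// lanes into 0xffffffff and zeroes the inactive ones.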
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_le<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svdup_n_s32_z(svcmple_s32(svptrue_b32(), a, b), 0xffffffffu);
}
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_lt<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
}
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_eq<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu);
}
template <>
EIGEN_STRONG_INLINE PacketXi ptrue<PacketXi>(const PacketXi& /*a*/)
{
return svdup_n_s32_z(svptrue_b32(), 0xffffffffu);
}
template <>
EIGEN_STRONG_INLINE PacketXi pzero<PacketXi>(const PacketXi& /*a*/)
{
return svdup_n_s32_z(svptrue_b32(), 0);
}
template <>
EIGEN_STRONG_INLINE PacketXi pand<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svand_s32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi por<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svorr_s32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pxor<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return sveor_s32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pandnot<PacketXi>(const PacketXi& a, const PacketXi& b)
{
return svbic_s32_z(svptrue_b32(), a, b);
}
template <int N>
EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a)
{
return svasrd_n_s32_z(svptrue_b32(), a, N);
}
template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a)
{
return svreinterpret_s32_u32(svlsr_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), svdup_n_u32_z(svptrue_b32(), N)));
}
template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a)
{
return svlsl_s32_z(svptrue_b32(), a, svdup_n_u32_z(svptrue_b32(), N));
}
template <>
EIGEN_STRONG_INLINE PacketXi pload<PacketXi>(const numext::int32_t* from)
{
EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}
template <>
EIGEN_STRONG_INLINE PacketXi ploadu<PacketXi>(const numext::int32_t* from)
{
EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}
template <>
EIGEN_STRONG_INLINE PacketXi ploaddup<PacketXi>(const numext::int32_t* from)
{
svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
}
template <>
EIGEN_STRONG_INLINE PacketXi ploadquad<PacketXi>(const numext::int32_t* from)
{
svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
}
template <>
EIGEN_STRONG_INLINE void pstore<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
{
EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
{
EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}
template <>
EIGEN_DEVICE_FUNC inline PacketXi pgather<numext::int32_t, PacketXi>(const numext::int32_t* from, Index stride)
{
// Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
svint32_t indices = svindex_s32(0, stride);
return svld1_gather_s32index_s32(svptrue_b32(), from, indices);
}
template <>
EIGEN_DEVICE_FUNC inline void pscatter<numext::int32_t, PacketXi>(numext::int32_t* to, const PacketXi& from, Index stride)
{
// Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
svint32_t indices = svindex_s32(0, stride);
svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from);
}
template <>
EIGEN_STRONG_INLINE numext::int32_t pfirst<PacketXi>(const PacketXi& a)
{
// svlasta returns the first element if all predicate bits are 0
return svlasta_s32(svpfalse_b(), a);
}
template <>
EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a)
{
return svrev_s32(a);
}
template <>
EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a)
{
return svabs_s32_z(svptrue_b32(), a);
}
template <>
EIGEN_STRONG_INLINE numext::int32_t predux<PacketXi>(const PacketXi& a)
{
return static_cast<numext::int32_t>(svaddv_s32(svptrue_b32(), a));
}
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_mul<PacketXi>(const PacketXi& a)
{
EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
// Multiply the vector by its reverse
svint32_t prod = svmul_s32_z(svptrue_b32(), a, svrev_s32(a));
svint32_t half_prod;
// Extract the high half of the vector. Depending on the VL, more reduction steps may be needed.
if (EIGEN_ARM64_SVE_VL >= 2048) {
half_prod = svtbl_s32(prod, svindex_u32(32, 1));
prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 1024) {
half_prod = svtbl_s32(prod, svindex_u32(16, 1));
prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 512) {
half_prod = svtbl_s32(prod, svindex_u32(8, 1));
prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 256) {
half_prod = svtbl_s32(prod, svindex_u32(4, 1));
prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
}
// Last reduction
half_prod = svtbl_s32(prod, svindex_u32(2, 1));
prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
// The reduction is done to the first element.
return pfirst<PacketXi>(prod);
}
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_min<PacketXi>(const PacketXi& a)
{
return svminv_s32(svptrue_b32(), a);
}
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_max<PacketXi>(const PacketXi& a)
{
return svmaxv_s32(svptrue_b32(), a);
}
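// Transpose via memory: scatter each packet with stride N into a scratch buffer
// so that element j of packet i lands at buffer[i + j*N], then reload the packets
// from contiguous rows of the buffer.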
template <int N>
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXi, N>& kernel) {
int buffer[packet_traits<numext::int32_t>::size * N] = {0};
int i = 0;
PacketXi stride_index = svindex_s32(0, N);
for (i = 0; i < N; i++) {
svst1_scatter_s32index_s32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
}
for (i = 0; i < N; i++) {
kernel.packet[i] = svld1_s32(svptrue_b32(), buffer + i * packet_traits<numext::int32_t>::size);
}
}
/********************************* float32 ************************************/
typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
template <>
struct packet_traits<float> : default_packet_traits {
typedef PacketXf type;
typedef PacketXf half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
HasHalfPacket = 0,
HasAdd = 1,
HasSub = 1,
HasShift = 1,
HasMul = 1,
HasNegate = 1,
HasAbs = 1,
HasArg = 0,
HasAbs2 = 1,
HasMin = 1,
HasMax = 1,
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
HasReduxp = 0, // Not implemented in SVE
HasDiv = 1,
HasFloor = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasExp = 1,
HasSqrt = 0,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH
};
};
template <>
struct unpacket_traits<PacketXf> {
typedef float type;
typedef PacketXf half; // Half not yet implemented
typedef PacketXi integer_packet;
enum {
size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
alignment = Aligned64,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
template <>
EIGEN_STRONG_INLINE PacketXf pset1<PacketXf>(const float& from)
{
return svdup_n_f32(from);
}
template <>
EIGEN_STRONG_INLINE PacketXf pset1frombits<PacketXf>(numext::uint32_t from)
{
return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), from));
}
template <>
EIGEN_STRONG_INLINE PacketXf plset<PacketXf>(const float& a)
{
float c[packet_traits<float>::size];
for (int i = 0; i < packet_traits<float>::size; i++) c[i] = i;
return svadd_f32_z(svptrue_b32(), pset1<PacketXf>(a), svld1_f32(svptrue_b32(), c));
}
template <>
EIGEN_STRONG_INLINE PacketXf padd<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svadd_f32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf psub<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svsub_f32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a)
{
return svneg_f32_z(svptrue_b32(), a);
}
template <>
EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a)
{
return a;
}
template <>
EIGEN_STRONG_INLINE PacketXf pmul<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svmul_f32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pdiv<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svdiv_f32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c)
{
return svmla_f32_z(svptrue_b32(), c, a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svmin_f32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
{
return pmin<PacketXf>(a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svminnm_f32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svmax_f32_z(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
{
return pmax<PacketXf>(a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svmaxnm_f32_z(svptrue_b32(), a, b);
}
// Float comparisons in SVE return svbool (predicate). Use svdup to set active
// lanes to 1 (0xffffffffu) and inactive lanes to 0.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_le<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svreinterpret_f32_u32(svdup_n_u32_z(svcmple_f32(svptrue_b32(), a, b), 0xffffffffu));
}
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
}
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_eq<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu));
}
// Invert (svnot_b_z) the predicate resulting from the greater-than-or-equal
// comparison (svcmpge_f32), then fill a float vector from the active lanes.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
}
template <>
EIGEN_STRONG_INLINE PacketXf pfloor<PacketXf>(const PacketXf& a)
{
return svrintm_f32_z(svptrue_b32(), a);
}
template <>
EIGEN_STRONG_INLINE PacketXf ptrue<PacketXf>(const PacketXf& /*a*/)
{
return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), 0xffffffffu));
}
// Logical Operations are not supported for float, so reinterpret casts
template <>
EIGEN_STRONG_INLINE PacketXf pand<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svreinterpret_f32_u32(svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}
template <>
EIGEN_STRONG_INLINE PacketXf por<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}
template <>
EIGEN_STRONG_INLINE PacketXf pxor<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svreinterpret_f32_u32(sveor_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}
template <>
EIGEN_STRONG_INLINE PacketXf pandnot<PacketXf>(const PacketXf& a, const PacketXf& b)
{
return svreinterpret_f32_u32(svbic_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}
template <>
EIGEN_STRONG_INLINE PacketXf pload<PacketXf>(const float* from)
{
EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}
template <>
EIGEN_STRONG_INLINE PacketXf ploadu<PacketXf>(const float* from)
{
EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}
template <>
EIGEN_STRONG_INLINE PacketXf ploaddup<PacketXf>(const float* from)
{
svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
}
template <>
EIGEN_STRONG_INLINE PacketXf ploadquad<PacketXf>(const float* from)
{
svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
}
template <>
EIGEN_STRONG_INLINE void pstore<float>(float* to, const PacketXf& from)
{
EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const PacketXf& from)
{
EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}
template <>
EIGEN_DEVICE_FUNC inline PacketXf pgather<float, PacketXf>(const float* from, Index stride)
{
// Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
svint32_t indices = svindex_s32(0, stride);
return svld1_gather_s32index_f32(svptrue_b32(), from, indices);
}
template <>
EIGEN_DEVICE_FUNC inline void pscatter<float, PacketXf>(float* to, const PacketXf& from, Index stride)
{
// Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
svint32_t indices = svindex_s32(0, stride);
svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from);
}
template <>
EIGEN_STRONG_INLINE float pfirst<PacketXf>(const PacketXf& a)
{
// svlasta returns the first element if all predicate bits are 0
return svlasta_f32(svpfalse_b(), a);
}
template <>
EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a)
{
return svrev_f32(a);
}
template <>
EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a)
{
return svabs_f32_z(svptrue_b32(), a);
}
// TODO(tellenbach): Should this go into MathFunctions.h? If so, change for
// all vector extensions and the generic version.
template <>
EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent)
{
return pfrexp_generic(a, exponent);
}
template <>
EIGEN_STRONG_INLINE float predux<PacketXf>(const PacketXf& a)
{
return svaddv_f32(svptrue_b32(), a);
}
// Other reduction functions:
// mul
// Only works for SVE VLs that are a multiple of 128 bits
template <>
EIGEN_STRONG_INLINE float predux_mul<PacketXf>(const PacketXf& a)
{
EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
// Multiply the vector by its reverse
svfloat32_t prod = svmul_f32_z(svptrue_b32(), a, svrev_f32(a));
svfloat32_t half_prod;
// Extract the high half of the vector. Depending on the VL, more reduction steps may be needed.
if (EIGEN_ARM64_SVE_VL >= 2048) {
half_prod = svtbl_f32(prod, svindex_u32(32, 1));
prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 1024) {
half_prod = svtbl_f32(prod, svindex_u32(16, 1));
prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 512) {
half_prod = svtbl_f32(prod, svindex_u32(8, 1));
prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 256) {
half_prod = svtbl_f32(prod, svindex_u32(4, 1));
prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
}
// Last reduction
half_prod = svtbl_f32(prod, svindex_u32(2, 1));
prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
// The reduction is done to the first element.
return pfirst<PacketXf>(prod);
}
template <>
EIGEN_STRONG_INLINE float predux_min<PacketXf>(const PacketXf& a)
{
return svminv_f32(svptrue_b32(), a);
}
template <>
EIGEN_STRONG_INLINE float predux_max<PacketXf>(const PacketXf& a)
{
return svmaxv_f32(svptrue_b32(), a);
}
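// Same scatter-to-buffer/reload scheme as the integer ptranspose above.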
template<int N>
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel)
{
float buffer[packet_traits<float>::size * N] = {0};
int i = 0;
PacketXi stride_index = svindex_s32(0, N);
for (i = 0; i < N; i++) {
svst1_scatter_s32index_f32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
}
for (i = 0; i < N; i++) {
kernel.packet[i] = svld1_f32(svptrue_b32(), buffer + i * packet_traits<float>::size);
}
}
template<>
EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent)
{
return pldexp_generic(a, exponent);
}
} // namespace internal
} // namespace Eigen
#endif // EIGEN_PACKET_MATH_SVE_H

View File

@ -0,0 +1,49 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020, Arm Limited and Contributors
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_TYPE_CASTING_SVE_H
#define EIGEN_TYPE_CASTING_SVE_H
namespace Eigen {
namespace internal {
template <>
struct type_casting_traits<float, numext::int32_t> {
enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
};
template <>
struct type_casting_traits<numext::int32_t, float> {
enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
};
template <>
EIGEN_STRONG_INLINE PacketXf pcast<PacketXi, PacketXf>(const PacketXi& a) {
return svcvt_f32_s32_z(svptrue_b32(), a);
}
template <>
EIGEN_STRONG_INLINE PacketXi pcast<PacketXf, PacketXi>(const PacketXf& a) {
return svcvt_s32_f32_z(svptrue_b32(), a);
}
template <>
EIGEN_STRONG_INLINE PacketXf preinterpret<PacketXf, PacketXi>(const PacketXi& a) {
return svreinterpret_f32_s32(a);
}
template <>
EIGEN_STRONG_INLINE PacketXi preinterpret<PacketXi, PacketXf>(const PacketXf& a) {
return svreinterpret_s32_f32(a);
}
} // namespace internal
} // namespace Eigen
#endif // EIGEN_TYPE_CASTING_SVE_H

View File

@ -0,0 +1,232 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli Codeplay Software Ltd.
// Ralph Potter Codeplay Software Ltd.
// Luke Iwanski Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
/*****************************************************************
* InteropHeaders.h
*
* \brief:
* InteropHeaders
*
*****************************************************************/
#ifndef EIGEN_INTEROP_HEADERS_SYCL_H
#define EIGEN_INTEROP_HEADERS_SYCL_H
namespace Eigen {
#if !defined(EIGEN_DONT_VECTORIZE_SYCL)
namespace internal {
template <int has_blend, int lengths>
struct sycl_packet_traits : default_packet_traits {
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = lengths,
HasHalfPacket = 0,
HasDiv = 1,
HasLog = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasSin = 1,
HasCos = 1,
HasTan = 1,
HasASin = 1,
HasACos = 1,
HasATan = 1,
HasSinh = 1,
HasCosh = 1,
HasTanh = 1,
HasLGamma = 0,
HasDiGamma = 0,
HasZeta = 0,
HasPolygamma = 0,
HasErf = 0,
HasErfc = 0,
HasNdtri = 0,
HasIGamma = 0,
HasIGammac = 0,
HasBetaInc = 0,
HasBlend = has_blend,
// This flag is used to indicate whether packet comparison is supported.
// pcmp_eq, pcmp_lt and pcmp_le should be defined for it to be true.
HasCmp = 1,
HasMax = 1,
HasMin = 1,
HasMul = 1,
HasAdd = 1,
HasFloor = 1,
HasRound = 1,
HasRint = 1,
HasLog1p = 1,
HasExpm1 = 1,
HasCeil = 1,
};
};
#ifdef SYCL_DEVICE_ONLY
#define SYCL_PACKET_TRAITS(packet_type, has_blend, unpacket_type, lengths) \
template <> \
struct packet_traits<unpacket_type> \
: sycl_packet_traits<has_blend, lengths> { \
typedef packet_type type; \
typedef packet_type half; \
};
SYCL_PACKET_TRAITS(cl::sycl::cl_float4, 1, float, 4)
SYCL_PACKET_TRAITS(cl::sycl::cl_float4, 1, const float, 4)
SYCL_PACKET_TRAITS(cl::sycl::cl_double2, 0, double, 2)
SYCL_PACKET_TRAITS(cl::sycl::cl_double2, 0, const double, 2)
#undef SYCL_PACKET_TRAITS
// Make sure this is only available when targeting a GPU: we don't want to
// introduce conflicts between these packet_traits definitions and the ones
// we'll use on the host side (SSE, AVX, ...)
#define SYCL_ARITHMETIC(packet_type) \
template <> \
struct is_arithmetic<packet_type> { \
enum { value = true }; \
};
SYCL_ARITHMETIC(cl::sycl::cl_float4)
SYCL_ARITHMETIC(cl::sycl::cl_double2)
#undef SYCL_ARITHMETIC
#define SYCL_UNPACKET_TRAITS(packet_type, unpacket_type, lengths) \
template <> \
struct unpacket_traits<packet_type> { \
typedef unpacket_type type; \
enum { size = lengths, vectorizable = true, alignment = Aligned16 }; \
typedef packet_type half; \
};
SYCL_UNPACKET_TRAITS(cl::sycl::cl_float4, float, 4)
SYCL_UNPACKET_TRAITS(cl::sycl::cl_double2, double, 2)
#undef SYCL_UNPACKET_TRAITS
#endif
} // end namespace internal
#endif
namespace TensorSycl {
namespace internal {
template <typename PacketReturnType, int PacketSize>
struct PacketWrapper;
// This function should never get called on the device
#ifndef SYCL_DEVICE_ONLY
template <typename PacketReturnType, int PacketSize>
struct PacketWrapper {
typedef typename ::Eigen::internal::unpacket_traits<PacketReturnType>::type
Scalar;
template <typename Index>
EIGEN_DEVICE_FUNC static Scalar scalarize(Index, PacketReturnType &) {
eigen_assert(false && "THERE IS NO PACKETIZE VERSION FOR THE CHOSEN TYPE");
abort();
}
EIGEN_DEVICE_FUNC static PacketReturnType convert_to_packet_type(Scalar in,
Scalar) {
return ::Eigen::internal::template plset<PacketReturnType>(in);
}
EIGEN_DEVICE_FUNC static void set_packet(PacketReturnType, Scalar *) {
eigen_assert(false && "THERE IS NO PACKETIZE VERSION FOR THE CHOSEN TYPE");
abort();
}
};
#elif defined(SYCL_DEVICE_ONLY)
template <typename PacketReturnType>
struct PacketWrapper<PacketReturnType, 4> {
typedef typename ::Eigen::internal::unpacket_traits<PacketReturnType>::type
Scalar;
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Scalar scalarize(Index index, PacketReturnType &in) {
switch (index) {
case 0:
return in.x();
case 1:
return in.y();
case 2:
return in.z();
case 3:
return in.w();
default:
// INDEX MUST BE BETWEEN 0 AND 3. There is no abort function in a SYCL kernel, so we cannot use abort here.
// The code will never reach this point.
__builtin_unreachable();
}
__builtin_unreachable();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static PacketReturnType convert_to_packet_type(
Scalar in, Scalar other) {
return PacketReturnType(in, other, other, other);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static void set_packet(PacketReturnType &lhs, Scalar *rhs) {
lhs = PacketReturnType(rhs[0], rhs[1], rhs[2], rhs[3]);
}
};
template <typename PacketReturnType>
struct PacketWrapper<PacketReturnType, 1> {
typedef typename ::Eigen::internal::unpacket_traits<PacketReturnType>::type
Scalar;
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Scalar scalarize(Index, PacketReturnType &in) {
return in;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static PacketReturnType convert_to_packet_type(Scalar in,
Scalar) {
return PacketReturnType(in);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static void set_packet(PacketReturnType &lhs, Scalar *rhs) {
lhs = rhs[0];
}
};
template <typename PacketReturnType>
struct PacketWrapper<PacketReturnType, 2> {
typedef typename ::Eigen::internal::unpacket_traits<PacketReturnType>::type
Scalar;
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Scalar scalarize(Index index, PacketReturnType &in) {
switch (index) {
case 0:
return in.x();
case 1:
return in.y();
default:
// INDEX MUST BE BETWEEN 0 AND 1. There is no abort function in a SYCL kernel, so we cannot use abort here.
// The code will never reach this point.
__builtin_unreachable();
}
__builtin_unreachable();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static PacketReturnType convert_to_packet_type(
Scalar in, Scalar other) {
return PacketReturnType(in, other);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static void set_packet(PacketReturnType &lhs, Scalar *rhs) {
lhs = PacketReturnType(rhs[0], rhs[1]);
}
};
#endif
} // end namespace internal
} // end namespace TensorSycl
} // end namespace Eigen
#endif // EIGEN_INTEROP_HEADERS_SYCL_H

View File

@ -0,0 +1,301 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli Codeplay Software Ltd.
// Ralph Potter Codeplay Software Ltd.
// Luke Iwanski Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
/*****************************************************************
* MathFunctions.h
*
* \brief:
* MathFunctions
*
*****************************************************************/
#ifndef EIGEN_MATH_FUNCTIONS_SYCL_H
#define EIGEN_MATH_FUNCTIONS_SYCL_H
namespace Eigen {
namespace internal {
// Make sure this is only available when targeting a GPU: we don't want to
// introduce conflicts between these packet_traits definitions and the ones
// we'll use on the host side (SSE, AVX, ...)
#if defined(SYCL_DEVICE_ONLY)
#define SYCL_PLOG(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog<packet_type>( \
const packet_type& a) { \
return cl::sycl::log(a); \
}
SYCL_PLOG(cl::sycl::cl_float4)
SYCL_PLOG(cl::sycl::cl_double2)
#undef SYCL_PLOG
#define SYCL_PLOG1P(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog1p<packet_type>( \
const packet_type& a) { \
return cl::sycl::log1p(a); \
}
SYCL_PLOG1P(cl::sycl::cl_float4)
SYCL_PLOG1P(cl::sycl::cl_double2)
#undef SYCL_PLOG1P
#define SYCL_PLOG10(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog10<packet_type>( \
const packet_type& a) { \
return cl::sycl::log10(a); \
}
SYCL_PLOG10(cl::sycl::cl_float4)
SYCL_PLOG10(cl::sycl::cl_double2)
#undef SYCL_PLOG10
#define SYCL_PEXP(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pexp<packet_type>( \
const packet_type& a) { \
return cl::sycl::exp(a); \
}
SYCL_PEXP(cl::sycl::cl_float4)
SYCL_PEXP(cl::sycl::cl_float)
SYCL_PEXP(cl::sycl::cl_double2)
#undef SYCL_PEXP
#define SYCL_PEXPM1(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pexpm1<packet_type>( \
const packet_type& a) { \
return cl::sycl::expm1(a); \
}
SYCL_PEXPM1(cl::sycl::cl_float4)
SYCL_PEXPM1(cl::sycl::cl_double2)
#undef SYCL_PEXPM1
#define SYCL_PSQRT(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psqrt<packet_type>( \
const packet_type& a) { \
return cl::sycl::sqrt(a); \
}
SYCL_PSQRT(cl::sycl::cl_float4)
SYCL_PSQRT(cl::sycl::cl_double2)
#undef SYCL_PSQRT
#define SYCL_PRSQRT(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type prsqrt<packet_type>( \
const packet_type& a) { \
return cl::sycl::rsqrt(a); \
}
SYCL_PRSQRT(cl::sycl::cl_float4)
SYCL_PRSQRT(cl::sycl::cl_double2)
#undef SYCL_PRSQRT
/** \internal \returns the sine of \a a (coeff-wise) */
#define SYCL_PSIN(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psin<packet_type>( \
const packet_type& a) { \
return cl::sycl::sin(a); \
}
SYCL_PSIN(cl::sycl::cl_float4)
SYCL_PSIN(cl::sycl::cl_double2)
#undef SYCL_PSIN
/** \internal \returns the cosine of \a a (coeff-wise) */
#define SYCL_PCOS(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pcos<packet_type>( \
const packet_type& a) { \
return cl::sycl::cos(a); \
}
SYCL_PCOS(cl::sycl::cl_float4)
SYCL_PCOS(cl::sycl::cl_double2)
#undef SYCL_PCOS
/** \internal \returns the tangent of \a a (coeff-wise) */
#define SYCL_PTAN(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ptan<packet_type>( \
const packet_type& a) { \
return cl::sycl::tan(a); \
}
SYCL_PTAN(cl::sycl::cl_float4)
SYCL_PTAN(cl::sycl::cl_double2)
#undef SYCL_PTAN
/** \internal \returns the arc sine of \a a (coeff-wise) */
#define SYCL_PASIN(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pasin<packet_type>( \
const packet_type& a) { \
return cl::sycl::asin(a); \
}
SYCL_PASIN(cl::sycl::cl_float4)
SYCL_PASIN(cl::sycl::cl_double2)
#undef SYCL_PASIN
/** \internal \returns the arc cosine of \a a (coeff-wise) */
#define SYCL_PACOS(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pacos<packet_type>( \
const packet_type& a) { \
return cl::sycl::acos(a); \
}
SYCL_PACOS(cl::sycl::cl_float4)
SYCL_PACOS(cl::sycl::cl_double2)
#undef SYCL_PACOS
/** \internal \returns the arc tangent of \a a (coeff-wise) */
#define SYCL_PATAN(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type patan<packet_type>( \
const packet_type& a) { \
return cl::sycl::atan(a); \
}
SYCL_PATAN(cl::sycl::cl_float4)
SYCL_PATAN(cl::sycl::cl_double2)
#undef SYCL_PATAN
/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */
#define SYCL_PSINH(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psinh<packet_type>( \
const packet_type& a) { \
return cl::sycl::sinh(a); \
}
SYCL_PSINH(cl::sycl::cl_float4)
SYCL_PSINH(cl::sycl::cl_double2)
#undef SYCL_PSINH
/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */
#define SYCL_PCOSH(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pcosh<packet_type>( \
const packet_type& a) { \
return cl::sycl::cosh(a); \
}
SYCL_PCOSH(cl::sycl::cl_float4)
SYCL_PCOSH(cl::sycl::cl_double2)
#undef SYCL_PCOSH
/** \internal \returns the hyperbolic tangent of \a a (coeff-wise) */
#define SYCL_PTANH(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ptanh<packet_type>( \
const packet_type& a) { \
return cl::sycl::tanh(a); \
}
SYCL_PTANH(cl::sycl::cl_float4)
SYCL_PTANH(cl::sycl::cl_double2)
#undef SYCL_PTANH
#define SYCL_PCEIL(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pceil<packet_type>( \
const packet_type& a) { \
return cl::sycl::ceil(a); \
}
SYCL_PCEIL(cl::sycl::cl_float4)
SYCL_PCEIL(cl::sycl::cl_double2)
#undef SYCL_PCEIL
#define SYCL_PROUND(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pround<packet_type>( \
const packet_type& a) { \
return cl::sycl::round(a); \
}
SYCL_PROUND(cl::sycl::cl_float4)
SYCL_PROUND(cl::sycl::cl_double2)
#undef SYCL_PROUND
#define SYCL_PRINT(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type print<packet_type>( \
const packet_type& a) { \
return cl::sycl::rint(a); \
}
SYCL_PRINT(cl::sycl::cl_float4)
SYCL_PRINT(cl::sycl::cl_double2)
#undef SYCL_PRINT
#define SYCL_FLOOR(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pfloor<packet_type>( \
const packet_type& a) { \
return cl::sycl::floor(a); \
}
SYCL_FLOOR(cl::sycl::cl_float4)
SYCL_FLOOR(cl::sycl::cl_double2)
#undef SYCL_FLOOR
#define SYCL_PMIN(packet_type, expr) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pmin<packet_type>( \
const packet_type& a, const packet_type& b) { \
return expr; \
}
SYCL_PMIN(cl::sycl::cl_float4, cl::sycl::fmin(a, b))
SYCL_PMIN(cl::sycl::cl_double2, cl::sycl::fmin(a, b))
#undef SYCL_PMIN
#define SYCL_PMAX(packet_type, expr) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pmax<packet_type>( \
const packet_type& a, const packet_type& b) { \
return expr; \
}
SYCL_PMAX(cl::sycl::cl_float4, cl::sycl::fmax(a, b))
SYCL_PMAX(cl::sycl::cl_double2, cl::sycl::fmax(a, b))
#undef SYCL_PMAX
#define SYCL_PLDEXP(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pldexp( \
const packet_type& a, const packet_type& exponent) { \
return cl::sycl::ldexp( \
a, exponent.template convert<cl::sycl::cl_int, \
cl::sycl::rounding_mode::automatic>()); \
}
SYCL_PLDEXP(cl::sycl::cl_float4)
SYCL_PLDEXP(cl::sycl::cl_double2)
#undef SYCL_PLDEXP
#endif
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_MATH_FUNCTIONS_SYCL_H

View File

@ -0,0 +1,670 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli Codeplay Software Ltd.
// Ralph Potter Codeplay Software Ltd.
// Luke Iwanski Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
/*****************************************************************
* PacketMath.h
*
* \brief:
* PacketMath
*
*****************************************************************/
#ifndef EIGEN_PACKET_MATH_SYCL_H
#define EIGEN_PACKET_MATH_SYCL_H
#include <type_traits>
namespace Eigen {
namespace internal {
#ifdef SYCL_DEVICE_ONLY
#define SYCL_PLOADT_RO(address_space_target) \
template <typename packet_type, int Alignment> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type ploadt_ro( \
typename cl::sycl::multi_ptr< \
const typename unpacket_traits<packet_type>::type, \
cl::sycl::access::address_space::address_space_target>::pointer_t \
from) { \
typedef typename unpacket_traits<packet_type>::type scalar; \
typedef cl::sycl::multi_ptr< \
scalar, cl::sycl::access::address_space::address_space_target> \
multi_ptr; \
auto res = packet_type( \
static_cast<typename unpacket_traits<packet_type>::type>(0)); \
res.load(0, multi_ptr(const_cast<typename multi_ptr::pointer_t>(from))); \
return res; \
}
SYCL_PLOADT_RO(global_space)
SYCL_PLOADT_RO(local_space)
#undef SYCL_PLOADT_RO
#endif
template <typename packet_type, int Alignment, typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type
ploadt_ro(const Eigen::TensorSycl::internal::RangeAccess<
cl::sycl::access::mode::read_write, T>& from) {
return ploadt_ro<packet_type, Alignment>(from.get_pointer());
}
#ifdef SYCL_DEVICE_ONLY
#define SYCL_PLOAD(address_space_target, Alignment, AlignedType) \
template <typename packet_type> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##AlignedType( \
typename cl::sycl::multi_ptr< \
const typename unpacket_traits<packet_type>::type, \
cl::sycl::access::address_space::address_space_target>::pointer_t \
from) { \
return ploadt_ro<packet_type, Alignment>(from); \
}
// global space
SYCL_PLOAD(global_space, Unaligned, u)
SYCL_PLOAD(global_space, Aligned, )
// local space
SYCL_PLOAD(local_space, Unaligned, u)
SYCL_PLOAD(local_space, Aligned, )
#undef SYCL_PLOAD
#endif
#define SYCL_PLOAD(Alignment, AlignedType) \
template <typename packet_type> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##AlignedType( \
const Eigen::TensorSycl::internal::RangeAccess< \
cl::sycl::access::mode::read_write, \
typename unpacket_traits<packet_type>::type> \
from) { \
return ploadt_ro<packet_type, Alignment>(from); \
}
SYCL_PLOAD(Unaligned, u)
SYCL_PLOAD(Aligned, )
#undef SYCL_PLOAD
#ifdef SYCL_DEVICE_ONLY
/** \internal \returns a packet version of \a *from.
* The pointer \a from must be aligned on a \a Alignment bytes boundary. */
#define SYCL_PLOADT(address_space_target) \
template <typename packet_type, int Alignment> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type ploadt( \
typename cl::sycl::multi_ptr< \
const typename unpacket_traits<packet_type>::type, \
cl::sycl::access::address_space::address_space_target>::pointer_t \
from) { \
if (Alignment >= unpacket_traits<packet_type>::alignment) \
return pload<packet_type>(from); \
else \
return ploadu<packet_type>(from); \
}
// global space
SYCL_PLOADT(global_space)
// local space
SYCL_PLOADT(local_space)
#undef SYCL_PLOADT
#endif
template <typename packet_type, int Alignment>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type
ploadt(const Eigen::TensorSycl::internal::RangeAccess<
cl::sycl::access::mode::read_write,
typename unpacket_traits<packet_type>::type>& from) {
return ploadt<packet_type, Alignment>(from.get_pointer());
}
#ifdef SYCL_DEVICE_ONLY
// private_space
#define SYCL_PLOADT_RO_SPECIAL(packet_type, Alignment) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type \
ploadt_ro<packet_type, Alignment>( \
const typename unpacket_traits<packet_type>::type* from) { \
typedef typename unpacket_traits<packet_type>::type scalar; \
auto res = packet_type(static_cast<scalar>(0)); \
res.template load<cl::sycl::access::address_space::private_space>( \
0, const_cast<scalar*>(from)); \
return res; \
}
SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_float4, Aligned)
SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_double2, Aligned)
SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_float4, Unaligned)
SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_double2, Unaligned)
#define SYCL_PLOAD_SPECIAL(packet_type, alignment_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##alignment_type( \
const typename unpacket_traits<packet_type>::type* from) { \
typedef typename unpacket_traits<packet_type>::type scalar; \
auto res = packet_type(static_cast<scalar>(0)); \
res.template load<cl::sycl::access::address_space::private_space>( \
0, const_cast<scalar*>(from)); \
return res; \
}
SYCL_PLOAD_SPECIAL(cl::sycl::cl_float4, )
SYCL_PLOAD_SPECIAL(cl::sycl::cl_double2, )
SYCL_PLOAD_SPECIAL(cl::sycl::cl_float4, u)
SYCL_PLOAD_SPECIAL(cl::sycl::cl_double2, u)
#undef SYCL_PLOAD_SPECIAL
#define SYCL_PSTORE(scalar, packet_type, address_space_target, alignment) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstore##alignment( \
typename cl::sycl::multi_ptr< \
scalar, \
cl::sycl::access::address_space::address_space_target>::pointer_t \
to, \
const packet_type& from) { \
typedef cl::sycl::multi_ptr< \
scalar, cl::sycl::access::address_space::address_space_target> \
multi_ptr; \
from.store(0, multi_ptr(to)); \
}
// global space
SYCL_PSTORE(float, cl::sycl::cl_float4, global_space, )
SYCL_PSTORE(float, cl::sycl::cl_float4, global_space, u)
SYCL_PSTORE(double, cl::sycl::cl_double2, global_space, )
SYCL_PSTORE(double, cl::sycl::cl_double2, global_space, u)
SYCL_PSTORE(float, cl::sycl::cl_float4, local_space, )
SYCL_PSTORE(float, cl::sycl::cl_float4, local_space, u)
SYCL_PSTORE(double, cl::sycl::cl_double2, local_space, )
SYCL_PSTORE(double, cl::sycl::cl_double2, local_space, u)
SYCL_PSTORE(float, cl::sycl::cl_float4, private_space, )
SYCL_PSTORE(float, cl::sycl::cl_float4, private_space, u)
SYCL_PSTORE(double, cl::sycl::cl_double2, private_space, )
SYCL_PSTORE(double, cl::sycl::cl_double2, private_space, u)
#undef SYCL_PSTORE
#define SYCL_PSTORE_T(address_space_target) \
template <typename scalar, typename packet_type, int Alignment> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstoret( \
typename cl::sycl::multi_ptr< \
scalar, \
cl::sycl::access::address_space::address_space_target>::pointer_t \
to, \
const packet_type& from) { \
if (Alignment) \
pstore(to, from); \
else \
pstoreu(to, from); \
}
SYCL_PSTORE_T(global_space)
SYCL_PSTORE_T(local_space)
#undef SYCL_PSTORE_T
#define SYCL_PSET1(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pset1<packet_type>( \
const typename unpacket_traits<packet_type>::type& from) { \
return packet_type(from); \
}
// global space
SYCL_PSET1(cl::sycl::cl_float4)
SYCL_PSET1(cl::sycl::cl_double2)
#undef SYCL_PSET1
template <typename packet_type>
struct get_base_packet {
template <typename sycl_multi_pointer>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type
get_ploaddup(sycl_multi_pointer) {}
template <typename sycl_multi_pointer>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type
get_pgather(sycl_multi_pointer, Index) {}
};
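// The generic template deliberately has empty bodies; only the cl::sycl::cl_float4
// and cl::sycl::cl_double2 specializations below are meant to be used.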
template <>
struct get_base_packet<cl::sycl::cl_float4> {
template <typename sycl_multi_pointer>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 get_ploaddup(
sycl_multi_pointer from) {
return cl::sycl::cl_float4(from[0], from[0], from[1], from[1]);
}
template <typename sycl_multi_pointer>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 get_pgather(
sycl_multi_pointer from, Index stride) {
return cl::sycl::cl_float4(from[0 * stride], from[1 * stride],
from[2 * stride], from[3 * stride]);
}
template <typename sycl_multi_pointer>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_pscatter(
sycl_multi_pointer to, const cl::sycl::cl_float4& from, Index stride) {
auto tmp = stride;
to[0] = from.x();
to[tmp] = from.y();
to[tmp += stride] = from.z();
to[tmp += stride] = from.w();
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 set_plset(
const float& a) {
return cl::sycl::cl_float4(static_cast<float>(a), static_cast<float>(a + 1),
static_cast<float>(a + 2),
static_cast<float>(a + 3));
}
};
template <>
struct get_base_packet<cl::sycl::cl_double2> {
template <typename sycl_multi_pointer>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2
get_ploaddup(const sycl_multi_pointer from) {
return cl::sycl::cl_double2(from[0], from[0]);
}
template <typename sycl_multi_pointer, typename Index>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2 get_pgather(
const sycl_multi_pointer from, Index stride) {
return cl::sycl::cl_double2(from[0 * stride], from[1 * stride]);
}
template <typename sycl_multi_pointer>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_pscatter(
sycl_multi_pointer to, const cl::sycl::cl_double2& from, Index stride) {
to[0] = from.x();
to[stride] = from.y();
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2 set_plset(
const double& a) {
return cl::sycl::cl_double2(static_cast<double>(a),
static_cast<double>(a + 1));
}
};
#define SYCL_PLOAD_DUP(address_space_target) \
template <typename packet_type> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ploaddup( \
typename cl::sycl::multi_ptr< \
const typename unpacket_traits<packet_type>::type, \
cl::sycl::access::address_space::address_space_target>::pointer_t \
from) { \
return get_base_packet<packet_type>::get_ploaddup(from); \
}
// global space
SYCL_PLOAD_DUP(global_space)
// local_space
SYCL_PLOAD_DUP(local_space)
#undef SYCL_PLOAD_DUP
#define SYCL_PLOAD_DUP_SPECILIZE(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ploaddup<packet_type>( \
const typename unpacket_traits<packet_type>::type* from) { \
return get_base_packet<packet_type>::get_ploaddup(from); \
}
SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_float4)
SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_double2)
#undef SYCL_PLOAD_DUP_SPECILIZE
#define SYCL_PLSET(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type plset<packet_type>( \
const typename unpacket_traits<packet_type>::type& a) { \
return get_base_packet<packet_type>::set_plset(a); \
}
SYCL_PLSET(cl::sycl::cl_float4)
SYCL_PLSET(cl::sycl::cl_double2)
#undef SYCL_PLSET
#define SYCL_PGATHER(address_space_target) \
template <typename Scalar, typename packet_type> \
EIGEN_DEVICE_FUNC inline packet_type pgather( \
typename cl::sycl::multi_ptr< \
const typename unpacket_traits<packet_type>::type, \
cl::sycl::access::address_space::address_space_target>::pointer_t \
from, \
Index stride) { \
return get_base_packet<packet_type>::get_pgather(from, stride); \
}
// global space
SYCL_PGATHER(global_space)
// local space
SYCL_PGATHER(local_space)
#undef SYCL_PGATHER
#define SYCL_PGATHER_SPECILIZE(scalar, packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type \
pgather<scalar, packet_type>( \
const typename unpacket_traits<packet_type>::type* from, Index stride) { \
return get_base_packet<packet_type>::get_pgather(from, stride); \
}
SYCL_PGATHER_SPECILIZE(float, cl::sycl::cl_float4)
SYCL_PGATHER_SPECILIZE(double, cl::sycl::cl_double2)
#undef SYCL_PGATHER_SPECILIZE
#define SYCL_PSCATTER(address_space_target) \
template <typename Scalar, typename packet_type> \
EIGEN_DEVICE_FUNC inline void pscatter( \
typename cl::sycl::multi_ptr< \
typename unpacket_traits<packet_type>::type, \
cl::sycl::access::address_space::address_space_target>::pointer_t \
to, \
const packet_type& from, Index stride) { \
get_base_packet<packet_type>::set_pscatter(to, from, stride); \
}
// global space
SYCL_PSCATTER(global_space)
// local space
SYCL_PSCATTER(local_space)
#undef SYCL_PSCATTER
#define SYCL_PSCATTER_SPECILIZE(scalar, packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<scalar, packet_type>( \
typename unpacket_traits<packet_type>::type * to, \
const packet_type& from, Index stride) { \
get_base_packet<packet_type>::set_pscatter(to, from, stride); \
}
SYCL_PSCATTER_SPECILIZE(float, cl::sycl::cl_float4)
SYCL_PSCATTER_SPECILIZE(double, cl::sycl::cl_double2)
#undef SYCL_PSCATTER_SPECILIZE
#define SYCL_PMAD(packet_type) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pmadd( \
const packet_type& a, const packet_type& b, const packet_type& c) { \
return cl::sycl::mad(a, b, c); \
}
SYCL_PMAD(cl::sycl::cl_float4)
SYCL_PMAD(cl::sycl::cl_double2)
#undef SYCL_PMAD
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float pfirst<cl::sycl::cl_float4>(
const cl::sycl::cl_float4& a) {
return a.x();
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double pfirst<cl::sycl::cl_double2>(
const cl::sycl::cl_double2& a) {
return a.x();
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux<cl::sycl::cl_float4>(
const cl::sycl::cl_float4& a) {
return a.x() + a.y() + a.z() + a.w();
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux<cl::sycl::cl_double2>(
const cl::sycl::cl_double2& a) {
return a.x() + a.y();
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_max<cl::sycl::cl_float4>(
const cl::sycl::cl_float4& a) {
return cl::sycl::fmax(cl::sycl::fmax(a.x(), a.y()),
cl::sycl::fmax(a.z(), a.w()));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_max<cl::sycl::cl_double2>(
const cl::sycl::cl_double2& a) {
return cl::sycl::fmax(a.x(), a.y());
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_min<cl::sycl::cl_float4>(
const cl::sycl::cl_float4& a) {
return cl::sycl::fmin(cl::sycl::fmin(a.x(), a.y()),
cl::sycl::fmin(a.z(), a.w()));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_min<cl::sycl::cl_double2>(
const cl::sycl::cl_double2& a) {
return cl::sycl::fmin(a.x(), a.y());
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_mul<cl::sycl::cl_float4>(
const cl::sycl::cl_float4& a) {
return a.x() * a.y() * a.z() * a.w();
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_mul<cl::sycl::cl_double2>(
const cl::sycl::cl_double2& a) {
return a.x() * a.y();
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4
pabs<cl::sycl::cl_float4>(const cl::sycl::cl_float4& a) {
return cl::sycl::cl_float4(cl::sycl::fabs(a.x()), cl::sycl::fabs(a.y()),
cl::sycl::fabs(a.z()), cl::sycl::fabs(a.w()));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_double2
pabs<cl::sycl::cl_double2>(const cl::sycl::cl_double2& a) {
return cl::sycl::cl_double2(cl::sycl::fabs(a.x()), cl::sycl::fabs(a.y()));
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_le(const Packet &a,
const Packet &b) {
return ((a <= b)
.template convert<typename unpacket_traits<Packet>::type,
cl::sycl::rounding_mode::automatic>());
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_lt(const Packet &a,
const Packet &b) {
return ((a < b)
.template convert<typename unpacket_traits<Packet>::type,
cl::sycl::rounding_mode::automatic>());
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_eq(const Packet &a,
const Packet &b) {
return ((a == b)
.template convert<typename unpacket_traits<Packet>::type,
cl::sycl::rounding_mode::automatic>());
}
#define SYCL_PCMP(OP, TYPE) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE TYPE pcmp_##OP<TYPE>(const TYPE &a, \
const TYPE &b) { \
return sycl_pcmp_##OP<TYPE>(a, b); \
}
SYCL_PCMP(le, cl::sycl::cl_float4)
SYCL_PCMP(lt, cl::sycl::cl_float4)
SYCL_PCMP(eq, cl::sycl::cl_float4)
SYCL_PCMP(le, cl::sycl::cl_double2)
SYCL_PCMP(lt, cl::sycl::cl_double2)
SYCL_PCMP(eq, cl::sycl::cl_double2)
#undef SYCL_PCMP
template <typename T> struct convert_to_integer;
template <> struct convert_to_integer<float> {
using type = std::int32_t;
using packet_type = cl::sycl::cl_int4;
};
template <> struct convert_to_integer<double> {
using type = std::int64_t;
using packet_type = cl::sycl::cl_long2;
};
template <typename PacketIn>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename convert_to_integer<
typename unpacket_traits<PacketIn>::type>::packet_type
vector_as_int(const PacketIn &p) {
return (
p.template convert<typename convert_to_integer<
typename unpacket_traits<PacketIn>::type>::type,
cl::sycl::rounding_mode::automatic>());
}
template <typename packetOut, typename PacketIn>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packetOut
convert_vector(const PacketIn &p) {
return (p.template convert<typename unpacket_traits<packetOut>::type,
cl::sycl::rounding_mode::automatic>());
}
#define SYCL_PAND(TYPE) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pand<TYPE>(const TYPE &a, \
const TYPE &b) { \
return convert_vector<TYPE>(vector_as_int(a) & vector_as_int(b)); \
}
SYCL_PAND(cl::sycl::cl_float4)
SYCL_PAND(cl::sycl::cl_double2)
#undef SYCL_PAND
#define SYCL_POR(TYPE) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE por<TYPE>(const TYPE &a, \
const TYPE &b) { \
return convert_vector<TYPE>(vector_as_int(a) | vector_as_int(b)); \
}
SYCL_POR(cl::sycl::cl_float4)
SYCL_POR(cl::sycl::cl_double2)
#undef SYCL_POR
#define SYCL_PXOR(TYPE) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pxor<TYPE>(const TYPE &a, \
const TYPE &b) { \
return convert_vector<TYPE>(vector_as_int(a) ^ vector_as_int(b)); \
}
SYCL_PXOR(cl::sycl::cl_float4)
SYCL_PXOR(cl::sycl::cl_double2)
#undef SYCL_PXOR
#define SYCL_PANDNOT(TYPE) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pandnot<TYPE>(const TYPE &a, \
const TYPE &b) { \
return convert_vector<TYPE>(vector_as_int(a) & (~vector_as_int(b))); \
}
SYCL_PANDNOT(cl::sycl::cl_float4)
SYCL_PANDNOT(cl::sycl::cl_double2)
#undef SYCL_PANDNOT
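// Illustrative note (not part of the original file): the bitwise helpers above
// value-convert each lane to the matching signed integer type, apply the integer
// bitwise operator, and convert the result back to the floating-point packet.
// A minimal usage sketch of the generated specialization:
//   cl::sycl::cl_float4 a(1.f, 2.f, 3.f, 4.f), b(3.f, 3.f, 3.f, 3.f);
//   cl::sycl::cl_float4 r = pand<cl::sycl::cl_float4>(a, b);  // lanes: 1&3, 2&3, 3&3, 4&3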
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void ptranspose(
PacketBlock<cl::sycl::cl_float4, 4>& kernel) {
float tmp = kernel.packet[0].y();
kernel.packet[0].y() = kernel.packet[1].x();
kernel.packet[1].x() = tmp;
tmp = kernel.packet[0].z();
kernel.packet[0].z() = kernel.packet[2].x();
kernel.packet[2].x() = tmp;
tmp = kernel.packet[0].w();
kernel.packet[0].w() = kernel.packet[3].x();
kernel.packet[3].x() = tmp;
tmp = kernel.packet[1].z();
kernel.packet[1].z() = kernel.packet[2].y();
kernel.packet[2].y() = tmp;
tmp = kernel.packet[1].w();
kernel.packet[1].w() = kernel.packet[3].y();
kernel.packet[3].y() = tmp;
tmp = kernel.packet[2].w();
kernel.packet[2].w() = kernel.packet[3].z();
kernel.packet[3].z() = tmp;
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void ptranspose(
PacketBlock<cl::sycl::cl_double2, 2>& kernel) {
double tmp = kernel.packet[0].y();
kernel.packet[0].y() = kernel.packet[1].x();
kernel.packet[1].x() = tmp;
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4 pblend(
const Selector<unpacket_traits<cl::sycl::cl_float4>::size>& ifPacket,
const cl::sycl::cl_float4& thenPacket,
const cl::sycl::cl_float4& elsePacket) {
cl::sycl::cl_int4 condition(
ifPacket.select[0] ? 0 : -1, ifPacket.select[1] ? 0 : -1,
ifPacket.select[2] ? 0 : -1, ifPacket.select[3] ? 0 : -1);
return cl::sycl::select(thenPacket, elsePacket, condition);
}
template <>
inline cl::sycl::cl_double2 pblend(
const Selector<unpacket_traits<cl::sycl::cl_double2>::size>& ifPacket,
const cl::sycl::cl_double2& thenPacket,
const cl::sycl::cl_double2& elsePacket) {
cl::sycl::cl_long2 condition(ifPacket.select[0] ? 0 : -1,
ifPacket.select[1] ? 0 : -1);
return cl::sycl::select(thenPacket, elsePacket, condition);
}
#endif // SYCL_DEVICE_ONLY
#define SYCL_PSTORE(alignment) \
template <typename packet_type> \
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstore##alignment( \
const Eigen::TensorSycl::internal::RangeAccess< \
cl::sycl::access::mode::read_write, \
typename unpacket_traits<packet_type>::type>& to, \
const packet_type& from) { \
pstore##alignment(to.get_pointer(), from); \
}
// aligned and unaligned stores
SYCL_PSTORE()
SYCL_PSTORE(u)
#undef SYCL_PSTORE
template <typename scalar, typename packet_type, int Alignment>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstoret(
Eigen::TensorSycl::internal::RangeAccess<
cl::sycl::access::mode::read_write,
typename unpacket_traits<packet_type>::type>
to,
const packet_type& from) {
pstoret<scalar, packet_type, Alignment>(to.get_pointer(), from);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_PACKET_MATH_SYCL_H

View File

@ -0,0 +1,694 @@
/***************************************************************************
* Copyright (C) 2017 Codeplay Software Limited
* This Source Code Form is subject to the terms of the Mozilla
* Public License v. 2.0. If a copy of the MPL was not distributed
* with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
*
*
* SyclMemoryModel.h
*
* Description:
* Interface for SYCL buffers to behave as a non-dereferenceable pointer
* Interface for Placeholder accessor to behave as a pointer on both host
* and device
*
* Authors:
*
* Ruyman Reyes Codeplay Software Ltd.
* Mehdi Goli Codeplay Software Ltd.
* Vanya Yaneva Codeplay Software Ltd.
*
**************************************************************************/
#if defined(EIGEN_USE_SYCL) && \
!defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H)
#define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
#include <CL/sycl.hpp>
#ifdef EIGEN_EXCEPTIONS
#include <stdexcept>
#endif
#include <cstddef>
#include <queue>
#include <set>
#include <unordered_map>
namespace Eigen {
namespace TensorSycl {
namespace internal {
using sycl_acc_target = cl::sycl::access::target;
using sycl_acc_mode = cl::sycl::access::mode;
/**
* Default values for template arguments
*/
using buffer_data_type_t = uint8_t;
const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
/**
* PointerMapper
* Associates fake pointers with buffers.
*
*/
class PointerMapper {
public:
using base_ptr_t = std::intptr_t;
/* Structure of a virtual pointer
*
* |================================================|
* | POINTER ADDRESS |
* |================================================|
*/
struct virtual_pointer_t {
/* Type for the pointers
*/
base_ptr_t m_contents;
/** Conversions from virtual_pointer_t to
* void * should just reinterpret_cast the integer number
*/
operator void *() const { return reinterpret_cast<void *>(m_contents); }
/**
* Convert back to the integer number.
*/
operator base_ptr_t() const { return m_contents; }
/**
* Add a certain value to the pointer to create a
* new pointer to that offset
*/
virtual_pointer_t operator+(size_t off) { return m_contents + off; }
/* Numerical order for sorting pointers in containers. */
bool operator<(virtual_pointer_t rhs) const {
return (static_cast<base_ptr_t>(m_contents) <
static_cast<base_ptr_t>(rhs.m_contents));
}
bool operator>(virtual_pointer_t rhs) const {
return (static_cast<base_ptr_t>(m_contents) >
static_cast<base_ptr_t>(rhs.m_contents));
}
/**
* Equality comparison between virtual pointers.
*/
bool operator==(virtual_pointer_t rhs) const {
return (static_cast<base_ptr_t>(m_contents) ==
static_cast<base_ptr_t>(rhs.m_contents));
}
/**
* Simple forward to the equality overload.
*/
bool operator!=(virtual_pointer_t rhs) const {
return !(this->operator==(rhs));
}
/**
* Converts a void * into a virtual pointer structure.
* Note that this will only work if the void * was
* already a virtual_pointer_t, but we have no way of
* checking
*/
virtual_pointer_t(const void *ptr)
: m_contents(reinterpret_cast<base_ptr_t>(ptr)){};
/**
* Creates a virtual_pointer_t from the given integer
* number
*/
virtual_pointer_t(base_ptr_t u) : m_contents(u){};
};
/* Definition of a null pointer
*/
const virtual_pointer_t null_virtual_ptr = nullptr;
/**
* Checks whether a pointer is null.
* A pointer is null if its value equals null_virtual_ptr.
*/
static inline bool is_nullptr(virtual_pointer_t ptr) {
return (static_cast<void *>(ptr) == nullptr);
}
/* basic type for all buffers
*/
using buffer_t = cl::sycl::buffer_mem;
/**
* Node that stores information about a device allocation.
* Nodes are sorted by size to organise a free list of nodes
* that can be recovered.
*/
struct pMapNode_t {
buffer_t m_buffer;
size_t m_size;
bool m_free;
pMapNode_t(buffer_t b, size_t size, bool f)
: m_buffer{b}, m_size{size}, m_free{f} {
m_buffer.set_final_data(nullptr);
}
bool operator<=(const pMapNode_t &rhs) { return (m_size <= rhs.m_size); }
};
/** Storage of the pointer / buffer tree
*/
using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
/**
* Obtain the insertion point in the pointer map for
* a pointer of the given size.
* \param requiredSize Size attempted to reclaim
*/
typename pointerMap_t::iterator get_insertion_point(size_t requiredSize) {
typename pointerMap_t::iterator retVal;
bool reuse = false;
if (!m_freeList.empty()) {
// try to re-use an existing block
for (auto freeElem : m_freeList) {
if (freeElem->second.m_size >= requiredSize) {
retVal = freeElem;
reuse = true;
// Element is not going to be free anymore
m_freeList.erase(freeElem);
break;
}
}
}
if (!reuse) {
retVal = std::prev(m_pointerMap.end());
}
return retVal;
}
/**
* Returns an iterator to the node that stores the information
* of the given virtual pointer from the given pointer map structure.
* If pointer is not found, throws std::out_of_range.
* If the pointer map structure is empty, throws std::out_of_range
*
* \param pMap the pointerMap_t structure storing all the pointers
* \param virtual_pointer_ptr The virtual pointer to obtain the node of
* \throws std::out_of_range if the pointer is not found or pMap is empty
*/
typename pointerMap_t::iterator get_node(const virtual_pointer_t ptr) {
if (this->count() == 0) {
m_pointerMap.clear();
EIGEN_THROW_X(std::out_of_range("There are no pointers allocated\n"));
}
if (is_nullptr(ptr)) {
m_pointerMap.clear();
EIGEN_THROW_X(std::out_of_range("Cannot access null pointer\n"));
}
// The previous element to the lower bound is the node that
// holds this memory address
auto node = m_pointerMap.lower_bound(ptr);
// If the value of the pointer is not the one of the node
// then we return the previous one
if (node == std::end(m_pointerMap)) {
--node;
} else if (node->first != ptr) {
if (node == std::begin(m_pointerMap)) {
m_pointerMap.clear();
EIGEN_THROW_X(
std::out_of_range("The pointer is not registered in the map\n"));
}
--node;
}
return node;
}
/* get_buffer.
* Returns a buffer from the map using the pointer address
*/
template <typename buffer_data_type = buffer_data_type_t>
cl::sycl::buffer<buffer_data_type, 1> get_buffer(
const virtual_pointer_t ptr) {
using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
// get_node() returns a node holding a `buffer_mem`, which we need to cast to a
// `buffer<>`. The cast is valid even though we store the `buffer_mem` by value,
// because all member variables are declared in the base class (`buffer_mem`) and
// none in the child class (`buffer<>`), so both classes have the same layout.
auto node = get_node(ptr);
eigen_assert(node->first == ptr || node->first < ptr);
eigen_assert(ptr < static_cast<virtual_pointer_t>(node->second.m_size +
node->first));
return *(static_cast<sycl_buffer_t *>(&node->second.m_buffer));
}
/**
* @brief Returns an accessor to the buffer of the given virtual pointer
* @param accessMode
* @param accessTarget
* @param ptr The virtual pointer
*/
template <sycl_acc_mode access_mode = default_acc_mode,
sycl_acc_target access_target = default_acc_target,
typename buffer_data_type = buffer_data_type_t>
cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
get_access(const virtual_pointer_t ptr) {
auto buf = get_buffer<buffer_data_type>(ptr);
return buf.template get_access<access_mode, access_target>();
}
/**
* @brief Returns an accessor to the buffer of the given virtual pointer
* in the given command group scope
* @param accessMode
* @param accessTarget
* @param ptr The virtual pointer
* @param cgh Reference to the command group scope
*/
template <sycl_acc_mode access_mode = default_acc_mode,
sycl_acc_target access_target = default_acc_target,
typename buffer_data_type = buffer_data_type_t>
cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
get_access(const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
auto buf = get_buffer<buffer_data_type>(ptr);
return buf.template get_access<access_mode, access_target>(cgh);
}
/*
* Returns the offset from the base address of this pointer.
*/
inline std::ptrdiff_t get_offset(const virtual_pointer_t ptr) {
// The previous element to the lower bound is the node that
// holds this memory address
auto node = get_node(ptr);
auto start = node->first;
eigen_assert(start == ptr || start < ptr);
eigen_assert(ptr < start + node->second.m_size);
return (ptr - start);
}
/*
* Returns the number of elements by which the given pointer is offset from
* the base address.
*/
template <typename buffer_data_type>
inline size_t get_element_offset(const virtual_pointer_t ptr) {
return get_offset(ptr) / sizeof(buffer_data_type);
}
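// Worked example (illustrative, not part of the original file): for a virtual
// pointer that lies 16 bytes past the start of its allocation, get_offset()
// returns 16 and get_element_offset<double>() returns 16 / sizeof(double) == 2.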
/**
* Constructs the PointerMapper structure.
*/
PointerMapper(base_ptr_t baseAddress = 4096)
: m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
if (m_baseAddress == 0) {
EIGEN_THROW_X(std::invalid_argument("Base address cannot be zero\n"));
}
};
/**
* PointerMapper cannot be copied or moved
*/
PointerMapper(const PointerMapper &) = delete;
/**
* Empty the pointer list
*/
inline void clear() {
m_freeList.clear();
m_pointerMap.clear();
}
/* add_pointer.
* Adds an existing pointer to the map and returns the virtual pointer id.
*/
inline virtual_pointer_t add_pointer(const buffer_t &b) {
return add_pointer_impl(b);
}
/* add_pointer.
* Adds a pointer to the map and returns the virtual pointer id.
*/
inline virtual_pointer_t add_pointer(buffer_t &&b) {
return add_pointer_impl(b);
}
/**
* @brief Fuses the given node with the following nodes in the
* pointer map if they are free
*
* @param node A reference to the free node to be fused
*/
void fuse_forward(typename pointerMap_t::iterator &node) {
while (node != std::prev(m_pointerMap.end())) {
// if following node is free
// remove it and extend the current node with its size
auto fwd_node = std::next(node);
if (!fwd_node->second.m_free) {
break;
}
auto fwd_size = fwd_node->second.m_size;
m_freeList.erase(fwd_node);
m_pointerMap.erase(fwd_node);
node->second.m_size += fwd_size;
}
}
/**
* @brief Fuses the given node with the previous nodes in the
* pointer map if they are free
*
* @param node A reference to the free node to be fused
*/
void fuse_backward(typename pointerMap_t::iterator &node) {
while (node != m_pointerMap.begin()) {
// if previous node is free, extend it
// with the size of the current one
auto prev_node = std::prev(node);
if (!prev_node->second.m_free) {
break;
}
prev_node->second.m_size += node->second.m_size;
// remove the current node
m_freeList.erase(node);
m_pointerMap.erase(node);
// point to the previous node
node = prev_node;
}
}
/* remove_pointer.
* Removes the given pointer from the map.
* The pointer is allowed to be reused only if ReUse is true.
*/
template <bool ReUse = true>
void remove_pointer(const virtual_pointer_t ptr) {
if (is_nullptr(ptr)) {
return;
}
auto node = this->get_node(ptr);
node->second.m_free = true;
m_freeList.emplace(node);
// Fuse the node
// with free nodes before and after it
fuse_forward(node);
fuse_backward(node);
// If after fusing the node is the last one
// simply remove it (since it is free)
if (node == std::prev(m_pointerMap.end())) {
m_freeList.erase(node);
m_pointerMap.erase(node);
}
}
/* count.
* Return the number of active pointers (i.e., pointers that
* have been allocated but not freed).
*/
size_t count() const { return (m_pointerMap.size() - m_freeList.size()); }
private:
/* add_pointer_impl.
* Adds a pointer to the map and returns the virtual pointer id.
* BufferT is either a const buffer_t& or a buffer_t&&.
*/
template <class BufferT>
virtual_pointer_t add_pointer_impl(BufferT b) {
virtual_pointer_t retVal = nullptr;
size_t bufSize = b.get_count();
pMapNode_t p{b, bufSize, false};
// If this is the first pointer:
if (m_pointerMap.empty()) {
virtual_pointer_t initialVal{m_baseAddress};
m_pointerMap.emplace(initialVal, p);
return initialVal;
}
auto lastElemIter = get_insertion_point(bufSize);
// We are recovering an existing free node
if (lastElemIter->second.m_free) {
lastElemIter->second.m_buffer = b;
lastElemIter->second.m_free = false;
// If the recovered node is bigger than the inserted one
// add a new free node with the remaining space
if (lastElemIter->second.m_size > bufSize) {
// create a new node with the remaining space
auto remainingSize = lastElemIter->second.m_size - bufSize;
pMapNode_t p2{b, remainingSize, true};
// update size of the current node
lastElemIter->second.m_size = bufSize;
// add the new free node
auto newFreePtr = lastElemIter->first + bufSize;
auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first;
m_freeList.emplace(freeNode);
}
retVal = lastElemIter->first;
} else {
size_t lastSize = lastElemIter->second.m_size;
retVal = lastElemIter->first + lastSize;
m_pointerMap.emplace(retVal, p);
}
return retVal;
}
/**
* Compare two iterators to pointer map entries according to
* the size of the allocation on the device.
*/
struct SortBySize {
bool operator()(typename pointerMap_t::iterator a,
typename pointerMap_t::iterator b) const {
return ((a->first < b->first) && (a->second <= b->second)) ||
((a->first < b->first) && (b->second <= a->second));
}
};
/* Maps the pointer addresses to buffer and size pairs.
*/
pointerMap_t m_pointerMap;
/* List of free nodes available for re-using
*/
std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;
/* Base address used when issuing the first virtual pointer, allows users
* to specify alignment. Cannot be zero. */
std::intptr_t m_baseAddress;
};
/* remove_pointer.
* Removes the given pointer from the map.
* The pointer is allowed to be reused only if ReUse is true.
*/
template <>
inline void PointerMapper::remove_pointer<false>(const virtual_pointer_t ptr) {
if (is_nullptr(ptr)) {
return;
}
m_pointerMap.erase(this->get_node(ptr));
}
/**
* Malloc-like interface to the pointer-mapper.
* Given a size, creates a byte-typed buffer and returns a
* fake pointer to keep track of it.
* \param size Size in bytes of the desired allocation
* \throw cl::sycl::exception if error while creating the buffer
*/
inline void *SYCLmalloc(size_t size, PointerMapper &pMap) {
if (size == 0) {
return nullptr;
}
// Create a generic buffer of the given size
using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{size}));
// Store the buffer on the global list
return static_cast<void *>(thePointer);
}
/**
* Free-like interface to the pointer mapper.
* Given a fake-pointer created with the virtual-pointer malloc,
* destroys the buffer and remove it from the list.
* If ReUse is false, the pointer is not added to the freeList,
* it should be false only for sub-buffers.
*/
template <bool ReUse = true, typename PointerMapper>
inline void SYCLfree(void *ptr, PointerMapper &pMap) {
pMap.template remove_pointer<ReUse>(ptr);
}
/**
* Clear all the memory allocated by SYCL.
*/
template <typename PointerMapper>
inline void SYCLfreeAll(PointerMapper &pMap) {
pMap.clear();
}
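// Usage sketch (illustrative, not part of the original file; assumes a working
// SYCL runtime). The malloc/free-like interface combines with PointerMapper as:
//   Eigen::TensorSycl::internal::PointerMapper pMap;
//   void *vptr = Eigen::TensorSycl::internal::SYCLmalloc(1024, pMap); // fake pointer, backed by a buffer
//   auto buf   = pMap.get_buffer(vptr);                               // recover the underlying SYCL buffer
//   Eigen::TensorSycl::internal::SYCLfree(vptr, pMap);                // mark the allocation as free
//   Eigen::TensorSycl::internal::SYCLfreeAll(pMap);                   // drop every remaining allocation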
template <cl::sycl::access::mode AcMd, typename T>
struct RangeAccess {
static const auto global_access = cl::sycl::access::target::global_buffer;
static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
typedef T scalar_t;
typedef scalar_t &ref_t;
typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;
// the accessor type is not necessarily the same as T
typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
accessor;
typedef RangeAccess<AcMd, T> self_t;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RangeAccess(accessor access,
size_t offset,
std::intptr_t virtual_ptr)
: access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {}
RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
: access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {}
// This should be only used for null constructor on the host side
RangeAccess(std::nullptr_t) : RangeAccess() {}
// This template parameter must be removed and scalar_t should be replaced
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t get_pointer() const {
return (access_.get_pointer().get() + offset_);
}
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator+=(Index offset) {
offset_ += (offset);
return *this;
}
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator+(Index offset) const {
return self_t(access_, offset_ + offset, virtual_ptr_);
}
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator-(Index offset) const {
return self_t(access_, offset_ - offset, virtual_ptr_);
}
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator-=(Index offset) {
offset_ -= offset;
return *this;
}
// THIS IS FOR NULL COMPARISON ONLY
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
const RangeAccess &lhs, std::nullptr_t) {
return ((lhs.virtual_ptr_ == -1));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
const RangeAccess &lhs, std::nullptr_t i) {
return !(lhs == i);
}
// THIS IS FOR NULL COMPARISON ONLY
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
std::nullptr_t, const RangeAccess &rhs) {
return ((rhs.virtual_ptr_ == -1));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
std::nullptr_t i, const RangeAccess &rhs) {
return !(i == rhs);
}
// Prefix operator (Increment and return value)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator++() {
offset_++;
return (*this);
}
// Postfix operator (Return value and increment)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator++(int i) {
EIGEN_UNUSED_VARIABLE(i);
self_t temp_iterator(*this);
offset_++;
return temp_iterator;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_size() const {
return (access_.get_count() - offset_);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_offset() const {
return offset_;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_offset(std::ptrdiff_t offset) {
offset_ = offset;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() const {
return *get_pointer();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() {
return *get_pointer();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t operator->() = delete;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) {
return *(get_pointer() + x);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) const {
return *(get_pointer() + x);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_t *get_virtual_pointer() const {
return reinterpret_cast<scalar_t *>(virtual_ptr_ +
(offset_ * sizeof(scalar_t)));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit operator bool() const {
return (virtual_ptr_ != -1);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE operator RangeAccess<AcMd, const T>() {
return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
operator RangeAccess<AcMd, const T>() const {
return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
}
// binding placeholder accessors to a command group handler for SYCL
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
cl::sycl::handler &cgh) const {
cgh.require(access_);
}
private:
accessor access_;
size_t offset_;
std::intptr_t virtual_ptr_; // the location of the buffer in the map
};
template <cl::sycl::access::mode AcMd, typename T>
struct RangeAccess<AcMd, const T> : RangeAccess<AcMd, T> {
typedef RangeAccess<AcMd, T> Base;
using Base::Base;
};
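// Usage sketch (illustrative, not part of the original file): RangeAccess mimics a
// raw pointer over a placeholder accessor; once bound to a command group via
// bind(cgh) it supports the usual pointer idioms inside a kernel, e.g.
//   RangeAccess<cl::sycl::access::mode::read_write, float> ptr = /* ... */;
//   ptr += 4;           // advance the offset by four elements
//   float v = ptr[0];   // element access relative to the current offset
//   float w = *ptr;     // dereference; same element as ptr[0]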
} // namespace internal
} // namespace TensorSycl
} // namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H

View File

@ -0,0 +1,85 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli Codeplay Software Ltd.
// Ralph Potter Codeplay Software Ltd.
// Luke Iwanski Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
/*****************************************************************
* TypeCasting.h
*
* \brief:
* TypeCasting
*
*****************************************************************/
#ifndef EIGEN_TYPE_CASTING_SYCL_H
#define EIGEN_TYPE_CASTING_SYCL_H
namespace Eigen {
namespace internal {
#ifdef SYCL_DEVICE_ONLY
template <>
struct type_casting_traits<float, int> {
enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
};
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_int4
pcast<cl::sycl::cl_float4, cl::sycl::cl_int4>(const cl::sycl::cl_float4& a) {
return a
.template convert<cl::sycl::cl_int, cl::sycl::rounding_mode::automatic>();
}
template <>
struct type_casting_traits<int, float> {
enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
};
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4
pcast<cl::sycl::cl_int4, cl::sycl::cl_float4>(const cl::sycl::cl_int4& a) {
return a.template convert<cl::sycl::cl_float,
cl::sycl::rounding_mode::automatic>();
}
template <>
struct type_casting_traits<double, float> {
enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
};
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4
pcast<cl::sycl::cl_double2, cl::sycl::cl_float4>(
const cl::sycl::cl_double2& a, const cl::sycl::cl_double2& b) {
auto a1 = a.template convert<cl::sycl::cl_float,
cl::sycl::rounding_mode::automatic>();
auto b1 = b.template convert<cl::sycl::cl_float,
cl::sycl::rounding_mode::automatic>();
return cl::sycl::float4(a1.x(), a1.y(), b1.x(), b1.y());
}
template <>
struct type_casting_traits<float, double> {
enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
};
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_double2
pcast<cl::sycl::cl_float4, cl::sycl::cl_double2>(const cl::sycl::cl_float4& a) {
// Simply discard the second half of the input
return cl::sycl::cl_double2(a.x(), a.y());
}
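// Illustrative note (not part of the original file): SrcCoeffRatio/TgtCoeffRatio
// state how many source packets feed one target packet. For double -> float,
// SrcCoeffRatio == 2, so the two-argument pcast consumes two cl_double2 packets
// to fill one cl_float4:
//   cl::sycl::cl_double2 lo(0.0, 1.0), hi(2.0, 3.0);
//   cl::sycl::cl_float4 f = pcast<cl::sycl::cl_double2, cl::sycl::cl_float4>(lo, hi); // (0, 1, 2, 3)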
#endif
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_TYPE_CASTING_SYCL_H

View File

@ -0,0 +1,512 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2018 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2020, Arm Limited and Contributors
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CONFIGURE_VECTORIZATION_H
#define EIGEN_CONFIGURE_VECTORIZATION_H
//------------------------------------------------------------------------------------------
// Static and dynamic alignment control
//
// The main purpose of this section is to define EIGEN_MAX_ALIGN_BYTES and EIGEN_MAX_STATIC_ALIGN_BYTES
// as the maximal boundary in bytes on which dynamically and statically allocated data may be aligned, respectively.
// The values of EIGEN_MAX_ALIGN_BYTES and EIGEN_MAX_STATIC_ALIGN_BYTES can be specified by the user. If not,
// a default value is automatically computed based on architecture, compiler, and OS.
//
// This section also defines macros EIGEN_ALIGN_TO_BOUNDARY(N) and the shortcuts EIGEN_ALIGN{8,16,32,_MAX}
// to be used to declare statically aligned buffers.
//------------------------------------------------------------------------------------------
/* EIGEN_ALIGN_TO_BOUNDARY(n) forces data to be n-byte aligned. This is used to satisfy SIMD requirements.
* However, we do that EVEN if vectorization (EIGEN_VECTORIZE) is disabled,
* so that vectorization doesn't affect binary compatibility.
*
* If we made alignment depend on whether or not EIGEN_VECTORIZE is defined, it would be impossible to link
* vectorized and non-vectorized code.
*
* FIXME: this code can be cleaned up once we switch to proper C++11 only.
*/
#if (defined EIGEN_CUDACC)
#define EIGEN_ALIGN_TO_BOUNDARY(n) __align__(n)
#define EIGEN_ALIGNOF(x) __alignof(x)
#elif EIGEN_HAS_ALIGNAS
#define EIGEN_ALIGN_TO_BOUNDARY(n) alignas(n)
#define EIGEN_ALIGNOF(x) alignof(x)
#elif EIGEN_COMP_GNUC || EIGEN_COMP_PGI || EIGEN_COMP_IBM || EIGEN_COMP_ARM
#define EIGEN_ALIGN_TO_BOUNDARY(n) __attribute__((aligned(n)))
#define EIGEN_ALIGNOF(x) __alignof(x)
#elif EIGEN_COMP_MSVC
#define EIGEN_ALIGN_TO_BOUNDARY(n) __declspec(align(n))
#define EIGEN_ALIGNOF(x) __alignof(x)
#elif EIGEN_COMP_SUNCC
// FIXME not sure about this one:
#define EIGEN_ALIGN_TO_BOUNDARY(n) __attribute__((aligned(n)))
#define EIGEN_ALIGNOF(x) __alignof(x)
#else
#error Please tell me what is the equivalent of alignas(n) and alignof(x) for your compiler
#endif
// If the user explicitly disable vectorization, then we also disable alignment
#if defined(EIGEN_DONT_VECTORIZE)
#if defined(EIGEN_GPUCC)
// GPU code is always vectorized and requires memory alignment for
// statically allocated buffers.
#define EIGEN_IDEAL_MAX_ALIGN_BYTES 16
#else
#define EIGEN_IDEAL_MAX_ALIGN_BYTES 0
#endif
#elif defined(__AVX512F__)
// 64 bytes static alignment is preferred only if really required
#define EIGEN_IDEAL_MAX_ALIGN_BYTES 64
#elif defined(__AVX__)
// 32 bytes static alignment is preferred only if really required
#define EIGEN_IDEAL_MAX_ALIGN_BYTES 32
#else
#define EIGEN_IDEAL_MAX_ALIGN_BYTES 16
#endif
// EIGEN_MIN_ALIGN_BYTES defines the minimal value for which the notion of explicit alignment makes sense
#define EIGEN_MIN_ALIGN_BYTES 16
// Defines the boundary (in bytes) on which the data needs to be aligned. Note
// that unless EIGEN_ALIGN is defined and not equal to 0, the data may not be
// aligned at all regardless of the value of this #define.
#if (defined(EIGEN_DONT_ALIGN_STATICALLY) || defined(EIGEN_DONT_ALIGN)) && defined(EIGEN_MAX_STATIC_ALIGN_BYTES) && EIGEN_MAX_STATIC_ALIGN_BYTES>0
#error EIGEN_MAX_STATIC_ALIGN_BYTES and EIGEN_DONT_ALIGN[_STATICALLY] are both defined with EIGEN_MAX_STATIC_ALIGN_BYTES!=0. Use EIGEN_MAX_STATIC_ALIGN_BYTES=0 as a synonym of EIGEN_DONT_ALIGN_STATICALLY.
#endif
// EIGEN_DONT_ALIGN_STATICALLY and EIGEN_DONT_ALIGN are deprecated
// They imply EIGEN_MAX_STATIC_ALIGN_BYTES=0
#if defined(EIGEN_DONT_ALIGN_STATICALLY) || defined(EIGEN_DONT_ALIGN)
#ifdef EIGEN_MAX_STATIC_ALIGN_BYTES
#undef EIGEN_MAX_STATIC_ALIGN_BYTES
#endif
#define EIGEN_MAX_STATIC_ALIGN_BYTES 0
#endif
#ifndef EIGEN_MAX_STATIC_ALIGN_BYTES
// Try to automatically guess what is the best default value for EIGEN_MAX_STATIC_ALIGN_BYTES
// 16 byte alignment is only useful for vectorization. Since it affects the ABI, we need to enable
// 16 byte alignment on all platforms where vectorization might be enabled. In theory we could always
// enable alignment, but it can be a cause of problems on some platforms, so we just disable it in
// certain common platform (compiler+architecture combinations) to avoid these problems.
// Only static alignment is really problematic (relies on nonstandard compiler extensions),
// try to keep heap alignment even when we have to disable static alignment.
#if EIGEN_COMP_GNUC && !(EIGEN_ARCH_i386_OR_x86_64 || EIGEN_ARCH_ARM_OR_ARM64 || EIGEN_ARCH_PPC || EIGEN_ARCH_IA64 || EIGEN_ARCH_MIPS)
#define EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT 1
#elif EIGEN_ARCH_ARM_OR_ARM64 && EIGEN_COMP_GNUC_STRICT && EIGEN_GNUC_AT_MOST(4, 6)
// Old versions of GCC on ARM, at least 4.4, were once seen to have buggy static alignment support.
// Not sure which version fixed it, hopefully it doesn't affect 4.7, which is still somewhat in use.
// 4.8 and newer seem definitely unaffected.
#define EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT 1
#else
#define EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT 0
#endif
// static alignment is completely disabled with GCC 3, Sun Studio, and QCC/QNX
#if !EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT \
&& !EIGEN_GCC3_OR_OLDER \
&& !EIGEN_COMP_SUNCC \
&& !EIGEN_OS_QNX
#define EIGEN_ARCH_WANTS_STACK_ALIGNMENT 1
#else
#define EIGEN_ARCH_WANTS_STACK_ALIGNMENT 0
#endif
#if EIGEN_ARCH_WANTS_STACK_ALIGNMENT
#define EIGEN_MAX_STATIC_ALIGN_BYTES EIGEN_IDEAL_MAX_ALIGN_BYTES
#else
#define EIGEN_MAX_STATIC_ALIGN_BYTES 0
#endif
#endif
// If EIGEN_MAX_ALIGN_BYTES is defined, then it is considered as an upper bound for EIGEN_MAX_STATIC_ALIGN_BYTES
#if defined(EIGEN_MAX_ALIGN_BYTES) && EIGEN_MAX_ALIGN_BYTES<EIGEN_MAX_STATIC_ALIGN_BYTES
#undef EIGEN_MAX_STATIC_ALIGN_BYTES
#define EIGEN_MAX_STATIC_ALIGN_BYTES EIGEN_MAX_ALIGN_BYTES
#endif
#if EIGEN_MAX_STATIC_ALIGN_BYTES==0 && !defined(EIGEN_DISABLE_UNALIGNED_ARRAY_ASSERT)
#define EIGEN_DISABLE_UNALIGNED_ARRAY_ASSERT
#endif
// At this stage, EIGEN_MAX_STATIC_ALIGN_BYTES>0 is the true test whether we want to align arrays on the stack or not.
// It takes into account both the user choice to explicitly enable/disable alignment (by setting EIGEN_MAX_STATIC_ALIGN_BYTES)
// and the architecture config (EIGEN_ARCH_WANTS_STACK_ALIGNMENT).
// Henceforth, only EIGEN_MAX_STATIC_ALIGN_BYTES should be used.
// Shortcuts to EIGEN_ALIGN_TO_BOUNDARY
#define EIGEN_ALIGN8 EIGEN_ALIGN_TO_BOUNDARY(8)
#define EIGEN_ALIGN16 EIGEN_ALIGN_TO_BOUNDARY(16)
#define EIGEN_ALIGN32 EIGEN_ALIGN_TO_BOUNDARY(32)
#define EIGEN_ALIGN64 EIGEN_ALIGN_TO_BOUNDARY(64)
#if EIGEN_MAX_STATIC_ALIGN_BYTES>0
#define EIGEN_ALIGN_MAX EIGEN_ALIGN_TO_BOUNDARY(EIGEN_MAX_STATIC_ALIGN_BYTES)
#else
#define EIGEN_ALIGN_MAX
#endif
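// Usage sketch (illustrative, not part of the original file): these shortcuts are
// intended for declaring statically aligned buffers, e.g.
//   EIGEN_ALIGN16  float  packet_buffer[4]; // 16-byte aligned storage for one SSE packet
//   EIGEN_ALIGN_MAX double scratch[8];      // aligned to EIGEN_MAX_STATIC_ALIGN_BYTES (expands to nothing if that is 0)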
// Dynamic alignment control
#if defined(EIGEN_DONT_ALIGN) && defined(EIGEN_MAX_ALIGN_BYTES) && EIGEN_MAX_ALIGN_BYTES>0
#error EIGEN_MAX_ALIGN_BYTES and EIGEN_DONT_ALIGN are both defined with EIGEN_MAX_ALIGN_BYTES!=0. Use EIGEN_MAX_ALIGN_BYTES=0 as a synonym of EIGEN_DONT_ALIGN.
#endif
#ifdef EIGEN_DONT_ALIGN
#ifdef EIGEN_MAX_ALIGN_BYTES
#undef EIGEN_MAX_ALIGN_BYTES
#endif
#define EIGEN_MAX_ALIGN_BYTES 0
#elif !defined(EIGEN_MAX_ALIGN_BYTES)
#define EIGEN_MAX_ALIGN_BYTES EIGEN_IDEAL_MAX_ALIGN_BYTES
#endif
#if EIGEN_IDEAL_MAX_ALIGN_BYTES > EIGEN_MAX_ALIGN_BYTES
#define EIGEN_DEFAULT_ALIGN_BYTES EIGEN_IDEAL_MAX_ALIGN_BYTES
#else
#define EIGEN_DEFAULT_ALIGN_BYTES EIGEN_MAX_ALIGN_BYTES
#endif
#ifndef EIGEN_UNALIGNED_VECTORIZE
#define EIGEN_UNALIGNED_VECTORIZE 1
#endif
//----------------------------------------------------------------------
// if alignment is disabled, then disable vectorization. Note: EIGEN_MAX_ALIGN_BYTES is the proper check, it takes into
// account both the user's will (EIGEN_MAX_ALIGN_BYTES,EIGEN_DONT_ALIGN) and our own platform checks
#if EIGEN_MAX_ALIGN_BYTES==0
#ifndef EIGEN_DONT_VECTORIZE
#define EIGEN_DONT_VECTORIZE
#endif
#endif
// The following (except #include <malloc.h> and _M_IX86_FP ??) can likely be
// removed as gcc 4.1 and msvc 2008 are not supported anyway.
#if EIGEN_COMP_MSVC
#include <malloc.h> // for _aligned_malloc -- need it regardless of whether vectorization is enabled
#if (EIGEN_COMP_MSVC >= 1500) // 2008 or later
// a user reported that in 64-bit mode, MSVC doesn't care to define _M_IX86_FP.
#if (defined(_M_IX86_FP) && (_M_IX86_FP >= 2)) || EIGEN_ARCH_x86_64
#define EIGEN_SSE2_ON_MSVC_2008_OR_LATER
#endif
#endif
#else
#if (defined __SSE2__) && ( (!EIGEN_COMP_GNUC) || EIGEN_COMP_ICC || EIGEN_GNUC_AT_LEAST(4,2) )
#define EIGEN_SSE2_ON_NON_MSVC_BUT_NOT_OLD_GCC
#endif
#endif
#if !(defined(EIGEN_DONT_VECTORIZE) || defined(EIGEN_GPUCC))
#if defined (EIGEN_SSE2_ON_NON_MSVC_BUT_NOT_OLD_GCC) || defined(EIGEN_SSE2_ON_MSVC_2008_OR_LATER)
// Defines symbols for compile-time detection of which instructions are
// used.
// EIGEN_VECTORIZE_YY is defined if and only if the instruction set YY is used
#define EIGEN_VECTORIZE
#define EIGEN_VECTORIZE_SSE
#define EIGEN_VECTORIZE_SSE2
// Detect sse3/ssse3/sse4:
// gcc and icc defines __SSE3__, ...
// there is no way to know about this on msvc. You can define EIGEN_VECTORIZE_SSE* if you
// want to force the use of those instructions with msvc.
#ifdef __SSE3__
#define EIGEN_VECTORIZE_SSE3
#endif
#ifdef __SSSE3__
#define EIGEN_VECTORIZE_SSSE3
#endif
#ifdef __SSE4_1__
#define EIGEN_VECTORIZE_SSE4_1
#endif
#ifdef __SSE4_2__
#define EIGEN_VECTORIZE_SSE4_2
#endif
#ifdef __AVX__
#ifndef EIGEN_USE_SYCL
#define EIGEN_VECTORIZE_AVX
#endif
#define EIGEN_VECTORIZE_SSE3
#define EIGEN_VECTORIZE_SSSE3
#define EIGEN_VECTORIZE_SSE4_1
#define EIGEN_VECTORIZE_SSE4_2
#endif
#ifdef __AVX2__
#ifndef EIGEN_USE_SYCL
#define EIGEN_VECTORIZE_AVX2
#define EIGEN_VECTORIZE_AVX
#endif
#define EIGEN_VECTORIZE_SSE3
#define EIGEN_VECTORIZE_SSSE3
#define EIGEN_VECTORIZE_SSE4_1
#define EIGEN_VECTORIZE_SSE4_2
#endif
#if defined(__FMA__) || (EIGEN_COMP_MSVC && defined(__AVX2__))
// MSVC does not expose a switch dedicated for FMA
// For MSVC, AVX2 => FMA
#define EIGEN_VECTORIZE_FMA
#endif
#if defined(__AVX512F__)
#ifndef EIGEN_VECTORIZE_FMA
#if EIGEN_COMP_GNUC
#error Please add -mfma to your compiler flags: compiling with -mavx512f alone without SSE/AVX FMA is not supported (bug 1638).
#else
#error Please enable FMA in your compiler flags (e.g. -mfma): compiling with AVX512 alone without SSE/AVX FMA is not supported (bug 1638).
#endif
#endif
#ifndef EIGEN_USE_SYCL
#define EIGEN_VECTORIZE_AVX512
#define EIGEN_VECTORIZE_AVX2
#define EIGEN_VECTORIZE_AVX
#endif
#define EIGEN_VECTORIZE_FMA
#define EIGEN_VECTORIZE_SSE3
#define EIGEN_VECTORIZE_SSSE3
#define EIGEN_VECTORIZE_SSE4_1
#define EIGEN_VECTORIZE_SSE4_2
#ifndef EIGEN_USE_SYCL
#ifdef __AVX512DQ__
#define EIGEN_VECTORIZE_AVX512DQ
#endif
#ifdef __AVX512ER__
#define EIGEN_VECTORIZE_AVX512ER
#endif
#ifdef __AVX512BF16__
#define EIGEN_VECTORIZE_AVX512BF16
#endif
#endif
#endif
// Disable AVX support on broken xcode versions
#if defined(__apple_build_version__) && (__apple_build_version__ == 11000033 ) && ( __MAC_OS_X_VERSION_MIN_REQUIRED == 101500 )
// A nasty bug in the clang compiler shipped with XCode 11.0, triggered in a common compilation
// situation when the Mac deployment target is macOS 10.15, is described at https://trac.macports.org/ticket/58776#no1
#ifdef EIGEN_VECTORIZE_AVX
#undef EIGEN_VECTORIZE_AVX
#warning "Disabling AVX support: clang compiler shipped with XCode 11.[012] generates broken assembly with -macosx-version-min=10.15 and AVX enabled. "
#ifdef EIGEN_VECTORIZE_AVX2
#undef EIGEN_VECTORIZE_AVX2
#endif
#ifdef EIGEN_VECTORIZE_FMA
#undef EIGEN_VECTORIZE_FMA
#endif
#ifdef EIGEN_VECTORIZE_AVX512
#undef EIGEN_VECTORIZE_AVX512
#endif
#ifdef EIGEN_VECTORIZE_AVX512DQ
#undef EIGEN_VECTORIZE_AVX512DQ
#endif
#ifdef EIGEN_VECTORIZE_AVX512ER
#undef EIGEN_VECTORIZE_AVX512ER
#endif
#endif
// NOTE: Confirmed test failures in XCode 11.0, and XCode 11.2 with -macosx-version-min=10.15 and AVX
// NOTE using -macosx-version-min=10.15 with Xcode 11.0 results in runtime segmentation faults in many tests, 11.2 produce core dumps in 3 tests
// NOTE using -macosx-version-min=10.14 produces functioning and passing tests in all cases
// NOTE __clang_version__ "11.0.0 (clang-1100.0.33.8)" XCode 11.0 <- Produces many segfault and core dumping tests
// with -macosx-version-min=10.15 and AVX
// NOTE __clang_version__ "11.0.0 (clang-1100.0.33.12)" XCode 11.2 <- Produces 3 core dumping tests with
// -macosx-version-min=10.15 and AVX
#endif
// include files
// This extern "C" works around a MINGW-w64 compilation issue
// https://sourceforge.net/tracker/index.php?func=detail&aid=3018394&group_id=202880&atid=983354
// In essence, intrin.h is included by windows.h and also declares intrinsics (just as emmintrin.h etc. below do).
// However, intrin.h uses an extern "C" declaration, and g++ thus complains of duplicate declarations
// with conflicting linkage. The linkage for intrinsics doesn't matter, but at that stage the compiler doesn't know;
// so, to avoid compile errors when windows.h is included after Eigen/Core, ensure intrinsics are extern "C" here too.
// notice that since these are C headers, the extern "C" is theoretically needed anyway.
extern "C" {
// In theory we should only include immintrin.h and not the other *mmintrin.h header files directly.
// Doing so triggers some issues with ICC. However, old gcc versions seem not to have this file, thus:
#if EIGEN_COMP_ICC >= 1110
#include <immintrin.h>
#else
#include <mmintrin.h>
#include <emmintrin.h>
#include <xmmintrin.h>
#ifdef EIGEN_VECTORIZE_SSE3
#include <pmmintrin.h>
#endif
#ifdef EIGEN_VECTORIZE_SSSE3
#include <tmmintrin.h>
#endif
#ifdef EIGEN_VECTORIZE_SSE4_1
#include <smmintrin.h>
#endif
#ifdef EIGEN_VECTORIZE_SSE4_2
#include <nmmintrin.h>
#endif
#if defined(EIGEN_VECTORIZE_AVX) || defined(EIGEN_VECTORIZE_AVX512)
#include <immintrin.h>
#endif
#endif
} // end extern "C"
#elif defined __VSX__
#define EIGEN_VECTORIZE
#define EIGEN_VECTORIZE_VSX
#include <altivec.h>
// We need to #undef all these ugly tokens defined in <altivec.h>
// => use __vector instead of vector
#undef bool
#undef vector
#undef pixel
#elif defined __ALTIVEC__
#define EIGEN_VECTORIZE
#define EIGEN_VECTORIZE_ALTIVEC
#include <altivec.h>
// We need to #undef all these ugly tokens defined in <altivec.h>
// => use __vector instead of vector
#undef bool
#undef vector
#undef pixel
#elif ((defined __ARM_NEON) || (defined __ARM_NEON__)) && !(defined EIGEN_ARM64_USE_SVE)
#define EIGEN_VECTORIZE
#define EIGEN_VECTORIZE_NEON
#include <arm_neon.h>
// We currently require SVE to be enabled explicitly via EIGEN_ARM64_USE_SVE and
// will not select the backend automatically
#elif (defined __ARM_FEATURE_SVE) && (defined EIGEN_ARM64_USE_SVE)
#define EIGEN_VECTORIZE
#define EIGEN_VECTORIZE_SVE
#include <arm_sve.h>
// Since we depend on knowing SVE vector lengths at compile-time, we need
// to ensure a fixed length is set
#if defined __ARM_FEATURE_SVE_BITS
#define EIGEN_ARM64_SVE_VL __ARM_FEATURE_SVE_BITS
#else
#error "Eigen requires a fixed SVE lector length but EIGEN_ARM64_SVE_VL is not set."
#endif
#elif (defined __s390x__ && defined __VEC__)
#define EIGEN_VECTORIZE
#define EIGEN_VECTORIZE_ZVECTOR
#include <vecintrin.h>
#elif defined __mips_msa
// Limit MSA optimizations to little-endian CPUs for now.
// TODO: Perhaps, eventually support MSA optimizations on big-endian CPUs?
#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
#if defined(__LP64__)
#define EIGEN_MIPS_64
#else
#define EIGEN_MIPS_32
#endif
#define EIGEN_VECTORIZE
#define EIGEN_VECTORIZE_MSA
#include <msa.h>
#endif
#endif
#endif
// Following the Arm ACLE, arm_neon.h should also include arm_fp16.h, but not all
// compilers seem to follow this. We therefore include it explicitly.
// See also: https://bugs.llvm.org/show_bug.cgi?id=47955
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
#include <arm_fp16.h>
#endif
#if defined(__F16C__) && (!defined(EIGEN_GPUCC) && (!defined(EIGEN_COMP_CLANG) || EIGEN_COMP_CLANG>=380))
// We can use the optimized fp16 to float and float to fp16 conversion routines
#define EIGEN_HAS_FP16_C
#if defined(EIGEN_COMP_CLANG)
// Workaround for clang: The FP16C intrinsics for clang are included by
// immintrin.h, as opposed to emmintrin.h as suggested by Intel:
// https://software.intel.com/sites/landingpage/IntrinsicsGuide/#othertechs=FP16C&expand=1711
#include <immintrin.h>
#endif
#endif
#if defined EIGEN_CUDACC
#define EIGEN_VECTORIZE_GPU
#include <vector_types.h>
#if EIGEN_CUDA_SDK_VER >= 70500
#define EIGEN_HAS_CUDA_FP16
#endif
#endif
#if defined(EIGEN_HAS_CUDA_FP16)
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#endif
#if defined(EIGEN_HIPCC)
#define EIGEN_VECTORIZE_GPU
#include <hip/hip_vector_types.h>
#define EIGEN_HAS_HIP_FP16
#include <hip/hip_fp16.h>
#endif
/** \brief Namespace containing all symbols from the %Eigen library. */
namespace Eigen {
inline static const char *SimdInstructionSetsInUse(void) {
#if defined(EIGEN_VECTORIZE_AVX512)
return "AVX512, FMA, AVX2, AVX, SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2";
#elif defined(EIGEN_VECTORIZE_AVX)
return "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2";
#elif defined(EIGEN_VECTORIZE_SSE4_2)
return "SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2";
#elif defined(EIGEN_VECTORIZE_SSE4_1)
return "SSE, SSE2, SSE3, SSSE3, SSE4.1";
#elif defined(EIGEN_VECTORIZE_SSSE3)
return "SSE, SSE2, SSE3, SSSE3";
#elif defined(EIGEN_VECTORIZE_SSE3)
return "SSE, SSE2, SSE3";
#elif defined(EIGEN_VECTORIZE_SSE2)
return "SSE, SSE2";
#elif defined(EIGEN_VECTORIZE_ALTIVEC)
return "AltiVec";
#elif defined(EIGEN_VECTORIZE_VSX)
return "VSX";
#elif defined(EIGEN_VECTORIZE_NEON)
return "ARM NEON";
#elif defined(EIGEN_VECTORIZE_SVE)
return "ARM SVE";
#elif defined(EIGEN_VECTORIZE_ZVECTOR)
return "S390X ZVECTOR";
#elif defined(EIGEN_VECTORIZE_MSA)
return "MIPS MSA";
#else
return "None";
#endif
}
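// Usage sketch (illustrative, not part of the original file; assumes <iostream> is included):
//   std::cout << "SIMD instruction sets in use: " << Eigen::SimdInstructionSetsInUse() << std::endl;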
} // end namespace Eigen
#endif // EIGEN_CONFIGURE_VECTORIZATION_H

View File

@ -0,0 +1,186 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_INDEXED_VIEW_HELPER_H
#define EIGEN_INDEXED_VIEW_HELPER_H
namespace Eigen {
namespace internal {
struct symbolic_last_tag {};
}
/** \var last
* \ingroup Core_Module
*
* Can be used as a parameter to Eigen::seq and Eigen::seqN functions to symbolically reference the last element/row/column
* of the underlying vector or matrix once passed to DenseBase::operator()(const RowIndices&, const ColIndices&).
*
* This symbolic placeholder supports standard arithmetic operations.
*
* A typical usage example would be:
* \code
* using namespace Eigen;
* using Eigen::last;
* VectorXd v(n);
* v(seq(2,last-2)).setOnes();
* \endcode
*
* \sa end
*/
static const symbolic::SymbolExpr<internal::symbolic_last_tag> last; // PLEASE use Eigen::last instead of Eigen::placeholders::last
/** \var lastp1
* \ingroup Core_Module
*
* Can be used as a parameter to Eigen::seq and Eigen::seqN functions to symbolically
* reference the last+1 element/row/column of the underlying vector or matrix once
* passed to DenseBase::operator()(const RowIndices&, const ColIndices&).
*
* This symbolic placeholder supports standard arithmetic operations.
* It is essentially an alias to last+fix<1>.
*
* \sa last
*/
#ifdef EIGEN_PARSED_BY_DOXYGEN
static const auto lastp1 = last+fix<1>;
#else
// Using a FixedExpr<1> expression is important here to make sure the compiler
// can fully optimize the computation of starting indices with zero overhead.
static const symbolic::AddExpr<symbolic::SymbolExpr<internal::symbolic_last_tag>,symbolic::ValueExpr<Eigen::internal::FixedInt<1> > > lastp1(last+fix<1>());
#endif
namespace internal {
// Replace symbolic last/end "keywords" by their true runtime value
inline Index eval_expr_given_size(Index x, Index /* size */) { return x; }
template<int N>
FixedInt<N> eval_expr_given_size(FixedInt<N> x, Index /*size*/) { return x; }
template<typename Derived>
Index eval_expr_given_size(const symbolic::BaseExpr<Derived> &x, Index size)
{
return x.derived().eval(last=size-1);
}
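// Worked example (illustrative, not part of the original file): with a size of 10,
// eval_expr_given_size(last - 2, 10) evaluates the symbolic expression with last = 9
// and therefore returns 7.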
// Extract increment/step at compile time
template<typename T, typename EnableIf = void> struct get_compile_time_incr {
enum { value = UndefinedIncr };
};
// Analogue of std::get<0>(x), but tailored for our needs.
template<typename T>
EIGEN_CONSTEXPR Index first(const T& x) EIGEN_NOEXCEPT { return x.first(); }
// IndexedViewCompatibleType/makeIndexedViewCompatible turn an arbitrary object of type T into something usable by MatrixSlice
// The generic implementation is a no-op
template<typename T,int XprSize,typename EnableIf=void>
struct IndexedViewCompatibleType {
typedef T type;
};
template<typename T,typename Q>
const T& makeIndexedViewCompatible(const T& x, Index /*size*/, Q) { return x; }
//--------------------------------------------------------------------------------
// Handling of a single Index
//--------------------------------------------------------------------------------
struct SingleRange {
enum {
SizeAtCompileTime = 1
};
SingleRange(Index val) : m_value(val) {}
Index operator[](Index) const { return m_value; }
static EIGEN_CONSTEXPR Index size() EIGEN_NOEXCEPT { return 1; }
Index first() const EIGEN_NOEXCEPT { return m_value; }
Index m_value;
};
template<> struct get_compile_time_incr<SingleRange> {
enum { value = 1 }; // 1 or 0 ??
};
// Turn a single index into something that looks like an array (i.e., that exposes .size() and operator[](int) methods)
template<typename T, int XprSize>
struct IndexedViewCompatibleType<T,XprSize,typename internal::enable_if<internal::is_integral<T>::value>::type> {
// Here we could simply use Array, but maybe it's less work for the compiler to use
// a simpler wrapper such as SingleRange
//typedef Eigen::Array<Index,1,1> type;
typedef SingleRange type;
};
template<typename T, int XprSize>
struct IndexedViewCompatibleType<T, XprSize, typename enable_if<symbolic::is_symbolic<T>::value>::type> {
typedef SingleRange type;
};
template<typename T>
typename enable_if<symbolic::is_symbolic<T>::value,SingleRange>::type
makeIndexedViewCompatible(const T& id, Index size, SpecializedType) {
return eval_expr_given_size(id,size);
}
//--------------------------------------------------------------------------------
// Handling of all
//--------------------------------------------------------------------------------
struct all_t { all_t() {} };
// Convert a symbolic 'all' into a usable range type
template<int XprSize>
struct AllRange {
enum { SizeAtCompileTime = XprSize };
AllRange(Index size = XprSize) : m_size(size) {}
EIGEN_CONSTEXPR Index operator[](Index i) const EIGEN_NOEXCEPT { return i; }
EIGEN_CONSTEXPR Index size() const EIGEN_NOEXCEPT { return m_size.value(); }
EIGEN_CONSTEXPR Index first() const EIGEN_NOEXCEPT { return 0; }
variable_if_dynamic<Index,XprSize> m_size;
};
template<int XprSize>
struct IndexedViewCompatibleType<all_t,XprSize> {
typedef AllRange<XprSize> type;
};
template<typename XprSizeType>
inline AllRange<get_fixed_value<XprSizeType>::value> makeIndexedViewCompatible(all_t , XprSizeType size, SpecializedType) {
return AllRange<get_fixed_value<XprSizeType>::value>(size);
}
template<int Size> struct get_compile_time_incr<AllRange<Size> > {
enum { value = 1 };
};
} // end namespace internal
/** \var all
* \ingroup Core_Module
* Can be used as a parameter to DenseBase::operator()(const RowIndices&, const ColIndices&) to index all rows or columns
*/
static const Eigen::internal::all_t all; // PLEASE use Eigen::all instead of Eigen::placeholders::all
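// Usage sketch (illustrative, not part of the original file):
//   Eigen::MatrixXd A(4, 4);
//   A(Eigen::all, 1).setZero();              // select the second column, all rows
//   A(Eigen::seq(1, 2), Eigen::all) *= 2.0;  // rows 1..2, every column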
namespace placeholders {
typedef symbolic::SymbolExpr<internal::symbolic_last_tag> last_t;
typedef symbolic::AddExpr<symbolic::SymbolExpr<internal::symbolic_last_tag>,symbolic::ValueExpr<Eigen::internal::FixedInt<1> > > end_t;
typedef Eigen::internal::all_t all_t;
EIGEN_DEPRECATED static const all_t all = Eigen::all; // PLEASE use Eigen::all instead of Eigen::placeholders::all
EIGEN_DEPRECATED static const last_t last = Eigen::last; // PLEASE use Eigen::last instead of Eigen::placeholders::last
EIGEN_DEPRECATED static const end_t end = Eigen::lastp1; // PLEASE use Eigen::lastp1 instead of Eigen::placeholders::end
}
} // end namespace Eigen
#endif // EIGEN_INDEXED_VIEW_HELPER_H

View File

@ -0,0 +1,272 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_INTEGRAL_CONSTANT_H
#define EIGEN_INTEGRAL_CONSTANT_H
namespace Eigen {
namespace internal {
template<int N> class FixedInt;
template<int N> class VariableAndFixedInt;
/** \internal
* \class FixedInt
*
* This class embeds a compile-time integer \c N.
*
* It is similar to c++11 std::integral_constant<int,N> but with some additional features
* such as:
* - implicit conversion to int
* - arithmetic and some bitwise operators: -, +, *, /, %, &, |
* - c++98/14 compatibility with fix<N> and fix<N>() syntax to define integral constants.
*
* It is strongly discouraged to deal directly with this class FixedInt. Instances are expected to
* be created by the user using Eigen::fix<N> or Eigen::fix<N>(). In C++98-11, the former syntax does
* not create a FixedInt<N> instance but rather a pointer to a function that needs to be \em cleaned-up
* using the generic helper:
* \code
* internal::cleanup_index_type<T>::type
* internal::cleanup_index_type<T,DynamicKey>::type
* \endcode
* where T can be a FixedInt<N>, a pointer to a function FixedInt<N> (*)(), or numerous other integer-like representations.
* \c DynamicKey is either Dynamic (default) or DynamicIndex and used to identify true compile-time values.
*
* For convenience, you can extract the compile-time value \c N in a generic way using the following helper:
* \code
* internal::get_fixed_value<T,DefaultVal>::value
* \endcode
* that will give you \c N if T equals FixedInt<N> or FixedInt<N> (*)(), and \c DefaultVal if T does not embed any compile-time value (e.g., T==int).
*
* \sa fix<N>, class VariableAndFixedInt
*/
template<int N> class FixedInt
{
public:
static const int value = N;
EIGEN_CONSTEXPR operator int() const { return value; }
FixedInt() {}
FixedInt( VariableAndFixedInt<N> other) {
#ifndef EIGEN_INTERNAL_DEBUGGING
EIGEN_UNUSED_VARIABLE(other);
#endif
eigen_internal_assert(int(other)==N);
}
FixedInt<-N> operator-() const { return FixedInt<-N>(); }
template<int M>
FixedInt<N+M> operator+( FixedInt<M>) const { return FixedInt<N+M>(); }
template<int M>
FixedInt<N-M> operator-( FixedInt<M>) const { return FixedInt<N-M>(); }
template<int M>
FixedInt<N*M> operator*( FixedInt<M>) const { return FixedInt<N*M>(); }
template<int M>
FixedInt<N/M> operator/( FixedInt<M>) const { return FixedInt<N/M>(); }
template<int M>
FixedInt<N%M> operator%( FixedInt<M>) const { return FixedInt<N%M>(); }
template<int M>
FixedInt<N|M> operator|( FixedInt<M>) const { return FixedInt<N|M>(); }
template<int M>
FixedInt<N&M> operator&( FixedInt<M>) const { return FixedInt<N&M>(); }
#if EIGEN_HAS_CXX14_VARIABLE_TEMPLATES
// Needed in C++14 to allow fix<N>():
FixedInt operator() () const { return *this; }
VariableAndFixedInt<N> operator() (int val) const { return VariableAndFixedInt<N>(val); }
#else
FixedInt ( FixedInt<N> (*)() ) {}
#endif
#if EIGEN_HAS_CXX11
FixedInt(std::integral_constant<int,N>) {}
#endif
};
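// Usage sketch (illustrative, not part of the original file): FixedInt propagates its
// compile-time value through arithmetic, e.g.
//   auto n = Eigen::fix<3>() + Eigen::fix<4>();  // n has type Eigen::internal::FixedInt<7>
//   static_assert(Eigen::internal::get_fixed_value<decltype(n)>::value == 7, "");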
/** \internal
* \class VariableAndFixedInt
*
* This class embeds both a compile-time integer \c N and a runtime integer.
* Both values are supposed to be equal unless the compile-time value \c N has a special
* value meaning that the runtime-value should be used. Depending on the context, this special
* value can be either Eigen::Dynamic (for positive quantities) or Eigen::DynamicIndex (for
* quantities that can be negative).
*
* It is the return-type of the function Eigen::fix<N>(int), and most of the time this is the only
* way it is used. It is strongly discouraged to directly deal with instances of VariableAndFixedInt.
* Indeed, in order to write generic code, it is the responsibility of the callee to properly convert
* it to either a true compile-time quantity (i.e. a FixedInt<N>), or to a runtime quantity (e.g., an Index)
* using the following generic helper:
* \code
* internal::cleanup_index_type<T>::type
* internal::cleanup_index_type<T,DynamicKey>::type
* \endcode
* where T can be a template instantiation of VariableAndFixedInt or numerous other integer-like representations.
* \c DynamicKey is either Dynamic (default) or DynamicIndex and used to identify true compile-time values.
*
* For convenience, you can also extract the compile-time value \c N using the following helper:
* \code
* internal::get_fixed_value<T,DefaultVal>::value
* \endcode
* that will give you \c N if T equals VariableAndFixedInt<N>, and \c DefaultVal if T does not embed any compile-time value (e.g., T==int).
*
* \sa fix<N>(int), class FixedInt
*/
template<int N> class VariableAndFixedInt
{
public:
static const int value = N;
operator int() const { return m_value; }
VariableAndFixedInt(int val) { m_value = val; }
protected:
int m_value;
};
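// Illustrative sketch (not part of the original header): fix<N>(n), defined later in this file,
// is the usual way a VariableAndFixedInt gets created; it carries N at compile time and n at runtime
// (some_matrix is a hypothetical matrix used only for this example):
//   auto s = fix<Dynamic>(some_matrix.rows()); // decltype(s) is internal::VariableAndFixedInt<Dynamic>
//   // decltype(s)::value == Dynamic (compile-time part), int(s) == some_matrix.rows() (runtime part)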
template<typename T, int Default=Dynamic> struct get_fixed_value {
static const int value = Default;
};
template<int N,int Default> struct get_fixed_value<FixedInt<N>,Default> {
static const int value = N;
};
#if !EIGEN_HAS_CXX14
template<int N,int Default> struct get_fixed_value<FixedInt<N> (*)(),Default> {
static const int value = N;
};
#endif
template<int N,int Default> struct get_fixed_value<VariableAndFixedInt<N>,Default> {
static const int value = N;
};
template<typename T, int N, int Default>
struct get_fixed_value<variable_if_dynamic<T,N>,Default> {
static const int value = N;
};
template<typename T> EIGEN_DEVICE_FUNC Index get_runtime_value(const T &x) { return x; }
#if !EIGEN_HAS_CXX14
template<int N> EIGEN_DEVICE_FUNC Index get_runtime_value(FixedInt<N> (*)()) { return N; }
#endif
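// Illustrative sketch (not part of the original header): how the two helpers above split a
// fix<3>(n) argument into its compile-time and runtime parts:
//   get_fixed_value<VariableAndFixedInt<3>, Dynamic>::value   // == 3
//   get_fixed_value<int, Dynamic>::value                       // == Dynamic (no compile-time info)
//   get_runtime_value(VariableAndFixedInt<3>(3))               // == Index(3)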
// Cleanup integer/FixedInt/VariableAndFixedInt/etc types:
// By default, no cleanup:
template<typename T, int DynamicKey=Dynamic, typename EnableIf=void> struct cleanup_index_type { typedef T type; };
// Convert any integral type (e.g., short, int, unsigned int, etc.) to Eigen::Index
template<typename T, int DynamicKey> struct cleanup_index_type<T,DynamicKey,typename internal::enable_if<internal::is_integral<T>::value>::type> { typedef Index type; };
#if !EIGEN_HAS_CXX14
// In c++98/c++11, fix<N> is a pointer to function that we better cleanup to a true FixedInt<N>:
template<int N, int DynamicKey> struct cleanup_index_type<FixedInt<N> (*)(), DynamicKey> { typedef FixedInt<N> type; };
#endif
// If VariableAndFixedInt does not match DynamicKey, then we turn it to a pure compile-time value:
template<int N, int DynamicKey> struct cleanup_index_type<VariableAndFixedInt<N>, DynamicKey> { typedef FixedInt<N> type; };
// If VariableAndFixedInt matches DynamicKey, then we turn it to a pure runtime-value (aka Index):
template<int DynamicKey> struct cleanup_index_type<VariableAndFixedInt<DynamicKey>, DynamicKey> { typedef Index type; };
#if EIGEN_HAS_CXX11
template<int N, int DynamicKey> struct cleanup_index_type<std::integral_constant<int,N>, DynamicKey> { typedef FixedInt<N> type; };
#endif
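// Illustrative sketch (not part of the original header): what cleanup_index_type resolves to
// with the default DynamicKey==Dynamic:
//   cleanup_index_type<int>::type                            // -> Index
//   cleanup_index_type<VariableAndFixedInt<3> >::type        // -> FixedInt<3>  (true compile-time value)
//   cleanup_index_type<VariableAndFixedInt<Dynamic> >::type  // -> Index        (runtime value wins)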
} // end namespace internal
#ifndef EIGEN_PARSED_BY_DOXYGEN
#if EIGEN_HAS_CXX14_VARIABLE_TEMPLATES
template<int N>
static const internal::FixedInt<N> fix{};
#else
template<int N>
inline internal::FixedInt<N> fix() { return internal::FixedInt<N>(); }
// The generic typename T is mandatory; otherwise, an expression like fix<N> could refer to either the function above or this next overload.
// With T, an expression like fix<N> can only refer to the previous function.
template<int N,typename T>
inline internal::VariableAndFixedInt<N> fix(T val) { return internal::VariableAndFixedInt<N>(internal::convert_index<int>(val)); }
#endif
#else // EIGEN_PARSED_BY_DOXYGEN
/** \var fix<N>()
* \ingroup Core_Module
*
* This \em identifier permits constructing an object embedding a compile-time integer \c N.
*
* \tparam N the compile-time integer value
*
* It is typically used in conjunction with the Eigen::seq and Eigen::seqN functions to pass compile-time values to them:
* \code
* seqN(10,fix<4>,fix<-3>) // <=> [10 7 4 1]
* \endcode
*
* See also the function fix(int) to pass both a compile-time and runtime value.
*
* In c++14, it is implemented as:
* \code
* template<int N> static const internal::FixedInt<N> fix{};
* \endcode
* where internal::FixedInt<N> is an internal template class similar to
* <a href="http://en.cppreference.com/w/cpp/types/integral_constant">\c std::integral_constant </a><tt> <int,N> </tt>
* Here, \c fix<N> is thus an object of type \c internal::FixedInt<N>.
*
* In c++98/11, it is implemented as a function:
* \code
* template<int N> inline internal::FixedInt<N> fix();
* \endcode
* Here \c fix<N> is thus a pointer to function returning an internal::FixedInt<N>.
*
* If for some reason you want a true object in c++98 then you can write: \code fix<N>() \endcode which is also valid in c++14.
*
* \sa fix<N>(int), seq, seqN
*/
template<int N>
static const auto fix();
/** \fn fix<N>(int)
* \ingroup Core_Module
*
* This function returns an object embedding both a compile-time integer \c N, and a fallback runtime value \a val.
*
* \tparam N the compile-time integer value
* \param val the fallback runtime integer value
*
* This function is a more general version of the \ref fix identifier/function that can be used in template code
* where the compile-time value could turn out to actually mean "undefined at compile-time". For positive integers
* such as a size or a dimension, this case is identified by Eigen::Dynamic, whereas runtime signed integers
* (e.g., an increment/stride) are identified as Eigen::DynamicIndex. In such a case, the runtime value \a val
* will be used as a fallback.
*
* A typical use case would be:
* \code
* template<typename Derived> void foo(const MatrixBase<Derived> &mat) {
* const int N = Derived::RowsAtCompileTime==Dynamic ? Dynamic : Derived::RowsAtCompileTime/2;
* const int n = mat.rows()/2;
* ... mat( seqN(0,fix<N>(n) ) ...;
* }
* \endcode
* In this example, the function Eigen::seqN knows that the second argument is expected to be a size.
* If the passed compile-time value N equals Eigen::Dynamic, then the proxy object returned by fix will be dismissed, and converted to an Eigen::Index of value \c n.
* Otherwise, the runtime value \c n will be dismissed, and the returned ArithmeticSequence will be of the exact same type as <tt> seqN(0,fix<N>) </tt>.
*
* \sa fix, seqN, class ArithmeticSequence
*/
template<int N>
static const auto fix(int val);
#endif // EIGEN_PARSED_BY_DOXYGEN
} // end namespace Eigen
#endif // EIGEN_INTEGRAL_CONSTANT_H

View File

@ -0,0 +1,51 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_RESHAPED_HELPER_H
#define EIGEN_RESHAPED_HELPER_H
namespace Eigen {
enum AutoSize_t { AutoSize };
const int AutoOrder = 2;
namespace internal {
template<typename SizeType,typename OtherSize, int TotalSize>
struct get_compiletime_reshape_size {
enum { value = get_fixed_value<SizeType>::value };
};
template<typename SizeType>
Index get_runtime_reshape_size(SizeType size, Index /*other*/, Index /*total*/) {
return internal::get_runtime_value(size);
}
template<typename OtherSize, int TotalSize>
struct get_compiletime_reshape_size<AutoSize_t,OtherSize,TotalSize> {
enum {
other_size = get_fixed_value<OtherSize>::value,
value = (TotalSize==Dynamic || other_size==Dynamic) ? Dynamic : TotalSize / other_size };
};
inline Index get_runtime_reshape_size(AutoSize_t /*size*/, Index other, Index total) {
return total/other;
}
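// Illustrative sketch (not part of the original header): resolving AutoSize for a 24-element reshape
// whose other dimension is fix<4>:
//   get_compiletime_reshape_size<AutoSize_t, FixedInt<4>, 24>::value  // == 6
//   get_runtime_reshape_size(AutoSize, /*other=*/4, /*total=*/24)     // == 6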
template<int Flags, int Order>
struct get_compiletime_reshape_order {
enum { value = Order == AutoOrder ? Flags & RowMajorBit : Order };
};
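// Illustrative sketch (not part of the original header): with AutoOrder the reshaped expression
// keeps the storage order of its argument, otherwise the requested Order is used:
//   get_compiletime_reshape_order<RowMajorBit, AutoOrder>::value        // == 1 (RowMajor)
//   get_compiletime_reshape_order<0 /*col-major flags*/, AutoOrder>::value  // == 0 (ColMajor)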
}
} // end namespace Eigen
#endif // EIGEN_RESHAPED_HELPER_H

View File

@ -0,0 +1,293 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_SYMBOLIC_INDEX_H
#define EIGEN_SYMBOLIC_INDEX_H
namespace Eigen {
/** \namespace Eigen::symbolic
* \ingroup Core_Module
*
* This namespace defines a set of classes and functions to build and evaluate symbolic expressions of scalar type Index.
* Here is a simple example:
*
* \code
* // First step, defines symbols:
* struct x_tag {}; static const symbolic::SymbolExpr<x_tag> x;
* struct y_tag {}; static const symbolic::SymbolExpr<y_tag> y;
* struct z_tag {}; static const symbolic::SymbolExpr<z_tag> z;
*
* // Defines an expression:
* auto expr = (x+3)/y+z;
*
* // And evaluate it: (c++14)
* std::cout << expr.eval(x=6,y=3,z=-13) << "\n";
*
* // In c++98/11, only one symbol per expression is supported for now:
* auto expr98 = (3-x)/2;
* std::cout << expr98.eval(x=6) << "\n";
* \endcode
*
* It is currently only used internally to define and manipulate the Eigen::last and Eigen::lastp1 symbols in Eigen::seq and Eigen::seqN.
*
*/
namespace symbolic {
template<typename Tag> class Symbol;
template<typename Arg0> class NegateExpr;
template<typename Arg1,typename Arg2> class AddExpr;
template<typename Arg1,typename Arg2> class ProductExpr;
template<typename Arg1,typename Arg2> class QuotientExpr;
// A simple wrapper around an integral value to provide the eval method.
// We could also use a free-function symbolic_eval...
template<typename IndexType=Index>
class ValueExpr {
public:
ValueExpr(IndexType val) : m_value(val) {}
template<typename T>
IndexType eval_impl(const T&) const { return m_value; }
protected:
IndexType m_value;
};
// Specialization for a compile-time value.
// It is similar to ValueExpr(N), but this version helps the compiler generate better code.
template<int N>
class ValueExpr<internal::FixedInt<N> > {
public:
ValueExpr() {}
template<typename T>
EIGEN_CONSTEXPR Index eval_impl(const T&) const { return N; }
};
/** \class BaseExpr
* \ingroup Core_Module
* Common base class of any symbolic expressions
*/
template<typename Derived>
class BaseExpr
{
public:
const Derived& derived() const { return *static_cast<const Derived*>(this); }
/** Evaluate the expression given the \a values of the symbols.
*
* \param values defines the values of the symbols. It can be either a single SymbolValue or a std::tuple of SymbolValue objects,
* as constructed by SymbolExpr::operator=.
*
*/
template<typename T>
Index eval(const T& values) const { return derived().eval_impl(values); }
#if EIGEN_HAS_CXX14
template<typename... Types>
Index eval(Types&&... values) const { return derived().eval_impl(std::make_tuple(values...)); }
#endif
NegateExpr<Derived> operator-() const { return NegateExpr<Derived>(derived()); }
AddExpr<Derived,ValueExpr<> > operator+(Index b) const
{ return AddExpr<Derived,ValueExpr<> >(derived(), b); }
AddExpr<Derived,ValueExpr<> > operator-(Index a) const
{ return AddExpr<Derived,ValueExpr<> >(derived(), -a); }
ProductExpr<Derived,ValueExpr<> > operator*(Index a) const
{ return ProductExpr<Derived,ValueExpr<> >(derived(),a); }
QuotientExpr<Derived,ValueExpr<> > operator/(Index a) const
{ return QuotientExpr<Derived,ValueExpr<> >(derived(),a); }
friend AddExpr<Derived,ValueExpr<> > operator+(Index a, const BaseExpr& b)
{ return AddExpr<Derived,ValueExpr<> >(b.derived(), a); }
friend AddExpr<NegateExpr<Derived>,ValueExpr<> > operator-(Index a, const BaseExpr& b)
{ return AddExpr<NegateExpr<Derived>,ValueExpr<> >(-b.derived(), a); }
friend ProductExpr<ValueExpr<>,Derived> operator*(Index a, const BaseExpr& b)
{ return ProductExpr<ValueExpr<>,Derived>(a,b.derived()); }
friend QuotientExpr<ValueExpr<>,Derived> operator/(Index a, const BaseExpr& b)
{ return QuotientExpr<ValueExpr<>,Derived>(a,b.derived()); }
template<int N>
AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N>) const
{ return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >()); }
template<int N>
AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > > operator-(internal::FixedInt<N>) const
{ return AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > >(derived(), ValueExpr<internal::FixedInt<-N> >()); }
template<int N>
ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator*(internal::FixedInt<N>) const
{ return ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
template<int N>
QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator/(internal::FixedInt<N>) const
{ return QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
template<int N>
friend AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N>, const BaseExpr& b)
{ return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(b.derived(), ValueExpr<internal::FixedInt<N> >()); }
template<int N>
friend AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > > operator-(internal::FixedInt<N>, const BaseExpr& b)
{ return AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > >(-b.derived(), ValueExpr<internal::FixedInt<N> >()); }
template<int N>
friend ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator*(internal::FixedInt<N>, const BaseExpr& b)
{ return ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
template<int N>
friend QuotientExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator/(internal::FixedInt<N>, const BaseExpr& b)
{ return QuotientExpr<ValueExpr<internal::FixedInt<N> > ,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
#if (!EIGEN_HAS_CXX14)
template<int N>
AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N> (*)()) const
{ return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >()); }
template<int N>
AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > > operator-(internal::FixedInt<N> (*)()) const
{ return AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > >(derived(), ValueExpr<internal::FixedInt<-N> >()); }
template<int N>
ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator*(internal::FixedInt<N> (*)()) const
{ return ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
template<int N>
QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator/(internal::FixedInt<N> (*)()) const
{ return QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
template<int N>
friend AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N> (*)(), const BaseExpr& b)
{ return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(b.derived(), ValueExpr<internal::FixedInt<N> >()); }
template<int N>
friend AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > > operator-(internal::FixedInt<N> (*)(), const BaseExpr& b)
{ return AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > >(-b.derived(), ValueExpr<internal::FixedInt<N> >()); }
template<int N>
friend ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator*(internal::FixedInt<N> (*)(), const BaseExpr& b)
{ return ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
template<int N>
friend QuotientExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator/(internal::FixedInt<N> (*)(), const BaseExpr& b)
{ return QuotientExpr<ValueExpr<internal::FixedInt<N> > ,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
#endif
template<typename OtherDerived>
AddExpr<Derived,OtherDerived> operator+(const BaseExpr<OtherDerived> &b) const
{ return AddExpr<Derived,OtherDerived>(derived(), b.derived()); }
template<typename OtherDerived>
AddExpr<Derived,NegateExpr<OtherDerived> > operator-(const BaseExpr<OtherDerived> &b) const
{ return AddExpr<Derived,NegateExpr<OtherDerived> >(derived(), -b.derived()); }
template<typename OtherDerived>
ProductExpr<Derived,OtherDerived> operator*(const BaseExpr<OtherDerived> &b) const
{ return ProductExpr<Derived,OtherDerived>(derived(), b.derived()); }
template<typename OtherDerived>
QuotientExpr<Derived,OtherDerived> operator/(const BaseExpr<OtherDerived> &b) const
{ return QuotientExpr<Derived,OtherDerived>(derived(), b.derived()); }
};
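// Illustrative sketch (not part of the original header): the operators above let symbolic
// expressions mix runtime Index values and FixedInt values (C++14 syntax shown, with a symbol x
// declared as in the namespace documentation above):
//   auto e = (x + fix<2>) * 3;   // AddExpr/ProductExpr tree
//   Index v = e.eval(x = 5);     // v == 21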
template<typename T>
struct is_symbolic {
// BaseExpr has no conversion ctor, so we only have to check whether T can be statically cast to its base class BaseExpr<T>.
enum { value = internal::is_convertible<T,BaseExpr<T> >::value };
};
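// For instance (illustrative only): is_symbolic<SymbolExpr<some_tag> >::value is true because
// SymbolExpr derives from BaseExpr<SymbolExpr<some_tag> >, whereas is_symbolic<Index>::value is false.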
/** Represents the actual value of a symbol identified by its tag
*
* It is the return type of SymbolExpr::operator=, and most of the time this is the only way it is used.
*/
template<typename Tag>
class SymbolValue
{
public:
/** Constructor from the value \a val */
SymbolValue(Index val) : m_value(val) {}
/** \returns the stored value of the symbol */
Index value() const { return m_value; }
protected:
Index m_value;
};
/** Expression of a symbol uniquely identified by the template parameter type \c tag */
template<typename tag>
class SymbolExpr : public BaseExpr<SymbolExpr<tag> >
{
public:
/** Alias to the template parameter \c tag */
typedef tag Tag;
SymbolExpr() {}
/** Associate the value \a val to the given symbol \c *this, uniquely identified by its \c Tag.
*
* The returned object should be passed to BaseExpr::eval() to evaluate a given expression with this specified runtime value.
*/
SymbolValue<Tag> operator=(Index val) const {
return SymbolValue<Tag>(val);
}
Index eval_impl(const SymbolValue<Tag> &values) const { return values.value(); }
#if EIGEN_HAS_CXX14
// C++14 versions suitable for multiple symbols
template<typename... Types>
Index eval_impl(const std::tuple<Types...>& values) const { return std::get<SymbolValue<Tag> >(values).value(); }
#endif
};
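// Illustrative sketch (not part of the original header): declaring a symbol and evaluating an
// expression against it, much like Eigen::last is used in Eigen::seq:
//   struct n_tag {};
//   static const symbolic::SymbolExpr<n_tag> n;
//   Index r = (n - 1).eval(n = 10);   // r == 9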
template<typename Arg0>
class NegateExpr : public BaseExpr<NegateExpr<Arg0> >
{
public:
NegateExpr(const Arg0& arg0) : m_arg0(arg0) {}
template<typename T>
Index eval_impl(const T& values) const { return -m_arg0.eval_impl(values); }
protected:
Arg0 m_arg0;
};
template<typename Arg0, typename Arg1>
class AddExpr : public BaseExpr<AddExpr<Arg0,Arg1> >
{
public:
AddExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template<typename T>
Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) + m_arg1.eval_impl(values); }
protected:
Arg0 m_arg0;
Arg1 m_arg1;
};
template<typename Arg0, typename Arg1>
class ProductExpr : public BaseExpr<ProductExpr<Arg0,Arg1> >
{
public:
ProductExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template<typename T>
Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) * m_arg1.eval_impl(values); }
protected:
Arg0 m_arg0;
Arg1 m_arg1;
};
template<typename Arg0, typename Arg1>
class QuotientExpr : public BaseExpr<QuotientExpr<Arg0,Arg1> >
{
public:
QuotientExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template<typename T>
Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) / m_arg1.eval_impl(values); }
protected:
Arg0 m_arg0;
Arg1 m_arg1;
};
} // end namespace symbolic
} // end namespace Eigen
#endif // EIGEN_SYMBOLIC_INDEX_H