mirror of
https://github.com/paboyle/Grid.git
synced 2025-04-07 04:35:56 +01:00
IfGridTensor shorthand
This commit is contained in:
parent
ffc0639cb9
commit
e936f5b80b
Grid/tensors
@ -69,6 +69,35 @@ accelerator_inline auto trace(const iVector<vtype,N> &arg) -> iVector<decltype(t
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
////////////////////////////
|
||||
// Fast path traceProduct
|
||||
////////////////////////////
|
||||
template<class S1 , class S2, IfNotGridTensor<S1> = 0, IfNotGridTensor<S2> = 0>
|
||||
accelerator_inline auto traceProduct( const S1 &arg1,const S2 &arg2)
|
||||
-> decltype(arg1*arg2)
|
||||
{
|
||||
return arg1*arg2;
|
||||
}
|
||||
|
||||
template<class vtype,class rtype,int N >
|
||||
accelerator_inline auto traceProduct(const iMatrix<vtype,N> &arg1,const iMatrix<rtype,N> &arg2) -> iScalar<decltype(trace(arg1._internal[0][0]*arg2._internal[0][0]))>
|
||||
{
|
||||
iScalar<decltype( trace(arg1._internal[0][0]*arg2._internal[0][0] )) > ret;
|
||||
zeroit(ret._internal);
|
||||
for(int i=0;i<N;i++){
|
||||
for(int j=0;j<N;j++){
|
||||
ret._internal=ret._internal+traceProduct(arg1._internal[i][j],arg2._internal[j][i]);
|
||||
}}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template<class vtype,class rtype >
|
||||
accelerator_inline auto traceProduct(const iScalar<vtype> &arg1,const iScalar<rtype> &arg2) -> iScalar<decltype(trace(arg1._internal*arg2._internal))>
|
||||
{
|
||||
iScalar<decltype(trace(arg1._internal*arg2._internal))> ret;
|
||||
ret._internal=traceProduct(arg1._internal,arg2._internal);
|
||||
return ret;
|
||||
}
|
||||
|
||||
NAMESPACE_END(Grid);
|
||||
|
||||
|
@ -34,9 +34,12 @@ NAMESPACE_BEGIN(Grid);
|
||||
|
||||
// These are the Grid tensors
|
||||
template<typename T> struct isGridTensor : public std::false_type { static constexpr bool notvalue = true; };
|
||||
template<class T> struct isGridTensor<iScalar<T>> : public std::true_type { static constexpr bool notvalue = false; };
|
||||
template<class T, int N> struct isGridTensor<iVector<T, N>> : public std::true_type { static constexpr bool notvalue = false; };
|
||||
template<class T, int N> struct isGridTensor<iMatrix<T, N>> : public std::true_type { static constexpr bool notvalue = false; };
|
||||
template<class T> struct isGridTensor<iScalar<T> > : public std::true_type { static constexpr bool notvalue = false; };
|
||||
template<class T, int N> struct isGridTensor<iVector<T, N> >: public std::true_type { static constexpr bool notvalue = false; };
|
||||
template<class T, int N> struct isGridTensor<iMatrix<T, N> >: public std::true_type { static constexpr bool notvalue = false; };
|
||||
|
||||
template <typename T> using IfGridTensor = Invoke<std::enable_if<isGridTensor<T>::value, int> >;
|
||||
template <typename T> using IfNotGridTensor = Invoke<std::enable_if<!isGridTensor<T>::value, int> >;
|
||||
|
||||
// Traits to identify scalars
|
||||
template<typename T> struct isGridScalar : public std::false_type { static constexpr bool notvalue = true; };
|
||||
|
Loading…
x
Reference in New Issue
Block a user