mirror of
https://github.com/paboyle/Grid.git
synced 2025-04-09 13:40:46 +01:00
IfGridTensor shorthand
This commit is contained in:
parent
ffc0639cb9
commit
e936f5b80b
@ -69,6 +69,35 @@ accelerator_inline auto trace(const iVector<vtype,N> &arg) -> iVector<decltype(t
|
|||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
////////////////////////////
|
||||||
|
// Fast path traceProduct
|
||||||
|
////////////////////////////
|
||||||
|
template<class S1 , class S2, IfNotGridTensor<S1> = 0, IfNotGridTensor<S2> = 0>
|
||||||
|
accelerator_inline auto traceProduct( const S1 &arg1,const S2 &arg2)
|
||||||
|
-> decltype(arg1*arg2)
|
||||||
|
{
|
||||||
|
return arg1*arg2;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class vtype,class rtype,int N >
|
||||||
|
accelerator_inline auto traceProduct(const iMatrix<vtype,N> &arg1,const iMatrix<rtype,N> &arg2) -> iScalar<decltype(trace(arg1._internal[0][0]*arg2._internal[0][0]))>
|
||||||
|
{
|
||||||
|
iScalar<decltype( trace(arg1._internal[0][0]*arg2._internal[0][0] )) > ret;
|
||||||
|
zeroit(ret._internal);
|
||||||
|
for(int i=0;i<N;i++){
|
||||||
|
for(int j=0;j<N;j++){
|
||||||
|
ret._internal=ret._internal+traceProduct(arg1._internal[i][j],arg2._internal[j][i]);
|
||||||
|
}}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class vtype,class rtype >
|
||||||
|
accelerator_inline auto traceProduct(const iScalar<vtype> &arg1,const iScalar<rtype> &arg2) -> iScalar<decltype(trace(arg1._internal*arg2._internal))>
|
||||||
|
{
|
||||||
|
iScalar<decltype(trace(arg1._internal*arg2._internal))> ret;
|
||||||
|
ret._internal=traceProduct(arg1._internal,arg2._internal);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
NAMESPACE_END(Grid);
|
NAMESPACE_END(Grid);
|
||||||
|
|
||||||
|
@ -34,9 +34,12 @@ NAMESPACE_BEGIN(Grid);
|
|||||||
|
|
||||||
// These are the Grid tensors
|
// These are the Grid tensors
|
||||||
template<typename T> struct isGridTensor : public std::false_type { static constexpr bool notvalue = true; };
|
template<typename T> struct isGridTensor : public std::false_type { static constexpr bool notvalue = true; };
|
||||||
template<class T> struct isGridTensor<iScalar<T>> : public std::true_type { static constexpr bool notvalue = false; };
|
template<class T> struct isGridTensor<iScalar<T> > : public std::true_type { static constexpr bool notvalue = false; };
|
||||||
template<class T, int N> struct isGridTensor<iVector<T, N>> : public std::true_type { static constexpr bool notvalue = false; };
|
template<class T, int N> struct isGridTensor<iVector<T, N> >: public std::true_type { static constexpr bool notvalue = false; };
|
||||||
template<class T, int N> struct isGridTensor<iMatrix<T, N>> : public std::true_type { static constexpr bool notvalue = false; };
|
template<class T, int N> struct isGridTensor<iMatrix<T, N> >: public std::true_type { static constexpr bool notvalue = false; };
|
||||||
|
|
||||||
|
template <typename T> using IfGridTensor = Invoke<std::enable_if<isGridTensor<T>::value, int> >;
|
||||||
|
template <typename T> using IfNotGridTensor = Invoke<std::enable_if<!isGridTensor<T>::value, int> >;
|
||||||
|
|
||||||
// Traits to identify scalars
|
// Traits to identify scalars
|
||||||
template<typename T> struct isGridScalar : public std::false_type { static constexpr bool notvalue = true; };
|
template<typename T> struct isGridScalar : public std::false_type { static constexpr bool notvalue = true; };
|
||||||
|
Loading…
x
Reference in New Issue
Block a user