diff --git a/lib/qcd/action/Actions.h b/lib/qcd/action/Actions.h index 9976c670..cf44aa7a 100644 --- a/lib/qcd/action/Actions.h +++ b/lib/qcd/action/Actions.h @@ -54,6 +54,7 @@ //////////////////////////////////////////// // Fermion operators / actions //////////////////////////////////////////// +#include #include #include // 4d wilson like diff --git a/lib/qcd/action/fermion/FermionImplTypedefs.h b/lib/qcd/action/fermion/FermionImplTypedefs.h index 99657d83..59f39be7 100644 --- a/lib/qcd/action/fermion/FermionImplTypedefs.h +++ b/lib/qcd/action/fermion/FermionImplTypedefs.h @@ -8,3 +8,4 @@ typedef typename Impl::SiteHalfSpinor SiteHalfSpinor; typedef typename Impl::Compressor Compressor; typedef WilsonKernels Kernels; + diff --git a/lib/qcd/action/fermion/FermionOperator.h b/lib/qcd/action/fermion/FermionOperator.h index 9829d350..4903f0e1 100644 --- a/lib/qcd/action/fermion/FermionOperator.h +++ b/lib/qcd/action/fermion/FermionOperator.h @@ -10,227 +10,7 @@ namespace Grid { // between gauge representation rank bc's, flavours etc. // and single/double precision. //////////////////////////////////////////////////////////////// - - template - class WilsonImpl { - public: - - typedef S Simd; - - template using iImplSpinor = iScalar, Ns> >; - template using iImplHalfSpinor = iScalar, Nhs> >; - - template using iImplGaugeLink = iScalar > >; - template using iImplGaugeField = iVector >, Nd >; - template using iImplDoubledGaugeField = iVector >, Nds >; - typedef iImplSpinor SiteSpinor; - typedef iImplHalfSpinor SiteHalfSpinor; - typedef iImplGaugeLink SiteGaugeLink; - typedef iImplGaugeField SiteGaugeField; - typedef iImplDoubledGaugeField SiteDoubledGaugeField; - - typedef Lattice FermionField; - typedef Lattice GaugeLinkField; // bit ugly naming; polarised gauge field, lorentz... all ugly - typedef Lattice GaugeField; - typedef Lattice DoubledGaugeField; - - typedef WilsonCompressor Compressor; - - static inline void multLink(SiteHalfSpinor &phi,const SiteDoubledGaugeField &U,const SiteHalfSpinor &chi,int mu,StencilEntry *SE,CartesianStencil &St){ - mult(&phi(),&U(mu),&chi()); - } - static inline void DoubleStore(GridBase *GaugeGrid,DoubledGaugeField &Uds,const GaugeField &Umu) - { - conformable(Uds._grid,GaugeGrid); - conformable(Umu._grid,GaugeGrid); - GaugeLinkField U(GaugeGrid); - for(int mu=0;mu(Umu,mu); - PokeIndex(Uds,U,mu); - U = adj(Cshift(U,mu,-1)); - PokeIndex(Uds,U,mu+4); - } - } - static inline void InsertForce(GaugeField &mat,const FermionField &Btilde,const FermionField &A,int mu){ - GaugeLinkField link(mat._grid); - link = TraceIndex(outerProduct(Btilde,A)); - PokeIndex(mat,link,mu); - } - - }; - - typedef WilsonImpl WilsonImplR; // Real.. whichever prec - typedef WilsonImpl WilsonImplF; // Float - typedef WilsonImpl WilsonImplD; // Double - - template - class GparityWilsonImpl { - public: - - typedef S Simd; - - template using iImplSpinor = iVector, Ns>, Ngp >; - template using iImplHalfSpinor = iVector, Nhs>, Ngp >; - template using iImplGaugeField = iVector >, Nd >; - - template using iImplGaugeLink = iScalar > >; - template using iImplDoubledGaugeField = iVector >, Nds >, Ngp >; - - typedef iImplSpinor SiteSpinor; - typedef iImplHalfSpinor SiteHalfSpinor; - typedef iImplGaugeLink SiteGaugeLink; - typedef iImplGaugeField SiteGaugeField; - typedef iImplDoubledGaugeField SiteDoubledGaugeField; - - typedef Lattice FermionField; - typedef Lattice GaugeLinkField; // bit ugly naming; polarised gauge field, lorentz... all ugly - typedef Lattice GaugeField; - typedef Lattice DoubledGaugeField; - - // typedef GparityWilsonCompressor Compressor; - typedef WilsonCompressor Compressor; - - // provide the multiply by link that is differentiated between Gparity (with flavour index) and - // non-Gparity - static inline void multLink(SiteHalfSpinor &phi,const SiteDoubledGaugeField &U,const SiteHalfSpinor &chi,int mu,StencilEntry *SE,CartesianStencil &St){ - // FIXME; need to be more careful. If this is a simd direction we are still stuffed - // Need access to _simd_layout[mu]. mu is not necessarily dim. - typedef SiteHalfSpinor vobj; - typedef typename SiteHalfSpinor::scalar_object sobj; - - vobj vtmp; - sobj stmp; - std::vector gpbc({0,0,0,1,0,0,0,1}); - - GridBase *grid = St._grid; - - const int Nsimd = grid->Nsimd(); - - int direction = St._directions[mu]; - int distance = St._distances[mu]; - int ptype = St._permute_type[mu]; - int sl = St._grid->_simd_layout[direction]; - - // assert our assumptions - assert((distance==1)||(distance==-1)); // nearest neighbour stencil hard code - assert((sl==1)||(sl==2)); - - std::vector icoor; - - if ( SE->_around_the_world && gpbc[mu] ) { - if ( sl == 2 ) { - - // std::cout << "multLink for mu= "< vals(Nsimd); - extract(chi,vals); - - for(int s=0;siCoorFromIindex(icoor,s); - - assert((icoor[direction]==0)||(icoor[direction]==1)); - - int permute_lane; - if ( distance == 1) { - permute_lane = icoor[direction]?1:0; - } else { - permute_lane = icoor[direction]?0:1; - } - - if ( permute_lane ) { - stmp(0) = vals[s](1); - stmp(1) = vals[s](0); - vals[s] = stmp; - } - } - - merge(vtmp,vals); - - } else { - vtmp(0) = chi(1); - vtmp(1) = chi(0); - } - mult(&phi(0),&U(0)(mu),&vtmp(0)); - mult(&phi(1),&U(1)(mu),&vtmp(1)); - - } else { - mult(&phi(0),&U(0)(mu),&chi(0)); - mult(&phi(1),&U(1)(mu),&chi(1)); - } - - } - - static inline void InsertForce(GaugeField &mat,const FermionField &Btilde,const FermionField &A,int mu){ - // Fixme - return; - } - static inline void DoubleStore(GridBase *GaugeGrid,DoubledGaugeField &Uds,const GaugeField &Umu) - { - conformable(Uds._grid,GaugeGrid); - conformable(Umu._grid,GaugeGrid); - - GaugeLinkField Utmp(GaugeGrid); - GaugeLinkField U(GaugeGrid); - GaugeLinkField Uconj(GaugeGrid); - - Lattice > coor(GaugeGrid); - - std::vector gpdirs({0,0,0,1}); - - for(int mu=0;mu(Umu,mu); - Uconj = conjugate(U); - - int neglink = GaugeGrid->GlobalDimensions()[mu]-1; - - if ( gpdirs[mu] ) { - Uconj = where(coor==neglink,-Uconj,Uconj); - } - -PARALLEL_FOR_LOOP - for(auto ss=U.begin();ss GparityWilsonImplR; // Real.. whichever prec - typedef GparityWilsonImpl GparityWilsonImplF; // Float - typedef GparityWilsonImpl GparityWilsonImplD; // Double - - ////////////////////////////////////////////////////////////////////////////// // Four component fermions @@ -238,10 +18,11 @@ PARALLEL_FOR_LOOP // Think about multiple representations ////////////////////////////////////////////////////////////////////////////// template - class FermionOperator : public CheckerBoardedSparseMatrixBase + class FermionOperator : public CheckerBoardedSparseMatrixBase, public Impl { public: #include + public: GridBase * Grid(void) { return FermionGrid(); }; // this is all the linalg routines need to know diff --git a/lib/qcd/action/fermion/FermionOperatorImpl.h b/lib/qcd/action/fermion/FermionOperatorImpl.h new file mode 100644 index 00000000..38a12d32 --- /dev/null +++ b/lib/qcd/action/fermion/FermionOperatorImpl.h @@ -0,0 +1,232 @@ +#ifndef GRID_QCD_FERMION_OPERATOR_IMPL_H +#define GRID_QCD_FERMION_OPERATOR_IMPL_H + +namespace Grid { + + namespace QCD { + + // Variable precision "S" and variable Nc + template + class WilsonImpl { + public: + + typedef S Simd; + + template using iImplSpinor = iScalar, Ns> >; + template using iImplHalfSpinor = iScalar, Nhs> >; + template using iImplGaugeLink = iScalar > >; + template using iImplGaugeField = iVector >, Nd >; + template using iImplDoubledGaugeField = iVector >, Nds >; + + typedef iImplSpinor SiteSpinor; + typedef iImplHalfSpinor SiteHalfSpinor; + typedef iImplGaugeLink SiteGaugeLink; + typedef iImplGaugeField SiteGaugeField; + typedef iImplDoubledGaugeField SiteDoubledGaugeField; + + typedef Lattice FermionField; + typedef Lattice GaugeLinkField; // bit ugly naming; polarised gauge field, lorentz... all ugly + typedef Lattice GaugeField; + typedef Lattice DoubledGaugeField; + + typedef WilsonCompressor Compressor; + + static inline void multLink(SiteHalfSpinor &phi,const SiteDoubledGaugeField &U,const SiteHalfSpinor &chi,int mu,StencilEntry *SE,CartesianStencil &St){ + mult(&phi(),&U(mu),&chi()); + } + inline void DoubleStore(GridBase *GaugeGrid,DoubledGaugeField &Uds,const GaugeField &Umu) + { + conformable(Uds._grid,GaugeGrid); + conformable(Umu._grid,GaugeGrid); + GaugeLinkField U(GaugeGrid); + for(int mu=0;mu(Umu,mu); + PokeIndex(Uds,U,mu); + U = adj(Cshift(U,mu,-1)); + PokeIndex(Uds,U,mu+4); + } + } + + inline void InsertForce(GaugeField &mat,const FermionField &Btilde,const FermionField &A,int mu){ + GaugeLinkField link(mat._grid); + link = TraceIndex(outerProduct(Btilde,A)); + PokeIndex(mat,link,mu); + } + + }; + + template + class GparityWilsonImpl { + public: + + typedef S Simd; + + template using iImplSpinor = iVector, Ns>, Ngp >; + template using iImplHalfSpinor = iVector, Nhs>, Ngp >; + template using iImplGaugeField = iVector >, Nd >; + + template using iImplGaugeLink = iScalar > >; + template using iImplDoubledGaugeField = iVector >, Nds >, Ngp >; + + typedef iImplSpinor SiteSpinor; + typedef iImplHalfSpinor SiteHalfSpinor; + typedef iImplGaugeLink SiteGaugeLink; + typedef iImplGaugeField SiteGaugeField; + typedef iImplDoubledGaugeField SiteDoubledGaugeField; + + typedef Lattice FermionField; + typedef Lattice GaugeLinkField; // bit ugly naming; polarised gauge field, lorentz... all ugly + typedef Lattice GaugeField; + typedef Lattice DoubledGaugeField; + + // typedef GparityWilsonCompressor Compressor; + typedef WilsonCompressor Compressor; + + // provide the multiply by link that is differentiated between Gparity (with flavour index) and + // non-Gparity + static inline void multLink(SiteHalfSpinor &phi,const SiteDoubledGaugeField &U,const SiteHalfSpinor &chi,int mu,StencilEntry *SE,CartesianStencil &St){ + + typedef SiteHalfSpinor vobj; + typedef typename SiteHalfSpinor::scalar_object sobj; + + vobj vtmp; + sobj stmp; + std::vector gpbc({0,0,0,1,0,0,0,1}); + + GridBase *grid = St._grid; + + const int Nsimd = grid->Nsimd(); + + int direction = St._directions[mu]; + int distance = St._distances[mu]; + int ptype = St._permute_type[mu]; + int sl = St._grid->_simd_layout[direction]; + + // assert our assumptions + assert((distance==1)||(distance==-1)); // nearest neighbour stencil hard code + assert((sl==1)||(sl==2)); + + std::vector icoor; + + if ( SE->_around_the_world && gpbc[mu] ) { + + if ( sl == 2 ) { + + std::vector vals(Nsimd); + + extract(chi,vals); + for(int s=0;siCoorFromIindex(icoor,s); + + assert((icoor[direction]==0)||(icoor[direction]==1)); + + int permute_lane; + if ( distance == 1) { + permute_lane = icoor[direction]?1:0; + } else { + permute_lane = icoor[direction]?0:1; + } + + if ( permute_lane ) { + stmp(0) = vals[s](1); + stmp(1) = vals[s](0); + vals[s] = stmp; + } + } + merge(vtmp,vals); + + } else { + vtmp(0) = chi(1); + vtmp(1) = chi(0); + } + mult(&phi(0),&U(0)(mu),&vtmp(0)); + mult(&phi(1),&U(1)(mu),&vtmp(1)); + + } else { + mult(&phi(0),&U(0)(mu),&chi(0)); + mult(&phi(1),&U(1)(mu),&chi(1)); + } + + } + + static inline void InsertForce(GaugeField &mat,const FermionField &Btilde,const FermionField &A,int mu){ + // Fixme + return; + } + static inline void DoubleStore(GridBase *GaugeGrid,DoubledGaugeField &Uds,const GaugeField &Umu) + { + + conformable(Uds._grid,GaugeGrid); + conformable(Umu._grid,GaugeGrid); + + GaugeLinkField Utmp(GaugeGrid); + GaugeLinkField U(GaugeGrid); + GaugeLinkField Uconj(GaugeGrid); + + Lattice > coor(GaugeGrid); + + std::vector gpdirs({0,0,0,1}); + + for(int mu=0;mu(Umu,mu); + Uconj = conjugate(U); + + int neglink = GaugeGrid->GlobalDimensions()[mu]-1; + + if ( gpdirs[mu] ) { + Uconj = where(coor==neglink,-Uconj,Uconj); + } + +PARALLEL_FOR_LOOP + for(auto ss=U.begin();ss WilsonImplR; // Real.. whichever prec + typedef WilsonImpl WilsonImplF; // Float + typedef WilsonImpl WilsonImplD; // Double + + typedef GparityWilsonImpl GparityWilsonImplR; // Real.. whichever prec + typedef GparityWilsonImpl GparityWilsonImplF; // Float + typedef GparityWilsonImpl GparityWilsonImplD; // Double + + + } +} +#endif