1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-09 21:50:45 +01:00

Support a view for passing to accelerator

This commit is contained in:
paboyle 2018-03-04 15:54:35 +00:00
parent e5ea04ee0c
commit 9b1f29c4c2
3 changed files with 26 additions and 17 deletions

View File

@ -271,7 +271,8 @@ public:
SimpleCompressor<siteVector> compressor; SimpleCompressor<siteVector> compressor;
Stencil.HaloExchange(in,compressor); Stencil.HaloExchange(in,compressor);
auto in_v = in.View();
auto out_v = in.View();
thread_loop( (int ss=0;ss<Grid()->oSites();ss++),{ thread_loop( (int ss=0;ss<Grid()->oSites();ss++),{
siteVector res = Zero(); siteVector res = Zero();
siteVector nbr; siteVector nbr;
@ -282,15 +283,16 @@ public:
SE=Stencil.GetEntry(ptype,point,ss); SE=Stencil.GetEntry(ptype,point,ss);
if(SE->_is_local&&SE->_permute) { if(SE->_is_local&&SE->_permute) {
permute(nbr,in[SE->_offset],ptype); permute(nbr,in_v[SE->_offset],ptype);
} else if(SE->_is_local) { } else if(SE->_is_local) {
nbr = in[SE->_offset]; nbr = in_v[SE->_offset];
} else { } else {
nbr = Stencil.CommBuf()[SE->_offset]; nbr = Stencil.CommBuf()[SE->_offset];
} }
res = res + A[point][ss]*nbr; auto A_point = A[point].View();
res = res + A_point[ss]*nbr;
} }
vstream(out[ss],res); vstream(out_v[ss],res);
}); });
return norm2(out); return norm2(out);
}; };
@ -384,12 +386,16 @@ public:
Subspace.ProjectToSubspace(oProj,oblock); Subspace.ProjectToSubspace(oProj,oblock);
// blockProject(iProj,iblock,Subspace.subspace); // blockProject(iProj,iblock,Subspace.subspace);
// blockProject(oProj,oblock,Subspace.subspace); // blockProject(oProj,oblock,Subspace.subspace);
auto iProj_v = iProj.View() ;
auto oProj_v = oProj.View() ;
auto A_p = A[p].View();
auto A_self = A[self_stencil].View();
thread_loop( (int ss=0;ss<Grid()->oSites();ss++),{ thread_loop( (int ss=0;ss<Grid()->oSites();ss++),{
for(int j=0;j<nbasis;j++){ for(int j=0;j<nbasis;j++){
if( disp!= 0 ) { if( disp!= 0 ) {
A[p][ss](j,i) = oProj[ss](j); A_p[ss](j,i) = oProj_v[ss](j);
} }
A[self_stencil][ss](j,i) = A[self_stencil][ss](j,i) + iProj[ss](j); A_self[ss](j,i) = A_self[ss](j,i) + iProj_v[ss](j);
} }
}); });
} }

View File

@ -191,7 +191,7 @@ public:
typedef typename sobj::scalar_type scalar; typedef typename sobj::scalar_type scalar;
Lattice<sobj> pgbuf(&pencil_g); Lattice<sobj> pgbuf(&pencil_g);
auto pgbuf_v = pgbuf.View();
typedef typename FFTW<scalar>::FFTW_scalar FFTW_scalar; typedef typename FFTW<scalar>::FFTW_scalar FFTW_scalar;
typedef typename FFTW<scalar>::FFTW_plan FFTW_plan; typedef typename FFTW<scalar>::FFTW_plan FFTW_plan;
@ -217,8 +217,8 @@ public:
FFTW_plan p; FFTW_plan p;
{ {
FFTW_scalar *in = (FFTW_scalar *)&pgbuf[0]; FFTW_scalar *in = (FFTW_scalar *)&pgbuf_v[0];
FFTW_scalar *out= (FFTW_scalar *)&pgbuf[0]; FFTW_scalar *out= (FFTW_scalar *)&pgbuf_v[0];
p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany, p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany,
in,inembed, in,inembed,
istride,idist, istride,idist,
@ -254,8 +254,8 @@ public:
Coordinate cbuf(Nd); Coordinate cbuf(Nd);
pencil_g.LocalIndexToLocalCoor(idx, cbuf); pencil_g.LocalIndexToLocalCoor(idx, cbuf);
if ( cbuf[dim] == 0 ) { // restricts loop to plane at lcoor[dim]==0 if ( cbuf[dim] == 0 ) { // restricts loop to plane at lcoor[dim]==0
FFTW_scalar *in = (FFTW_scalar *)&pgbuf[idx]; FFTW_scalar *in = (FFTW_scalar *)&pgbuf_v[idx];
FFTW_scalar *out= (FFTW_scalar *)&pgbuf[idx]; FFTW_scalar *out= (FFTW_scalar *)&pgbuf_v[idx];
FFTW<scalar>::fftw_execute_dft(p,in,out); FFTW<scalar>::fftw_execute_dft(p,in,out);
} }
}); });

View File

@ -58,17 +58,18 @@ void basisRotate(std::vector<Field> &basis,Eigen::MatrixXd& Qt,int j0, int j1, i
thread_region thread_region
{ {
std::vector < vobj > B(Nm); // Thread private std::vector < vobj > B(Nm); // Thread private
thread_loop_in_region( (int ss=0;ss < grid->oSites();ss++),{ thread_loop_in_region( (int ss=0;ss < grid->oSites();ss++),{
for(int j=j0; j<j1; ++j) B[j]=0.; for(int j=j0; j<j1; ++j) B[j]=0.;
for(int j=j0; j<j1; ++j){ for(int j=j0; j<j1; ++j){
for(int k=k0; k<k1; ++k){ for(int k=k0; k<k1; ++k){
B[j] +=Qt(j,k) * basis[k][ss]; auto basis_k = basis[k].View();
B[j] +=Qt(j,k) * basis_k[ss];
} }
} }
for(int j=j0; j<j1; ++j){ for(int j=j0; j<j1; ++j){
basis[j][ss] = B[j]; auto basis_j = basis[j].View();
basis_j[ss] = B[j];
} }
}); });
} }
@ -82,12 +83,14 @@ void basisRotateJ(Field &result,std::vector<Field> &basis,Eigen::MatrixXd& Qt,in
GridBase* grid = basis[0].Grid(); GridBase* grid = basis[0].Grid();
result.Checkerboard() = basis[0].Checkerboard(); result.Checkerboard() = basis[0].Checkerboard();
auto result_v=result.View();
thread_loop( (int ss=0;ss < grid->oSites();ss++),{ thread_loop( (int ss=0;ss < grid->oSites();ss++),{
vobj B = Zero(); vobj B = Zero();
for(int k=k0; k<k1; ++k){ for(int k=k0; k<k1; ++k){
B +=Qt(j,k) * basis[k][ss]; auto basis_k = basis[k].View();
B +=Qt(j,k) * basis_k[ss];
} }
result[ss] = B; result_v[ss] = B;
}); });
} }