mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-11-04 14:04:32 +00:00 
			
		
		
		
	Support a view for passing to accelerator
This commit is contained in:
		@@ -271,7 +271,8 @@ public:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    SimpleCompressor<siteVector> compressor;
 | 
					    SimpleCompressor<siteVector> compressor;
 | 
				
			||||||
    Stencil.HaloExchange(in,compressor);
 | 
					    Stencil.HaloExchange(in,compressor);
 | 
				
			||||||
 | 
					    auto in_v = in.View();
 | 
				
			||||||
 | 
					    auto out_v = in.View();
 | 
				
			||||||
    thread_loop( (int ss=0;ss<Grid()->oSites();ss++),{
 | 
					    thread_loop( (int ss=0;ss<Grid()->oSites();ss++),{
 | 
				
			||||||
      siteVector res = Zero();
 | 
					      siteVector res = Zero();
 | 
				
			||||||
      siteVector nbr;
 | 
					      siteVector nbr;
 | 
				
			||||||
@@ -282,15 +283,16 @@ public:
 | 
				
			|||||||
	SE=Stencil.GetEntry(ptype,point,ss);
 | 
						SE=Stencil.GetEntry(ptype,point,ss);
 | 
				
			||||||
	  
 | 
						  
 | 
				
			||||||
	if(SE->_is_local&&SE->_permute) { 
 | 
						if(SE->_is_local&&SE->_permute) { 
 | 
				
			||||||
	  permute(nbr,in[SE->_offset],ptype);
 | 
						  permute(nbr,in_v[SE->_offset],ptype);
 | 
				
			||||||
	} else if(SE->_is_local) { 
 | 
						} else if(SE->_is_local) { 
 | 
				
			||||||
	  nbr = in[SE->_offset];
 | 
						  nbr = in_v[SE->_offset];
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
	  nbr = Stencil.CommBuf()[SE->_offset];
 | 
						  nbr = Stencil.CommBuf()[SE->_offset];
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	res = res + A[point][ss]*nbr;
 | 
						auto A_point = A[point].View();
 | 
				
			||||||
 | 
						res = res + A_point[ss]*nbr;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      vstream(out[ss],res);
 | 
					      vstream(out_v[ss],res);
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
    return norm2(out);
 | 
					    return norm2(out);
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
@@ -384,12 +386,16 @@ public:
 | 
				
			|||||||
	Subspace.ProjectToSubspace(oProj,oblock);
 | 
						Subspace.ProjectToSubspace(oProj,oblock);
 | 
				
			||||||
	//	  blockProject(iProj,iblock,Subspace.subspace);
 | 
						//	  blockProject(iProj,iblock,Subspace.subspace);
 | 
				
			||||||
	//	  blockProject(oProj,oblock,Subspace.subspace);
 | 
						//	  blockProject(oProj,oblock,Subspace.subspace);
 | 
				
			||||||
 | 
						auto iProj_v = iProj.View() ;
 | 
				
			||||||
 | 
						auto oProj_v = oProj.View() ;
 | 
				
			||||||
 | 
						auto A_p     =  A[p].View();
 | 
				
			||||||
 | 
						auto A_self  = A[self_stencil].View();
 | 
				
			||||||
	thread_loop( (int ss=0;ss<Grid()->oSites();ss++),{
 | 
						thread_loop( (int ss=0;ss<Grid()->oSites();ss++),{
 | 
				
			||||||
	  for(int j=0;j<nbasis;j++){
 | 
						  for(int j=0;j<nbasis;j++){
 | 
				
			||||||
	    if( disp!= 0 ) {
 | 
						    if( disp!= 0 ) {
 | 
				
			||||||
	      A[p][ss](j,i) = oProj[ss](j);
 | 
						      A_p[ss](j,i) = oProj_v[ss](j);
 | 
				
			||||||
	    }
 | 
						    }
 | 
				
			||||||
	    A[self_stencil][ss](j,i) =	A[self_stencil][ss](j,i) + iProj[ss](j);
 | 
						    A_self[ss](j,i) =	A_self[ss](j,i) + iProj_v[ss](j);
 | 
				
			||||||
	  }
 | 
						  }
 | 
				
			||||||
	});
 | 
						});
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -191,7 +191,7 @@ public:
 | 
				
			|||||||
    typedef typename sobj::scalar_type   scalar;
 | 
					    typedef typename sobj::scalar_type   scalar;
 | 
				
			||||||
      
 | 
					      
 | 
				
			||||||
    Lattice<sobj> pgbuf(&pencil_g);
 | 
					    Lattice<sobj> pgbuf(&pencil_g);
 | 
				
			||||||
      
 | 
					    auto pgbuf_v = pgbuf.View();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    typedef typename FFTW<scalar>::FFTW_scalar FFTW_scalar;
 | 
					    typedef typename FFTW<scalar>::FFTW_scalar FFTW_scalar;
 | 
				
			||||||
    typedef typename FFTW<scalar>::FFTW_plan   FFTW_plan;
 | 
					    typedef typename FFTW<scalar>::FFTW_plan   FFTW_plan;
 | 
				
			||||||
@@ -217,8 +217,8 @@ public:
 | 
				
			|||||||
      
 | 
					      
 | 
				
			||||||
    FFTW_plan p;
 | 
					    FFTW_plan p;
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
      FFTW_scalar *in = (FFTW_scalar *)&pgbuf[0];
 | 
					      FFTW_scalar *in = (FFTW_scalar *)&pgbuf_v[0];
 | 
				
			||||||
      FFTW_scalar *out= (FFTW_scalar *)&pgbuf[0];
 | 
					      FFTW_scalar *out= (FFTW_scalar *)&pgbuf_v[0];
 | 
				
			||||||
      p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany,
 | 
					      p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany,
 | 
				
			||||||
					   in,inembed,
 | 
										   in,inembed,
 | 
				
			||||||
					   istride,idist,
 | 
										   istride,idist,
 | 
				
			||||||
@@ -254,8 +254,8 @@ public:
 | 
				
			|||||||
        Coordinate cbuf(Nd);
 | 
					        Coordinate cbuf(Nd);
 | 
				
			||||||
	pencil_g.LocalIndexToLocalCoor(idx, cbuf);
 | 
						pencil_g.LocalIndexToLocalCoor(idx, cbuf);
 | 
				
			||||||
	if ( cbuf[dim] == 0 ) {  // restricts loop to plane at lcoor[dim]==0
 | 
						if ( cbuf[dim] == 0 ) {  // restricts loop to plane at lcoor[dim]==0
 | 
				
			||||||
	  FFTW_scalar *in = (FFTW_scalar *)&pgbuf[idx];
 | 
						  FFTW_scalar *in = (FFTW_scalar *)&pgbuf_v[idx];
 | 
				
			||||||
	  FFTW_scalar *out= (FFTW_scalar *)&pgbuf[idx];
 | 
						  FFTW_scalar *out= (FFTW_scalar *)&pgbuf_v[idx];
 | 
				
			||||||
	  FFTW<scalar>::fftw_execute_dft(p,in,out);
 | 
						  FFTW<scalar>::fftw_execute_dft(p,in,out);
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -58,17 +58,18 @@ void basisRotate(std::vector<Field> &basis,Eigen::MatrixXd& Qt,int j0, int j1, i
 | 
				
			|||||||
  thread_region
 | 
					  thread_region
 | 
				
			||||||
  {
 | 
					  {
 | 
				
			||||||
    std::vector < vobj > B(Nm); // Thread private
 | 
					    std::vector < vobj > B(Nm); // Thread private
 | 
				
			||||||
        
 | 
					 | 
				
			||||||
    thread_loop_in_region( (int ss=0;ss < grid->oSites();ss++),{
 | 
					    thread_loop_in_region( (int ss=0;ss < grid->oSites();ss++),{
 | 
				
			||||||
      for(int j=j0; j<j1; ++j) B[j]=0.;
 | 
					      for(int j=j0; j<j1; ++j) B[j]=0.;
 | 
				
			||||||
      
 | 
					      
 | 
				
			||||||
      for(int j=j0; j<j1; ++j){
 | 
					      for(int j=j0; j<j1; ++j){
 | 
				
			||||||
	for(int k=k0; k<k1; ++k){
 | 
						for(int k=k0; k<k1; ++k){
 | 
				
			||||||
	  B[j] +=Qt(j,k) * basis[k][ss];
 | 
						  auto basis_k = basis[k].View();
 | 
				
			||||||
 | 
						  B[j] +=Qt(j,k) * basis_k[ss];
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      for(int j=j0; j<j1; ++j){
 | 
					      for(int j=j0; j<j1; ++j){
 | 
				
			||||||
	  basis[j][ss] = B[j];
 | 
						auto basis_j = basis[j].View();
 | 
				
			||||||
 | 
						basis_j[ss] = B[j];
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@@ -82,12 +83,14 @@ void basisRotateJ(Field &result,std::vector<Field> &basis,Eigen::MatrixXd& Qt,in
 | 
				
			|||||||
  GridBase* grid = basis[0].Grid();
 | 
					  GridBase* grid = basis[0].Grid();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  result.Checkerboard() = basis[0].Checkerboard();
 | 
					  result.Checkerboard() = basis[0].Checkerboard();
 | 
				
			||||||
 | 
					  auto result_v=result.View();
 | 
				
			||||||
  thread_loop( (int ss=0;ss < grid->oSites();ss++),{
 | 
					  thread_loop( (int ss=0;ss < grid->oSites();ss++),{
 | 
				
			||||||
    vobj B = Zero();
 | 
					    vobj B = Zero();
 | 
				
			||||||
    for(int k=k0; k<k1; ++k){
 | 
					    for(int k=k0; k<k1; ++k){
 | 
				
			||||||
      B +=Qt(j,k) * basis[k][ss];
 | 
					      auto basis_k = basis[k].View();
 | 
				
			||||||
 | 
					      B +=Qt(j,k) * basis_k[ss];
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    result[ss] = B;
 | 
					    result_v[ss] = B;
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user