diff --git a/Grid/algorithms/deflation/MultiRHSBlockCGLinalg.h b/Grid/algorithms/deflation/MultiRHSBlockCGLinalg.h new file mode 100644 index 00000000..9db5313d --- /dev/null +++ b/Grid/algorithms/deflation/MultiRHSBlockCGLinalg.h @@ -0,0 +1,249 @@ +/************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + Source file: MultiRHSBlockCGLinalg.h + + Copyright (C) 2024 + +Author: Peter Boyle + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + + See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ +#pragma once + +NAMESPACE_BEGIN(Grid); + + +/* Need helper object for BLAS accelerated mrhs blockCG */ +template +class MultiRHSBlockCGLinalg +{ +public: + + typedef typename Field::scalar_type scalar; + typedef typename Field::scalar_object scalar_object; + + deviceVector BLAS_X; // nrhs x vol -- the sources + deviceVector BLAS_Y; // nrhs x vol -- the result + deviceVector BLAS_C; // nrhs x nrhs -- the coefficients + + MultiRHSBlockCGLinalg() {}; + ~MultiRHSBlockCGLinalg(){ Deallocate(); }; + + void Deallocate(void) + { + BLAS_C.resize(0); + BLAS_X.resize(0); + BLAS_Y.resize(0); + } + void MaddMatrix(std::vector &AP, Eigen::MatrixXcd &m , const std::vector &X,const std::vector &Y,RealD scale=1.0) + { + std::vector Y_copy(AP.size(),AP[0].Grid()); + for(int r=0;r &Y, Eigen::MatrixXcd &m , const std::vector &X) + { + typedef typename Field::scalar_type scomplex; + GridBase *grid; + uint64_t vol; + uint64_t words; + + int nrhs = Y.size(); + grid = X[0].Grid(); + vol = grid->lSites(); + words = sizeof(scalar_object)/sizeof(scalar); + int64_t vw = vol * words; + + RealD t0 = usecond(); + BLAS_X.resize(nrhs * vw); // cost free if size doesn't change + BLAS_Y.resize(nrhs * vw); // cost free if size doesn't change + BLAS_C.resize(nrhs * nrhs);// cost free if size doesn't change + RealD t1 = usecond(); + + ///////////////////////////////////////////// + // Copy in the multi-rhs sources + ///////////////////////////////////////////// + for(int r=0;r Xd(1); + deviceVector Yd(1); + deviceVector Cd(1); + + scalar * Xh = & BLAS_X[0]; + scalar * Yh = & BLAS_Y[0]; + scalar * Ch = & BLAS_C[0]; + + acceleratorPut(Xd[0],Xh); + acceleratorPut(Yd[0],Yh); + acceleratorPut(Cd[0],Ch); + + RealD t2 = usecond(); + GridBLAS BLAS; + ///////////////////////////////////////// + // Y = X*C (transpose?) + ///////////////////////////////////////// + BLAS.gemmBatched(GridBLAS_OP_N,GridBLAS_OP_N, + vw,nrhs,nrhs, + ComplexD(1.0), + Xd, + Cd, + ComplexD(0.0), // wipe out Y + Yd); + BLAS.synchronise(); + RealD t3 = usecond(); + + // Copy back Y = m X + for(int r=0;r &X, const std::vector &Y) + { + int nrhs; + GridBase *grid; + uint64_t vol; + uint64_t words; + + nrhs = X.size(); + assert(X.size()==Y.size()); + conformable(X[0],Y[0]); + + grid = X[0].Grid(); + vol = grid->lSites(); + words = sizeof(scalar_object)/sizeof(scalar); + int64_t vw = vol * words; + + RealD t0 = usecond(); + BLAS_X.resize(nrhs * vw); // cost free if size doesn't change + BLAS_Y.resize(nrhs * vw); // cost free if size doesn't change + BLAS_C.resize(nrhs * nrhs);// cost free if size doesn't change + RealD t1 = usecond(); + + ///////////////////////////////////////////// + // Copy in the multi-rhs sources + ///////////////////////////////////////////// + for(int r=0;r Xd(1); + deviceVector Yd(1); + deviceVector Cd(1); + + scalar * Xh = & BLAS_X[0]; + scalar * Yh = & BLAS_Y[0]; + scalar * Ch = & BLAS_C[0]; + + acceleratorPut(Xd[0],Xh); + acceleratorPut(Yd[0],Yh); + acceleratorPut(Cd[0],Ch); + + GridBLAS BLAS; + + RealD t3 = usecond(); + ///////////////////////////////////////// + // C_rs = X^dag Y + ///////////////////////////////////////// + BLAS.gemmBatched(GridBLAS_OP_C,GridBLAS_OP_N, + nrhs,nrhs,vw, + ComplexD(1.0), + Xd, + Yd, + ComplexD(0.0), // wipe out C + Cd); + BLAS.synchronise(); + RealD t4 = usecond(); + + std::vector HOST_C(BLAS_C.size()); // nrhs . nrhs -- the coefficients + acceleratorCopyFromDevice(&BLAS_C[0],&HOST_C[0],BLAS_C.size()*sizeof(scalar)); + grid->GlobalSumVector(&HOST_C[0],nrhs*nrhs); + + RealD t5 = usecond(); + for(int rr=0;rr