diff --git a/.gitignore b/.gitignore index 5338acb9..13efd67c 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,8 @@ *~ *# *.sublime-* +.ctags +tags # Precompiled Headers # ####################### diff --git a/Grid/algorithms/iterative/ImplicitlyRestartedBlockLanczos.h b/Grid/algorithms/iterative/ImplicitlyRestartedBlockLanczos.h index 5076a527..7cc11653 100644 --- a/Grid/algorithms/iterative/ImplicitlyRestartedBlockLanczos.h +++ b/Grid/algorithms/iterative/ImplicitlyRestartedBlockLanczos.h @@ -39,6 +39,10 @@ Author: Guido Cossu #undef USE_LAPACK #define Glog std::cout << GridLogMessage +#ifdef GRID_NVCC +#include "cublas_v2.h" +#endif + namespace Grid { //////////////////////////////////////////////////////////////////////////////// @@ -89,6 +93,12 @@ class SortEigen { enum class LanczosType { irbl, rbl }; +enum IRBLdiagonalisation { + IRBLdiagonaliseWithDSTEGR, + IRBLdiagonaliseWithQR, + IRBLdiagonaliseWithEigen +}; + ///////////////////////////////////////////////////////////// // Implicitly restarted block lanczos ///////////////////////////////////////////////////////////// @@ -107,7 +117,7 @@ private: int Nblock_m; // Nm/Nu int Nconv_test_interval; // Number of skipped vectors when checking a convergence RealD eresid; - IRLdiagonalisation diagonalisation; + IRBLdiagonalisation diagonalisation; int split_test; //test split in the first iteration //////////////////////////////////// // Embedded objects @@ -137,7 +147,7 @@ public: int _Nm, // total vecs RealD _eresid, // resid in lmd deficit int _MaxIter, // Max iterations - IRLdiagonalisation _diagonalisation = IRLdiagonaliseWithEigen) + IRBLdiagonalisation _diagonalisation = IRBLdiagonaliseWithEigen) : _Linop(Linop), _SLinop(SLinop), _poly(poly),sf_grid(SFrbGrid),f_grid(FrbGrid), Nstop(_Nstop), Nconv_test_interval(_Nconv_test_interval), mrhs(_mrhs), Nu(_Nu), Nk(_Nk), Nm(_Nm), @@ -211,7 +221,126 @@ public: #endif }} for(int i=0; i<_Nu; ++i) - normalize(w[i],if_print); + assert(normalize(w[i],if_print) !=0); + } + + + void orthogonalize_blas(std::vector& w, int _Nu, std::vector& evec, int _R, int _print=0) + { +#ifdef GRID_NVCC + Glog << "cuBLAS orthogonalize" << std::endl; + + typedef typename Field::vector_object vobj; + typedef typename vobj::scalar_type scalar_type; + typedef typename vobj::vector_type vector_type; + + typedef typename Field::scalar_type MyComplex; + + GridBase *grid = w[0].Grid(); + //grid->show_decomposition(); + //const uint64_t nsimd = grid->Nsimd(); + const uint64_t sites = grid->lSites(); + + //auto w_v = w[0].View(); + //cuDoubleComplex *z = reinterpret_cast(&w_v._odata[0]); + //cuDoubleComplex *z = w_v._odata._internal; + //thread_for(ss,w_v.size(),{ + // Glog << w_v[ss] << std::endl; + //}); + //w_v[0] + //exit(0); + //scalar_type *z = (scalar_type *)&w_v[0]; // OK + //cuDoubleComplex *z = reinterpret_cast(&w_v[0]); // OK + + cudaError_t cudaStat; + + cuDoubleComplex *w_acc, *evec_acc, *c_acc; + + cudaStat = cudaMallocManaged((void **)&w_acc, _Nu*sites*12*sizeof(cuDoubleComplex)); + Glog << cudaStat << std::endl; + cudaStat = cudaMallocManaged((void **)&evec_acc, _R*sites*12*sizeof(cuDoubleComplex)); + Glog << cudaStat << std::endl; + cudaStat = cudaMallocManaged((void **)&c_acc, _Nu*_R*12*sizeof(cuDoubleComplex)); + Glog << cudaStat << std::endl; + + Glog << "cuBLAS prepare array"<< std::endl; +#if 0 // a trivial test + for (int col=0; col<_Nu; ++col) { + for (size_t row=0; row(&w_v[0]); + for (size_t row=0; row(&evec_v[0]); + for (size_t row=0; rowGlobalSumVector((double*)c_acc,2*_Nu*_R); + + cublasDestroy(handle); + + Glog << "cuBLAS Zgemm done"<< std::endl; + + for (int i=0; i<_Nu; ++i) { + for (size_t j=0; j<_R; ++j) { + cuDoubleComplex z = c_acc[i*_R+j]; + MyComplex ip(z.x,z.y); + if (_print) { + Glog << "[" << j << "," << i << "] = " + << z.x << " + i " << z.y << std::endl; + } + w[i] = w[i] - ip * evec[j]; + } + assert(normalize(w[i],_print)!=0); + } + + cudaFree(w_acc); + cudaFree(evec_acc); + cudaFree(c_acc); + + Glog << "cuBLAS orthogonalize done" << std::endl; +#else + Glog << "BLAS wrapper is not implemented" << std::endl; + exit(1); +#endif } @@ -310,10 +439,10 @@ for( int i =0;i eval(JP.Nm);