1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-10 07:55:35 +00:00

MixedPrec Multishift with better precision scheme for GPU

This commit is contained in:
Peter Boyle 2022-09-23 16:18:47 -04:00
parent 1713de35c0
commit 5b128a6f9f

View File

@ -164,12 +164,8 @@ public:
RealD cp,bp,qq; //prev RealD cp,bp,qq; //prev
// Matrix mult fields // Matrix mult fields
FieldF r_f(SinglePrecGrid);
FieldF p_f(SinglePrecGrid); FieldF p_f(SinglePrecGrid);
FieldF tmp_f(SinglePrecGrid);
FieldF mmp_f(SinglePrecGrid); FieldF mmp_f(SinglePrecGrid);
FieldF src_f(SinglePrecGrid);
precisionChange(src_f, src_d);
// Check lightest mass // Check lightest mass
for(int s=0;s<nshift;s++){ for(int s=0;s<nshift;s++){
@ -198,14 +194,13 @@ public:
ps_d[s] = src_d; ps_d[s] = src_d;
} }
// r and p for primary // r and p for primary
r_f=src_f; //residual maintained in single
p_f=src_f;
p_d = src_d; //primary copy --- make this a reference to ps_d to save axpys p_d = src_d; //primary copy --- make this a reference to ps_d to save axpys
r_d = p_d;
//MdagM+m[0] //MdagM+m[0]
Linop_f.HermOpAndNorm(p_f,mmp_f,d,qq); // mmp = MdagM p d=real(dot(p, mmp)), qq=norm2(mmp) Linop_d.HermOpAndNorm(p_d,mmp_d,d,qq); // mmp = MdagM p d=real(dot(p, mmp)), qq=norm2(mmp)
axpy(mmp_f,mass[0],p_f,mmp_f); axpy(mmp_d,mass[0],p_d,mmp_d);
RealD rn = norm2(p_f); RealD rn = norm2(p_d);
d += rn*mass[0]; d += rn*mass[0];
b = -cp /d; b = -cp /d;
@ -223,7 +218,7 @@ public:
// r += b[0] A.p[0] // r += b[0] A.p[0]
// c= norm(r) // c= norm(r)
c=axpy_norm(r_f,b,mmp_f,r_f); c=axpy_norm(r_d,b,mmp_d,r_d);
for(int s=0;s<nshift;s++) { for(int s=0;s<nshift;s++) {
axpby(psi_d[s],0.,-bs[s]*alpha[s],src_d,src_d); axpby(psi_d[s],0.,-bs[s]*alpha[s],src_d,src_d);
@ -240,13 +235,8 @@ public:
int k; int k;
for (k=1;k<=MaxIterations;k++){ for (k=1;k<=MaxIterations;k++){
a = c /cp; a = c /cp;
//Update double precision search direction by residual
PrecChangeTimer.Start();
precisionChange(r_d, r_f);
PrecChangeTimer.Stop();
AXPYTimer.Start(); AXPYTimer.Start();
axpy(p_d,a,p_d,r_d); axpy(p_d,a,p_d,r_d);
@ -263,19 +253,23 @@ public:
AXPYTimer.Stop(); AXPYTimer.Stop();
PrecChangeTimer.Start(); PrecChangeTimer.Start();
precisionChange(p_f, p_d); //get back single prec search direction for linop precisionChangeFast(p_f, p_d); //get back single prec search direction for linop
PrecChangeTimer.Stop(); PrecChangeTimer.Stop();
cp=c; cp=c;
MatrixTimer.Start(); MatrixTimer.Start();
Linop_f.HermOp(p_f,mmp_f); Linop_f.HermOp(p_f,mmp_f);
d=real(innerProduct(p_f,mmp_f));
MatrixTimer.Stop(); MatrixTimer.Stop();
PrecChangeTimer.Start();
precisionChangeFast(mmp_d, mmp_f); // From Float to Double
PrecChangeTimer.Stop();
AXPYTimer.Start(); AXPYTimer.Start();
axpy(mmp_f,mass[0],p_f,mmp_f); d=real(innerProduct(p_d,mmp_d));
axpy(mmp_d,mass[0],p_d,mmp_d);
AXPYTimer.Stop(); AXPYTimer.Stop();
RealD rn = norm2(p_f); RealD rn = norm2(p_d);
d += rn*mass[0]; d += rn*mass[0];
bp=b; bp=b;
@ -306,12 +300,11 @@ public:
} }
//Perform reliable update if necessary; otherwise update residual from single-prec mmp //Perform reliable update if necessary; otherwise update residual from single-prec mmp
RealD c_f = axpy_norm(r_f,b,mmp_f,r_f); c = axpy_norm(r_d,b,mmp_d,r_d);
AXPYTimer.Stop(); AXPYTimer.Stop();
c = c_f;
if(k % ReliableUpdateFreq == 0){ if(k % ReliableUpdateFreq == 0){
RealD c_old = c;
//Replace r with true residual //Replace r with true residual
MatrixTimer.Start(); MatrixTimer.Start();
Linop_d.HermOp(psi_d[0],mmp_d); Linop_d.HermOp(psi_d[0],mmp_d);
@ -320,15 +313,10 @@ public:
AXPYTimer.Start(); AXPYTimer.Start();
axpy(mmp_d,mass[0],psi_d[0],mmp_d); axpy(mmp_d,mass[0],psi_d[0],mmp_d);
RealD c_d = axpy_norm(r_d, -1.0, mmp_d, src_d); c = axpy_norm(r_d, -1.0, mmp_d, src_d);
AXPYTimer.Stop(); AXPYTimer.Stop();
std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrec k="<<k<< ", replaced |r|^2 = "<<c_f <<" with |r|^2 = "<<c_d<<std::endl; std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrec k="<<k<< ", replaced |r|^2 = "<<c_old <<" with |r|^2 = "<<c<<std::endl;
PrecChangeTimer.Start();
precisionChange(r_f, r_d);
PrecChangeTimer.Stop();
c = c_d;
} }
// Convergence checks // Convergence checks