mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-10 07:55:35 +00:00
MixedPrec Multishift with better precision scheme for GPU
This commit is contained in:
parent
1713de35c0
commit
5b128a6f9f
@ -164,12 +164,8 @@ public:
|
|||||||
RealD cp,bp,qq; //prev
|
RealD cp,bp,qq; //prev
|
||||||
|
|
||||||
// Matrix mult fields
|
// Matrix mult fields
|
||||||
FieldF r_f(SinglePrecGrid);
|
|
||||||
FieldF p_f(SinglePrecGrid);
|
FieldF p_f(SinglePrecGrid);
|
||||||
FieldF tmp_f(SinglePrecGrid);
|
|
||||||
FieldF mmp_f(SinglePrecGrid);
|
FieldF mmp_f(SinglePrecGrid);
|
||||||
FieldF src_f(SinglePrecGrid);
|
|
||||||
precisionChange(src_f, src_d);
|
|
||||||
|
|
||||||
// Check lightest mass
|
// Check lightest mass
|
||||||
for(int s=0;s<nshift;s++){
|
for(int s=0;s<nshift;s++){
|
||||||
@ -198,14 +194,13 @@ public:
|
|||||||
ps_d[s] = src_d;
|
ps_d[s] = src_d;
|
||||||
}
|
}
|
||||||
// r and p for primary
|
// r and p for primary
|
||||||
r_f=src_f; //residual maintained in single
|
|
||||||
p_f=src_f;
|
|
||||||
p_d = src_d; //primary copy --- make this a reference to ps_d to save axpys
|
p_d = src_d; //primary copy --- make this a reference to ps_d to save axpys
|
||||||
|
r_d = p_d;
|
||||||
|
|
||||||
//MdagM+m[0]
|
//MdagM+m[0]
|
||||||
Linop_f.HermOpAndNorm(p_f,mmp_f,d,qq); // mmp = MdagM p d=real(dot(p, mmp)), qq=norm2(mmp)
|
Linop_d.HermOpAndNorm(p_d,mmp_d,d,qq); // mmp = MdagM p d=real(dot(p, mmp)), qq=norm2(mmp)
|
||||||
axpy(mmp_f,mass[0],p_f,mmp_f);
|
axpy(mmp_d,mass[0],p_d,mmp_d);
|
||||||
RealD rn = norm2(p_f);
|
RealD rn = norm2(p_d);
|
||||||
d += rn*mass[0];
|
d += rn*mass[0];
|
||||||
|
|
||||||
b = -cp /d;
|
b = -cp /d;
|
||||||
@ -223,7 +218,7 @@ public:
|
|||||||
|
|
||||||
// r += b[0] A.p[0]
|
// r += b[0] A.p[0]
|
||||||
// c= norm(r)
|
// c= norm(r)
|
||||||
c=axpy_norm(r_f,b,mmp_f,r_f);
|
c=axpy_norm(r_d,b,mmp_d,r_d);
|
||||||
|
|
||||||
for(int s=0;s<nshift;s++) {
|
for(int s=0;s<nshift;s++) {
|
||||||
axpby(psi_d[s],0.,-bs[s]*alpha[s],src_d,src_d);
|
axpby(psi_d[s],0.,-bs[s]*alpha[s],src_d,src_d);
|
||||||
@ -240,13 +235,8 @@ public:
|
|||||||
int k;
|
int k;
|
||||||
|
|
||||||
for (k=1;k<=MaxIterations;k++){
|
for (k=1;k<=MaxIterations;k++){
|
||||||
|
|
||||||
a = c /cp;
|
a = c /cp;
|
||||||
|
|
||||||
//Update double precision search direction by residual
|
|
||||||
PrecChangeTimer.Start();
|
|
||||||
precisionChange(r_d, r_f);
|
|
||||||
PrecChangeTimer.Stop();
|
|
||||||
|
|
||||||
AXPYTimer.Start();
|
AXPYTimer.Start();
|
||||||
axpy(p_d,a,p_d,r_d);
|
axpy(p_d,a,p_d,r_d);
|
||||||
|
|
||||||
@ -263,19 +253,23 @@ public:
|
|||||||
AXPYTimer.Stop();
|
AXPYTimer.Stop();
|
||||||
|
|
||||||
PrecChangeTimer.Start();
|
PrecChangeTimer.Start();
|
||||||
precisionChange(p_f, p_d); //get back single prec search direction for linop
|
precisionChangeFast(p_f, p_d); //get back single prec search direction for linop
|
||||||
PrecChangeTimer.Stop();
|
PrecChangeTimer.Stop();
|
||||||
|
|
||||||
cp=c;
|
cp=c;
|
||||||
MatrixTimer.Start();
|
MatrixTimer.Start();
|
||||||
Linop_f.HermOp(p_f,mmp_f);
|
Linop_f.HermOp(p_f,mmp_f);
|
||||||
d=real(innerProduct(p_f,mmp_f));
|
|
||||||
MatrixTimer.Stop();
|
MatrixTimer.Stop();
|
||||||
|
|
||||||
|
PrecChangeTimer.Start();
|
||||||
|
precisionChangeFast(mmp_d, mmp_f); // From Float to Double
|
||||||
|
PrecChangeTimer.Stop();
|
||||||
|
|
||||||
AXPYTimer.Start();
|
AXPYTimer.Start();
|
||||||
axpy(mmp_f,mass[0],p_f,mmp_f);
|
d=real(innerProduct(p_d,mmp_d));
|
||||||
|
axpy(mmp_d,mass[0],p_d,mmp_d);
|
||||||
AXPYTimer.Stop();
|
AXPYTimer.Stop();
|
||||||
RealD rn = norm2(p_f);
|
RealD rn = norm2(p_d);
|
||||||
d += rn*mass[0];
|
d += rn*mass[0];
|
||||||
|
|
||||||
bp=b;
|
bp=b;
|
||||||
@ -306,12 +300,11 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
//Perform reliable update if necessary; otherwise update residual from single-prec mmp
|
//Perform reliable update if necessary; otherwise update residual from single-prec mmp
|
||||||
RealD c_f = axpy_norm(r_f,b,mmp_f,r_f);
|
c = axpy_norm(r_d,b,mmp_d,r_d);
|
||||||
AXPYTimer.Stop();
|
AXPYTimer.Stop();
|
||||||
|
|
||||||
c = c_f;
|
|
||||||
|
|
||||||
if(k % ReliableUpdateFreq == 0){
|
if(k % ReliableUpdateFreq == 0){
|
||||||
|
RealD c_old = c;
|
||||||
//Replace r with true residual
|
//Replace r with true residual
|
||||||
MatrixTimer.Start();
|
MatrixTimer.Start();
|
||||||
Linop_d.HermOp(psi_d[0],mmp_d);
|
Linop_d.HermOp(psi_d[0],mmp_d);
|
||||||
@ -320,15 +313,10 @@ public:
|
|||||||
AXPYTimer.Start();
|
AXPYTimer.Start();
|
||||||
axpy(mmp_d,mass[0],psi_d[0],mmp_d);
|
axpy(mmp_d,mass[0],psi_d[0],mmp_d);
|
||||||
|
|
||||||
RealD c_d = axpy_norm(r_d, -1.0, mmp_d, src_d);
|
c = axpy_norm(r_d, -1.0, mmp_d, src_d);
|
||||||
AXPYTimer.Stop();
|
AXPYTimer.Stop();
|
||||||
|
|
||||||
std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrec k="<<k<< ", replaced |r|^2 = "<<c_f <<" with |r|^2 = "<<c_d<<std::endl;
|
std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrec k="<<k<< ", replaced |r|^2 = "<<c_old <<" with |r|^2 = "<<c<<std::endl;
|
||||||
|
|
||||||
PrecChangeTimer.Start();
|
|
||||||
precisionChange(r_f, r_d);
|
|
||||||
PrecChangeTimer.Stop();
|
|
||||||
c = c_d;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convergence checks
|
// Convergence checks
|
||||||
|
Loading…
Reference in New Issue
Block a user