mirror of
https://github.com/paboyle/Grid.git
synced 2025-06-16 23:07:05 +01:00
Merge branch 'develop' into feature/hadrons
This commit is contained in:
@ -51,7 +51,7 @@ namespace Grid {
|
||||
|
||||
virtual void Op (const Field &in, Field &out) = 0; // Abstract base
|
||||
virtual void AdjOp (const Field &in, Field &out) = 0; // Abstract base
|
||||
virtual void HermOpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2)=0;
|
||||
virtual void HermOpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2) = 0;
|
||||
virtual void HermOp(const Field &in, Field &out)=0;
|
||||
};
|
||||
|
||||
@ -309,36 +309,59 @@ namespace Grid {
|
||||
class SchurStaggeredOperator : public SchurOperatorBase<Field> {
|
||||
protected:
|
||||
Matrix &_Mat;
|
||||
Field tmp;
|
||||
RealD mass;
|
||||
double tMpc;
|
||||
double tIP;
|
||||
double tMeo;
|
||||
double taxpby_norm;
|
||||
uint64_t ncall;
|
||||
public:
|
||||
SchurStaggeredOperator (Matrix &Mat): _Mat(Mat){};
|
||||
void Report(void)
|
||||
{
|
||||
std::cout << GridLogMessage << " HermOpAndNorm.Mpc "<< tMpc/ncall<<" usec "<<std::endl;
|
||||
std::cout << GridLogMessage << " HermOpAndNorm.IP "<< tIP /ncall<<" usec "<<std::endl;
|
||||
std::cout << GridLogMessage << " Mpc.MeoMoe "<< tMeo/ncall<<" usec "<<std::endl;
|
||||
std::cout << GridLogMessage << " Mpc.axpby_norm "<< taxpby_norm/ncall<<" usec "<<std::endl;
|
||||
}
|
||||
SchurStaggeredOperator (Matrix &Mat): _Mat(Mat), tmp(_Mat.RedBlackGrid())
|
||||
{
|
||||
assert( _Mat.isTrivialEE() );
|
||||
mass = _Mat.Mass();
|
||||
tMpc=0;
|
||||
tIP =0;
|
||||
tMeo=0;
|
||||
taxpby_norm=0;
|
||||
ncall=0;
|
||||
}
|
||||
virtual void HermOpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2){
|
||||
GridLogIterative.TimingMode(1);
|
||||
std::cout << GridLogIterative << " HermOpAndNorm "<<std::endl;
|
||||
ncall++;
|
||||
tMpc-=usecond();
|
||||
n2 = Mpc(in,out);
|
||||
std::cout << GridLogIterative << " HermOpAndNorm.Mpc "<<std::endl;
|
||||
tMpc+=usecond();
|
||||
tIP-=usecond();
|
||||
ComplexD dot= innerProduct(in,out);
|
||||
std::cout << GridLogIterative << " HermOpAndNorm.innerProduct "<<std::endl;
|
||||
tIP+=usecond();
|
||||
n1 = real(dot);
|
||||
}
|
||||
virtual void HermOp(const Field &in, Field &out){
|
||||
std::cout << GridLogIterative << " HermOp "<<std::endl;
|
||||
Mpc(in,out);
|
||||
ncall++;
|
||||
tMpc-=usecond();
|
||||
_Mat.Meooe(in,out);
|
||||
_Mat.Meooe(out,tmp);
|
||||
tMpc+=usecond();
|
||||
taxpby_norm-=usecond();
|
||||
axpby(out,-1.0,mass*mass,tmp,in);
|
||||
taxpby_norm+=usecond();
|
||||
}
|
||||
virtual RealD Mpc (const Field &in, Field &out) {
|
||||
Field tmp(in._grid);
|
||||
Field tmp2(in._grid);
|
||||
|
||||
std::cout << GridLogIterative << " HermOp.Mpc "<<std::endl;
|
||||
_Mat.Mooee(in,out);
|
||||
_Mat.Mooee(out,tmp);
|
||||
std::cout << GridLogIterative << " HermOp.MooeeMooee "<<std::endl;
|
||||
|
||||
tMeo-=usecond();
|
||||
_Mat.Meooe(in,out);
|
||||
_Mat.Meooe(out,tmp2);
|
||||
std::cout << GridLogIterative << " HermOp.MeooeMeooe "<<std::endl;
|
||||
|
||||
RealD nn=axpy_norm(out,-1.0,tmp2,tmp);
|
||||
std::cout << GridLogIterative << " HermOp.axpy_norm "<<std::endl;
|
||||
_Mat.Meooe(out,tmp);
|
||||
tMeo+=usecond();
|
||||
taxpby_norm-=usecond();
|
||||
RealD nn=axpby_norm(out,-1.0,mass*mass,tmp,in);
|
||||
taxpby_norm+=usecond();
|
||||
return nn;
|
||||
}
|
||||
virtual RealD MpcDag (const Field &in, Field &out){
|
||||
|
@ -54,6 +54,7 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
|
||||
void operator()(LinearOperatorBase<Field> &Linop, const Field &src, Field &psi) {
|
||||
|
||||
|
||||
psi.checkerboard = src.checkerboard;
|
||||
conformable(psi, src);
|
||||
|
||||
@ -69,7 +70,6 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
|
||||
|
||||
Linop.HermOpAndNorm(psi, mmp, d, b);
|
||||
|
||||
|
||||
r = src - mmp;
|
||||
p = r;
|
||||
@ -96,38 +96,44 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
<< "ConjugateGradient: k=0 residual " << cp << " target " << rsq << std::endl;
|
||||
|
||||
GridStopWatch LinalgTimer;
|
||||
GridStopWatch InnerTimer;
|
||||
GridStopWatch AxpyNormTimer;
|
||||
GridStopWatch LinearCombTimer;
|
||||
GridStopWatch MatrixTimer;
|
||||
GridStopWatch SolverTimer;
|
||||
|
||||
SolverTimer.Start();
|
||||
int k;
|
||||
for (k = 1; k <= MaxIterations; k++) {
|
||||
for (k = 1; k <= MaxIterations*1000; k++) {
|
||||
c = cp;
|
||||
|
||||
MatrixTimer.Start();
|
||||
Linop.HermOpAndNorm(p, mmp, d, qq);
|
||||
Linop.HermOp(p, mmp);
|
||||
MatrixTimer.Stop();
|
||||
|
||||
LinalgTimer.Start();
|
||||
// RealD qqck = norm2(mmp);
|
||||
// ComplexD dck = innerProduct(p,mmp);
|
||||
|
||||
InnerTimer.Start();
|
||||
ComplexD dc = innerProduct(p,mmp);
|
||||
InnerTimer.Stop();
|
||||
d = dc.real();
|
||||
a = c / d;
|
||||
b_pred = a * (a * qq - d) / c;
|
||||
|
||||
AxpyNormTimer.Start();
|
||||
cp = axpy_norm(r, -a, mmp, r);
|
||||
AxpyNormTimer.Stop();
|
||||
b = cp / c;
|
||||
|
||||
// Fuse these loops ; should be really easy
|
||||
psi = a * p + psi;
|
||||
p = p * b + r;
|
||||
|
||||
LinearCombTimer.Start();
|
||||
parallel_for(int ss=0;ss<src._grid->oSites();ss++){
|
||||
vstream(psi[ss], a * p[ss] + psi[ss]);
|
||||
vstream(p [ss], b * p[ss] + r[ss]);
|
||||
}
|
||||
LinearCombTimer.Stop();
|
||||
LinalgTimer.Stop();
|
||||
|
||||
std::cout << GridLogIterative << "ConjugateGradient: Iteration " << k
|
||||
<< " residual " << cp << " target " << rsq << std::endl;
|
||||
std::cout << GridLogDebug << "a = "<< a << " b_pred = "<< b_pred << " b = "<< b << std::endl;
|
||||
std::cout << GridLogDebug << "qq = "<< qq << " d = "<< d << " c = "<< c << std::endl;
|
||||
|
||||
// Stopping condition
|
||||
if (cp <= rsq) {
|
||||
@ -148,6 +154,9 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
std::cout << GridLogMessage << "\tElapsed " << SolverTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tLinalg " << LinalgTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tInner " << InnerTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tAxpyNorm " << AxpyNormTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tLinearComb " << LinearCombTimer.Elapsed() <<std::endl;
|
||||
|
||||
if (ErrorOnNoConverge) assert(true_residual / Tolerance < 10000.0);
|
||||
|
||||
|
@ -43,6 +43,7 @@ namespace Grid {
|
||||
public:
|
||||
RealD Tolerance;
|
||||
Integer MaxIterations;
|
||||
Integer IterationsToComplete; //Number of iterations the CG took to finish. Filled in upon completion
|
||||
int verbose;
|
||||
MultiShiftFunction shifts;
|
||||
|
||||
@ -163,7 +164,16 @@ void operator() (LinearOperatorBase<Field> &Linop, const Field &src, std::vector
|
||||
for(int s=0;s<nshift;s++) {
|
||||
axpby(psi[s],0.,-bs[s]*alpha[s],src,src);
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////
|
||||
// Timers
|
||||
///////////////////////////////////////
|
||||
GridStopWatch AXPYTimer;
|
||||
GridStopWatch ShiftTimer;
|
||||
GridStopWatch QRTimer;
|
||||
GridStopWatch MatrixTimer;
|
||||
GridStopWatch SolverTimer;
|
||||
SolverTimer.Start();
|
||||
|
||||
// Iteration loop
|
||||
int k;
|
||||
@ -171,7 +181,9 @@ void operator() (LinearOperatorBase<Field> &Linop, const Field &src, std::vector
|
||||
for (k=1;k<=MaxIterations;k++){
|
||||
|
||||
a = c /cp;
|
||||
AXPYTimer.Start();
|
||||
axpy(p,a,p,r);
|
||||
AXPYTimer.Stop();
|
||||
|
||||
// Note to self - direction ps is iterated seperately
|
||||
// for each shift. Does not appear to have any scope
|
||||
@ -180,6 +192,7 @@ void operator() (LinearOperatorBase<Field> &Linop, const Field &src, std::vector
|
||||
// However SAME r is used. Could load "r" and update
|
||||
// ALL ps[s]. 2/3 Bandwidth saving
|
||||
// New Kernel: Load r, vector of coeffs, vector of pointers ps
|
||||
AXPYTimer.Start();
|
||||
for(int s=0;s<nshift;s++){
|
||||
if ( ! converged[s] ) {
|
||||
if (s==0){
|
||||
@ -190,22 +203,34 @@ void operator() (LinearOperatorBase<Field> &Linop, const Field &src, std::vector
|
||||
}
|
||||
}
|
||||
}
|
||||
AXPYTimer.Stop();
|
||||
|
||||
cp=c;
|
||||
MatrixTimer.Start();
|
||||
//Linop.HermOpAndNorm(p,mmp,d,qq); // d is used
|
||||
// The below is faster on KNL
|
||||
Linop.HermOp(p,mmp);
|
||||
d=real(innerProduct(p,mmp));
|
||||
|
||||
Linop.HermOpAndNorm(p,mmp,d,qq);
|
||||
MatrixTimer.Stop();
|
||||
|
||||
AXPYTimer.Start();
|
||||
axpy(mmp,mass[0],p,mmp);
|
||||
AXPYTimer.Stop();
|
||||
RealD rn = norm2(p);
|
||||
d += rn*mass[0];
|
||||
|
||||
bp=b;
|
||||
b=-cp/d;
|
||||
|
||||
AXPYTimer.Start();
|
||||
c=axpy_norm(r,b,mmp,r);
|
||||
AXPYTimer.Stop();
|
||||
|
||||
// Toggle the recurrence history
|
||||
bs[0] = b;
|
||||
iz = 1-iz;
|
||||
ShiftTimer.Start();
|
||||
for(int s=1;s<nshift;s++){
|
||||
if((!converged[s])){
|
||||
RealD z0 = z[s][1-iz];
|
||||
@ -215,6 +240,7 @@ void operator() (LinearOperatorBase<Field> &Linop, const Field &src, std::vector
|
||||
bs[s] = b*z[s][iz]/z0; // NB sign rel to Mike
|
||||
}
|
||||
}
|
||||
ShiftTimer.Stop();
|
||||
|
||||
for(int s=0;s<nshift;s++){
|
||||
int ss = s;
|
||||
@ -257,6 +283,9 @@ void operator() (LinearOperatorBase<Field> &Linop, const Field &src, std::vector
|
||||
|
||||
if ( all_converged ){
|
||||
|
||||
SolverTimer.Stop();
|
||||
|
||||
|
||||
std::cout<<GridLogMessage<< "CGMultiShift: All shifts have converged iteration "<<k<<std::endl;
|
||||
std::cout<<GridLogMessage<< "CGMultiShift: Checking solutions"<<std::endl;
|
||||
|
||||
@ -269,8 +298,19 @@ void operator() (LinearOperatorBase<Field> &Linop, const Field &src, std::vector
|
||||
RealD cn = norm2(src);
|
||||
std::cout<<GridLogMessage<<"CGMultiShift: shift["<<s<<"] true residual "<<std::sqrt(rn/cn)<<std::endl;
|
||||
}
|
||||
|
||||
std::cout << GridLogMessage << "Time Breakdown "<<std::endl;
|
||||
std::cout << GridLogMessage << "\tElapsed " << SolverTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tAXPY " << AXPYTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tMarix " << MatrixTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tShift " << ShiftTimer.Elapsed() <<std::endl;
|
||||
|
||||
IterationsToComplete = k;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
// ugly hack
|
||||
std::cout<<GridLogMessage<<"CG multi shift did not converge"<<std::endl;
|
||||
|
Reference in New Issue
Block a user