mirror of
https://github.com/paboyle/Grid.git
synced 2025-06-14 05:07:05 +01:00
Merge branch 'feature/staggered-comms-compute' of https://github.com/paboyle/Grid into feature/staggered-comms-compute
This commit is contained in:
@ -51,7 +51,7 @@ namespace Grid {
|
||||
|
||||
virtual void Op (const Field &in, Field &out) = 0; // Abstract base
|
||||
virtual void AdjOp (const Field &in, Field &out) = 0; // Abstract base
|
||||
virtual void HermOpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2)=0;
|
||||
virtual void HermOpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2) = 0;
|
||||
virtual void HermOp(const Field &in, Field &out)=0;
|
||||
};
|
||||
|
||||
@ -305,36 +305,59 @@ namespace Grid {
|
||||
class SchurStaggeredOperator : public SchurOperatorBase<Field> {
|
||||
protected:
|
||||
Matrix &_Mat;
|
||||
Field tmp;
|
||||
RealD mass;
|
||||
double tMpc;
|
||||
double tIP;
|
||||
double tMeo;
|
||||
double taxpby_norm;
|
||||
uint64_t ncall;
|
||||
public:
|
||||
SchurStaggeredOperator (Matrix &Mat): _Mat(Mat){};
|
||||
void Report(void)
|
||||
{
|
||||
std::cout << GridLogMessage << " HermOpAndNorm.Mpc "<< tMpc/ncall<<" usec "<<std::endl;
|
||||
std::cout << GridLogMessage << " HermOpAndNorm.IP "<< tIP /ncall<<" usec "<<std::endl;
|
||||
std::cout << GridLogMessage << " Mpc.MeoMoe "<< tMeo/ncall<<" usec "<<std::endl;
|
||||
std::cout << GridLogMessage << " Mpc.axpby_norm "<< taxpby_norm/ncall<<" usec "<<std::endl;
|
||||
}
|
||||
SchurStaggeredOperator (Matrix &Mat): _Mat(Mat), tmp(_Mat.RedBlackGrid())
|
||||
{
|
||||
assert( _Mat.isTrivialEE() );
|
||||
mass = _Mat.Mass();
|
||||
tMpc=0;
|
||||
tIP =0;
|
||||
tMeo=0;
|
||||
taxpby_norm=0;
|
||||
ncall=0;
|
||||
}
|
||||
virtual void HermOpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2){
|
||||
GridLogIterative.TimingMode(1);
|
||||
std::cout << GridLogIterative << " HermOpAndNorm "<<std::endl;
|
||||
ncall++;
|
||||
tMpc-=usecond();
|
||||
n2 = Mpc(in,out);
|
||||
std::cout << GridLogIterative << " HermOpAndNorm.Mpc "<<std::endl;
|
||||
tMpc+=usecond();
|
||||
tIP-=usecond();
|
||||
ComplexD dot= innerProduct(in,out);
|
||||
std::cout << GridLogIterative << " HermOpAndNorm.innerProduct "<<std::endl;
|
||||
tIP+=usecond();
|
||||
n1 = real(dot);
|
||||
}
|
||||
virtual void HermOp(const Field &in, Field &out){
|
||||
std::cout << GridLogIterative << " HermOp "<<std::endl;
|
||||
Mpc(in,out);
|
||||
ncall++;
|
||||
tMpc-=usecond();
|
||||
_Mat.Meooe(in,out);
|
||||
_Mat.Meooe(out,tmp);
|
||||
tMpc+=usecond();
|
||||
taxpby_norm-=usecond();
|
||||
axpby(out,-1.0,mass*mass,tmp,in);
|
||||
taxpby_norm+=usecond();
|
||||
}
|
||||
virtual RealD Mpc (const Field &in, Field &out) {
|
||||
Field tmp(in._grid);
|
||||
Field tmp2(in._grid);
|
||||
|
||||
std::cout << GridLogIterative << " HermOp.Mpc "<<std::endl;
|
||||
_Mat.Mooee(in,out);
|
||||
_Mat.Mooee(out,tmp);
|
||||
std::cout << GridLogIterative << " HermOp.MooeeMooee "<<std::endl;
|
||||
|
||||
tMeo-=usecond();
|
||||
_Mat.Meooe(in,out);
|
||||
_Mat.Meooe(out,tmp2);
|
||||
std::cout << GridLogIterative << " HermOp.MeooeMeooe "<<std::endl;
|
||||
|
||||
RealD nn=axpy_norm(out,-1.0,tmp2,tmp);
|
||||
std::cout << GridLogIterative << " HermOp.axpy_norm "<<std::endl;
|
||||
_Mat.Meooe(out,tmp);
|
||||
tMeo+=usecond();
|
||||
taxpby_norm-=usecond();
|
||||
RealD nn=axpby_norm(out,-1.0,mass*mass,tmp,in);
|
||||
taxpby_norm+=usecond();
|
||||
return nn;
|
||||
}
|
||||
virtual RealD MpcDag (const Field &in, Field &out){
|
||||
|
@ -70,7 +70,6 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
|
||||
|
||||
Linop.HermOpAndNorm(psi, mmp, d, b);
|
||||
|
||||
|
||||
r = src - mmp;
|
||||
p = r;
|
||||
@ -97,6 +96,9 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
<< "ConjugateGradient: k=0 residual " << cp << " target " << rsq << std::endl;
|
||||
|
||||
GridStopWatch LinalgTimer;
|
||||
GridStopWatch InnerTimer;
|
||||
GridStopWatch AxpyNormTimer;
|
||||
GridStopWatch LinearCombTimer;
|
||||
GridStopWatch MatrixTimer;
|
||||
GridStopWatch SolverTimer;
|
||||
|
||||
@ -106,30 +108,32 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
c = cp;
|
||||
|
||||
MatrixTimer.Start();
|
||||
Linop.HermOpAndNorm(p, mmp, d, qq);
|
||||
Linop.HermOp(p, mmp);
|
||||
MatrixTimer.Stop();
|
||||
|
||||
LinalgTimer.Start();
|
||||
// AA
|
||||
// RealD qqck = norm2(mmp);
|
||||
// ComplexD dck = innerProduct(p,mmp);
|
||||
|
||||
InnerTimer.Start();
|
||||
ComplexD dc = innerProduct(p,mmp);
|
||||
InnerTimer.Stop();
|
||||
d = dc.real();
|
||||
a = c / d;
|
||||
b_pred = a * (a * qq - d) / c;
|
||||
|
||||
AxpyNormTimer.Start();
|
||||
cp = axpy_norm(r, -a, mmp, r);
|
||||
AxpyNormTimer.Stop();
|
||||
b = cp / c;
|
||||
|
||||
// Fuse these loops ; should be really easy
|
||||
psi = a * p + psi;
|
||||
p = p * b + r;
|
||||
|
||||
LinearCombTimer.Start();
|
||||
parallel_for(int ss=0;ss<src._grid->oSites();ss++){
|
||||
vstream(psi[ss], a * p[ss] + psi[ss]);
|
||||
vstream(p [ss], b * p[ss] + r[ss]);
|
||||
}
|
||||
LinearCombTimer.Stop();
|
||||
LinalgTimer.Stop();
|
||||
|
||||
std::cout << GridLogIterative << "ConjugateGradient: Iteration " << k
|
||||
<< " residual " << cp << " target " << rsq << std::endl;
|
||||
std::cout << GridLogDebug << "a = "<< a << " b_pred = "<< b_pred << " b = "<< b << std::endl;
|
||||
std::cout << GridLogDebug << "qq = "<< qq << " d = "<< d << " c = "<< c << std::endl;
|
||||
|
||||
// Stopping condition
|
||||
if (cp <= rsq) {
|
||||
@ -150,6 +154,9 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
std::cout << GridLogMessage << "\tElapsed " << SolverTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tLinalg " << LinalgTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tInner " << InnerTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tAxpyNorm " << AxpyNormTimer.Elapsed() <<std::endl;
|
||||
std::cout << GridLogMessage << "\tLinearComb " << LinearCombTimer.Elapsed() <<std::endl;
|
||||
|
||||
if (ErrorOnNoConverge) assert(true_residual / Tolerance < 10000.0);
|
||||
|
||||
|
Reference in New Issue
Block a user