mirror of
https://github.com/paboyle/Grid.git
synced 2025-04-09 21:50:45 +01:00
Faster linalg in CG, optimised against staggered.
The sum overhead is bigger for staggered.
This commit is contained in:
parent
eac6ec4b5e
commit
3e125c5b61
@ -70,7 +70,6 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
|||||||
|
|
||||||
|
|
||||||
Linop.HermOpAndNorm(psi, mmp, d, b);
|
Linop.HermOpAndNorm(psi, mmp, d, b);
|
||||||
|
|
||||||
|
|
||||||
r = src - mmp;
|
r = src - mmp;
|
||||||
p = r;
|
p = r;
|
||||||
@ -97,6 +96,9 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
|||||||
<< "ConjugateGradient: k=0 residual " << cp << " target " << rsq << std::endl;
|
<< "ConjugateGradient: k=0 residual " << cp << " target " << rsq << std::endl;
|
||||||
|
|
||||||
GridStopWatch LinalgTimer;
|
GridStopWatch LinalgTimer;
|
||||||
|
GridStopWatch InnerTimer;
|
||||||
|
GridStopWatch AxpyNormTimer;
|
||||||
|
GridStopWatch LinearCombTimer;
|
||||||
GridStopWatch MatrixTimer;
|
GridStopWatch MatrixTimer;
|
||||||
GridStopWatch SolverTimer;
|
GridStopWatch SolverTimer;
|
||||||
|
|
||||||
@ -106,30 +108,32 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
|||||||
c = cp;
|
c = cp;
|
||||||
|
|
||||||
MatrixTimer.Start();
|
MatrixTimer.Start();
|
||||||
Linop.HermOpAndNorm(p, mmp, d, qq);
|
Linop.HermOp(p, mmp);
|
||||||
MatrixTimer.Stop();
|
MatrixTimer.Stop();
|
||||||
|
|
||||||
LinalgTimer.Start();
|
LinalgTimer.Start();
|
||||||
// AA
|
|
||||||
// RealD qqck = norm2(mmp);
|
|
||||||
// ComplexD dck = innerProduct(p,mmp);
|
|
||||||
|
|
||||||
|
InnerTimer.Start();
|
||||||
|
ComplexD dc = innerProduct(p,mmp);
|
||||||
|
InnerTimer.Stop();
|
||||||
|
d = dc.real();
|
||||||
a = c / d;
|
a = c / d;
|
||||||
b_pred = a * (a * qq - d) / c;
|
|
||||||
|
|
||||||
|
AxpyNormTimer.Start();
|
||||||
cp = axpy_norm(r, -a, mmp, r);
|
cp = axpy_norm(r, -a, mmp, r);
|
||||||
|
AxpyNormTimer.Stop();
|
||||||
b = cp / c;
|
b = cp / c;
|
||||||
|
|
||||||
// Fuse these loops ; should be really easy
|
LinearCombTimer.Start();
|
||||||
psi = a * p + psi;
|
parallel_for(int ss=0;ss<src._grid->oSites();ss++){
|
||||||
p = p * b + r;
|
vstream(psi[ss], a * p[ss] + psi[ss]);
|
||||||
|
vstream(p [ss], b * p[ss] + r[ss]);
|
||||||
|
}
|
||||||
|
LinearCombTimer.Stop();
|
||||||
LinalgTimer.Stop();
|
LinalgTimer.Stop();
|
||||||
|
|
||||||
std::cout << GridLogIterative << "ConjugateGradient: Iteration " << k
|
std::cout << GridLogIterative << "ConjugateGradient: Iteration " << k
|
||||||
<< " residual " << cp << " target " << rsq << std::endl;
|
<< " residual " << cp << " target " << rsq << std::endl;
|
||||||
std::cout << GridLogDebug << "a = "<< a << " b_pred = "<< b_pred << " b = "<< b << std::endl;
|
|
||||||
std::cout << GridLogDebug << "qq = "<< qq << " d = "<< d << " c = "<< c << std::endl;
|
|
||||||
|
|
||||||
// Stopping condition
|
// Stopping condition
|
||||||
if (cp <= rsq) {
|
if (cp <= rsq) {
|
||||||
@ -150,6 +154,9 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
|||||||
std::cout << GridLogMessage << "\tElapsed " << SolverTimer.Elapsed() <<std::endl;
|
std::cout << GridLogMessage << "\tElapsed " << SolverTimer.Elapsed() <<std::endl;
|
||||||
std::cout << GridLogMessage << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
|
std::cout << GridLogMessage << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
|
||||||
std::cout << GridLogMessage << "\tLinalg " << LinalgTimer.Elapsed() <<std::endl;
|
std::cout << GridLogMessage << "\tLinalg " << LinalgTimer.Elapsed() <<std::endl;
|
||||||
|
std::cout << GridLogMessage << "\tInner " << InnerTimer.Elapsed() <<std::endl;
|
||||||
|
std::cout << GridLogMessage << "\tAxpyNorm " << AxpyNormTimer.Elapsed() <<std::endl;
|
||||||
|
std::cout << GridLogMessage << "\tLinearComb " << LinearCombTimer.Elapsed() <<std::endl;
|
||||||
|
|
||||||
if (ErrorOnNoConverge) assert(true_residual / Tolerance < 10000.0);
|
if (ErrorOnNoConverge) assert(true_residual / Tolerance < 10000.0);
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user