From 1671adfd492704cb4c15a4a759cfaba1d47ac526 Mon Sep 17 00:00:00 2001 From: Daniel Richtmann Date: Fri, 2 Feb 2018 10:03:15 +0100 Subject: [PATCH] WilsonMG: Add some tests for linear operators --- tests/solver/Test_wilson_mg.cc | 100 ++++++++++++++++++++++++++------- 1 file changed, 80 insertions(+), 20 deletions(-) diff --git a/tests/solver/Test_wilson_mg.cc b/tests/solver/Test_wilson_mg.cc index 9dd780cf..a402cf02 100644 --- a/tests/solver/Test_wilson_mg.cc +++ b/tests/solver/Test_wilson_mg.cc @@ -126,38 +126,98 @@ public: } }; -template void testOperator(LinearOperatorBase &LinOp, GridBase *Grid) { +template void testLinearOperator(LinearOperatorBase &LinOp, GridBase *Grid, std::string const &name = "") { std::vector seeds({1, 2, 3, 4}); GridParallelRNG RNG(Grid); RNG.SeedFixedIntegers(seeds); - // clang-format off - Field src(Grid); random(RNG, src); - Field result(Grid); result = zero; - Field ref(Grid); ref = zero; - Field tmp(Grid); - Field err(Grid); - // clang-format on + { + std::cout << GridLogMessage << "Testing that Mdiag + Σ_μ Mdir_μ == M for operator " << name << ":" << std::endl; - LinOp.Op(src, ref); + // clang-format off + Field src(Grid); random(RNG, src); + Field ref(Grid); ref = zero; + Field result(Grid); result = zero; + Field diag(Grid); diag = zero; + Field sumDir(Grid); sumDir = zero; + Field tmp(Grid); + Field err(Grid); + // clang-format on - LinOp.OpDiag(src, result); - std::cout << GridLogMessage << "diag: norm2(result) = " << norm2(result) << std::endl; + LinOp.Op(src, ref); + std::cout << GridLogMessage << " norm2(M * src) = " << norm2(ref) << std::endl; - for(int d = 0; d < 4; d++) { - LinOp.OpDir(src, tmp, d, +1); - std::cout << GridLogMessage << "dir + " << d << ": norm2(tmp) = " << norm2(tmp) << std::endl; - result = result + tmp; + LinOp.OpDiag(src, diag); + std::cout << GridLogMessage << " norm2(Mdiag * src) = " << norm2(diag) << std::endl; - LinOp.OpDir(src, tmp, d, -1); - std::cout << GridLogMessage << "dir - " << d << ": norm2(tmp) = " << norm2(tmp) << std::endl; - result = result + tmp; + for(int dir = 0; dir < 4; dir++) { + for(auto disp : {+1, -1}) { + LinOp.OpDir(src, tmp, dir, disp); + std::cout << GridLogMessage << " norm2(Mdir_{" << dir << "," << disp << "} * src) = " << norm2(tmp) << std::endl; + sumDir = sumDir + tmp; + } + } + std::cout << GridLogMessage << " norm2(Σ_μ Mdir_μ * src) = " << norm2(sumDir) << std::endl; + + result = diag + sumDir; + err = ref - result; + + std::cout << GridLogMessage << " Absolute deviation = " << norm2(err) << std::endl; + std::cout << GridLogMessage << " Relative deviation = " << norm2(err) / norm2(ref) << std::endl; } - err = result - ref; + { + std::cout << GridLogMessage << "Testing hermiticity stochastically for operator " << name << ":" << std::endl; - std::cout << GridLogMessage << "Error: absolute = " << norm2(err) << " relative = " << norm2(err) / norm2(ref) << std::endl; + // clang-format off + Field phi(Grid); random(RNG, phi); + Field chi(Grid); random(RNG, chi); + Field MPhi(Grid); + Field MdagChi(Grid); + // clang-format on + + LinOp.Op(phi, MPhi); + LinOp.AdjOp(chi, MdagChi); + + ComplexD chiMPhi = innerProduct(chi, MPhi); + ComplexD phiMdagChi = innerProduct(phi, MdagChi); + + ComplexD phiMPhi = innerProduct(phi, MPhi); + ComplexD chiMdagChi = innerProduct(chi, MdagChi); + + std::cout << GridLogMessage << " chiMPhi = " << chiMPhi << " phiMdagChi = " << phiMdagChi + << " difference = " << chiMPhi - conjugate(phiMdagChi) << std::endl; + + std::cout << GridLogMessage << " phiMPhi = " << phiMPhi << " chiMdagChi = " << chiMdagChi << " <- should be real if hermitian" + << std::endl; + } + + { + std::cout << GridLogMessage << "Testing linearity for operator " << name << ":" << std::endl; + + // clang-format off + Field phi(Grid); random(RNG, phi); + Field chi(Grid); random(RNG, chi); + Field phiPlusChi(Grid); + Field MPhi(Grid); + Field MChi(Grid); + Field MPhiPlusChi(Grid); + Field linearityError(Grid); + // clang-format on + + LinOp.Op(phi, MPhi); + LinOp.Op(chi, MChi); + + phiPlusChi = phi + chi; + + LinOp.Op(phiPlusChi, MPhiPlusChi); + + linearityError = MPhiPlusChi - MPhi; + linearityError = linearityError - MChi; + + std::cout << GridLogMessage << " norm2(linearityError) = " << norm2(linearityError) << std::endl; + } } // template < class Fobj, class CComplex, int coarseSpins, int nbasis, class Matrix >