mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-11-04 05:54:32 +00:00 
			
		
		
		
	multishift conjugate gradient added and a strong test: take a diagonal
but non-identity matrix l1 0 0 0 .... 0 l2 0 0 .... 0 0 l3 0 ... . . . . . . . . . And apply the multishift CG to it. Sum the poles and residues. Insist that this be the same as the exactly taken square root where l1,l2,l3 >= 0.
This commit is contained in:
		@@ -4,13 +4,16 @@
 | 
			
		||||
#include <algorithms/SparseMatrix.h>
 | 
			
		||||
#include <algorithms/LinearOperator.h>
 | 
			
		||||
 | 
			
		||||
#include <algorithms/approx/Zolotarev.h>
 | 
			
		||||
#include <algorithms/approx/Chebyshev.h>
 | 
			
		||||
#include <algorithms/approx/Remez.h>
 | 
			
		||||
#include <algorithms/approx/MultiShiftFunction.h>
 | 
			
		||||
 | 
			
		||||
#include <algorithms/iterative/ConjugateGradient.h>
 | 
			
		||||
#include <algorithms/iterative/NormalEquations.h>
 | 
			
		||||
#include <algorithms/iterative/SchurRedBlack.h>
 | 
			
		||||
 | 
			
		||||
#include <algorithms/approx/Zolotarev.h>
 | 
			
		||||
#include <algorithms/approx/Chebyshev.h>
 | 
			
		||||
#include <algorithms/approx/Remez.h>
 | 
			
		||||
#include <algorithms/iterative/ConjugateGradientMultiShift.h>
 | 
			
		||||
 | 
			
		||||
// Eigen/lanczos
 | 
			
		||||
// EigCg
 | 
			
		||||
 
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
 | 
			
		||||
HFILES=./algorithms/approx/bigfloat.h ./algorithms/approx/bigfloat_double.h ./algorithms/approx/Chebyshev.h ./algorithms/approx/Remez.h ./algorithms/approx/Zolotarev.h ./algorithms/iterative/ConjugateGradient.h ./algorithms/iterative/NormalEquations.h ./algorithms/iterative/SchurRedBlack.h ./algorithms/LinearOperator.h ./algorithms/SparseMatrix.h ./Algorithms.h ./AlignedAllocator.h ./cartesian/Cartesian_base.h ./cartesian/Cartesian_full.h ./cartesian/Cartesian_red_black.h ./Cartesian.h ./communicator/Communicator_base.h ./Communicator.h ./Comparison.h ./cshift/Cshift_common.h ./cshift/Cshift_mpi.h ./cshift/Cshift_none.h ./Cshift.h ./Grid.h ./GridConfig.h ./lattice/Lattice_arith.h ./lattice/Lattice_base.h ./lattice/Lattice_comparison.h ./lattice/Lattice_conformable.h ./lattice/Lattice_coordinate.h ./lattice/Lattice_ET.h ./lattice/Lattice_local.h ./lattice/Lattice_overload.h ./lattice/Lattice_peekpoke.h ./lattice/Lattice_reality.h ./lattice/Lattice_reduction.h ./lattice/Lattice_rng.h ./lattice/Lattice_trace.h ./lattice/Lattice_transfer.h ./lattice/Lattice_transpose.h ./lattice/Lattice_where.h ./Lattice.h ./parallelIO/NerscIO.h ./qcd/action/Actions.h ./qcd/action/fermion/CayleyFermion5D.h ./qcd/action/fermion/ContinuedFractionFermion5D.h ./qcd/action/fermion/DomainWallFermion.h ./qcd/action/fermion/FermionOperator.h ./qcd/action/fermion/MobiusFermion.h ./qcd/action/fermion/MobiusZolotarevFermion.h ./qcd/action/fermion/OverlapWilsonCayleyTanhFermion.h ./qcd/action/fermion/OverlapWilsonCayleyZolotarevFermion.h ./qcd/action/fermion/OverlapWilsonContfracTanhFermion.h ./qcd/action/fermion/OverlapWilsonContfracZolotarevFermion.h ./qcd/action/fermion/OverlapWilsonPartialFractionTanhFermion.h ./qcd/action/fermion/OverlapWilsonPartialFractionZolotarevFermion.h ./qcd/action/fermion/PartialFractionFermion5D.h ./qcd/action/fermion/ScaledShamirFermion.h ./qcd/action/fermion/ShamirZolotarevFermion.h ./qcd/action/fermion/WilsonCompressor.h ./qcd/action/fermion/WilsonFermion.h ./qcd/action/fermion/WilsonFermion5D.h ./qcd/action/fermion/WilsonKernels.h ./qcd/action/gauge/GaugeActionBase.h ./qcd/action/gauge/WilsonGaugeAction.h ./qcd/Dirac.h ./qcd/LinalgUtils.h ./qcd/QCD.h ./qcd/SpaceTimeGrid.h ./qcd/TwoSpinor.h ./qcd/utils/CovariantCshift.h ./qcd/utils/WilsonLoops.h ./simd/Grid_avx.h ./simd/Grid_avx512.h ./simd/Grid_qpx.h ./simd/Grid_sse4.h ./simd/Grid_vector_types.h ./simd/Old/Grid_vComplexD.h ./simd/Old/Grid_vComplexF.h ./simd/Old/Grid_vInteger.h ./simd/Old/Grid_vRealD.h ./simd/Old/Grid_vRealF.h ./Simd.h ./stencil/Lebesgue.h ./Stencil.h ./tensors/Tensor_arith.h ./tensors/Tensor_arith_add.h ./tensors/Tensor_arith_mac.h ./tensors/Tensor_arith_mul.h ./tensors/Tensor_arith_scalar.h ./tensors/Tensor_arith_sub.h ./tensors/Tensor_class.h ./tensors/Tensor_extract_merge.h ./tensors/Tensor_inner.h ./tensors/Tensor_outer.h ./tensors/Tensor_peek.h ./tensors/Tensor_poke.h ./tensors/Tensor_reality.h ./tensors/Tensor_Ta.h ./tensors/Tensor_trace.h ./tensors/Tensor_traits.h ./tensors/Tensor_transpose.h ./Tensors.h ./Threads.h
 | 
			
		||||
HFILES=./algorithms/approx/bigfloat.h ./algorithms/approx/bigfloat_double.h ./algorithms/approx/Chebyshev.h ./algorithms/approx/MultiShiftFunction.h ./algorithms/approx/Remez.h ./algorithms/approx/Zolotarev.h ./algorithms/iterative/ConjugateGradient.h ./algorithms/iterative/ConjugateGradientMultiShift.h ./algorithms/iterative/NormalEquations.h ./algorithms/iterative/SchurRedBlack.h ./algorithms/LinearOperator.h ./algorithms/SparseMatrix.h ./Algorithms.h ./AlignedAllocator.h ./cartesian/Cartesian_base.h ./cartesian/Cartesian_full.h ./cartesian/Cartesian_red_black.h ./Cartesian.h ./communicator/Communicator_base.h ./Communicator.h ./Comparison.h ./cshift/Cshift_common.h ./cshift/Cshift_mpi.h ./cshift/Cshift_none.h ./Cshift.h ./Grid.h ./GridConfig.h ./lattice/Lattice_arith.h ./lattice/Lattice_base.h ./lattice/Lattice_comparison.h ./lattice/Lattice_conformable.h ./lattice/Lattice_coordinate.h ./lattice/Lattice_ET.h ./lattice/Lattice_local.h ./lattice/Lattice_overload.h ./lattice/Lattice_peekpoke.h ./lattice/Lattice_reality.h ./lattice/Lattice_reduction.h ./lattice/Lattice_rng.h ./lattice/Lattice_trace.h ./lattice/Lattice_transfer.h ./lattice/Lattice_transpose.h ./lattice/Lattice_where.h ./Lattice.h ./parallelIO/NerscIO.h ./qcd/action/Actions.h ./qcd/action/fermion/CayleyFermion5D.h ./qcd/action/fermion/ContinuedFractionFermion5D.h ./qcd/action/fermion/DomainWallFermion.h ./qcd/action/fermion/FermionOperator.h ./qcd/action/fermion/MobiusFermion.h ./qcd/action/fermion/MobiusZolotarevFermion.h ./qcd/action/fermion/OverlapWilsonCayleyTanhFermion.h ./qcd/action/fermion/OverlapWilsonCayleyZolotarevFermion.h ./qcd/action/fermion/OverlapWilsonContfracTanhFermion.h ./qcd/action/fermion/OverlapWilsonContfracZolotarevFermion.h ./qcd/action/fermion/OverlapWilsonPartialFractionTanhFermion.h ./qcd/action/fermion/OverlapWilsonPartialFractionZolotarevFermion.h ./qcd/action/fermion/PartialFractionFermion5D.h ./qcd/action/fermion/ScaledShamirFermion.h ./qcd/action/fermion/ShamirZolotarevFermion.h ./qcd/action/fermion/WilsonCompressor.h ./qcd/action/fermion/WilsonFermion.h ./qcd/action/fermion/WilsonFermion5D.h ./qcd/action/fermion/WilsonKernels.h ./qcd/action/gauge/GaugeActionBase.h ./qcd/action/gauge/WilsonGaugeAction.h ./qcd/Dirac.h ./qcd/LinalgUtils.h ./qcd/QCD.h ./qcd/SpaceTimeGrid.h ./qcd/TwoSpinor.h ./qcd/utils/CovariantCshift.h ./qcd/utils/WilsonLoops.h ./simd/Grid_avx.h ./simd/Grid_avx512.h ./simd/Grid_qpx.h ./simd/Grid_sse4.h ./simd/Grid_vector_types.h ./simd/Old/Grid_vComplexD.h ./simd/Old/Grid_vComplexF.h ./simd/Old/Grid_vInteger.h ./simd/Old/Grid_vRealD.h ./simd/Old/Grid_vRealF.h ./Simd.h ./stencil/Lebesgue.h ./Stencil.h ./tensors/Tensor_arith.h ./tensors/Tensor_arith_add.h ./tensors/Tensor_arith_mac.h ./tensors/Tensor_arith_mul.h ./tensors/Tensor_arith_scalar.h ./tensors/Tensor_arith_sub.h ./tensors/Tensor_class.h ./tensors/Tensor_extract_merge.h ./tensors/Tensor_inner.h ./tensors/Tensor_outer.h ./tensors/Tensor_peek.h ./tensors/Tensor_poke.h ./tensors/Tensor_reality.h ./tensors/Tensor_Ta.h ./tensors/Tensor_trace.h ./tensors/Tensor_traits.h ./tensors/Tensor_transpose.h ./Tensors.h ./Threads.h
 | 
			
		||||
 | 
			
		||||
CCFILES=./algorithms/approx/Remez.cc ./algorithms/approx/Zolotarev.cc ./GridInit.cc ./qcd/action/fermion/CayleyFermion5D.cc ./qcd/action/fermion/ContinuedFractionFermion5D.cc ./qcd/action/fermion/PartialFractionFermion5D.cc ./qcd/action/fermion/WilsonFermion.cc ./qcd/action/fermion/WilsonFermion5D.cc ./qcd/action/fermion/WilsonKernels.cc ./qcd/action/fermion/WilsonKernelsHand.cc ./qcd/Dirac.cc ./qcd/SpaceTimeGrid.cc ./stencil/Lebesgue.cc ./stencil/Stencil_common.cc
 | 
			
		||||
CCFILES=./algorithms/approx/MultiShiftFunction.cc ./algorithms/approx/Remez.cc ./algorithms/approx/Zolotarev.cc ./GridInit.cc ./qcd/action/fermion/CayleyFermion5D.cc ./qcd/action/fermion/ContinuedFractionFermion5D.cc ./qcd/action/fermion/PartialFractionFermion5D.cc ./qcd/action/fermion/WilsonFermion.cc ./qcd/action/fermion/WilsonFermion5D.cc ./qcd/action/fermion/WilsonKernels.cc ./qcd/action/fermion/WilsonKernelsHand.cc ./qcd/Dirac.cc ./qcd/SpaceTimeGrid.cc ./stencil/Lebesgue.cc ./stencil/Stencil_common.cc
 | 
			
		||||
 
 | 
			
		||||
@@ -167,6 +167,14 @@ namespace Grid {
 | 
			
		||||
      virtual void operator() (LinearOperatorBase<Field> &Linop, const Field &in, Field &out) = 0;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    /////////////////////////////////////////////////////////////
 | 
			
		||||
    // Base classes for Multishift solvers for operators
 | 
			
		||||
    /////////////////////////////////////////////////////////////
 | 
			
		||||
    template<class Field> class OperatorMultiFunction {
 | 
			
		||||
    public:
 | 
			
		||||
      virtual void operator() (LinearOperatorBase<Field> &Linop, const Field &in, std::vector<Field> &out) = 0;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // FIXME : To think about
 | 
			
		||||
 | 
			
		||||
    // Chroma functionality list defining LinearOperator
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										29
									
								
								lib/algorithms/approx/MultiShiftFunction.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								lib/algorithms/approx/MultiShiftFunction.cc
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
			
		||||
#include <Grid.h>
 | 
			
		||||
 | 
			
		||||
namespace Grid {
 | 
			
		||||
double MultiShiftFunction::approx(double x)
 | 
			
		||||
{
 | 
			
		||||
  double a = norm;
 | 
			
		||||
  for(int n=0;n<poles.size();n++){
 | 
			
		||||
    a = a + residues[n]/(x+poles[n]);
 | 
			
		||||
  }
 | 
			
		||||
  return a;
 | 
			
		||||
}
 | 
			
		||||
void MultiShiftFunction::gnuplot(std::ostream &out)
 | 
			
		||||
{
 | 
			
		||||
  out<<"f(x) = "<<norm<<"";
 | 
			
		||||
  for(int n=0;n<poles.size();n++){
 | 
			
		||||
    out<<"+("<<residues[n]<<"/(x+"<<poles[n]<<"))";
 | 
			
		||||
  }
 | 
			
		||||
  out<<";"<<std::endl;
 | 
			
		||||
}
 | 
			
		||||
void MultiShiftFunction::csv(std::ostream &out)
 | 
			
		||||
{
 | 
			
		||||
  for (double x=lo; x<hi; x*=1.05) {
 | 
			
		||||
    double f = approx(x);
 | 
			
		||||
    double r = sqrt(x);
 | 
			
		||||
    out<< x<<","<<r<<","<<f<<","<<r-f<<std::endl;
 | 
			
		||||
  }
 | 
			
		||||
  return;
 | 
			
		||||
}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										28
									
								
								lib/algorithms/approx/MultiShiftFunction.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								lib/algorithms/approx/MultiShiftFunction.h
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
			
		||||
#ifndef MULTI_SHIFT_FUNCTION
 | 
			
		||||
#define MULTI_SHIFT_FUNCTION
 | 
			
		||||
namespace Grid {
 | 
			
		||||
class MultiShiftFunction {
 | 
			
		||||
public:
 | 
			
		||||
  int order;
 | 
			
		||||
  std::vector<RealD> poles;
 | 
			
		||||
  std::vector<RealD> residues;
 | 
			
		||||
  std::vector<RealD> tolerances;
 | 
			
		||||
  RealD norm;
 | 
			
		||||
  RealD lo,hi;
 | 
			
		||||
  MultiShiftFunction(int n,RealD _lo,RealD _hi): poles(n), residues(n), lo(_lo), hi(_hi) {;};
 | 
			
		||||
  RealD approx(RealD x);
 | 
			
		||||
  void csv(std::ostream &out);
 | 
			
		||||
  void gnuplot(std::ostream &out);
 | 
			
		||||
  MultiShiftFunction(AlgRemez & remez,double tol,bool inverse) :
 | 
			
		||||
      order(remez.getDegree()),
 | 
			
		||||
      tolerances(remez.getDegree(),tol),
 | 
			
		||||
      poles(remez.getDegree()),
 | 
			
		||||
      residues(remez.getDegree())
 | 
			
		||||
  {
 | 
			
		||||
    remez.getBounds(lo,hi);
 | 
			
		||||
    if ( inverse ) remez.getIPFE (&residues[0],&poles[0],&norm);
 | 
			
		||||
    else remez.getPFE (&residues[0],&poles[0],&norm);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
@@ -125,8 +125,17 @@ class AlgRemez
 | 
			
		||||
  // Destructor
 | 
			
		||||
  virtual ~AlgRemez();
 | 
			
		||||
 | 
			
		||||
  int getDegree(void){ 
 | 
			
		||||
    assert(n==d);
 | 
			
		||||
    return n;
 | 
			
		||||
  }
 | 
			
		||||
  // Reset the bounds of the approximation
 | 
			
		||||
  void setBounds(double lower, double upper);
 | 
			
		||||
  // Reset the bounds of the approximation
 | 
			
		||||
  void getBounds(double &lower, double &upper) { 
 | 
			
		||||
    lower=(double)apstrt;
 | 
			
		||||
    upper=(double)apend;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Generate the rational approximation x^(pnum/pden)
 | 
			
		||||
  double generateApprox(int num_degree, int den_degree, 
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										250
									
								
								lib/algorithms/iterative/ConjugateGradientMultiShift.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										250
									
								
								lib/algorithms/iterative/ConjugateGradientMultiShift.h
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,250 @@
 | 
			
		||||
#ifndef GRID_CONJUGATE_MULTI_SHIFT_GRADIENT_H
 | 
			
		||||
#define GRID_CONJUGATE_MULTI_SHIFT_GRADIENT_H
 | 
			
		||||
 | 
			
		||||
namespace Grid {
 | 
			
		||||
 | 
			
		||||
    /////////////////////////////////////////////////////////////
 | 
			
		||||
    // Base classes for iterative processes based on operators
 | 
			
		||||
    // single input vec, single output vec.
 | 
			
		||||
    /////////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
  template<class Field> 
 | 
			
		||||
    class ConjugateGradientMultiShift : public OperatorMultiFunction<Field>,
 | 
			
		||||
                                        public OperatorFunction<Field>
 | 
			
		||||
    {
 | 
			
		||||
public:                                                
 | 
			
		||||
    RealD   Tolerance;
 | 
			
		||||
    Integer MaxIterations;
 | 
			
		||||
    int verbose;
 | 
			
		||||
    MultiShiftFunction shifts;
 | 
			
		||||
 | 
			
		||||
    ConjugateGradientMultiShift(Integer maxit,MultiShiftFunction &_shifts) : 
 | 
			
		||||
	MaxIterations(maxit),
 | 
			
		||||
	shifts(_shifts)
 | 
			
		||||
    { 
 | 
			
		||||
      verbose=1;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
void operator() (LinearOperatorBase<Field> &Linop, const Field &src, Field &psi)
 | 
			
		||||
{
 | 
			
		||||
 | 
			
		||||
  GridBase *grid = src._grid;
 | 
			
		||||
  int nshift = shifts.order;
 | 
			
		||||
  std::vector<Field> results(nshift,grid);
 | 
			
		||||
 | 
			
		||||
  (*this)(Linop,src,results);
 | 
			
		||||
  
 | 
			
		||||
  psi = shifts.norm*src;
 | 
			
		||||
  for(int i=0;i<nshift;i++){
 | 
			
		||||
    psi = psi + shifts.residues[i]*results[i];
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void operator() (LinearOperatorBase<Field> &Linop, const Field &src, std::vector<Field> &psi)
 | 
			
		||||
{
 | 
			
		||||
  
 | 
			
		||||
  GridBase *grid = src._grid;
 | 
			
		||||
  
 | 
			
		||||
  ////////////////////////////////////////////////////////////////////////
 | 
			
		||||
  // Convenience references to the info stored in "MultiShiftFunction"
 | 
			
		||||
  ////////////////////////////////////////////////////////////////////////
 | 
			
		||||
  int nshift = shifts.order;
 | 
			
		||||
 | 
			
		||||
  std::vector<RealD> &mass(shifts.poles); // Make references to array in "shifts"
 | 
			
		||||
  std::vector<RealD> &mresidual(shifts.tolerances);
 | 
			
		||||
  std::vector<RealD> alpha(nshift,1.0);
 | 
			
		||||
  std::vector<Field>   ps(nshift,grid);// Search directions
 | 
			
		||||
 | 
			
		||||
  assert(psi.size()==nshift);
 | 
			
		||||
  assert(mass.size()==nshift);
 | 
			
		||||
  assert(mresidual.size()==nshift);
 | 
			
		||||
  
 | 
			
		||||
  // dynamic sized arrays on stack; 2d is a pain with vector
 | 
			
		||||
  RealD  bs[nshift];
 | 
			
		||||
  RealD  rsq[nshift];
 | 
			
		||||
  RealD  z[nshift][2];
 | 
			
		||||
  int     converged[nshift];
 | 
			
		||||
  
 | 
			
		||||
  const int       primary =0;
 | 
			
		||||
  
 | 
			
		||||
  //Primary shift fields CG iteration
 | 
			
		||||
  RealD a,b,c,d;
 | 
			
		||||
  RealD cp,bp,qq; //prev
 | 
			
		||||
  
 | 
			
		||||
  // Matrix mult fields
 | 
			
		||||
  Field r(grid);
 | 
			
		||||
  Field p(grid);
 | 
			
		||||
  Field tmp(grid);
 | 
			
		||||
  Field mmp(grid);
 | 
			
		||||
  
 | 
			
		||||
  // Check lightest mass
 | 
			
		||||
  for(int s=0;s<nshift;s++){
 | 
			
		||||
    assert( mass[s]>= mass[primary] );
 | 
			
		||||
    converged[s]=0;
 | 
			
		||||
  }
 | 
			
		||||
  
 | 
			
		||||
  // Wire guess to zero
 | 
			
		||||
  // Residuals "r" are src
 | 
			
		||||
  // First search direction "p" is also src
 | 
			
		||||
  cp = norm2(src);
 | 
			
		||||
  for(int s=0;s<nshift;s++){
 | 
			
		||||
    rsq[s] = cp * mresidual[s] * mresidual[s];
 | 
			
		||||
    std::cout<<"ConjugateGradientMultiShift: shift "<<s
 | 
			
		||||
	     <<" target resid "<<rsq[s]<<std::endl;
 | 
			
		||||
    ps[s] = src;
 | 
			
		||||
  }
 | 
			
		||||
  // r and p for primary
 | 
			
		||||
  r=src;
 | 
			
		||||
  p=src;
 | 
			
		||||
  
 | 
			
		||||
  //MdagM+m[0]
 | 
			
		||||
  Linop.HermOpAndNorm(p,mmp,d,qq);
 | 
			
		||||
  axpy(mmp,mass[0],p,mmp);
 | 
			
		||||
  RealD rn = norm2(p);
 | 
			
		||||
  d += rn*mass[0];
 | 
			
		||||
  
 | 
			
		||||
  // have verified that inner product of 
 | 
			
		||||
  // p and mmp is equal to d after this since
 | 
			
		||||
  // the d computation is tricky
 | 
			
		||||
  //  qq = real(innerProduct(p,mmp));
 | 
			
		||||
  //  std::cout << "debug equal ?  qq "<<qq<<" d "<< d<<std::endl;
 | 
			
		||||
  
 | 
			
		||||
  b = -cp /d;
 | 
			
		||||
  
 | 
			
		||||
  // Set up the various shift variables
 | 
			
		||||
  int       iz=0;
 | 
			
		||||
  z[0][1-iz] = 1.0;
 | 
			
		||||
  z[0][iz]   = 1.0;
 | 
			
		||||
  bs[0]      = b;
 | 
			
		||||
  for(int s=1;s<nshift;s++){
 | 
			
		||||
    z[s][1-iz] = 1.0;
 | 
			
		||||
    z[s][iz]   = 1.0/( 1.0 - b*(mass[s]-mass[0]));
 | 
			
		||||
    bs[s]      = b*z[s][iz]; 
 | 
			
		||||
  }
 | 
			
		||||
  
 | 
			
		||||
  // r += b[0] A.p[0]
 | 
			
		||||
  // c= norm(r)
 | 
			
		||||
  c=axpy_norm(r,b,mmp,r);
 | 
			
		||||
  
 | 
			
		||||
  for(int s=0;s<nshift;s++) {
 | 
			
		||||
    axpby(psi[s],0.,-bs[s]*alpha[s],src,src);
 | 
			
		||||
  }
 | 
			
		||||
  
 | 
			
		||||
  
 | 
			
		||||
  // Iteration loop
 | 
			
		||||
  int k;
 | 
			
		||||
  
 | 
			
		||||
  for (k=1;k<=MaxIterations;k++){
 | 
			
		||||
    
 | 
			
		||||
    a = c /cp;
 | 
			
		||||
    axpy(p,a,p,r);
 | 
			
		||||
    
 | 
			
		||||
    // Note to self - direction ps is iterated seperately
 | 
			
		||||
    // for each shift. Does not appear to have any scope
 | 
			
		||||
    // for avoiding linear algebra in "single" case.
 | 
			
		||||
    // 
 | 
			
		||||
    // However SAME r is used. Could load "r" and update
 | 
			
		||||
    // ALL ps[s]. 2/3 Bandwidth saving
 | 
			
		||||
    // New Kernel: Load r, vector of coeffs, vector of pointers ps
 | 
			
		||||
    for(int s=0;s<nshift;s++){
 | 
			
		||||
      if ( ! converged[s] ) { 
 | 
			
		||||
	if (s==0){
 | 
			
		||||
	  axpy(ps[s],a,ps[s],r);
 | 
			
		||||
	} else{
 | 
			
		||||
	  RealD as =a *z[s][iz]*bs[s] /(z[s][1-iz]*b);
 | 
			
		||||
	  axpby(ps[s],z[s][iz],as,r,ps[s]);
 | 
			
		||||
	}
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    cp=c;
 | 
			
		||||
    
 | 
			
		||||
    Linop.HermOpAndNorm(p,mmp,d,qq);
 | 
			
		||||
    axpy(mmp,mass[0],p,mmp);
 | 
			
		||||
    RealD rn = norm2(p);
 | 
			
		||||
    d += rn*mass[0];
 | 
			
		||||
    
 | 
			
		||||
    bp=b;
 | 
			
		||||
    b=-cp/d;
 | 
			
		||||
    
 | 
			
		||||
    c=axpy_norm(r,b,mmp,r);
 | 
			
		||||
 | 
			
		||||
    // Toggle the recurrence history
 | 
			
		||||
    bs[0] = b;
 | 
			
		||||
    iz = 1-iz;
 | 
			
		||||
    for(int s=1;s<nshift;s++){
 | 
			
		||||
      if((!converged[s])){
 | 
			
		||||
	RealD z0 = z[s][1-iz];
 | 
			
		||||
	RealD z1 = z[s][iz];
 | 
			
		||||
	z[s][iz] = z0*z1*bp
 | 
			
		||||
	  / (b*a*(z1-z0) + z1*bp*(1- (mass[s]-mass[0])*b)); 
 | 
			
		||||
	bs[s] = b*z[s][iz]/z0; // NB sign  rel to Mike
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    for(int s=0;s<nshift;s++){
 | 
			
		||||
      int ss = s;
 | 
			
		||||
      // Scope for optimisation here in case of "single".
 | 
			
		||||
      // Could load psi[0] and pull all ps[s] in.
 | 
			
		||||
      //      if ( single ) ss=primary;
 | 
			
		||||
      // Bandwith saving in single case is Ls * 3 -> 2+Ls, so ~ 3x saving
 | 
			
		||||
      // Pipelined CG gain:
 | 
			
		||||
      //
 | 
			
		||||
      // New Kernel: Load r, vector of coeffs, vector of pointers ps
 | 
			
		||||
      // New Kernel: Load psi[0], vector of coeffs, vector of pointers ps
 | 
			
		||||
      // If can predict the coefficient bs then we can fuse these and avoid write reread cyce
 | 
			
		||||
      //  on ps[s].
 | 
			
		||||
      // Before:  3 x npole  + 3 x npole
 | 
			
		||||
      // After :  2 x npole (ps[s])        => 3x speed up of multishift CG.
 | 
			
		||||
      
 | 
			
		||||
      if( (!converged[s]) ) { 
 | 
			
		||||
	axpy(psi[ss],-bs[s]*alpha[s],ps[s],psi[ss]);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    // Convergence checks
 | 
			
		||||
    int all_converged = 1;
 | 
			
		||||
    for(int s=0;s<nshift;s++){
 | 
			
		||||
      
 | 
			
		||||
      if ( (!converged[s]) ){
 | 
			
		||||
	
 | 
			
		||||
	RealD css  = c * z[s][iz]* z[s][iz];
 | 
			
		||||
	
 | 
			
		||||
	if(css<rsq[s]){
 | 
			
		||||
	  if ( ! converged[s] )
 | 
			
		||||
	    std::cout<<"ConjugateGradientMultiShift k="<<k<<" Shift "<<s<<" has converged"<<std::endl;
 | 
			
		||||
	      converged[s]=1;
 | 
			
		||||
	} else {
 | 
			
		||||
	  all_converged=0;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    if ( all_converged ){
 | 
			
		||||
 | 
			
		||||
      std::cout<< "CGMultiShift: All shifts have converged iteration "<<k<<std::endl;
 | 
			
		||||
      std::cout<< "CGMultiShift: Checking solutions"<<std::endl;
 | 
			
		||||
      
 | 
			
		||||
      // Check answers 
 | 
			
		||||
      for(int s=0; s < nshift; s++) { 
 | 
			
		||||
	Linop.HermOpAndNorm(psi[s],mmp,d,qq);
 | 
			
		||||
	axpy(tmp,mass[s],psi[s],mmp);
 | 
			
		||||
	axpy(r,-alpha[s],src,tmp);
 | 
			
		||||
	RealD rn = norm2(r);
 | 
			
		||||
	RealD cn = norm2(src);
 | 
			
		||||
	std::cout<<"CGMultiShift: shift["<<s<<"] true residual "<<std::sqrt(rn/cn)<<std::endl;
 | 
			
		||||
      }
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  // ugly hack
 | 
			
		||||
  std::cout<<"CG multi shift did not converge"<<std::endl;
 | 
			
		||||
  assert(0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
  };
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
		Reference in New Issue
	
	Block a user