mirror of
https://github.com/paboyle/Grid.git
synced 2025-04-25 13:15:55 +01:00
Added Hermition check in BlockCG
This commit is contained in:
parent
95e9fd1889
commit
8f6039646b
@ -54,9 +54,10 @@ class BlockConjugateGradient : public OperatorFunction<Field> {
|
||||
RealD Tolerance;
|
||||
Integer MaxIterations;
|
||||
Integer IterationsToComplete; //Number of iterations the CG took to finish. Filled in upon completion
|
||||
Integer PrintInterval; //GridLogMessages or Iterative
|
||||
|
||||
BlockConjugateGradient(BlockCGtype cgtype,int _Orthog,RealD tol, Integer maxit, bool err_on_no_conv = true)
|
||||
: Tolerance(tol), CGtype(cgtype), blockDim(_Orthog), MaxIterations(maxit), ErrorOnNoConverge(err_on_no_conv)
|
||||
: Tolerance(tol), CGtype(cgtype), blockDim(_Orthog), MaxIterations(maxit), ErrorOnNoConverge(err_on_no_conv),PrintInterval(100)
|
||||
{};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -608,6 +609,28 @@ void CGmultiRHSsolve(LinearOperatorBase<Field> &Linop, const Field &Src, Field &
|
||||
IterationsToComplete = k;
|
||||
}
|
||||
|
||||
void InnerProductMatrix(Eigen::MatrixXcd &m , const std::vector<Field> &X, std::vector<Field> &Y){
|
||||
for(int b=0;b<Nblock;b++)
|
||||
for(int bp=0;bp<Nblock;bp++) {
|
||||
m(b,bp) = innerProduct(X[b],Y[bp]);
|
||||
}
|
||||
}
|
||||
double HermCheck( Eigen::MatrixXcd &m, const std::string &str, int ForceHerm=1 , int Print = 0) {
|
||||
for(int b=0;b<Nblock;b++)
|
||||
for(int bp=0;bp<=b;bp++) {
|
||||
if(Print)
|
||||
std::cout<<GridLogMessage << "HermCheck "<<str<<" "<<b<<" "<<bp<<" : "<< m(b,bp) <<" "<<conj(m(bp,b))<<" " <<m(b,bp)-conj(m(bp,b)) <<std::endl;
|
||||
if(ForceHerm){
|
||||
if(b==bp) m(b,b) = real(m(b,b));
|
||||
else{
|
||||
auto temp = 0.5*(m(b,bp)+conj(m(bp,b)));
|
||||
m(b,bp) = temp;
|
||||
m(bp,b) = conj(temp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BlockCGVecsolve(LinearOperatorBase<Field> &Linop, const std::vector<Field> &Src, std::vector<Field> &Psi)
|
||||
{
|
||||
// int Orthog = blockDim; // First dimension is block dim; this is an assumption
|
||||
@ -680,11 +703,16 @@ void BlockCGVecsolve(LinearOperatorBase<Field> &Linop, const std::vector<Field>
|
||||
P[b] = R[b]; // P_1
|
||||
}
|
||||
// sliceInnerProductMatrix(m_rr,R,R,Orthog);
|
||||
InnerProductMatrix(m_rr,R,R);
|
||||
HermCheck(m_rr, "R_0 R_0",1,1);
|
||||
HermCheck(m_rr, "R_0 R_0",0,1);
|
||||
#if 0
|
||||
for(int b=0;b<Nblock;b++)
|
||||
for(int bp=0;bp<Nblock;bp++) {
|
||||
m_rr(b,bp) = innerProduct(R[b],R[bp]);
|
||||
std::cout << 0 <<" : R_0 R_0 "<< b <<" "<<bp<<" "<<innerProduct(R[b],R[bp]) <<std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
GridStopWatch sliceInnerTimer;
|
||||
GridStopWatch sliceMaddTimer;
|
||||
@ -693,14 +721,21 @@ void BlockCGVecsolve(LinearOperatorBase<Field> &Linop, const std::vector<Field>
|
||||
SolverTimer.Start();
|
||||
|
||||
int k;
|
||||
int if_print =0;
|
||||
for (k = 1; k <= MaxIterations; k++) {
|
||||
|
||||
RealD rrsum=0;
|
||||
for(int b=0;b<Nblock;b++) rrsum+=real(m_rr(b,b));
|
||||
|
||||
if((k%1)==0)
|
||||
std::cout << GridLogMessage << "\titeration "<<k<<" rr_sum "<<rrsum<<" ssq_sum "<< sssum
|
||||
if(PrintInterval && (k%PrintInterval)==0){
|
||||
if_print=1;
|
||||
std::cout << GridLogMessage << "\titeration "<<k<<" rr_sum "<<rrsum<<" ssq_sum "<< sssum
|
||||
<<" / "<<std::sqrt(rrsum/sssum) <<std::endl;
|
||||
} else {
|
||||
if_print=0;
|
||||
std::cout << GridLogIterative << "\titeration "<<k<<" rr_sum "<<rrsum<<" ssq_sum "<< sssum
|
||||
<<" / "<<std::sqrt(rrsum/sssum) <<std::endl;
|
||||
}
|
||||
|
||||
MatrixTimer.Start();
|
||||
for(int b=0;b<Nblock;b++) Linop.HermOp(P[b], AP[b]);
|
||||
@ -709,13 +744,21 @@ void BlockCGVecsolve(LinearOperatorBase<Field> &Linop, const std::vector<Field>
|
||||
// Alpha
|
||||
sliceInnerTimer.Start();
|
||||
// sliceInnerProductMatrix(m_pAp,P,AP,Orthog);
|
||||
InnerProductMatrix(m_pAp,P,AP);
|
||||
HermCheck(m_pAp, "P AP",1,if_print);
|
||||
if(if_print) HermCheck(m_pAp, "P AP",0,if_print);
|
||||
#if 0
|
||||
for(int b=0;b<Nblock;b++)
|
||||
for(int bp=0;bp<Nblock;bp++) {
|
||||
m_pAp(b,bp) = innerProduct(P[b],AP[bp]);
|
||||
std::cout << k <<" : m_pAp "<< b <<" "<<bp<<" "<<innerProduct(P[b],AP[bp]) <<std::endl;
|
||||
}
|
||||
#endif
|
||||
sliceInnerTimer.Stop();
|
||||
m_pAp_inv = m_pAp.inverse();
|
||||
HermCheck(m_pAp_inv, "inv (P AP)",1,if_print);
|
||||
if(if_print) HermCheck(m_pAp_inv, "inv (P AP)",0,if_print);
|
||||
if(if_print)
|
||||
{
|
||||
m_alpha = m_pAp*m_pAp_inv;
|
||||
for(int b=0;b<Nblock;b++){
|
||||
@ -741,6 +784,7 @@ void BlockCGVecsolve(LinearOperatorBase<Field> &Linop, const std::vector<Field>
|
||||
R[b] -= m_alpha(bp,b)*AP[bp]; // R_k+1 = R_k - AP_k+1 alpha_k+1
|
||||
}
|
||||
sliceMaddTimer.Stop();
|
||||
if(if_print)
|
||||
{
|
||||
//check
|
||||
for(int b=0;b<Nblock;b++){
|
||||
@ -753,14 +797,16 @@ void BlockCGVecsolve(LinearOperatorBase<Field> &Linop, const std::vector<Field>
|
||||
|
||||
// Beta
|
||||
m_rr_inv = m_rr.inverse(); //m_rr_inv = (R_k^t R_k)^-1
|
||||
HermCheck(m_rr_inv,"m_rr_inv",1,if_print);
|
||||
if(if_print) HermCheck(m_rr_inv,"m_rr_inv",0,if_print);
|
||||
sliceInnerTimer.Start();
|
||||
// sliceInnerProductMatrix(m_rr,R,R,Orthog);
|
||||
for(int b=0;b<Nblock;b++)
|
||||
for(int bp=0;bp<Nblock;bp++) {
|
||||
m_rr(b,bp) = innerProduct(R[b],R[bp]); // beta_k+2 = (R_k
|
||||
}
|
||||
InnerProductMatrix(m_rr,R,R);
|
||||
HermCheck(m_rr,"m_rr",1,if_print);
|
||||
if(if_print) HermCheck(m_rr,"m_rr",0,if_print);
|
||||
sliceInnerTimer.Stop();
|
||||
m_beta = m_rr_inv *m_rr; // beta_k+2 = (R_k^t R_k)^-1 (R_k+1^5 R_k+1)
|
||||
// HermCheck(m_beta,"m_beta");
|
||||
|
||||
// Search update
|
||||
sliceMaddTimer.Start();
|
||||
@ -771,6 +817,7 @@ void BlockCGVecsolve(LinearOperatorBase<Field> &Linop, const std::vector<Field>
|
||||
AP[b] += m_beta(bp,b)*P[bp]; //AP = R_k+1 + P_k+1 beta_k+1
|
||||
}
|
||||
}
|
||||
if(if_print)
|
||||
{
|
||||
//check
|
||||
for(int b=0;b<Nblock;b++) Linop.HermOp(P[b], TMP[b]);
|
||||
|
@ -125,7 +125,7 @@ int main (int argc, char ** argv)
|
||||
for(int s=0;s<nrhs;s++) {
|
||||
random(pRNG5,src[s]);
|
||||
tmp = 10.0*s;
|
||||
src[s] = (src[s] * 0.1) + tmp;
|
||||
// src[s] = (src[s] * 0.1) + tmp;
|
||||
std::cout << GridLogMessage << " src ["<<s<<"] "<<norm2(src[s])<<std::endl;
|
||||
}
|
||||
#endif
|
||||
@ -240,14 +240,10 @@ int main (int argc, char ** argv)
|
||||
|
||||
for(int s=0;s<nrhs;s++) result[s]=zero;
|
||||
|
||||
// ConjugateGradient<FermionField> CG(stp,10000);
|
||||
int blockDim = 0;
|
||||
// BlockConjugateGradient<FermionField> BCGrQ(BlockCGrQ,blockDim,stp,10000);
|
||||
BlockConjugateGradient<FermionField> BCG (BlockCG,blockDim,stp,10000);
|
||||
// BlockConjugateGradient<FermionField> mCG (CGmultiRHS,blockDim,stp,10000);
|
||||
int blockDim = 0;//not used for BlockCGVec
|
||||
BlockConjugateGradient<FermionField> BCGV (BlockCGVec,blockDim,stp,10000);
|
||||
BCGV.PrintInterval=10;
|
||||
{
|
||||
// BCG(HermOpCk,src[0],result[0]);
|
||||
BCGV(HermOpCk,src,result);
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user