mirror of
https://github.com/paboyle/Grid.git
synced 2025-04-09 21:50:45 +01:00
Dirichlet filters running on AMD and now integrated in Fermion op
This commit is contained in:
parent
70988e43d2
commit
0f1c5b08a1
@ -37,6 +37,9 @@ NAMESPACE_CHECK(ActionSet);
|
|||||||
#include <Grid/qcd/action/ActionParams.h>
|
#include <Grid/qcd/action/ActionParams.h>
|
||||||
NAMESPACE_CHECK(ActionParams);
|
NAMESPACE_CHECK(ActionParams);
|
||||||
|
|
||||||
|
#include <Grid/qcd/action/filters/MomentumFilter.h>
|
||||||
|
#include <Grid/qcd/action/filters/DirichletFilter.h>
|
||||||
|
|
||||||
////////////////////////////////////////////
|
////////////////////////////////////////////
|
||||||
// Gauge Actions
|
// Gauge Actions
|
||||||
////////////////////////////////////////////
|
////////////////////////////////////////////
|
||||||
|
@ -49,6 +49,8 @@ public:
|
|||||||
|
|
||||||
virtual FermionField &tmp(void) = 0;
|
virtual FermionField &tmp(void) = 0;
|
||||||
|
|
||||||
|
virtual void DirichletBlock(Coordinate & _Block) { assert(0); };
|
||||||
|
|
||||||
GridBase * Grid(void) { return FermionGrid(); }; // this is all the linalg routines need to know
|
GridBase * Grid(void) { return FermionGrid(); }; // this is all the linalg routines need to know
|
||||||
GridBase * RedBlackGrid(void) { return FermionRedBlackGrid(); };
|
GridBase * RedBlackGrid(void) { return FermionRedBlackGrid(); };
|
||||||
|
|
||||||
|
@ -75,6 +75,10 @@ public:
|
|||||||
FermionField _tmp;
|
FermionField _tmp;
|
||||||
FermionField &tmp(void) { return _tmp; }
|
FermionField &tmp(void) { return _tmp; }
|
||||||
|
|
||||||
|
int Dirichlet;
|
||||||
|
Coordinate Block;
|
||||||
|
|
||||||
|
/********** Deprecate timers **********/
|
||||||
void Report(void);
|
void Report(void);
|
||||||
void ZeroCounters(void);
|
void ZeroCounters(void);
|
||||||
double DhopCalls;
|
double DhopCalls;
|
||||||
@ -174,10 +178,16 @@ public:
|
|||||||
GridRedBlackCartesian &FourDimRedBlackGrid,
|
GridRedBlackCartesian &FourDimRedBlackGrid,
|
||||||
double _M5,const ImplParams &p= ImplParams());
|
double _M5,const ImplParams &p= ImplParams());
|
||||||
|
|
||||||
void DirichletBlock(std::vector<int> & block){
|
virtual void DirichletBlock(Coordinate & block)
|
||||||
Stencil.DirichletBlock(block);
|
{
|
||||||
StencilEven.DirichletBlock(block);
|
assert(block.size()==Nd+1);
|
||||||
StencilOdd.DirichletBlock(block);
|
if ( block[0] || block[1] || block[2] || block[3] || block[4] ){
|
||||||
|
Dirichlet = 1;
|
||||||
|
Block = block;
|
||||||
|
Stencil.DirichletBlock(block);
|
||||||
|
StencilEven.DirichletBlock(block);
|
||||||
|
StencilOdd.DirichletBlock(block);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Constructors
|
// Constructors
|
||||||
/*
|
/*
|
||||||
|
@ -60,7 +60,8 @@ WilsonFermion5D<Impl>::WilsonFermion5D(GaugeField &_Umu,
|
|||||||
UmuOdd (_FourDimRedBlackGrid),
|
UmuOdd (_FourDimRedBlackGrid),
|
||||||
Lebesgue(_FourDimGrid),
|
Lebesgue(_FourDimGrid),
|
||||||
LebesgueEvenOdd(_FourDimRedBlackGrid),
|
LebesgueEvenOdd(_FourDimRedBlackGrid),
|
||||||
_tmp(&FiveDimRedBlackGrid)
|
_tmp(&FiveDimRedBlackGrid),
|
||||||
|
Dirichlet(0)
|
||||||
{
|
{
|
||||||
// some assertions
|
// some assertions
|
||||||
assert(FiveDimGrid._ndimension==5);
|
assert(FiveDimGrid._ndimension==5);
|
||||||
@ -218,6 +219,14 @@ void WilsonFermion5D<Impl>::ImportGauge(const GaugeField &_Umu)
|
|||||||
{
|
{
|
||||||
GaugeField HUmu(_Umu.Grid());
|
GaugeField HUmu(_Umu.Grid());
|
||||||
HUmu = _Umu*(-0.5);
|
HUmu = _Umu*(-0.5);
|
||||||
|
if ( Dirichlet ) {
|
||||||
|
std::cout << GridLogMessage << " Dirichlet BCs 5d " <<Block<<std::endl;
|
||||||
|
Coordinate GaugeBlock(Nd);
|
||||||
|
for(int d=0;d<Nd;d++) GaugeBlock[d] = Block[d+1];
|
||||||
|
std::cout << GridLogMessage << " Dirichlet BCs 4d " <<GaugeBlock<<std::endl;
|
||||||
|
DirichletFilter<GaugeField> Filter(GaugeBlock);
|
||||||
|
Filter.applyFilter(HUmu);
|
||||||
|
}
|
||||||
Impl::DoubleStore(GaugeGrid(),Umu,HUmu);
|
Impl::DoubleStore(GaugeGrid(),Umu,HUmu);
|
||||||
pickCheckerboard(Even,UmuEven,Umu);
|
pickCheckerboard(Even,UmuEven,Umu);
|
||||||
pickCheckerboard(Odd ,UmuOdd,Umu);
|
pickCheckerboard(Odd ,UmuOdd,Umu);
|
||||||
|
71
Grid/qcd/action/filters/DirichletFilter.h
Normal file
71
Grid/qcd/action/filters/DirichletFilter.h
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
/*************************************************************************************
|
||||||
|
|
||||||
|
Grid physics library, www.github.com/paboyle/Grid
|
||||||
|
|
||||||
|
Source file: ./lib/qcd/hmc/integrators/DirichletFilter.h
|
||||||
|
|
||||||
|
Copyright (C) 2015
|
||||||
|
|
||||||
|
Author: Peter Boyle <paboyle@ph.ed.ac.uk>
|
||||||
|
|
||||||
|
This program is free software; you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation; either version 2 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License along
|
||||||
|
with this program; if not, write to the Free Software Foundation, Inc.,
|
||||||
|
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
|
|
||||||
|
See the full license in the file "LICENSE" in the top level distribution
|
||||||
|
directory
|
||||||
|
*************************************************************************************/
|
||||||
|
/* END LEGAL */
|
||||||
|
//--------------------------------------------------------------------
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
NAMESPACE_BEGIN(Grid);
|
||||||
|
|
||||||
|
template<typename MomentaField>
|
||||||
|
struct DirichletFilter: public MomentumFilterBase<MomentaField>
|
||||||
|
{
|
||||||
|
typedef typename MomentaField::vector_type vector_type; //SIMD-vectorized complex type
|
||||||
|
typedef typename MomentaField::scalar_type scalar_type; //scalar complex type
|
||||||
|
|
||||||
|
typedef iScalar<iScalar<iScalar<vector_type> > > ScalarType; //complex phase for each site
|
||||||
|
|
||||||
|
Coordinate Block;
|
||||||
|
|
||||||
|
DirichletFilter(const Coordinate &_Block): Block(_Block){}
|
||||||
|
|
||||||
|
void applyFilter(MomentaField &P) const override
|
||||||
|
{
|
||||||
|
GridBase *grid = P.Grid();
|
||||||
|
typedef decltype(PeekIndex<LorentzIndex>(P, 0)) LatCM;
|
||||||
|
////////////////////////////////////////////////////
|
||||||
|
// Zero strictly links crossing between domains
|
||||||
|
////////////////////////////////////////////////////
|
||||||
|
LatticeInteger coor(grid);
|
||||||
|
LatCM zz(grid); zz = Zero();
|
||||||
|
for(int mu=0;mu<Nd;mu++) {
|
||||||
|
if ( (Block[mu]) && (Block[mu] < grid->GlobalDimensions()[mu] ) ) {
|
||||||
|
// If costly could provide Grid earlier and precompute masks
|
||||||
|
std::cout << " Dirichlet in mu="<<mu<<std::endl;
|
||||||
|
LatticeCoordinate(coor,mu);
|
||||||
|
auto P_mu = PeekIndex<LorentzIndex>(P, mu);
|
||||||
|
P_mu = where(mod(coor,Block[mu])==Integer(Block[mu]-1),zz,P_mu);
|
||||||
|
PokeIndex<LorentzIndex>(P, P_mu, mu);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
NAMESPACE_END(Grid);
|
||||||
|
|
@ -33,7 +33,6 @@ directory
|
|||||||
#define INTEGRATOR_INCLUDED
|
#define INTEGRATOR_INCLUDED
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "MomentumFilter.h"
|
|
||||||
|
|
||||||
NAMESPACE_BEGIN(Grid);
|
NAMESPACE_BEGIN(Grid);
|
||||||
|
|
||||||
|
@ -648,7 +648,7 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// Introduce a block structure and switch off comms on boundaries
|
/// Introduce a block structure and switch off comms on boundaries
|
||||||
void DirichletBlock(const std::vector<int> &dirichlet_block)
|
void DirichletBlock(const Coordinate &dirichlet_block)
|
||||||
{
|
{
|
||||||
this->_dirichlet = 1;
|
this->_dirichlet = 1;
|
||||||
for(int ii=0;ii<this->_npoints;ii++){
|
for(int ii=0;ii<this->_npoints;ii++){
|
||||||
|
@ -36,41 +36,6 @@ using namespace Grid;
|
|||||||
/// Move to domains ////
|
/// Move to domains ////
|
||||||
////////////////////////
|
////////////////////////
|
||||||
|
|
||||||
template<typename MomentaField>
|
|
||||||
struct DirichletFilter: public MomentumFilterBase<MomentaField>
|
|
||||||
{
|
|
||||||
typedef typename MomentaField::vector_type vector_type; //SIMD-vectorized complex type
|
|
||||||
typedef typename MomentaField::scalar_type scalar_type; //scalar complex type
|
|
||||||
|
|
||||||
typedef iScalar<iScalar<iScalar<vector_type> > > ScalarType; //complex phase for each site
|
|
||||||
|
|
||||||
Coordinate Block;
|
|
||||||
|
|
||||||
DirichletFilter(const Coordinate &_Block): Block(_Block){}
|
|
||||||
|
|
||||||
void applyFilter(MomentaField &P) const override
|
|
||||||
{
|
|
||||||
GridBase *grid = P.Grid();
|
|
||||||
typedef decltype(PeekIndex<LorentzIndex>(P, 0)) LatCM;
|
|
||||||
////////////////////////////////////////////////////
|
|
||||||
// Zero strictly links crossing between domains
|
|
||||||
////////////////////////////////////////////////////
|
|
||||||
LatticeInteger coor(grid);
|
|
||||||
LatCM zz(grid); zz = Zero();
|
|
||||||
for(int mu=0;mu<Nd;mu++) {
|
|
||||||
|
|
||||||
if ( Block[mu] ) {
|
|
||||||
// If costly could provide Grid earlier and precompute masks
|
|
||||||
LatticeCoordinate(coor,mu);
|
|
||||||
auto P_mu = PeekIndex<LorentzIndex>(P, mu);
|
|
||||||
P_mu = where(mod(coor,Block[mu])==Integer(Block[mu]-1),zz,P_mu);
|
|
||||||
PokeIndex<LorentzIndex>(P, P_mu, mu);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
Gamma::Algebra Gmu [] = {
|
Gamma::Algebra Gmu [] = {
|
||||||
Gamma::Algebra::GammaX,
|
Gamma::Algebra::GammaX,
|
||||||
Gamma::Algebra::GammaY,
|
Gamma::Algebra::GammaY,
|
||||||
@ -78,7 +43,7 @@ Gamma::Algebra Gmu [] = {
|
|||||||
Gamma::Algebra::GammaT
|
Gamma::Algebra::GammaT
|
||||||
};
|
};
|
||||||
|
|
||||||
void Benchmark(int Ls, std::vector<int> Dirichlet);
|
void Benchmark(int Ls, Coordinate Dirichlet);
|
||||||
|
|
||||||
int main (int argc, char ** argv)
|
int main (int argc, char ** argv)
|
||||||
{
|
{
|
||||||
@ -97,8 +62,9 @@ int main (int argc, char ** argv)
|
|||||||
//////////////////
|
//////////////////
|
||||||
// With comms
|
// With comms
|
||||||
//////////////////
|
//////////////////
|
||||||
std::vector<int> Dirichlet(5,0);
|
Coordinate Dirichlet(Nd+1,0);
|
||||||
|
|
||||||
|
std::cout << "\n\n\n\n\n\n" <<std::endl;
|
||||||
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
std::cout << GridLogMessage<< " Testing with full communication " <<std::endl;
|
std::cout << GridLogMessage<< " Testing with full communication " <<std::endl;
|
||||||
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
@ -110,7 +76,7 @@ int main (int argc, char ** argv)
|
|||||||
//////////////////
|
//////////////////
|
||||||
Coordinate latt4 = GridDefaultLatt();
|
Coordinate latt4 = GridDefaultLatt();
|
||||||
Coordinate mpi = GridDefaultMpi();
|
Coordinate mpi = GridDefaultMpi();
|
||||||
std::vector<int> CommDim(Nd);
|
Coordinate CommDim(Nd);
|
||||||
Coordinate shm;
|
Coordinate shm;
|
||||||
GlobalSharedMemory::GetShmDims(mpi,shm);
|
GlobalSharedMemory::GetShmDims(mpi,shm);
|
||||||
|
|
||||||
@ -118,36 +84,39 @@ int main (int argc, char ** argv)
|
|||||||
//////////////////////
|
//////////////////////
|
||||||
// Node level
|
// Node level
|
||||||
//////////////////////
|
//////////////////////
|
||||||
|
std::cout << "\n\n\n\n\n\n" <<std::endl;
|
||||||
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
std::cout << GridLogMessage<< " Testing without internode communication " <<std::endl;
|
std::cout << GridLogMessage<< " Testing without internode communication " <<std::endl;
|
||||||
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
|
|
||||||
for(int d=0;d<Nd;d++) CommDim[d]= (mpi[d]/shm[d])>1 ? 1 : 0;
|
for(int d=0;d<Nd;d++) CommDim[d]= (mpi[d]/shm[d])>1 ? 1 : 0;
|
||||||
Dirichlet = std::vector<int>({0,
|
Dirichlet[0] = 0;
|
||||||
CommDim[0]*latt4[0]/mpi[0] * shm[0],
|
Dirichlet[1] = CommDim[0]*latt4[0]/mpi[0] * shm[0];
|
||||||
CommDim[1]*latt4[1]/mpi[1] * shm[1],
|
Dirichlet[2] = CommDim[1]*latt4[1]/mpi[1] * shm[1];
|
||||||
CommDim[2]*latt4[2]/mpi[2] * shm[2],
|
Dirichlet[3] = CommDim[2]*latt4[2]/mpi[2] * shm[2];
|
||||||
CommDim[3]*latt4[3]/mpi[3] * shm[3]});
|
Dirichlet[4] = CommDim[3]*latt4[3]/mpi[3] * shm[3];
|
||||||
|
|
||||||
Benchmark(Ls,Dirichlet);
|
Benchmark(Ls,Dirichlet);
|
||||||
|
|
||||||
|
std::cout << "\n\n\n\n\n\n" <<std::endl;
|
||||||
|
|
||||||
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
std::cout << GridLogMessage<< " Testing without intranode communication " <<std::endl;
|
std::cout << GridLogMessage<< " Testing without intranode communication " <<std::endl;
|
||||||
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
|
|
||||||
for(int d=0;d<Nd;d++) CommDim[d]= mpi[d]>1 ? 1 : 0;
|
for(int d=0;d<Nd;d++) CommDim[d]= mpi[d]>1 ? 1 : 0;
|
||||||
Dirichlet = std::vector<int>({0,
|
Dirichlet[0] = 0;
|
||||||
CommDim[0]*latt4[0]/mpi[0],
|
Dirichlet[1] = CommDim[0]*latt4[0]/mpi[0];
|
||||||
CommDim[1]*latt4[1]/mpi[1],
|
Dirichlet[2] = CommDim[1]*latt4[1]/mpi[1];
|
||||||
CommDim[2]*latt4[2]/mpi[2],
|
Dirichlet[3] = CommDim[2]*latt4[2]/mpi[2];
|
||||||
CommDim[3]*latt4[3]/mpi[3]});
|
Dirichlet[4] = CommDim[3]*latt4[3]/mpi[3];
|
||||||
|
|
||||||
Benchmark(Ls,Dirichlet);
|
Benchmark(Ls,Dirichlet);
|
||||||
|
|
||||||
Grid_finalize();
|
Grid_finalize();
|
||||||
exit(0);
|
exit(0);
|
||||||
}
|
}
|
||||||
void Benchmark(int Ls, std::vector<int> Dirichlet)
|
void Benchmark(int Ls, Coordinate Dirichlet)
|
||||||
{
|
{
|
||||||
Coordinate latt4 = GridDefaultLatt();
|
Coordinate latt4 = GridDefaultLatt();
|
||||||
GridLogLayout();
|
GridLogLayout();
|
||||||
@ -196,7 +165,9 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
|
|||||||
|
|
||||||
std::cout << GridLogMessage << "Drawing gauge field" << std::endl;
|
std::cout << GridLogMessage << "Drawing gauge field" << std::endl;
|
||||||
LatticeGaugeFieldF Umu(UGrid);
|
LatticeGaugeFieldF Umu(UGrid);
|
||||||
|
LatticeGaugeFieldF UmuCopy(UGrid);
|
||||||
SU<Nc>::HotConfiguration(RNG4,Umu);
|
SU<Nc>::HotConfiguration(RNG4,Umu);
|
||||||
|
UmuCopy=Umu;
|
||||||
std::cout << GridLogMessage << "Random gauge initialised " << std::endl;
|
std::cout << GridLogMessage << "Random gauge initialised " << std::endl;
|
||||||
|
|
||||||
////////////////////////////////////
|
////////////////////////////////////
|
||||||
@ -205,7 +176,8 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
|
|||||||
Coordinate Block(4);
|
Coordinate Block(4);
|
||||||
for(int d=0;d<4;d++) Block[d]= Dirichlet[d+1];
|
for(int d=0;d<4;d++) Block[d]= Dirichlet[d+1];
|
||||||
|
|
||||||
std::cout << GridLogMessage << "Applying BCs for Dirichlet Block " << Block << std::endl;
|
std::cout << GridLogMessage << "Applying BCs for Dirichlet Block5 " << Dirichlet << std::endl;
|
||||||
|
std::cout << GridLogMessage << "Applying BCs for Dirichlet Block4 " << Block << std::endl;
|
||||||
|
|
||||||
DirichletFilter<LatticeGaugeFieldF> Filter(Block);
|
DirichletFilter<LatticeGaugeFieldF> Filter(Block);
|
||||||
Filter.applyFilter(Umu);
|
Filter.applyFilter(Umu);
|
||||||
@ -279,6 +251,7 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
|
|||||||
|
|
||||||
DomainWallFermionF Dw(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5);
|
DomainWallFermionF Dw(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5);
|
||||||
Dw.DirichletBlock(Dirichlet);
|
Dw.DirichletBlock(Dirichlet);
|
||||||
|
Dw.ImportGauge(Umu);
|
||||||
|
|
||||||
int ncall =300;
|
int ncall =300;
|
||||||
|
|
||||||
@ -377,9 +350,6 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
|
|||||||
std::cout<<GridLogMessage << "norm dag ref "<< norm2(ref)<<std::endl;
|
std::cout<<GridLogMessage << "norm dag ref "<< norm2(ref)<<std::endl;
|
||||||
err = ref-result;
|
err = ref-result;
|
||||||
std::cout<<GridLogMessage << "norm dag diff "<< norm2(err)<<std::endl;
|
std::cout<<GridLogMessage << "norm dag diff "<< norm2(err)<<std::endl;
|
||||||
if ( norm2(err)>1.0e-4) {
|
|
||||||
std::cout << err << std::endl;
|
|
||||||
}
|
|
||||||
assert((norm2(err)<1.0e-4));
|
assert((norm2(err)<1.0e-4));
|
||||||
|
|
||||||
LatticeFermionF src_e (FrbGrid);
|
LatticeFermionF src_e (FrbGrid);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user