diff --git a/Grid/qcd/action/momentum/DirichletFilter.h b/Grid/qcd/action/momentum/DirichletFilter.h index 85010ed1..d7606bb2 100644 --- a/Grid/qcd/action/momentum/DirichletFilter.h +++ b/Grid/qcd/action/momentum/DirichletFilter.h @@ -32,37 +32,67 @@ directory //////////////////////////////////////////////////// #pragma once +#include + NAMESPACE_BEGIN(Grid); + template struct DirichletFilter: public MomentumFilterBase { Coordinate Block; - DirichletFilter(const Coordinate &_Block): Block(_Block){} + DirichletFilter(const Coordinate &_Block): Block(_Block) {} - void applyFilter(MomentaField &P) const override + // Edge detect using domain projectors + void applyFilter (MomentaField &U) const override { - GridBase *grid = P.Grid(); + DomainDecomposition Domains(Block); + GridBase *grid = U.Grid(); + LatticeInteger coor(grid); + LatticeInteger face(grid); + LatticeInteger one(grid); one = 1; + LatticeInteger zero(grid); zero = 0; + LatticeInteger omega(grid); + LatticeInteger omegabar(grid); + LatticeInteger tmp(grid); - //////////////////////////////////////////////////// - // Zero strictly links crossing between domains - //////////////////////////////////////////////////// - LatticeInteger coor(grid); - typedef decltype(PeekIndex(P,0)) MatrixType; - MatrixType zz(grid); zz = Zero(); + omega=one; Domains.ProjectDomain(omega,0); + omegabar=one; Domains.ProjectDomain(omegabar,1); + + LatticeInteger nface(grid); nface=Zero(); + + MomentaField projected(grid); projected=Zero(); + typedef decltype(PeekIndex(U,0)) MomentaLinkField; + MomentaLinkField Umu(grid); + MomentaLinkField zz(grid); zz=Zero(); + + int dims = grid->Nd(); Coordinate Global=grid->GlobalDimensions(); - for(int mu=0;mu1) ) { - // If costly could provide Grid earlier and precompute masks - LatticeCoordinate(coor,mu); - - auto P_mu = PeekIndex(P, mu); - P_mu = where(mod(coor,Block[mu])==Integer(Block[mu]-1),zz,P_mu); - PokeIndex(P, P_mu, mu); + assert(dims==Nd); + + for(int mu=0;mu(U,mu); + + // Upper face + tmp = Cshift(omegabar,mu,1); + tmp = tmp + omega; + face = where(tmp == Integer(2),one,zero ); + + tmp = Cshift(omega,mu,1); + tmp = tmp + omegabar; + face = where(tmp == Integer(2),one,face ); + + Umu = where(face,zz,Umu); + + PokeIndex(U, Umu, mu); } } } + }; NAMESPACE_END(Grid);