1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-09-20 09:15:38 +01:00
Grid/Grid_predicated.h
Peter Boyle 31f4f4f1e1 "where" and integer comparisons logic implemented for conditional
assignment. LatticeCoordinate helper to get global (reduced) coordinate.

Some more work of similar type perhaps needed, but the bulk of the required
structure for masked array assignment is now in place.
2015-04-09 08:06:03 +02:00

63 lines
1.8 KiB
C++

#ifndef GRID_PREDICATED_H
#define GRID_PREDICATED_H
// Must implement the predicate gating the conditional assignment.
// Must be able to reduce the predicate down to a single vInteger per site.
// Must be able to require that the type be iScalar x iScalar x ....;
// provide a GetVtype method in iScalar
// and blow away the tensor structures.
//
// Masked (conditional) assignment: for every outer site and every SIMD lane,
// ret takes its value from iftrue where the matching predicate entry is
// non-zero and from iffalse otherwise.
//
//   ret       - destination lattice; each outer site is overwritten via merge().
//   predicate - integer lattice; a non-zero lane entry selects iftrue.
//   iftrue    - source for lanes whose predicate entry is non-zero.
//   iffalse   - source for lanes whose predicate entry is zero.
//
// conformable() is called on iftrue against iffalse, predicate and ret before
// any data is touched (presumably a grid-compatibility check; its definition
// is outside this file section).
template<class vobj>
inline void where(Lattice<vobj> &ret,const LatticeInteger &predicate,Lattice<vobj> &iftrue,Lattice<vobj> &iffalse)
{
  conformable(iftrue,iffalse);
  conformable(iftrue,predicate);
  conformable(iftrue,ret);

  GridBase *grid=iftrue._grid;

  typedef typename vobj::scalar_type scalar_type;
  typedef typename vobj::vector_type vector_type;

  const int Nsimd = grid->Nsimd();
  // Number of vector_type words making up one vobj; each per-lane scratch
  // buffer holds this many scalars.
  const int words = sizeof(vobj)/sizeof(vector_type);

  // The scratch buffers are written on every iteration, so they must not be
  // shared between threads. Declaring them inside the parallel region gives
  // each thread its own set (allocated once per thread, not once per site).
  // Without OpenMP the pragmas are ignored and this is a plain serial loop.
#pragma omp parallel
  {
    std::vector<Integer> mask(Nsimd);
    std::vector<std::vector<scalar_type> > truevals (Nsimd,std::vector<scalar_type>(words) );
    std::vector<std::vector<scalar_type> > falsevals(Nsimd,std::vector<scalar_type>(words) );
    std::vector<scalar_type *> pointers(Nsimd);

#pragma omp for
    for(int ss=0;ss<grid->oSites(); ss++){

      // Unpack the "true" candidate into one scalar buffer per lane.
      for(int s=0;s<Nsimd;s++) pointers[s] = & truevals[s][0];
      extract(iftrue._odata[ss]   ,pointers);

      // Unpack the "false" candidate the same way.
      for(int s=0;s<Nsimd;s++) pointers[s] = & falsevals[s][0];
      extract(iffalse._odata[ss]  ,pointers);

      // Unpack the predicate into one Integer per lane.
      extract(predicate._odata[ss],mask);

      // Point each lane at whichever buffer the predicate selects.
      for(int s=0;s<Nsimd;s++){
        if (mask[s]) pointers[s]=&truevals[s][0];
        else         pointers[s]=&falsevals[s][0];
      }

      // Repack the selected lanes into the destination site.
      merge(ret._odata[ss],pointers);
    }
  }
}
// Value-returning form of where(): builds a new lattice on iftrue's grid,
// fills it through the four-argument where(), and returns it by value.
//
//   predicate - integer lattice deciding, per entry, which source is taken.
//   iftrue    - source lattice; also supplies the grid for the result.
//   iffalse   - alternative source lattice.
template<class vobj>
inline Lattice<vobj> where(const LatticeInteger &predicate,Lattice<vobj> &iftrue,Lattice<vobj> &iffalse)
{
  // Check the inputs against each other before allocating the result.
  conformable(iftrue,iffalse);
  conformable(iftrue,predicate);

  GridBase *grid = iftrue._grid;
  Lattice<vobj> selected(grid);

  where(selected,predicate,iftrue,iffalse);
  return selected;
}
#endif