1
0
mirror of https://github.com/paboyle/Grid.git synced 2026-06-06 04:04:36 +01:00

Compare commits

..

65 Commits

Author SHA1 Message Date
Peter Boyle 8540b2a85d Test_extended_meson_field: add view_open timers to measure MemoryManager H2D transfers
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-28 14:04:38 -04:00
Peter Boyle dbd3a0e612 A2ALoopPropagator: fuse outer product sum into single accelerator_for kernel
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-28 10:39:37 -04:00
Peter Boyle 5b58d1da62 Test_extended_meson_field: add --Ni and --Nj command line options
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-27 22:46:11 -04:00
Peter Boyle 377db1bc08 Tensor_inner: move scalar innerProductD overloads before norm2 for ADL visibility
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-27 22:24:52 -04:00
Peter Boyle 699564997e Test_extended_meson_field: use decltype(coalescedRead) for arch-portable kernel types
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-27 22:18:43 -04:00
Peter Boyle f2750fae09 Test_extended_meson_field: use Grid norm2 instead of std::norm for HIP compatibility
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-27 22:04:10 -04:00
Peter Boyle ed12fa09c5 Not sure how this old lattice slice sum core
fix didn't propagate
2026-05-27 21:54:24 -04:00
Peter Boyle b914403bbe Better setup Frontier 2026-05-27 21:31:54 -04:00
Peter Boyle c1566fb9a2 Merge branch 'develop' into feature/Kpipi-masaaki-offload 2026-05-27 21:03:16 -04:00
Peter Boyle 905da6f083 Merge branch 'feature/reduction-reorganisation' into develop 2026-05-27 21:01:30 -04:00
Peter Boyle cb199c127c Merge branch 'develop' into feature/Kpipi-masaaki-offload 2026-05-27 20:59:30 -04:00
Peter Boyle 1a932ea33b Merge 2026-05-27 20:45:00 -04:00
Peter Boyle 86c7f29183 Config command update 2026-05-27 16:19:33 -04:00
Peter Boyle b0c99f876e Configure on mac update 2026-05-27 16:16:55 -04:00
Peter Boyle bf5fcdc860 Ease of use for std::complex interchangable with thrust 2026-05-27 16:05:37 -04:00
Peter Boyle 5822a6599c skills: add GPU/A2A reference skill documents
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-27 11:12:47 -04:00
Peter Boyle 0eeb334fe0 systems/mac-arm: add MPI configure command and spack sourceme
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-27 11:12:42 -04:00
Peter Boyle d8d16407e9 A2ASpatialSum: extended meson field kernel and test
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-27 11:12:29 -04:00
Peter Boyle b58a1508fa Perlmutter cuda version update 2026-05-21 13:25:13 -07:00
Peter Boyle 4d527e81fa Remove hip specific files 2026-05-21 12:34:30 -04:00
Peter Boyle 7803580aa6 Lattice_reduction_gpu: demote timing logs to Debug, disable by default
skills/mpi-heterogeneous: add Bug Class 4 for Frontier GTL/libamdhip64 ABI mismatch

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 32654db366 Test_planned_fft: fix PlannedFFT template parameter to use ::vector_object
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle cd340cfab3 tests: add Test_planned_fft exercising PlannedFFT<vobj>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle f32866b2ff tests/fft: remove PlanDestroy calls (FFT handles plans per-call)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 1cd1dc091e FFT: add FFTbase, PlannedFFT; factor FFT_dim_execute free function
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 0493656e86 debug: add Test_hipfft_repro — reproducer for hipFFT PARSE_ERROR on ROCm 7
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 66fd504c4d tests/debug: add G=4 to hipfft fail reproducer
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle be4dd2b52f tests/debug: test hipMemset variant before cache is populated
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 707d059766 tests/debug: extend hipfft fail reproducer with hipMemset and sync variants
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle f08c755ae6 FFT: use host stack buffer in PlanCreate, not deviceVector
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle dbbfdd4e4b tests/debug: add minimal hipfft ordering bug fail/pass pair
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle f967fb40bf tests/debug: test plan-before-malloc vs malloc-before-plan ordering
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 74e0f846cb tests/debug: extend hipfft reproducer with Grid-realistic howmany and exec tests
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 303a4d26e5 tests/debug: add minimal hipfft plan-creation reproducer
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 119888653c FFT HIP: use hipfftCreate+hipfftMakePlanMany instead of hipfftPlanMany 2026-05-21 12:34:30 -04:00
Peter Boyle a9f42c08f9 FFT: pass nullptr for inembed/onembed in hipfftPlanMany to avoid HIPFFT_PARSE_ERROR 2026-05-21 12:34:30 -04:00
Peter Boyle e79adc9d31 FFT: cache plans per vobj type across calls
Plans are created lazily on the first FFT_dim call and reused for all
subsequent calls on the same FFT object.  PlanCreate<vobj>() can be
called explicitly to pre-warm the cache.  PlanDestroy() must be called
before switching to a different vobj type; the destructor cleans up any
live plans automatically.

Update Test_fft.cc and Test_fftf.cc to call PlanDestroy() between the
LatticeComplex and LatticeSpinMatrix sections that reuse the same FFT object.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 5a9056cd93 Accelerator: lower default accelerator_threads from 16 to 8
Benchmark_dwf_fp32 on MI250X GCD: 1.7 TF/s at nt=8, ~300 GF/s at nt=16.
With Nsimd=8 (fp32, GEN_SIMD_WIDTH=64B), nt=8 gives exactly 64 threads =
one full AMD wavefront. Higher values double register demand per block and
hit a register-pressure cliff for stencil kernels.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 012c36ab5a Accelerator: raise default accelerator_threads from 2 to 16 2026-05-21 12:34:30 -04:00
Peter Boyle 5c4574f9aa skills: add gpu-memory-performance.md
Documents the acceleratorThreads() default=2 trap, LambdaApply thread
mapping, coalescedRead/Write idiom, when to use __global__ vs
accelerator_for, and fused vs staged HBM access patterns.

Includes observed MI250X numbers from LatticePropagatorD reduction
(50 → 297 → 546 GB/s progression).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle a424775884 sumD_gpu_reduce_words: fuse pack+reduce into single packReduceKernel
Replace the two-kernel pack+reduce sequence with a single fused kernel
packReduceKernel<R> that reads R words of each vobj at offset 'base'
and accumulates directly into iVector<iScalar<scalarD>,R>, eliminating
the intermediate bundle buffer entirely.

HBM access per word-group drops from 3x (pack-read + pack-write +
reduce-read) to 1x.  Thread count comes from getNumBlocksAndThreads
(warpSize..256) rather than acceleratorThreads(), so occupancy is
correct regardless of the --accelerator-threads setting.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle d6b1388741 Modified repack 2026-05-21 12:34:30 -04:00
Peter Boyle 796c6cae4e Enable GRID_REDUCTION_TIMING unconditionally
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 1a8064d6d9 Lattice_reduction_gpu: add GRID_REDUCTION_TIMING instrumentation
Uncomment #define GRID_REDUCTION_TIMING to enable per-phase timing output:

  sumD_gpu_reduce_words: pack time (accelerator_for) per R and base
  sumD_gpu_small:        reduceKernel+barrier time and D2H time separately
  sumD_gpu_large:        total wall time across all word groups

This lets us identify whether the large-type bottleneck is in the pack
kernel, the shared-memory reduction kernel, the barrier, or the D2H.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 43648924c3 sumD_gpu_large: radix-12 word-bundle reduction replacing radix-1
Replace the word-by-word loop (one kernel launch per scalar word) with
sumD_gpu_reduce_words<R> which packs R consecutive vector_type words per
site into iVector<iScalar<vector>,R>, then calls the existing sumD_gpu_small
shared-memory kernel once for the whole bundle.

Dispatch: radix-12 first, radix-4 for the remainder < 12, radix-1 for
any final < 4 words.  For LatticePropagator (144 words = 12x12), this
reduces the kernel-launch count from 144 to 12 -- a 12x reduction.

Bundle::Nsimd() inherits from vector_type so sumD_gpu_small handles SIMD
lane extraction and double-precision promotion identically to the scalar
word case.  sizeof(Bundle::scalar_objectD) = R*16 <= 192 B; well within
sharedMemPerBlock on all supported devices.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle bf2140e74d Lattice_reduction_sycl: fix double-precision accumulation in sumD_gpu_tensor
Accumulate in sobjD throughout rather than accumulating in sobj and
converting the final sum. For float fields this matters: summing N floats
then casting loses O(N*eps_float) relative precision vs accumulating in
double from the start.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle a1119266c1 Revert to hand-rolled reduction; drop Lattice_reduction_gpu_cub.h
Remove the CUB/hipCUB direction entirely. Restore Lattice_reduction_gpu.h,
Lattice_reduction_sycl.h, and Lattice_reduction.h to the state before the
CUB rewrite (commit 969b0a39), recovering the original primary function names
(sumD_gpu_small, sumD_gpu_large, sumD_gpu, sum_gpu, sum_gpu_large) and the
hand-rolled shared-memory reduction kernel.

Delete Lattice_reduction_gpu_cub.h. Update Test_reduction to remove the
old/new comparison sections that depended on sum_gpu_old.

The lesson: CUB DeviceReduce is slower than the hand-rolled kernel for small
types, and the smem sizing problem for the extraction pass has no clean
solution within the accelerator_for abstraction. The right improvement is
a higher radix (12 then 4) in sumD_gpu_large, applied directly to the
existing hand-rolled kernel.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle a0f00c0eca sumD_gpu_direct: revert to per-lane write; CUB handles Nsimd*osites inputs
Benchmarking showed the shared-memory lane-summation approach (843d6497)
was slower than writing each SIMD lane individually and letting CUB reduce
the full nlanes = osites*Nsimd array. CUB's device reduce is more efficient
over the larger input than the smem overhead + serialised lane-0 summation.
The smem approach also required overriding acceleratorThreads() to avoid
the block-size sizing problem. Restore the simpler per-lane path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle d358954a84 sumD_gpu_direct: shared-memory lane reduction with acceleratorThreads(1)
Set acceleratorThreads to 1 before the extraction kernel so that
dim3(nsimd,1,1) blocks give exactly one site group per block and
__shared__ sobjD smem[nsimd] is correctly sized without depending on
the runtime acceleratorThreads() value. threadIdx.x (acceleratorSIMTlane)
indexes the SIMD lane for coalesced reads; lane 0 sums smem[0..nsimd-1]
and writes one sobjD per site. CUB then reduces osites elements instead
of osites*nsimd, reducing both store traffic and CUB work by Nsimd.
acceleratorSynchronise() (warp-level) suffices since nsimd < warpSize.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle aee00bdfb5 sumD_gpu_direct: one thread per SIMD lane using extractLane
Replaces one thread per outer site calling Reduce() (sequential Nsimd-wide
loop) with one thread per lane calling extractLane() — O(1) per thread.
CUB now reduces over osites*Nsimd elements. Avoids serial lane reduction
but leaves the per-lane sobjD store stride as a known remaining concern.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle cf324b0fa1 Lattice_reduction_gpu_cub: define GRID_REDUCTION_TIMING in header
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle b314dc224d Lattice_reduction_gpu_cub: add GRID_REDUCTION_TIMING instrumentation
Guards accelerator_for and CUB DeviceReduce calls in sumD_gpu_direct
and sumD_gpu_large with #ifdef GRID_REDUCTION_TIMING to isolate where
time is spent in each path. Large path accumulates across all groups
and prints totals with words/nfull/rem context.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 1bbd62498e Lattice_reduction_gpu_cub: replace WordBundle4 with iVector<iScalar<scalarD>,4>
WordBundle4 was redundant with Grid's existing tensor infrastructure.
iVector<iScalar<scalarD>,4> already provides accelerator_inline operator+,
zeroit(), and sycl::is_device_copyable — no new type needed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle f3c3b1c04b Test_reduction: add timing benchmark for new vs old reduction paths
Reports us/call and GB/s for sum_gpu (CUB/sycl::reduction) and
sum_gpu_old (hand-rolled shared-memory) for each field type, with
5-call warmup and 100-call timed loop.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 069f98b253 skills: HPC battle-hardening skill files for GPU+MPI correctness
Six skill files encoding expertise for making codebases robust on
problematic HPC systems, covering: correctness verification
(double-run, fingerprinting, flight recorder), hang diagnosis,
GPU runtime correctness (premature barrier, infinite poll),
MPI correctness on heterogeneous systems (device buffer aliasing,
AARCH64 PLT corruption, deterministic reductions),
compiler validation, and communication/computation overlap pipeline
design.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle dfd0503eae Test_reduction: use separate float and double grids
Float fields require a grid constructed with vComplexF::Nsimd(); using
a double grid causes grid->_gsites to undercount the sites in float
vobjF, making the constant-field expected value wrong.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle c629b2e87e Rename scalarNorm2 to squaredSum in Test_reduction.cc
The function computes |sum|^2 — the squared magnitude of an aggregate sum —
not a norm. squaredSum makes clear that squaring is applied to the sum, not
to individual site values before summing, distinguishing it from sumOfSquares
(the squared L2 norm).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 7c8462abd1 Fix Zero() used on thrust::complex in WordBundle4 initialisation
Grid's Zero() sentinel is not assignable to thrust::complex<double>;
use scalarD(0) instead.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 95a6a0bde7 Reinstate large/small dispatch in CUB reduction path; radix-4 word-bundle for large types
rocPRIM's DeviceReduce requires warpSize(64) threads each holding one element in shared
memory, so sizeof(T)*64 must fit in sharedMemPerBlock.  LatticePropagator::scalar_objectD
is 2304 bytes (64*2304 = 147 KB), exceeding the budget and triggering a compile-time
static_assert in limit_block_size.

Introduce sumD_gpu_direct (the original direct-CUB path, safe for small types) and a new
sumD_gpu_large that groups the vobj's vector_type words in bundles of 4, reducing each
bundle as WordBundle4<scalarD> (64 bytes, 64*64 = 4 KB — always within budget).  If
words % 4 != 0, the final partial bundle is zero-padded.  sumD_gpu dispatches at compile
time via if constexpr on sizeof(sobjD) > 512.

For LatticePropagator (144 words) this gives 36 CUB launches instead of 144.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle bba328fac5 Add Test_reduction to tests/debug
Tests the new CUB/hipCUB/SYCL lattice reduction (sum_gpu) against the
preserved hand-rolled implementation (sum_gpu_old) for LatticeComplexF/D,
LatticeColourMatrixF/D and LatticePropagatorF/D.

Part a) gaussian random field: checks that old and new agree to within
float/double roundoff tolerance.
Part b) constant field (= 1.0, identity-matrix init): verifies
innerProduct(sum, sum) = Ncomp * V^2 where Ncomp counts the nonzero
diagonal scalar components per site (1 / Nc / Ns*Nc respectively).

Make.inc is auto-generated by scripts/filelist on bootstrap and is not
tracked; the new .cc file is all that is needed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle 41362349f3 Rewrite lattice GPU reduction to use CUB, hipCUB, and SYCL reduction
Replace hand-rolled shared-memory reduction kernels (reduceBlock/reduceBlocks/
reduceKernel) and the global device variable retirementCount with a unified
CUB/hipCUB DeviceReduce::Reduce path for CUDA/HIP and sycl::reduction for SYCL.
No small/large split is needed: both CUB and sycl::reduction handle arbitrary
object sizes internally.

Old implementations preserved as sum_gpu_old / sumD_gpu_old etc. in the
original files for regression testing on GPU hardware.

Also add CLAUDE.md with build, test, and architecture guidance.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 12:34:30 -04:00
Peter Boyle a5a04929fb Merge pull request #492 from giltirn/develop
Fixes to support CUDA > 13
2026-05-19 15:26:58 -04:00
Christopher Kelly 77b8657fcc Fixes to support CUDA > 13. Specifically, the CUDA header is no longer accidentally included within Grid's namespace, and the breaking change to cub::Sum() -> ::cuda::std::plus<>{} in CUDA-13 has been worked around 2026-05-19 12:22:14 -04:00
Peter Boyle f8b2eacf99 File list issue (Ed Bennets pull request?) 2026-05-15 12:57:42 -04:00
Peter Boyle 6140ac6864 Hip Happy 2026-05-15 12:13:01 -04:00
19 changed files with 1446 additions and 67 deletions
+1
View File
@@ -53,6 +53,7 @@ NAMESPACE_CHECK(approx);
#include <Grid/algorithms/deflation/MultiRHSBlockCGLinalg.h>
// Not really deflation, but useful
#include <Grid/algorithms/blas/MomentumProject.h>
#include <Grid/algorithms/blas/A2ASpatialSum.h>
NAMESPACE_CHECK(deflation);
#include <Grid/algorithms/iterative/ConjugateGradient.h>
NAMESPACE_CHECK(ConjGrad);
+213
View File
@@ -0,0 +1,213 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: Grid/algorithms/blas/A2ASpatialSum.h
Copyright (C) 2025
Author: Peter Boyle <pboyle@bnl.gov>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution directory
*************************************************************************************/
/* END LEGAL */
#pragma once
NAMESPACE_BEGIN(Grid);
/*
A2ASpatialSum
Replaces the scalar spatial accumulation loop in A2A extended meson field
contractions with a batched GEMM over local time slices, enabling GPU offload.
Given:
leftv[N_i][osite] — conjugated left SpinColourVectors (SIMD-packed)
loopRight[N_j][osite]— type-contracted right SpinColourVectors (SIMD-packed)
Computes:
EMF[i,j,t] = sum_{x,s,c} leftv[i][x,t,s,c] * loopRight[j][x,t,s,c]
via batched GEMM over nt local time slices, then GlobalSumVector across MPI.
Memory layout (all C row-major):
W_buf [nt][N_i][nxyz*Nsc] — W[t][i][x*Nsc+sc] = leftv[i] at (x,t)
LR_buf [nt][N_j][nxyz*Nsc] — LR[t][j][x*Nsc+sc] = loopRight[j] at (x,t)
EMF_buf[nt][N_j][N_i] — column-major result; EMF[i,j,t] = EMF_buf[t][j][i]
BLAS call (column-major, OP_T on A so A is read as W[i][k]):
C = A^T * B where A=W[N_i×K C-row], B=LR[N_j×K C-row], C=[N_j×N_i C-row]
→ C[i,j] = sum_k W[i][k] * LR[j][k] = EMF[i,j] ✓
*/
template<class vobj>
class A2ASpatialSum
{
public:
typedef typename vobj::scalar_type scalar;
typedef typename vobj::scalar_object sobj;
GridBase *grid;
int N_i, N_j;
int nt, nxyz, Nsc;
deviceVector<scalar> W_buf;
deviceVector<scalar> LR_buf;
deviceVector<scalar> EMF_buf;
deviceVector<scalar *> W_ptrs;
deviceVector<scalar *> LR_ptrs;
deviceVector<scalar *> EMF_ptrs;
A2ASpatialSum() : grid(nullptr), N_i(0), N_j(0), nt(0), nxyz(0), Nsc(0) {}
void Allocate(int _N_i, int _N_j, GridBase *_grid)
{
grid = _grid;
N_i = _N_i;
N_j = _N_j;
Coordinate ldims = grid->LocalDimensions();
nt = ldims[grid->Nd() - 1];
nxyz = grid->lSites() / nt;
Nsc = sizeof(sobj) / sizeof(scalar);
W_buf.resize(nt * N_i * nxyz * Nsc);
LR_buf.resize(nt * N_j * nxyz * Nsc);
EMF_buf.resize(nt * N_j * N_i);
// Build persistent batch pointer arrays
W_ptrs.resize(nt);
LR_ptrs.resize(nt);
EMF_ptrs.resize(nt);
scalar *Wh = &W_buf[0];
scalar *LRh = &LR_buf[0];
scalar *EMFh = &EMF_buf[0];
int lN_i = N_i, lN_j = N_j, lnxyz = nxyz, lNsc = Nsc;
for (int t = 0; t < nt; t++) {
acceleratorPut(W_ptrs[t], Wh + t * lN_i * lnxyz * lNsc);
acceleratorPut(LR_ptrs[t], LRh + t * lN_j * lnxyz * lNsc);
acceleratorPut(EMF_ptrs[t], EMFh + t * lN_j * lN_i);
}
}
void PackLeft(const std::vector<Lattice<vobj>> &leftv)
{
GRID_ASSERT((int)leftv.size() == N_i);
PackVectors(leftv, &W_buf[0], N_i);
}
void PackRight(const std::vector<Lattice<vobj>> &loopRight)
{
GRID_ASSERT((int)loopRight.size() == N_j);
PackVectors(loopRight, &LR_buf[0], N_j);
}
private:
// Pack vecs[N] lattice fields into buf[nt][N][nxyz*Nsc], extracting all SIMD lanes.
void PackVectors(const std::vector<Lattice<vobj>> &vecs, scalar *buf, int N)
{
int nd = grid->_ndimension;
int osites = grid->oSites();
int Nsimd = vobj::Nsimd();
int lN = N;
int lNsc = Nsc;
int lnxyz = nxyz;
Coordinate rdimensions = grid->_rdimensions;
Coordinate ldims = grid->LocalDimensions();
Coordinate simd = grid->_simd_layout;
for (int n = 0; n < N; n++) {
autoView(src_v, vecs[n], AcceleratorRead);
accelerator_for(sf, osites, Nsimd, {
#ifdef GRID_SIMT
{
int lane = acceleratorSIMTlane(Nsimd);
#else
for (int lane = 0; lane < Nsimd; lane++) {
#endif
Coordinate icoor(nd), ocoor(nd), lcoor(nd);
Lexicographic::CoorFromIndex(icoor, lane, simd);
Lexicographic::CoorFromIndex(ocoor, sf, rdimensions);
for (int d = 0; d < nd; d++)
lcoor[d] = rdimensions[d] * icoor[d] + ocoor[d];
int l_t = lcoor[nd - 1];
Coordinate xyz_coor = lcoor;
xyz_coor[nd - 1] = 0;
int64_t l_xyz;
Lexicographic::IndexFromCoor(xyz_coor, l_xyz, ldims);
sobj data = extractLane(lane, src_v[sf]);
scalar *data_s = (scalar *)&data;
int64_t base = (int64_t)l_t * lN * lnxyz * lNsc
+ (int64_t)n * lnxyz * lNsc
+ l_xyz * lNsc;
for (int sc = 0; sc < lNsc; sc++)
buf[base + sc] = data_s[sc];
}
});
}
}
public:
// Batched GEMM + MPI reduction → result[nt_global][N_i][N_j]
//
// BLAS (column-major, OP_T on A):
// C[N_j×N_i] = A^T[N_i×K] * B[N_j×K] with K=nxyz*Nsc
// reading A as C row-major [N_i][K] and B as C row-major [N_j][K]
// → C[i,j] = sum_k W[i,k] * LR[j,k] = EMF[i,j] ✓
void Sum(Eigen::Tensor<ComplexD, 3> &result)
{
GridBLAS BLAS;
int K = nxyz * Nsc;
BLAS.gemmBatched(GridBLAS_OP_T, GridBLAS_OP_N,
N_i, N_j, K,
scalar(1.0),
W_ptrs,
LR_ptrs,
scalar(0.0),
EMF_ptrs);
BLAS.synchronise();
// Copy from device and distribute into global-t layout
int nt_global = result.dimension(0);
int nd = grid->Nd();
int lt_start = grid->LocalStarts()[nd - 1];
std::vector<scalar> host_emf(nt * N_j * N_i);
acceleratorCopyFromDevice(&EMF_buf[0], host_emf.data(),
nt * N_j * N_i * sizeof(scalar));
// EMF_buf[t][j*N_i + i] = EMF[i,j] for local t
std::vector<scalar> global_emf(nt_global * N_i * N_j, scalar(0.0));
for (int lt = 0; lt < nt; lt++) {
int gt = lt + lt_start;
for (int i = 0; i < N_i; i++)
for (int j = 0; j < N_j; j++)
global_emf[gt * N_i * N_j + i * N_j + j] = host_emf[lt * N_j * N_i + j * N_i + i];
}
grid->GlobalSumVector(global_emf.data(), nt_global * N_i * N_j);
for (int gt = 0; gt < nt_global; gt++)
for (int i = 0; i < N_i; i++)
for (int j = 0; j < N_j; j++)
result(gt, i, j) = global_emf[gt * N_i * N_j + i * N_j + j];
}
};
NAMESPACE_END(Grid);
+9 -3
View File
@@ -1,7 +1,6 @@
#pragma once
#if defined(GRID_CUDA)
#include <cub/cub.cuh>
#define gpucub cub
#define gpuError_t cudaError_t
@@ -57,8 +56,13 @@ inline void sliceSumReduction_cub_small(const vobj *Data,
//copy offsets to device
acceleratorCopyToDeviceAsynch(&offsets[0],d_offsets,sizeof(int)*(rd+1),computeStream);
#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 13)
#define GRID_CUB_SUM_OP ::cuda::std::plus<>{}
#else
#define GRID_CUB_SUM_OP ::gpucub::Sum()
#endif
gpuError_t gpuErr = gpucub::DeviceSegmentedReduce::Reduce(temp_storage_array, temp_storage_bytes, rb_p,d_out, rd, d_offsets, d_offsets+1, ::gpucub::Sum(), zero_init, computeStream);
gpuError_t gpuErr = gpucub::DeviceSegmentedReduce::Reduce(temp_storage_array, temp_storage_bytes, rb_p,d_out, rd, d_offsets, d_offsets+1, GRID_CUB_SUM_OP, zero_init, computeStream);
if (gpuErr!=gpuSuccess) {
std::cout << GridLogError << "Lattice_slicesum_gpu.h: Encountered error during gpucub::DeviceSegmentedReduce::Reduce (setup)! Error: " << gpuErr <<std::endl;
exit(EXIT_FAILURE);
@@ -82,11 +86,13 @@ inline void sliceSumReduction_cub_small(const vobj *Data,
});
//issue segmented reductions in computeStream
gpuErr = gpucub::DeviceSegmentedReduce::Reduce(temp_storage_array, temp_storage_bytes, rb_p, d_out, rd, d_offsets, d_offsets+1,::gpucub::Sum(), zero_init, computeStream);
gpuErr = gpucub::DeviceSegmentedReduce::Reduce(temp_storage_array, temp_storage_bytes, rb_p, d_out, rd, d_offsets, d_offsets+1, GRID_CUB_SUM_OP, zero_init, computeStream);
if (gpuErr!=gpuSuccess) {
std::cout << GridLogError << "Lattice_slicesum_gpu.h: Encountered error during gpucub::DeviceSegmentedReduce::Reduce! Error: " << gpuErr <<std::endl;
exit(EXIT_FAILURE);
}
#undef GRID_CUB_SUM_OP
acceleratorCopyFromDeviceAsynch(d_out,&lvSum[0],rd*sizeof(vobj),computeStream);
+8
View File
@@ -113,6 +113,14 @@ accelerator_inline RealD adj(const RealD & r){ return r; }
accelerator_inline ComplexD adj(const ComplexD& r){ return(conjugate(r)); }
accelerator_inline ComplexF adj(const ComplexF& r ){ return(conjugate(r)); }
#if defined(GRID_CUDA) || defined(GRID_HIP)
//Provide for convenience
accelerator_inline std::complex<double> conjugate(const std::complex<double>& r){ return(conj(r)); }
accelerator_inline std::complex<float> conjugate(const std::complex<float>& r) { return(conj(r)); }
accelerator_inline std::complex<double> adj(const std::complex<double>& r) { return(conj(r)); }
accelerator_inline std::complex<float> adj(const std::complex<float>& r) { return(conj(r)); }
#endif
accelerator_inline RealF real(const RealF & r){ return r; }
accelerator_inline RealD real(const RealD & r){ return r; }
accelerator_inline RealF real(const ComplexF & r){ return r.real(); }
+28 -23
View File
@@ -32,6 +32,33 @@ Author: Christoph Lehner <christoph@lhnr.de>
NAMESPACE_BEGIN(Grid);
//////////////////////////////////////
// innerProductD scalar overloads must be visible before norm2 is defined
//////////////////////////////////////
accelerator_inline ComplexD innerProductD(const ComplexF &l,const ComplexF &r){ return innerProduct(l,r); }
accelerator_inline ComplexD innerProductD(const ComplexD &l,const ComplexD &r){ return innerProduct(l,r); }
accelerator_inline RealD innerProductD(const RealD &l,const RealD &r){ return innerProduct(l,r); }
accelerator_inline RealD innerProductD(const RealF &l,const RealF &r){ return innerProduct(l,r); }
accelerator_inline vComplexD innerProductD(const vComplexD &l,const vComplexD &r){ return innerProduct(l,r); }
accelerator_inline vRealD innerProductD(const vRealD &l,const vRealD &r){ return innerProduct(l,r); }
accelerator_inline vComplexD innerProductD(const vComplexF &l,const vComplexF &r)
{
vComplexD la,lb;
vComplexD ra,rb;
Optimization::PrecisionChange::StoD(l.v,la.v,lb.v);
Optimization::PrecisionChange::StoD(r.v,ra.v,rb.v);
return innerProduct(la,ra) + innerProduct(lb,rb);
}
accelerator_inline vRealD innerProductD(const vRealF &l,const vRealF &r)
{
vRealD la,lb;
vRealD ra,rb;
Optimization::PrecisionChange::StoD(l.v,la.v,lb.v);
Optimization::PrecisionChange::StoD(r.v,ra.v,rb.v);
return innerProduct(la,ra) + innerProduct(lb,rb);
}
///////////////////////////////////////////////////////////////////////////////////////
// innerProduct Scalar x Scalar -> Scalar
// innerProduct Vector x Vector -> Scalar
@@ -138,30 +165,8 @@ auto Reduce (const iScalar<l>& lhs) -> iScalar<decltype(Reduce(lhs._internal))>
//////////////////////////////////////
// innerProductD : if single promote to double and evaluate with sum 2x
// (scalar/vector overloads are declared above norm2 for ADL visibility)
//////////////////////////////////////
accelerator_inline ComplexD innerProductD(const ComplexF &l,const ComplexF &r){ return innerProduct(l,r); }
accelerator_inline ComplexD innerProductD(const ComplexD &l,const ComplexD &r){ return innerProduct(l,r); }
accelerator_inline RealD innerProductD(const RealD &l,const RealD &r){ return innerProduct(l,r); }
accelerator_inline RealD innerProductD(const RealF &l,const RealF &r){ return innerProduct(l,r); }
accelerator_inline vComplexD innerProductD(const vComplexD &l,const vComplexD &r){ return innerProduct(l,r); }
accelerator_inline vRealD innerProductD(const vRealD &l,const vRealD &r){ return innerProduct(l,r); }
accelerator_inline vComplexD innerProductD(const vComplexF &l,const vComplexF &r)
{
vComplexD la,lb;
vComplexD ra,rb;
Optimization::PrecisionChange::StoD(l.v,la.v,lb.v);
Optimization::PrecisionChange::StoD(r.v,ra.v,rb.v);
return innerProduct(la,ra) + innerProduct(lb,rb);
}
accelerator_inline vRealD innerProductD(const vRealF &l,const vRealF &r)
{
vRealD la,lb;
vRealD ra,rb;
Optimization::PrecisionChange::StoD(l.v,la.v,lb.v);
Optimization::PrecisionChange::StoD(r.v,ra.v,rb.v);
return innerProduct(la,ra) + innerProduct(lb,rb);
}
// Now do it for vector, matrix, scalar
template<class l,class r,int N> accelerator_inline
+2
View File
@@ -96,7 +96,9 @@ void acceleratorInit(void);
#ifdef GRID_CUDA
NAMESPACE_END(Grid);
#include <cuda.h>
NAMESPACE_BEGIN(Grid);
#ifdef __CUDA_ARCH__
#define GRID_SIMT
+61
View File
@@ -0,0 +1,61 @@
---
name: ref_a2a_emf_work
description: "A2A Extended Meson Field GPU offload work — status, file locations, pending task"
metadata:
node_type: memory
type: project
originSessionId: 956e80aa-401d-481a-80bb-17f8abe1c131
---
## What was built
`Grid/algorithms/blas/A2ASpatialSum.h` — batched GEMM spatial sum replacing scalar SIMD accumulation. Included via `Grid/algorithms/Algorithms.h`.
`tests/Test_extended_meson_field.cc` — test with class `A2AExtendedMesonFieldRef` containing:
- CPU reference path (`use_blas=false`)
- BLAS path (`use_blas=true`) using `A2ASpatialSum`
- Per-phase timing with `[ref type=N]` / `[blas type=N]` labels
- 4 contraction types (0-3), all verified at machine precision (~4e-16 rel_err)
## Pending task: GPU offload class
**Goal**: Write `A2AExtendedMesonFieldGPU` in the same test file, replacing all `thread_for` loops with `accelerator_for`-based free function kernels.
The `thread_for` blocks to replace all have the form:
```cpp
thread_for(r, rd, {
int so = r * grid->_ostride[orthogdim];
for (int n = 0; n < e1; n++)
for (int b = 0; b < e2; b++) {
int ss = so + n * stride + b;
// work
}
});
```
Replace with `accelerator_for(ss, grid->oSites(), Nsimd, { ... })`.
**Free functions to write** (each takes `Lattice<T>` args, opens views internally):
- `A2ALoopPropagator` — outerProduct sum (loop build)
- `A2APackLeftConjugated` — conjugate left fermion fields into `Lattice<SpinColourVector_v>`
- `A2ALoopLeftContractionType0/1/2/3` — per-site loop × loop propagator → `tloop`
- `A2ALoopRightContractionType0/1/2/3` — per-site tloop × right → `loopRight[j]`
**Data structure changes required**:
- `tloopv`: `std::vector<SpinColourMatrix_v>``Lattice<SpinColourMatrix_v>` (PropagatorField)
- `leftv[i]`: `std::vector<SpinColourVector_v>``Lattice<SpinColourVector_v>`
- `loopRight[j]`: `std::vector<SpinColourVector_v>``Lattice<SpinColourVector_v>`
**Why**: `std::vector<vobj>` is host memory, not GPU accessible. See [[ref_lattice_vs_vector]].
**`A2ASpatialSum` impact**: `PackLeft`/`PackRight` currently take `std::vector<std::vector<vobj>>`. Once leftv/loopRight become `std::vector<Lattice<vobj>>`, those signatures must change to match.
## Timing on 8.8.8.16 (N_i=N_j=8, Nloop=4, 1 MPI rank)
Dominant costs:
- `loop_build`: 4-6 ms (outerProduct over 4 propagators)
- `pack_loopright`: 0.9-2.2 ms (type-dependent)
- `spatial_sum` (ref): ~1.5 ms
- `A2ASpatialSum TOTAL`: 2.5-4.3 ms (PackLeft+PackRight dominate GEMM on small volume)
## Related
[[ref_accelerator_for]] [[ref_coalesced_views]] [[ref_lattice_vs_vector]] [[ref_grid_simt_pattern]]
+43
View File
@@ -0,0 +1,43 @@
---
name: ref_accelerator_for
description: Grid accelerator_for usage — converting block-strided thread_for to GPU-portable oSites loops
metadata:
node_type: memory
type: reference
originSessionId: 956e80aa-401d-481a-80bb-17f8abe1c131
---
## Pattern: block-strided thread_for → accelerator_for over oSites
Old CPU-only pattern (block-strided over orthog dimension):
```cpp
thread_for(r, rd, {
int so = r * grid->_ostride[orthogdim];
for (int n = 0; n < e1; n++)
for (int b = 0; b < e2; b++) {
int ss = so + n * stride + b;
// work on site ss
}
});
```
GPU-portable replacement:
```cpp
accelerator_for(ss, grid->oSites(), Nsimd, {
// work on site ss — one SIMT thread per (osite, lane) on GPU
// one thread per osite (lane loop implicit via GRID_SIMT) on CPU
});
```
Key rules:
- `accelerator_for(iter, count, Nsimd, body)` — Nsimd is `vobj::Nsimd()` or `grid->Nsimd()`
- On CPU: expands to `thread_for` over count, `acceleratorSIMTlane` always returns 0 — must use `#ifdef GRID_SIMT` pattern if iterating lanes explicitly (see [[ref_grid_simt_pattern]])
- On GPU: one SIMT thread per (iter × lane), `acceleratorSIMTlane(Nsimd)` returns actual lane
- Loop body must capture only scalar/POD by value or via device-accessible pointers; no `std::vector` or host containers inside the body
- `Coordinate` inside `accelerator_for` must be `AcceleratorVector<int, MaxDims>` (stack-allocated, device-safe) — Grid's `Coordinate` typedef already satisfies this
## Where defined
`Grid/threads/Accelerator.h` — CPU path ~line 607; GPU paths in conditional blocks above.
## Model file
`Grid/algorithms/blas/MomentumProject.h``ImportVector` is the canonical example of correct `accelerator_for` + SIMD lane extraction.
+70
View File
@@ -0,0 +1,70 @@
---
name: ref_coalesced_views
description: Grid coalescedRead/coalescedWrite and autoView — GPU-portable field access inside accelerator_for
metadata:
node_type: memory
type: reference
originSessionId: 956e80aa-401d-481a-80bb-17f8abe1c131
---
## View access modes
```cpp
autoView(v, field, AcceleratorRead); // read-only, device-accessible
autoView(v, field, AcceleratorWrite); // write-only, device-accessible
autoView(v, field, AcceleratorReadWrite); // read-write, device-accessible
autoView(v, field, CpuRead); // CPU only (avoids GPU migration)
autoView(v, field, CpuWrite); // CPU only
```
Views must be opened **before** `accelerator_for` and closed (go out of scope) **after**. Never open a view inside the accelerator_for body.
## coalescedRead / coalescedWrite
Inside `accelerator_for(ss, oSites, Nsimd, { ... })`:
```cpp
auto site = coalescedRead(v[ss]); // reads SIMT lane; returns scalar_object on GPU, vobj on CPU
coalescedWrite(v[ss], site); // writes SIMT lane
```
- `coalescedRead(v[ss])` calls `v.operator()(ss)` which on GPU returns `extractLane(lane, v[ss])` — one lane per SIMT thread, contiguous across threads → coalesced
- On CPU returns the full vobj (no lane extraction needed; handled transparently)
- The returned type is `decltype(coalescedRead(v[ss]))` — use `auto` or match with scalar_object
## Typical kernel pattern
```cpp
autoView(out_v, out, AcceleratorWrite);
autoView(in_v, in, AcceleratorRead);
accelerator_for(ss, grid->oSites(), vobj::Nsimd(), {
auto x = coalescedRead(in_v[ss]);
// modify x ...
coalescedWrite(out_v[ss], x);
});
```
## Free function kernel signature
```cpp
template<class vobj>
void MyKernel(Lattice<vobj> &out, const Lattice<vobj> &in)
{
GridBase *grid = in.Grid();
autoView(out_v, out, AcceleratorWrite);
autoView(in_v, in, AcceleratorRead);
accelerator_for(ss, grid->oSites(), vobj::Nsimd(), {
auto x = coalescedRead(in_v[ss]);
coalescedWrite(out_v[ss], x);
});
}
```
## What NOT to do
- Do not access `std::vector` elements inside `accelerator_for` — not device-accessible
- Do not use `CpuRead`/`CpuWrite` views inside `accelerator_for` — GPU will fault
- Do not assign to `v[ss]` directly inside `accelerator_for` — use `coalescedWrite`
- Do not open multiple write views on the same field simultaneously
## Related
[[ref_accelerator_for]] [[ref_lattice_vs_vector]]
+47
View File
@@ -0,0 +1,47 @@
---
name: ref_grid_simt_pattern
description: Grid GRID_SIMT
metadata:
node_type: memory
type: reference
originSessionId: 956e80aa-401d-481a-80bb-17f8abe1c131
---
## The problem
On CPU, `accelerator_for(sf, oSites, Nsimd, {...})` expands to `thread_for(sf, oSites, {...})` — one thread per osite. `acceleratorSIMTlane(Nsimd)` always returns **0** on CPU. If you need to iterate all Nsimd lanes (e.g. to extract SIMD-packed data), you must loop explicitly on CPU.
On GPU, `accelerator_for` launches one SIMT thread per (osite × lane). `acceleratorSIMTlane(Nsimd)` returns the actual lane index [0, Nsimd).
## Correct pattern (from MomentumProject::ImportVector)
```cpp
accelerator_for(sf, osites, Nsimd, {
#ifdef GRID_SIMT
{
int lane = acceleratorSIMTlane(Nsimd);
#else
for (int lane = 0; lane < Nsimd; lane++) {
#endif
// body using lane
}
});
```
- On GPU: `GRID_SIMT` is defined → single-lane body, lane from hardware
- On CPU: `GRID_SIMT` is not defined → explicit lane loop inside the osite thread
## When is this needed?
Only when you explicitly need the lane index, e.g.:
- Extracting scalar data from SIMD-packed `vobj` via `extractLane(lane, src[sf])`
- Computing full local coordinates from (osite, lane) → `Lexicographic::CoorFromIndex(icoor, lane, simd_layout)`
When using `coalescedRead`/`coalescedWrite`, this pattern is **not needed** — those handle lane selection transparently.
## Pitfall that caused a bug
`A2ASpatialSum::PackVectors` originally used `accelerator_for` without the `#ifdef GRID_SIMT` lane loop. On CPU, only lane=0 was extracted, giving wrong norms (~8× too small for `GEN_SIMD_WIDTH=64`, `Nsimd=4`). Fix: add the `#ifdef GRID_SIMT` pattern. See [[ref_accelerator_for]].
## Model file
`Grid/algorithms/blas/MomentumProject.h`, function `ImportVector`, lines ~166-207.
+48
View File
@@ -0,0 +1,48 @@
---
name: ref_lattice_vs_vector
description: When to use Lattice<T> vs std::vector<T> for GPU-portable field storage in Grid
metadata:
node_type: memory
type: reference
originSessionId: 956e80aa-401d-481a-80bb-17f8abe1c131
---
## Rule
Use `Lattice<vobj>` (or `std::vector<Lattice<vobj>>`) for any field that will be read or written inside `accelerator_for`. `std::vector<vobj>` is host memory and is NOT device-accessible.
## Before vs after GPU offload
```cpp
// CPU-only (host memory, not GPU accessible)
std::vector<SpinColourVector_v> tloopv(oSites, Zero());
// accessed directly: tloopv[ss]
// GPU-portable
Lattice<SpinColourVector_v> tloop(grid);
// accessed via view: autoView(tloop_v, tloop, AcceleratorWrite);
// coalescedWrite(tloop_v[ss], val);
```
## Corollary: function signatures
CPU-only version:
```cpp
void PackLeft(const std::vector<std::vector<vobj>> &leftv);
```
GPU-portable version:
```cpp
void PackLeft(const std::vector<Lattice<vobj>> &leftv);
```
## deviceVector for raw device buffers
`deviceVector<T>` (defined in Grid) is like `std::vector<T>` but in device-accessible memory. Use for raw scalar scratch/pack buffers (e.g. GEMM input/output staging). Not for structured lattice data.
## Pointer arrays for batched BLAS
`deviceVector<scalar *>` holds batch pointer arrays. Populate with `acceleratorPut(ptrs[t], base + offset)` — sets device-side pointer from host. See `A2ASpatialSum::Allocate`.
## Related
[[ref_coalesced_views]] [[ref_accelerator_for]]
+3 -3
View File
@@ -1,4 +1,3 @@
CLIME=`spack find --paths c-lime@2-3-9 | grep c-lime| cut -c 15-`
../../configure --enable-comms=mpi-auto \
--with-lime=$CLIME \
--enable-unified=no \
@@ -9,8 +8,9 @@ CLIME=`spack find --paths c-lime@2-3-9 | grep c-lime| cut -c 15-`
--disable-gparity \
--disable-fermion-reps \
--enable-simd=GPU \
--with-gmp=$OLCF_GMP_ROOT \
--with-mpfr=/opt/cray/pe/gcc/mpfr/3.1.4/ \
--with-gmp=$GMP \
--with-mpfr=$MPFR \
--with-openssl=$OPENSSL \
--disable-fermion-reps \
CXX=hipcc MPICXX=mpicxx \
CXXFLAGS="-fPIC -I${ROCM_PATH}/include/ -I${MPICH_DIR}/include " \
+4
View File
@@ -2,6 +2,10 @@
echo spack
. /autofs/nccs-svm1_home1/paboyle/spack/share/spack/setup-env.sh
export CLIME=`spack find --paths c-lime | grep ^c-lime | awk '{print $2}' `
export MPFR=`spack find --paths mpfr | grep ^mpfr | awk '{print $2}' `
export OPENSSL=`spack find --paths openssl | grep openssl | awk '{print $2}' `
export GMP=`spack find --paths gmp | grep ^gmp | awk '{print $2}' `
module load cce/21.0.0
module load cpe/26.03
+3 -3
View File
@@ -1,12 +1,12 @@
DIR=`pwd`
PREFIX=$HOME/DDHMC/Grid/systems/Prerequisites/install/
../../configure \
--enable-comms=mpi \
--enable-simd=GPU \
--enable-shm=nvlink \
--enable-gen-simd-width=64 \
--with-gmp=$PREFIX \
--with-mpfr=$PREFIX \
--with-gmp=$GMP \
--with-mpfr=$MPFR \
--enable-accelerator=cuda \
--disable-fermion-reps \
--disable-unified \
+4 -2
View File
@@ -1,4 +1,6 @@
export CRAY_ACCEL_TARGET=nvidia80
source /global/homes/p/pboyle/spack/share/spack/setup-env.sh
export MPFR=`spack find --paths mpfr | grep mpfr | cut -c 13-`
export GMP=`spack find --paths gmp | grep gmp | cut -c 12-`
module load PrgEnv-gnu cpe-cuda cudatoolkit/11.4
module load PrgEnv-gnu cpe-cuda cudatoolkit/12.0
+5 -3
View File
@@ -3,12 +3,14 @@
CXX=mpicxx ../../configure \
--enable-simd=GEN \
--enable-comms=mpi-auto \
--enable-Sp=yes \
--enable-unified=yes \
--prefix /Users/peterboyle/QCD/vtk/Grid/install \
--disable-fermion-reps \
--disable-gparity \
--prefix /Users/peterboyle/QCD/Grid-install \
--with-lime=$CLIME \
--with-openssl=$OPENSSL \
--with-gmp=$GMP \
--with-mpfr=$MPFR \
--disable-debug
--with-fftw=$FFTW \
--disable-debug
+7
View File
@@ -0,0 +1,7 @@
source /Users/peterboyle/QCD//Spack/spack//share/spack/setup-env.sh
export FFTW=`spack find --paths fftw | grep ^fftw | awk '{print $2}' `
export CLIME=`spack find --paths c-lime | grep ^c-lime | awk '{print $2}' `
export MPFR=`spack find --paths mpfr | grep ^mpfr | awk '{print $2}' `
export OPENSSL=`spack find --paths openssl | grep openssl | awk '{print $2}' `
export GMP=`spack find --paths gmp | grep ^gmp | awk '{print $2}' `
+841
View File
@@ -0,0 +1,841 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: tests/Test_extended_meson_field.cc
Copyright (C) 2015-2025
Author: Peter Boyle <pboyle@bnl.gov>
Author: Masaaki Tomii <masaaki.tomii@uconn.edu> (original Hadrons kernels)
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution directory
*************************************************************************************/
#include "disable_tests_without_instantiations.h"
#ifdef ENABLE_FERMION_INSTANTIATIONS
#include <Grid/Grid.h>
#include <Grid/qcd/utils/A2Autils.h>
using namespace Grid;
typedef WilsonImplD FImpl;
typedef typename FImpl::FermionField FermionField;
typedef typename FImpl::SiteSpinor vobj;
typedef typename vobj::scalar_type scalar_type;
typedef typename vobj::vector_type vector_type;
typedef iSpinColourMatrix<vector_type> SpinColourMatrix_v;
typedef iSpinColourVector<vector_type> SpinColourVector_v;
typedef iSpinMatrix<vector_type> SpinMatrix_v;
typedef iSinglet<vector_type> Scalar_v;
typedef iSinglet<scalar_type> Scalar_s;
typedef Lattice<SpinColourMatrix_v> PropagatorField;
// CPU reference + optionally batched GEMM spatial sum, ported from
// Hadrons/Modules/MContraction/A2AExtendedMesonField.hpp
// (M. Tomii, mtomii/Hadrons:local-2025-edits). Hadrons infrastructure removed.
// thread_for / CpuRead / orthogdim=3 preserved.
class A2AExtendedMesonFieldRef
{
public:
// result is indexed [nt][N_i][N_j].
// use_blas=true replaces the scalar spatial accumulation with A2ASpatialSum.
static void compute(
Eigen::Tensor<ComplexD, 3> &result,
const std::vector<FermionField> &left,
const std::vector<FermionField> &right,
const std::vector<FermionField> &loop1,
const std::vector<FermionField> &loop2,
const std::vector<Gamma::Algebra> &gamma1,
const std::vector<Gamma::Algebra> &gamma2,
int type,
bool use_blas = false)
{
GridBase *grid = left[0].Grid();
const int orthogdim = 3;
int rd = grid->_rdimensions[orthogdim];
int ld = grid->_ldimensions[orthogdim];
int Nd = grid->_ndimension;
int Nsimd = grid->Nsimd();
int nt = result.dimension(0);
int N_i = (int)left.size();
int N_j = (int)right.size();
std::string tag = std::string(use_blas ? "[blas" : "[ref ") + " type=" + std::to_string(type) + "]";
auto Tms = [](double us) { return us * 1e-3; };
double t0;
// ------------------------------------------------------------------
// Loop propagator: sum_k outerProduct(loop1[k], loop2[k])
// ------------------------------------------------------------------
t0 = usecond();
PropagatorField loop(grid);
loop = Zero();
for (unsigned int k = 0; k < loop1.size(); ++k)
loop += outerProduct(loop1[k], loop2[k]);
std::cout << GridLogMessage << tag << " loop_build: " << Tms(usecond()-t0) << " ms\n";
// ------------------------------------------------------------------
// Pack conjugated left vectors
// ------------------------------------------------------------------
t0 = usecond();
std::vector<FermionField> leftv(N_i, grid);
for (int i = 0; i < N_i; i++)
leftv[i] = conjugate(left[i]);
std::cout << GridLogMessage << tag << " pack_left: " << Tms(usecond()-t0) << " ms\n";
// ------------------------------------------------------------------
// Per-site loop contraction into PropagatorField tloop (type-dependent)
// ------------------------------------------------------------------
t0 = usecond();
PropagatorField tloop(grid);
tloop = Zero();
{
autoView(tloopv, tloop, CpuWrite);
autoView(loopv, loop, CpuRead);
if (type == 0) {
thread_for(ss, grid->oSites(), {
for (int s1 = 0; s1 < Ns; ++s1)
for (int s2 = 0; s2 < Ns; ++s2)
tloopv[ss]()(s1,s2)(0,0) = loopv[ss]()(s1,s2)(0,0)
+ loopv[ss]()(s1,s2)(1,1)
+ loopv[ss]()(s1,s2)(2,2);
});
}
if (type == 1) {
thread_for(ss, grid->oSites(), {
tloopv[ss] = Zero();
for (int mu = 0; mu < (int)gamma1.size(); ++mu)
tloopv[ss] = tloopv[ss] + Gamma(gamma1[mu]) * loopv[ss] * Gamma(gamma2[mu]);
});
}
if (type == 2) {
thread_for(ss, grid->oSites(), {
tloopv[ss] = Zero();
for (int mu = 0; mu < (int)gamma2.size(); ++mu) {
SpinColourMatrix_v tmp = Gamma(gamma2[mu]) * loopv[ss];
int s1 = mu / Ns;
int s2 = mu % Ns;
for (int c1 = 0; c1 < Nc; ++c1)
for (int c2 = 0; c2 < Nc; ++c2)
tloopv[ss]()(s1,s2)(c1,c2) = tmp()(0,0)(c1,c2) + tmp()(1,1)(c1,c2)
+ tmp()(2,2)(c1,c2) + tmp()(3,3)(c1,c2);
}
});
}
if (type == 3) {
thread_for(ss, grid->oSites(), {
SpinMatrix_v spinLoop = Zero();
for (int s1 = 0; s1 < Ns; ++s1)
for (int s2 = 0; s2 < Ns; ++s2)
spinLoop()(s1,s2)() = loopv[ss]()(s1,s2)(0,0)
+ loopv[ss]()(s1,s2)(1,1)
+ loopv[ss]()(s1,s2)(2,2);
tloopv[ss] = Zero();
for (int mu = 0; mu < (int)gamma1.size(); ++mu) {
SpinMatrix_v tmp2 = Gamma(gamma1[mu]) * spinLoop * Gamma(gamma2[mu]);
for (int s1 = 0; s1 < Ns; ++s1)
for (int s2 = 0; s2 < Ns; ++s2)
tloopv[ss]()(s1,s2)(0,0) = tloopv[ss]()(s1,s2)(0,0) + tmp2()(s1,s2)();
}
});
}
}
std::cout << GridLogMessage << tag << " tloop: " << Tms(usecond()-t0) << " ms\n";
// Select addLoopRight kernel for this type
std::function<void(SpinColourVector_v &,
const SpinColourMatrix_v &,
const SpinColourVector_v &,
const std::vector<Gamma::Algebra> &,
const std::vector<Gamma::Algebra> &)> addLoopRight;
if (type == 0) {
addLoopRight = [](SpinColourVector_v &lR,
const SpinColourMatrix_v &loopm,
const SpinColourVector_v &rightv,
const std::vector<Gamma::Algebra> &g1,
const std::vector<Gamma::Algebra> &g2) {
SpinMatrix_v spinLoop = Zero();
for (int s1 = 0; s1 < Ns; ++s1)
for (int s2 = 0; s2 < Ns; ++s2)
spinLoop()(s1,s2)() = loopm()(s1,s2)(0,0);
for (int mu = 0; mu < (int)g1.size(); ++mu) {
SpinMatrix_v GLoop = Gamma(g2[mu]) * spinLoop;
auto trGLoop = GLoop()(0,0)() + GLoop()(1,1)() + GLoop()(2,2)() + GLoop()(3,3)();
SpinColourVector_v Grightv = Gamma(g1[mu]) * rightv;
for (int s = 0; s < Ns; ++s)
for (int c = 0; c < Nc; ++c)
lR()(s)(c) += Grightv()(s)(c) * trGLoop;
}
};
}
if (type == 1) {
addLoopRight = [](SpinColourVector_v &lR,
const SpinColourMatrix_v &loopm,
const SpinColourVector_v &rightv,
const std::vector<Gamma::Algebra> &g1,
const std::vector<Gamma::Algebra> &g2) {
for (int s = 0; s < Ns; ++s)
for (int c = 0; c < Nc; ++c) {
lR()(s)(c)
+= loopm()(s,0)(c,0) * rightv()(0)(0)
+ loopm()(s,0)(c,1) * rightv()(0)(1)
+ loopm()(s,0)(c,2) * rightv()(0)(2)
+ loopm()(s,1)(c,0) * rightv()(1)(0)
+ loopm()(s,1)(c,1) * rightv()(1)(1)
+ loopm()(s,1)(c,2) * rightv()(1)(2)
+ loopm()(s,2)(c,0) * rightv()(2)(0)
+ loopm()(s,2)(c,1) * rightv()(2)(1)
+ loopm()(s,2)(c,2) * rightv()(2)(2)
+ loopm()(s,3)(c,0) * rightv()(3)(0)
+ loopm()(s,3)(c,1) * rightv()(3)(1)
+ loopm()(s,3)(c,2) * rightv()(3)(2);
}
};
}
if (type == 2) {
addLoopRight = [](SpinColourVector_v &lR,
const SpinColourMatrix_v &loopm,
const SpinColourVector_v &rightv,
const std::vector<Gamma::Algebra> &g1,
const std::vector<Gamma::Algebra> &g2) {
for (int mu = 0; mu < (int)g1.size(); ++mu) {
int s1 = mu / Ns;
int s2 = mu % Ns;
SpinColourVector_v Grightv = Gamma(g1[mu]) * rightv;
for (int s = 0; s < Ns; ++s)
for (int c = 0; c < Nc; ++c)
lR()(s)(c) += loopm()(s1,s2)(c,0) * Grightv()(s)(0)
+ loopm()(s1,s2)(c,1) * Grightv()(s)(1)
+ loopm()(s1,s2)(c,2) * Grightv()(s)(2);
}
};
}
if (type == 3) {
addLoopRight = [](SpinColourVector_v &lR,
const SpinColourMatrix_v &loopm,
const SpinColourVector_v &rightv,
const std::vector<Gamma::Algebra> &g1,
const std::vector<Gamma::Algebra> &g2) {
for (int s = 0; s < Ns; ++s)
for (int c = 0; c < Nc; ++c)
lR()(s)(c) += loopm()(s,0)(0,0) * rightv()(0)(c)
+ loopm()(s,1)(0,0) * rightv()(1)(c)
+ loopm()(s,2)(0,0) * rightv()(2)(c)
+ loopm()(s,3)(0,0) * rightv()(3)(c);
};
}
// ------------------------------------------------------------------
// Pack loopRight[j] = type-kernel(tloop, right[j]) per site
// ------------------------------------------------------------------
t0 = usecond();
std::vector<FermionField> loopRight(N_j, grid);
{
autoView(tlv, tloop, CpuRead);
for (int j = 0; j < N_j; j++) {
loopRight[j] = Zero();
autoView(lRv, loopRight[j], CpuWrite);
autoView(rv, right[j], CpuRead);
thread_for(ss, grid->oSites(), {
addLoopRight(lRv[ss], tlv[ss], rv[ss], gamma1, gamma2);
});
}
}
std::cout << GridLogMessage << tag << " pack_loopright: " << Tms(usecond()-t0) << " ms\n";
if (use_blas) {
// ------------------------------------------------------------------
// BLAS path: A2ASpatialSum (Allocate + PackLeft + PackRight + Sum)
// ------------------------------------------------------------------
A2ASpatialSum<SpinColourVector_v> spatial_sum;
double t_blas_start = usecond();
t0 = usecond();
spatial_sum.Allocate(N_i, N_j, grid);
std::cout << GridLogMessage << tag << " Allocate: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
spatial_sum.PackLeft(leftv);
std::cout << GridLogMessage << tag << " PackLeft: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
spatial_sum.PackRight(loopRight);
std::cout << GridLogMessage << tag << " PackRight: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
spatial_sum.Sum(result);
std::cout << GridLogMessage << tag << " Sum (GEMM+MPI): " << Tms(usecond()-t0) << " ms\n";
std::cout << GridLogMessage << tag << " A2ASpatialSum: " << Tms(usecond()-t_blas_start) << " ms [TOTAL]\n";
} else {
// ------------------------------------------------------------------
// Reference path: SIMD spatial sum + scalar extraction
// ------------------------------------------------------------------
int MFrvol = rd * N_i * N_j;
int MFlvol = ld * N_i * N_j;
Vector<Scalar_v> lvSum(MFrvol);
thread_for(r, MFrvol, { lvSum[r] = Zero(); });
t0 = usecond();
{
int e1 = grid->_slice_nblock[orthogdim];
int e2 = grid->_slice_block [orthogdim];
int stride = grid->_slice_stride[orthogdim];
using LView = decltype(leftv[0].View(CpuRead));
using RView = decltype(loopRight[0].View(CpuRead));
std::vector<LView> lv_views;
std::vector<RView> lr_views;
lv_views.reserve(N_i);
lr_views.reserve(N_j);
for (int i = 0; i < N_i; i++) lv_views.push_back(leftv[i].View(CpuRead));
for (int j = 0; j < N_j; j++) lr_views.push_back(loopRight[j].View(CpuRead));
thread_for(r, rd, {
int so = r * grid->_ostride[orthogdim];
int base = N_i * N_j * r;
for (int n = 0; n < e1; n++)
for (int b = 0; b < e2; b++) {
int ss = so + n * stride + b;
for (int ii = 0; ii < N_i; ii++)
for (int jj = 0; jj < N_j; jj++) {
int idx = jj + N_j * ii + base;
for (int s = 0; s < Ns; ++s)
for (int c = 0; c < Nc; ++c)
lvSum[idx]()()() += lv_views[ii][ss]()(s)(c) * lr_views[jj][ss]()(s)(c);
}
}
});
for (auto &v : lv_views) v.ViewClose();
for (auto &v : lr_views) v.ViewClose();
}
std::cout << GridLogMessage << tag << " spatial_sum: " << Tms(usecond()-t0) << " ms\n";
Vector<Scalar_s> lsSum(MFlvol);
thread_for(r, MFlvol, { lsSum[r] = scalar_type(0.0); });
t0 = usecond();
thread_for(rt, rd, {
Coordinate icoor(Nd);
ExtractBuffer<Scalar_s> extracted(Nsimd);
for (int ii = 0; ii < N_i; ii++)
for (int jj = 0; jj < N_j; jj++) {
int ij_rdx = jj + N_j * (ii + N_i * rt);
extract(lvSum[ij_rdx], extracted);
for (int idx = 0; idx < Nsimd; idx++) {
grid->iCoorFromIindex(icoor, idx);
int ldx = rt + icoor[orthogdim] * rd;
int ij_ldx = jj + N_j * (ii + N_i * ldx);
lsSum[ij_ldx] = lsSum[ij_ldx] + extracted[idx];
}
}
});
std::cout << GridLogMessage << tag << " simd_extract: " << Tms(usecond()-t0) << " ms\n";
int pd = grid->_processors[orthogdim];
int pc = grid->_processor_coor[orthogdim];
t0 = usecond();
Vector<ComplexD> cache(nt * N_i * N_j, ComplexD(0.0));
for (int lt = 0; lt < ld; lt++)
for (int pt = 0; pt < pd; pt++) {
int t = lt + pt * ld;
for (int ii = 0; ii < N_i; ii++)
for (int jj = 0; jj < N_j; jj++) {
if (pt == pc) {
int ij_ldx = jj + N_j * (ii + N_i * lt);
cache[jj + N_j * (ii + N_i * t)] = lsSum[ij_ldx]()()();
}
}
}
grid->GlobalSumVector(cache.data(), nt * N_i * N_j);
std::cout << GridLogMessage << tag << " globalsum: " << Tms(usecond()-t0) << " ms\n";
for (int t = 0; t < nt; t++)
for (int ii = 0; ii < N_i; ii++)
for (int jj = 0; jj < N_j; jj++)
result(t, ii, jj) = cache[jj + N_j * (ii + N_i * t)];
}
}
};
// ================================================================
// Free-function GPU kernels — accelerator_for, v(ss) reads,
// coalescedWrite writes, vobj-level arithmetic throughout.
// Gamma arrays passed as Vector<Gamma::Algebra> (UVM).
// ================================================================
void A2ALoopPropagator(PropagatorField &loop,
const std::vector<FermionField> &loop1,
const std::vector<FermionField> &loop2)
{
int Nk = (int)loop1.size();
uint64_t oSites = loop.Grid()->oSites();
int Nsimd = SpinColourVector_v::Nsimd();
typedef decltype(loop1[0].View(AcceleratorRead)) View;
std::vector<View> v1, v2;
v1.reserve(Nk); v2.reserve(Nk);
for (int k = 0; k < Nk; k++) {
v1.push_back(loop1[k].View(AcceleratorRead));
v2.push_back(loop2[k].View(AcceleratorRead));
}
deviceVector<SpinColourVector_v *> l1p(Nk), l2p(Nk);
for (int k = 0; k < Nk; k++) {
acceleratorPut(l1p[k], &v1[k][0]);
acceleratorPut(l2p[k], &v2[k][0]);
}
autoView(loopv, loop, AcceleratorWrite);
SpinColourVector_v **l1 = &l1p[0];
SpinColourVector_v **l2 = &l2p[0];
int lNk = Nk;
accelerator_for(ss, oSites, Nsimd, {
auto result = outerProduct(coalescedRead(l1[0][ss]), coalescedRead(l2[0][ss]));
for (int k = 1; k < lNk; k++)
result = result + outerProduct(coalescedRead(l1[k][ss]), coalescedRead(l2[k][ss]));
coalescedWrite(loopv[ss], result);
});
}
void A2APackLeftConjugated(FermionField &out, const FermionField &in)
{
autoView(outv, out, AcceleratorWrite);
autoView(inv, in, AcceleratorRead);
uint64_t Osites = in.Grid()->oSites();
int Nsimd = SpinColourVector_v::Nsimd();
accelerator_for(ss, Osites, Nsimd, {
coalescedWrite(outv[ss], conjugate(inv(ss)));
});
}
// Type 0: colour-trace stored in (s1,s2)(0,0)
void A2ALoopLeftContractionType0(PropagatorField &tloop, const PropagatorField &loop)
{
autoView(tloopv, tloop, AcceleratorWrite);
autoView(loopv, loop, AcceleratorRead);
uint64_t Osites = loop.Grid()->oSites();
int Nsimd = SpinColourMatrix_v::Nsimd();
accelerator_for(ss, Osites, Nsimd, {
auto l = loopv(ss);
auto tmp = l; tmp = Zero();
for (int s1 = 0; s1 < Ns; ++s1)
for (int s2 = 0; s2 < Ns; ++s2)
tmp()(s1,s2)(0,0) = l()(s1,s2)(0,0) + l()(s1,s2)(1,1) + l()(s1,s2)(2,2);
coalescedWrite(tloopv[ss], tmp);
});
}
// Type 1: tloop = sum_mu Gamma(g1[mu]) * loop * Gamma(g2[mu])
void A2ALoopLeftContractionType1(PropagatorField &tloop, const PropagatorField &loop,
const Vector<Gamma::Algebra> &gamma1,
const Vector<Gamma::Algebra> &gamma2)
{
int ng = (int)gamma1.size();
const Gamma::Algebra *g1 = gamma1.data();
const Gamma::Algebra *g2 = gamma2.data();
autoView(tloopv, tloop, AcceleratorWrite);
autoView(loopv, loop, AcceleratorRead);
uint64_t Osites = loop.Grid()->oSites();
int Nsimd = SpinColourMatrix_v::Nsimd();
accelerator_for(ss, Osites, Nsimd, {
auto l = loopv(ss);
auto tmp = l; tmp = Zero();
for (int mu = 0; mu < ng; ++mu)
tmp = tmp + Gamma(g1[mu]) * l * Gamma(g2[mu]);
coalescedWrite(tloopv[ss], tmp);
});
}
// Type 2: for mu=[0..ng), s1=mu/Ns, s2=mu%Ns;
// tloop(s1,s2)(c1,c2) = Tr_spin( Gamma(g2[mu]) * loop )(c1,c2)
void A2ALoopLeftContractionType2(PropagatorField &tloop, const PropagatorField &loop,
const Vector<Gamma::Algebra> &gamma2)
{
int ng = (int)gamma2.size();
const Gamma::Algebra *g2 = gamma2.data();
autoView(tloopv, tloop, AcceleratorWrite);
autoView(loopv, loop, AcceleratorRead);
uint64_t Osites = loop.Grid()->oSites();
int Nsimd = SpinColourMatrix_v::Nsimd();
accelerator_for(ss, Osites, Nsimd, {
auto l = loopv(ss);
auto tmp = l; tmp = Zero();
for (int mu = 0; mu < ng; ++mu) {
auto gtmp = Gamma(g2[mu]) * l;
int s1 = mu / Ns;
int s2 = mu % Ns;
for (int c1 = 0; c1 < Nc; ++c1)
for (int c2 = 0; c2 < Nc; ++c2)
tmp()(s1,s2)(c1,c2) = gtmp()(0,0)(c1,c2) + gtmp()(1,1)(c1,c2)
+ gtmp()(2,2)(c1,c2) + gtmp()(3,3)(c1,c2);
}
coalescedWrite(tloopv[ss], tmp);
});
}
// Type 3: colour-trace → spin matrix → sum_mu G1*spinLoop*G2 stored in (s1,s2)(0,0)
void A2ALoopLeftContractionType3(PropagatorField &tloop, const PropagatorField &loop,
const Vector<Gamma::Algebra> &gamma1,
const Vector<Gamma::Algebra> &gamma2)
{
int ng = (int)gamma1.size();
const Gamma::Algebra *g1 = gamma1.data();
const Gamma::Algebra *g2 = gamma2.data();
autoView(tloopv, tloop, AcceleratorWrite);
autoView(loopv, loop, AcceleratorRead);
uint64_t Osites = loop.Grid()->oSites();
int Nsimd = SpinColourMatrix_v::Nsimd();
accelerator_for(ss, Osites, Nsimd, {
typedef decltype(coalescedRead(loopv[0])) calcSCMatrix;
typedef iSpinMatrix<typename calcSCMatrix::vector_type> calcSpinMatrix;
auto l = loopv(ss);
calcSpinMatrix spinLoop; spinLoop = Zero();
for (int s1 = 0; s1 < Ns; ++s1)
for (int s2 = 0; s2 < Ns; ++s2)
spinLoop()(s1,s2)() = l()(s1,s2)(0,0) + l()(s1,s2)(1,1) + l()(s1,s2)(2,2);
auto tmp = l; tmp = Zero();
for (int mu = 0; mu < ng; ++mu) {
calcSpinMatrix tmp2 = Gamma(g1[mu]) * spinLoop * Gamma(g2[mu]);
for (int s1 = 0; s1 < Ns; ++s1)
for (int s2 = 0; s2 < Ns; ++s2)
tmp()(s1,s2)(0,0) = tmp()(s1,s2)(0,0) + tmp2()(s1,s2)();
}
coalescedWrite(tloopv[ss], tmp);
});
}
// Type 0: loopRight = sum_mu Tr(G2*spinLoop) * G1*right
// where spinLoop(s1,s2) = tloop(s1,s2)(0,0)
void A2ALoopRightContractionType0(FermionField &loopRight,
const PropagatorField &tloop,
const FermionField &right,
const Vector<Gamma::Algebra> &gamma1,
const Vector<Gamma::Algebra> &gamma2)
{
int ng = (int)gamma1.size();
const Gamma::Algebra *g1 = gamma1.data();
const Gamma::Algebra *g2 = gamma2.data();
autoView(lRv, loopRight, AcceleratorWrite);
autoView(tlv, tloop, AcceleratorRead);
autoView(rv, right, AcceleratorRead);
uint64_t Osites = right.Grid()->oSites();
int Nsimd = SpinColourVector_v::Nsimd();
accelerator_for(ss, Osites, Nsimd, {
typedef decltype(coalescedRead(rv[0])) calcSCVector;
typedef decltype(coalescedRead(tlv[0])) calcSCMatrix;
typedef iSpinMatrix<typename calcSCMatrix::vector_type> calcSpinMatrix;
auto loopm = tlv(ss);
auto rightv = rv(ss);
calcSpinMatrix spinLoop; spinLoop = Zero();
for (int s1 = 0; s1 < Ns; ++s1)
for (int s2 = 0; s2 < Ns; ++s2)
spinLoop()(s1,s2)() = loopm()(s1,s2)(0,0);
calcSCVector lR; lR = Zero();
for (int mu = 0; mu < ng; ++mu) {
auto GLoop = Gamma(g2[mu]) * spinLoop;
auto trGLoop = GLoop()(0,0)() + GLoop()(1,1)() + GLoop()(2,2)() + GLoop()(3,3)();
auto Grightv = Gamma(g1[mu]) * rightv;
for (int s = 0; s < Ns; ++s)
for (int c = 0; c < Nc; ++c)
lR()(s)(c) = lR()(s)(c) + Grightv()(s)(c) * trGLoop;
}
coalescedWrite(lRv[ss], lR);
});
}
// Type 1: loopRight = tloop * right (SpinColourMatrix * SpinColourVector)
void A2ALoopRightContractionType1(FermionField &loopRight,
const PropagatorField &tloop,
const FermionField &right)
{
autoView(lRv, loopRight, AcceleratorWrite);
autoView(tlv, tloop, AcceleratorRead);
autoView(rv, right, AcceleratorRead);
uint64_t Osites = right.Grid()->oSites();
int Nsimd = SpinColourVector_v::Nsimd();
accelerator_for(ss, Osites, Nsimd, {
coalescedWrite(lRv[ss], tlv(ss) * rv(ss));
});
}
// Type 2: loopRight(s)(c) = sum_{mu,c'} tloop(s1,s2)(c,c') * (G(g1[mu])*right)(s)(c')
void A2ALoopRightContractionType2(FermionField &loopRight,
const PropagatorField &tloop,
const FermionField &right,
const Vector<Gamma::Algebra> &gamma1)
{
int ng = (int)gamma1.size();
const Gamma::Algebra *g1 = gamma1.data();
autoView(lRv, loopRight, AcceleratorWrite);
autoView(tlv, tloop, AcceleratorRead);
autoView(rv, right, AcceleratorRead);
uint64_t Osites = right.Grid()->oSites();
int Nsimd = SpinColourVector_v::Nsimd();
accelerator_for(ss, Osites, Nsimd, {
typedef decltype(coalescedRead(rv[0])) calcSCVector;
auto loopm = tlv(ss);
auto rightv = rv(ss);
calcSCVector lR; lR = Zero();
for (int mu = 0; mu < ng; ++mu) {
int s1 = mu / Ns;
int s2 = mu % Ns;
auto Grightv = Gamma(g1[mu]) * rightv;
for (int s = 0; s < Ns; ++s)
for (int c = 0; c < Nc; ++c)
lR()(s)(c) = lR()(s)(c)
+ loopm()(s1,s2)(c,0) * Grightv()(s)(0)
+ loopm()(s1,s2)(c,1) * Grightv()(s)(1)
+ loopm()(s1,s2)(c,2) * Grightv()(s)(2);
}
coalescedWrite(lRv[ss], lR);
});
}
// Type 3: loopRight(s)(c) = sum_{s'} tloop(s,s')(0,0) * right(s')(c)
void A2ALoopRightContractionType3(FermionField &loopRight,
const PropagatorField &tloop,
const FermionField &right)
{
autoView(lRv, loopRight, AcceleratorWrite);
autoView(tlv, tloop, AcceleratorRead);
autoView(rv, right, AcceleratorRead);
uint64_t Osites = right.Grid()->oSites();
int Nsimd = SpinColourVector_v::Nsimd();
accelerator_for(ss, Osites, Nsimd, {
typedef decltype(coalescedRead(rv[0])) calcSCVector;
auto loopm = tlv(ss);
auto rightv = rv(ss);
calcSCVector lR; lR = Zero();
for (int s = 0; s < Ns; ++s)
for (int c = 0; c < Nc; ++c)
lR()(s)(c) = loopm()(s,0)(0,0) * rightv()(0)(c)
+ loopm()(s,1)(0,0) * rightv()(1)(c)
+ loopm()(s,2)(0,0) * rightv()(2)(c)
+ loopm()(s,3)(0,0) * rightv()(3)(c);
coalescedWrite(lRv[ss], lR);
});
}
// ================================================================
// GPU-offloaded extended meson field: accelerator_for contractions
// + A2ASpatialSum GEMM spatial reduction.
// ================================================================
class A2AExtendedMesonFieldGPU
{
public:
static void compute(
Eigen::Tensor<ComplexD, 3> &result,
const std::vector<FermionField> &left,
const std::vector<FermionField> &right,
const std::vector<FermionField> &loop1,
const std::vector<FermionField> &loop2,
const std::vector<Gamma::Algebra> &gamma1_in,
const std::vector<Gamma::Algebra> &gamma2_in,
int type)
{
GridBase *grid = left[0].Grid();
int N_i = (int)left.size();
int N_j = (int)right.size();
std::string tag = std::string("[gpu type=") + std::to_string(type) + "]";
auto Tms = [](double us) { return us * 1e-3; };
double t0;
Vector<Gamma::Algebra> gamma1(gamma1_in.begin(), gamma1_in.end());
Vector<Gamma::Algebra> gamma2(gamma2_in.begin(), gamma2_in.end());
t0 = usecond();
for (auto &f : loop1) { autoView(v, f, AcceleratorRead); }
for (auto &f : loop2) { autoView(v, f, AcceleratorRead); }
std::cout << GridLogMessage << tag << " view_open_loop: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
PropagatorField loop(grid);
A2ALoopPropagator(loop, loop1, loop2);
std::cout << GridLogMessage << tag << " loop_build: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
for (int i = 0; i < N_i; i++) { autoView(v, left[i], AcceleratorRead); }
std::cout << GridLogMessage << tag << " view_open_left: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
std::vector<FermionField> leftv(N_i, grid);
for (int i = 0; i < N_i; i++)
A2APackLeftConjugated(leftv[i], left[i]);
std::cout << GridLogMessage << tag << " pack_left: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
PropagatorField tloop(grid);
tloop = Zero();
switch (type) {
case 0: A2ALoopLeftContractionType0(tloop, loop); break;
case 1: A2ALoopLeftContractionType1(tloop, loop, gamma1, gamma2); break;
case 2: A2ALoopLeftContractionType2(tloop, loop, gamma2); break;
case 3: A2ALoopLeftContractionType3(tloop, loop, gamma1, gamma2); break;
}
std::cout << GridLogMessage << tag << " tloop: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
{ autoView(tlv, tloop, AcceleratorRead); }
for (int j = 0; j < N_j; j++) { autoView(rv, right[j], AcceleratorRead); }
std::cout << GridLogMessage << tag << " view_open_right: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
std::vector<FermionField> loopRight(N_j, grid);
for (int j = 0; j < N_j; j++) {
switch (type) {
case 0: A2ALoopRightContractionType0(loopRight[j], tloop, right[j], gamma1, gamma2); break;
case 1: A2ALoopRightContractionType1(loopRight[j], tloop, right[j]); break;
case 2: A2ALoopRightContractionType2(loopRight[j], tloop, right[j], gamma1); break;
case 3: A2ALoopRightContractionType3(loopRight[j], tloop, right[j]); break;
}
}
std::cout << GridLogMessage << tag << " pack_loopright: " << Tms(usecond()-t0) << " ms\n";
A2ASpatialSum<SpinColourVector_v> spatial_sum;
double t_blas = usecond();
t0 = usecond();
spatial_sum.Allocate(N_i, N_j, grid);
std::cout << GridLogMessage << tag << " Allocate: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
spatial_sum.PackLeft(leftv);
std::cout << GridLogMessage << tag << " PackLeft: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
spatial_sum.PackRight(loopRight);
std::cout << GridLogMessage << tag << " PackRight: " << Tms(usecond()-t0) << " ms\n";
t0 = usecond();
spatial_sum.Sum(result);
std::cout << GridLogMessage << tag << " Sum (GEMM+MPI): " << Tms(usecond()-t0) << " ms\n";
std::cout << GridLogMessage << tag << " A2ASpatialSum: " << Tms(usecond()-t_blas) << " ms [TOTAL]\n";
}
};
int main(int argc, char *argv[])
{
Grid_init(&argc, &argv);
Coordinate latt_size = GridDefaultLatt();
Coordinate simd_layout = GridDefaultSimd(4, vComplexD::Nsimd());
Coordinate mpi_layout = GridDefaultMpi();
GridCartesian grid(latt_size, simd_layout, mpi_layout);
int Nt = latt_size[Tp];
int N_i = 8;
int N_j = 8;
int Nloop = 4;
if (GridCmdOptionExists(argv, argv+argc, "--Ni"))
N_i = std::stoi(GridCmdOptionPayload(argv, argv+argc, "--Ni"));
if (GridCmdOptionExists(argv, argv+argc, "--Nj"))
N_j = std::stoi(GridCmdOptionPayload(argv, argv+argc, "--Nj"));
GridParallelRNG pRNG(&grid);
pRNG.SeedFixedIntegers({1, 2, 3, 4});
std::vector<FermionField> left(N_i, &grid);
std::vector<FermionField> right(N_j, &grid);
std::vector<FermionField> loop1(Nloop, &grid);
std::vector<FermionField> loop2(Nloop, &grid);
for (auto &f : left) random(pRNG, f);
for (auto &f : right) random(pRNG, f);
for (auto &f : loop1) random(pRNG, f);
for (auto &f : loop2) random(pRNG, f);
std::vector<Gamma::Algebra> GammaMU = {
Gamma::Algebra::GammaX,
Gamma::Algebra::GammaY,
Gamma::Algebra::GammaZ,
Gamma::Algebra::GammaT
};
Eigen::Tensor<ComplexD, 3> result_ref(Nt, N_i, N_j);
Eigen::Tensor<ComplexD, 3> result_blas(Nt, N_i, N_j);
Eigen::Tensor<ComplexD, 3> result_gpu(Nt, N_i, N_j);
double t_ref = 0, t_blas = 0, t_gpu = 0, start, stop;
for (int type = 0; type < 4; type++) {
result_ref.setZero();
start = usecond();
A2AExtendedMesonFieldRef::compute(result_ref, left, right, loop1, loop2,
GammaMU, GammaMU, type, false);
stop = usecond(); t_ref = stop - start;
result_blas.setZero();
start = usecond();
A2AExtendedMesonFieldRef::compute(result_blas, left, right, loop1, loop2,
GammaMU, GammaMU, type, true);
stop = usecond(); t_blas = stop - start;
result_gpu.setZero();
start = usecond();
A2AExtendedMesonFieldGPU::compute(result_gpu, left, right, loop1, loop2,
GammaMU, GammaMU, type);
stop = usecond(); t_gpu = stop - start;
double norm2_ref = 0.0, norm2_blas = 0.0, norm2_gpu = 0.0;
double norm2_diff_blas = 0.0, norm2_diff_gpu = 0.0;
for (int t = 0; t < Nt; t++)
for (int ii = 0; ii < N_i; ii++)
for (int jj = 0; jj < N_j; jj++) {
norm2_ref += norm2(result_ref(t, ii, jj));
norm2_blas += norm2(result_blas(t, ii, jj));
norm2_gpu += norm2(result_gpu(t, ii, jj));
ComplexD diff_blas = result_ref(t, ii, jj) - result_blas(t, ii, jj);
ComplexD diff_gpu = result_ref(t, ii, jj) - result_gpu(t, ii, jj);
norm2_diff_blas += norm2(diff_blas);
norm2_diff_gpu += norm2(diff_gpu);
}
double rel_blas = (norm2_ref > 0) ? std::sqrt(norm2_diff_blas / norm2_ref) : 0.0;
double rel_gpu = (norm2_ref > 0) ? std::sqrt(norm2_diff_gpu / norm2_ref) : 0.0;
std::cout << GridLogMessage
<< "type=" << type
<< " norm2_ref=" << norm2_ref
<< " norm2_blas=" << norm2_blas
<< " norm2_gpu=" << norm2_gpu
<< " rel_blas=" << rel_blas
<< " rel_gpu=" << rel_gpu
<< " t_ref=" << t_ref * 1e-6 << "s"
<< " t_blas=" << t_blas * 1e-6 << "s"
<< " t_gpu=" << t_gpu * 1e-6 << "s"
<< std::endl;
GRID_ASSERT(rel_blas < 1e-10);
GRID_ASSERT(rel_gpu < 1e-10);
}
std::cout << GridLogMessage << "All types passed A2ASpatialSum and GPU regression." << std::endl;
Grid_finalize();
return EXIT_SUCCESS;
}
#endif
+49 -30
View File
@@ -47,49 +47,68 @@ static void tryPlanAndExec(int G, long howmany) {
printf("--- G=%-4d howmany=%-10ld total_elems=%-12ld ---\n",
G, howmany, nelems);
// --- A: create plan first, allocate buffer afterwards ---
// Allocate device buffer (hipfftDoubleComplex = 16 bytes each)
hipfftDoubleComplex *dbuf = nullptr;
hipError_t herr = hipMalloc(&dbuf, nelems * sizeof(hipfftDoubleComplex));
if (herr != hipSuccess) {
printf(" hipMalloc failed (%d) for %ld elems — skipping\n\n", (int)herr, nelems);
return;
}
hipMemset(dbuf, 0, nelems * sizeof(hipfftDoubleComplex));
// 1. hipfftPlanMany (one-step, nullptr embed) — current Grid path
{
hipfftHandle p;
hipfftResult rv = hipfftPlanMany(&p, 1, n,
nullptr, 1, G,
nullptr, 1, G,
HIPFFT_Z2Z, (int)howmany);
printf(" hipfftPlanMany create : %d (%s)\n", (int)rv, hipfftResultString(rv));
if (rv == HIPFFT_SUCCESS) {
rv = hipfftExecZ2Z(p, dbuf, dbuf, HIPFFT_FORWARD);
hipDeviceSynchronize();
printf(" hipfftPlanMany execFwd: %d (%s)\n", (int)rv, hipfftResultString(rv));
hipfftDestroy(p);
}
}
// 2. hipfftCreate + hipfftMakePlanMany (two-step) — also current Grid path
{
hipfftHandle p;
size_t workSize = 0;
hipfftCreate(&p);
hipfftResult rv = hipfftMakePlanMany(p, 1, n,
nullptr, 1, G, nullptr, 1, G,
HIPFFT_Z2Z, (int)howmany, &workSize);
printf(" plan-first create : %d (%s)\n", (int)rv, hipfftResultString(rv));
if (rv == HIPFFT_SUCCESS) {
hipfftDoubleComplex *buf = nullptr;
hipMalloc(&buf, nelems * sizeof(hipfftDoubleComplex));
hipMemset(buf, 0, nelems * sizeof(hipfftDoubleComplex));
rv = hipfftExecZ2Z(p, buf, buf, HIPFFT_FORWARD);
hipDeviceSynchronize();
printf(" plan-first execFwd: %d (%s)\n", (int)rv, hipfftResultString(rv));
hipFree(buf);
hipfftResult rc = hipfftCreate(&p);
if (rc == HIPFFT_SUCCESS) {
hipfftResult rv = hipfftMakePlanMany(p, 1, n,
nullptr, 1, G,
nullptr, 1, G,
HIPFFT_Z2Z, (int)howmany, &workSize);
printf(" hipfftMakePlanMany : %d (%s) workSize=%zu\n",
(int)rv, hipfftResultString(rv), workSize);
if (rv == HIPFFT_SUCCESS) {
rv = hipfftExecZ2Z(p, dbuf, dbuf, HIPFFT_FORWARD);
hipDeviceSynchronize();
printf(" hipfftMakePlanMany exec : %d (%s)\n", (int)rv, hipfftResultString(rv));
}
hipfftDestroy(p);
} else {
printf(" hipfftCreate : %d (%s)\n", (int)rc, hipfftResultString(rc));
}
hipfftDestroy(p);
}
// --- B: hipMalloc first, create plan afterwards ---
// 3. hipfftPlan1d (simplest API, batch = howmany)
{
hipfftDoubleComplex *buf = nullptr;
hipMalloc(&buf, nelems * sizeof(hipfftDoubleComplex));
hipMemset(buf, 0, nelems * sizeof(hipfftDoubleComplex));
hipfftHandle p;
size_t workSize = 0;
hipfftCreate(&p);
hipfftResult rv = hipfftMakePlanMany(p, 1, n,
nullptr, 1, G, nullptr, 1, G,
HIPFFT_Z2Z, (int)howmany, &workSize);
printf(" malloc-first create : %d (%s)\n", (int)rv, hipfftResultString(rv));
hipfftResult rv = hipfftPlan1d(&p, G, HIPFFT_Z2Z, (int)howmany);
printf(" hipfftPlan1d create : %d (%s)\n", (int)rv, hipfftResultString(rv));
if (rv == HIPFFT_SUCCESS) {
rv = hipfftExecZ2Z(p, buf, buf, HIPFFT_FORWARD);
rv = hipfftExecZ2Z(p, dbuf, dbuf, HIPFFT_FORWARD);
hipDeviceSynchronize();
printf(" malloc-first execFwd: %d (%s)\n", (int)rv, hipfftResultString(rv));
printf(" hipfftPlan1d execFwd: %d (%s)\n", (int)rv, hipfftResultString(rv));
hipfftDestroy(p);
}
hipfftDestroy(p);
hipFree(buf);
}
hipFree(dbuf);
printf("\n");
}