diff --git a/.gitignore b/.gitignore index da7de5e4..d743ee06 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ ################ *~ *# +*.sublime-* # Precompiled Headers # ####################### @@ -91,6 +92,7 @@ build*/* ##################### *.xcodeproj/* build.sh +.vscode # Eigen source # ################ @@ -103,4 +105,21 @@ lib/fftw/* # libtool macros # ################## m4/lt* -m4/libtool.m4 \ No newline at end of file +m4/libtool.m4 + +# github pages # +################ +gh-pages/ + +# Buck files # +############## +.buck* +buck-out +BUCK +make-bin-BUCK.sh + +# generated sources # +##################### +lib/qcd/spin/gamma-gen/*.h +lib/qcd/spin/gamma-gen/*.cc + diff --git a/.travis.yml b/.travis.yml index ae3efda8..7d8203ce 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,64 +7,8 @@ cache: matrix: include: - os: osx - osx_image: xcode7.2 + osx_image: xcode8.3 compiler: clang - - compiler: gcc - addons: - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - g++-4.9 - - libmpfr-dev - - libgmp-dev - - libmpc-dev - - libopenmpi-dev - - openmpi-bin - - binutils-dev - env: VERSION=-4.9 - - compiler: gcc - addons: - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - g++-5 - - libmpfr-dev - - libgmp-dev - - libmpc-dev - - libopenmpi-dev - - openmpi-bin - - binutils-dev - env: VERSION=-5 - - compiler: clang - addons: - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - g++-4.8 - - libmpfr-dev - - libgmp-dev - - libmpc-dev - - libopenmpi-dev - - openmpi-bin - - binutils-dev - env: CLANG_LINK=http://llvm.org/releases/3.8.0/clang+llvm-3.8.0-x86_64-linux-gnu-ubuntu-14.04.tar.xz - - compiler: clang - addons: - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - g++-4.8 - - libmpfr-dev - - libgmp-dev - - libmpc-dev - - libopenmpi-dev - - openmpi-bin - - binutils-dev - env: CLANG_LINK=http://llvm.org/releases/3.7.0/clang+llvm-3.7.0-x86_64-linux-gnu-ubuntu-14.04.tar.xz before_install: - export GRIDDIR=`pwd` @@ -73,13 +17,15 @@ 
before_install: - if [[ "$TRAVIS_OS_NAME" == "linux" ]] && [[ "$CC" == "clang" ]]; then export LD_LIBRARY_PATH="${GRIDDIR}/clang/lib:${LD_LIBRARY_PATH}"; fi - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew update; fi - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew install libmpc; fi - - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew install openmpi; fi - - if [[ "$TRAVIS_OS_NAME" == "osx" ]] && [[ "$CC" == "gcc" ]]; then brew install gcc5; fi install: - export CC=$CC$VERSION - export CXX=$CXX$VERSION - echo $PATH + - which autoconf + - autoconf --version + - which automake + - automake --version - which $CC - $CC --version - which $CXX @@ -92,15 +38,9 @@ script: - cd build - ../configure --enable-precision=single --enable-simd=SSE4 --enable-comms=none - make -j4 - - ./benchmarks/Benchmark_dwf --threads 1 + - ./benchmarks/Benchmark_dwf --threads 1 --debug-signals - echo make clean - ../configure --enable-precision=double --enable-simd=SSE4 --enable-comms=none - make -j4 - - ./benchmarks/Benchmark_dwf --threads 1 - - echo make clean - - if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then export CXXFLAGS='-DMPI_UINT32_T=MPI_UNSIGNED -DMPI_UINT64_T=MPI_UNSIGNED_LONG'; fi - - ../configure --enable-precision=single --enable-simd=SSE4 --enable-comms=mpi-auto - - make -j4 - - if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then mpirun.openmpi -n 2 ./benchmarks/Benchmark_dwf --threads 1 --mpi 2.1.1.1; fi - - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then mpirun -n 2 ./benchmarks/Benchmark_dwf --threads 1 --mpi 2.1.1.1; fi - + - ./benchmarks/Benchmark_dwf --threads 1 --debug-signals + - make check diff --git a/Makefile.am b/Makefile.am index 818f0983..3a65cf1b 100644 --- a/Makefile.am +++ b/Makefile.am @@ -1,12 +1,17 @@ # additional include paths necessary to compile the C++ library -SUBDIRS = lib benchmarks tests +SUBDIRS = lib benchmarks tests extras include $(top_srcdir)/doxygen.inc -tests: all - $(MAKE) -C tests tests +bin_SCRIPTS=grid-config -.PHONY: tests doxygen-run doxygen-doc 
$(DX_PS_GOAL) $(DX_PDF_GOAL) + +.PHONY: bench check tests doxygen-run doxygen-doc $(DX_PS_GOAL) $(DX_PDF_GOAL) + +tests-local: all +bench-local: all +check-local: all AM_CXXFLAGS += -I$(top_builddir)/include + ACLOCAL_AMFLAGS = -I m4 diff --git a/README.md b/README.md index c47a257c..13dd6996 100644 --- a/README.md +++ b/README.md @@ -1,41 +1,13 @@ -# Grid - - - - - - - - - -
Last stable release - -
Development branch - -
+# Grid [![Teamcity status](http://ci.cliath.ph.ed.ac.uk/app/rest/builds/aggregated/strob:(buildType:(affectedProject(id:Grid)),branch:name:develop)/statusIcon.svg)](http://ci.cliath.ph.ed.ac.uk/project.html?projectId=Grid&tab=projectOverview) [![Travis status](https://travis-ci.org/paboyle/Grid.svg?branch=develop)](https://travis-ci.org/paboyle/Grid) **Data parallel C++ mathematical object library.** License: GPL v2. -Last update Nov 2016. +Last update June 2017. _Please do not send pull requests to the `master` branch which is reserved for releases._ -### Bug report - -_To help us tracking and solving more efficiently issues with Grid, please report problems using the issue system of GitHub rather than sending emails to Grid developers._ - -When you file an issue, please go though the following checklist: - -1. Check that the code is pointing to the `HEAD` of `develop` or any commit in `master` which is tagged with a version number. -2. Give a description of the target platform (CPU, network, compiler). Please give the full CPU part description, using for example `cat /proc/cpuinfo | grep 'model name' | uniq` (Linux) or `sysctl machdep.cpu.brand_string` (macOS) and the full output the `--version` option of your compiler. -3. Give the exact `configure` command used. -4. Attach `config.log`. -5. Attach `config.summary`. -6. Attach the output of `make V=1`. -7. Describe the issue and any previous attempt to solve it. If relevant, show how to reproduce the issue using a minimal working example. - ### Description @@ -58,13 +30,68 @@ optimally use MPI, OpenMP and SIMD parallelism under the hood. This is a signifi for most programmers. The layout transformations are parametrised by the SIMD vector length. This adapts according to the architecture. -Presently SSE4 (128 bit) AVX, AVX2, QPX (256 bit), IMCI, and AVX512 (512 bit) targets are supported (ARM NEON on the way). 
+Presently SSE4, ARM NEON (128 bits), AVX, AVX2, QPX (256 bits), IMCI and AVX512 (512 bits) targets are supported. -These are presented as `vRealF`, `vRealD`, `vComplexF`, and `vComplexD` internal vector data types. These may be useful in themselves for other programmers. +These are presented as `vRealF`, `vRealD`, `vComplexF`, and `vComplexD` internal vector data types. The corresponding scalar types are named `RealF`, `RealD`, `ComplexF` and `ComplexD`. MPI, OpenMP, and SIMD parallelism are present in the library. -Please see https://arxiv.org/abs/1512.03487 for more detail. +Please see [this paper](https://arxiv.org/abs/1512.03487) for more detail. + + +### Compilers + +Intel ICPC v16.0.3 and later + +Clang v3.5 and later (need 3.8 and later for OpenMP) + +GCC v4.9.x (recommended) + +GCC v6.3 and later + +### Important: + +Some versions of GCC appear to have a bug under high optimisation (-O2, -O3). + +The safety of the following compiler versions cannot be guaranteed at this time. Follow Issue 100 for details and updates. + +GCC v5.x + +GCC v6.1, v6.2 + +### Bug report + +_To help us track and resolve issues with Grid more efficiently, please report problems using the issue system of GitHub rather than sending emails to Grid developers._ + +When you file an issue, please go through the following checklist: + +1. Check that the code is pointing to the `HEAD` of `develop` or any commit in `master` which is tagged with a version number. +2. Give a description of the target platform (CPU, network, compiler). Please give the full CPU part description, using for example `cat /proc/cpuinfo | grep 'model name' | uniq` (Linux) or `sysctl machdep.cpu.brand_string` (macOS) and the full output of the `--version` option of your compiler. +3. Give the exact `configure` command used. +4. Attach `config.log`. +5. Attach `grid.config.summary`. +6. Attach the output of `make V=1`. +7. Describe the issue and any previous attempt to solve it.
If relevant, show how to reproduce the issue using a minimal working example. + +### Required libraries +Grid requires: + +[GMP](https://gmplib.org/), + +[MPFR](http://www.mpfr.org/) + +Bootstrapping grid downloads and uses for internal dense matrix (non-QCD operations) the Eigen library. + +Grid optionally uses: + +[HDF5](https://support.hdfgroup.org/HDF5/) + +[LIME](http://usqcd-software.github.io/c-lime/) for ILDG and SciDAC file format support. + +[FFTW](http://www.fftw.org) either generic version or via the Intel MKL library. + +LAPACK either generic version or Intel MKL library. + ### Quick start First, start by cloning the repository: @@ -95,10 +122,10 @@ install Grid. Other options are detailed in the next section, you can also use ` `CXX`, `CXXFLAGS`, `LDFLAGS`, ... environment variables can be modified to customise the build. -Finally, you can build and install Grid: +Finally, you can build, check, and install Grid: ``` bash -make; make install +make; make check; make install ``` To minimise the build time, only the tests at the root of the `tests` directory are built by default. If you want to build tests in the sub-directory `` you can execute: @@ -121,7 +148,7 @@ If you want to build all the tests at once just use `make tests`. - `--enable-gen-simd-width=`: select the size (in bytes) of the generic SIMD vector type (default: 32 bytes). - `--enable-precision={single|double}`: set the default precision (default: `double`). - `--enable-precision=`: Use `` for message passing (default: `none`). A list of possible SIMD targets is detailed in a section below. -- `--enable-rng={ranlux48|mt19937}`: choose the RNG (default: `ranlux48 `). +- `--enable-rng={sitmo|ranlux48|mt19937}`: choose the RNG (default: `sitmo `). - `--disable-timers`: disable system dependent high-resolution timers. - `--enable-chroma`: enable Chroma regression tests. 
- `--enable-doxygen-doc`: enable the Doxygen documentation generation (build with `make doxygen-doc`) @@ -135,7 +162,6 @@ The following options can be use with the `--enable-comms=` option to target dif | `none` | no communications | | `mpi[-auto]` | MPI communications | | `mpi3[-auto]` | MPI communications using MPI 3 shared memory | -| `mpi3l[-auto]` | MPI communications using MPI 3 shared memory and leader model | | `shmem ` | Cray SHMEM communications | For the MPI interfaces the optional `-auto` suffix instructs the `configure` scripts to determine all the necessary compilation and linking flags. This is done by extracting the informations from the MPI wrapper specified in the environment variable `MPICXX` (if not specified `configure` will scan though a list of default names). The `-auto` suffix is not supported by the Cray environment wrapper scripts. Use the standard versions instead. @@ -153,13 +179,13 @@ The following options can be use with the `--enable-simd=` option to target diff | `AVXFMA4` | AVX (256 bit) + FMA4 | | `AVX2` | AVX 2 (256 bit) | | `AVX512` | AVX 512 bit | -| `QPX` | QPX (256 bit) | +| `NEONv8` | [ARM NEON](http://infocenter.arm.com/help/index.jsp?topic=/com.arm.doc.den0024a/ch07s03.html) (128 bit) | +| `QPX` | IBM QPX (256 bit) | Alternatively, some CPU codenames can be directly used: | `` | Description | | ----------- | -------------------------------------- | -| `KNC` | [Intel Xeon Phi codename Knights Corner](http://ark.intel.com/products/codename/57721/Knights-Corner) | | `KNL` | [Intel Xeon Phi codename Knights Landing](http://ark.intel.com/products/codename/48999/Knights-Landing) | | `BGQ` | Blue Gene/Q | @@ -176,21 +202,205 @@ The following configuration is recommended for the Intel Knights Landing platfor ``` bash ../configure --enable-precision=double\ --enable-simd=KNL \ - --enable-comms=mpi-auto \ - --with-gmp= \ - --with-mpfr= \ + --enable-comms=mpi-auto \ --enable-mkl \ CXX=icpc MPICXX=mpiicpc ``` +The MKL flag enables use 
of BLAS and FFTW from the Intel Math Kernel Library. -where `<path>` is the UNIX prefix where GMP and MPFR are installed. If you are working on a Cray machine that does not use the `mpiicpc` wrapper, please use: +If you are working on a Cray machine that does not use the `mpiicpc` wrapper, please use: ``` bash ../configure --enable-precision=double\ --enable-simd=KNL \ --enable-comms=mpi \ - --with-gmp=<path> \ - --with-mpfr=<path> \ --enable-mkl \ CXX=CC CC=cc -``` \ No newline at end of file +``` + +If gmp and mpfr are NOT in standard places (/usr/) these flags may be needed: +``` bash + --with-gmp=<path> \ + --with-mpfr=<path> \ +``` +where `<path>` is the UNIX prefix where GMP and MPFR are installed. + +Knights Landing with Intel Omnipath adapters with two adapters per node +presently performs better with use of more than one rank per node, using shared memory +for interior communication. This is the mpi3 communications implementation. +We recommend four ranks per node for best performance, but the optimum depends on the local volume. + +``` bash +../configure --enable-precision=double\ + --enable-simd=KNL \ + --enable-comms=mpi3-auto \ + --enable-mkl \ + CXX=icpc MPICXX=mpiicpc +``` + +### Build setup for Intel Haswell Xeon platform + +The following configuration is recommended for the Intel Haswell platform: + +``` bash +../configure --enable-precision=double\ + --enable-simd=AVX2 \ + --enable-comms=mpi3-auto \ + --enable-mkl \ + CXX=icpc MPICXX=mpiicpc +``` +The MKL flag enables use of BLAS and FFTW from the Intel Math Kernel Library. + +If gmp and mpfr are NOT in standard places (/usr/) these flags may be needed: +``` bash + --with-gmp=<path> \ + --with-mpfr=<path> \ +``` +where `<path>` is the UNIX prefix where GMP and MPFR are installed.
+ +If you are working on a Cray machine that does not use the `mpiicpc` wrapper, please use: + +``` bash +../configure --enable-precision=double\ + --enable-simd=AVX2 \ + --enable-comms=mpi3 \ + --enable-mkl \ + CXX=CC CC=cc +``` +Since dual socket nodes are commonplace, we recommend MPI-3 as the default with the use of +one rank per socket. If using the Intel MPI library, threads should be pinned to NUMA domains using +``` + export I_MPI_PIN=1 +``` +This is the default. + +### Build setup for Intel Skylake Xeon platform + +The following configuration is recommended for the Intel Skylake platform: + +``` bash +../configure --enable-precision=double\ + --enable-simd=AVX512 \ + --enable-comms=mpi3 \ + --enable-mkl \ + CXX=mpiicpc +``` +The MKL flag enables use of BLAS and FFTW from the Intel Math Kernel Library. + +If gmp and mpfr are NOT in standard places (/usr/) these flags may be needed: +``` bash + --with-gmp=<path> \ + --with-mpfr=<path> \ +``` +where `<path>` is the UNIX prefix where GMP and MPFR are installed. + +If you are working on a Cray machine that does not use the `mpiicpc` wrapper, please use: + +``` bash +../configure --enable-precision=double\ + --enable-simd=AVX512 \ + --enable-comms=mpi3 \ + --enable-mkl \ + CXX=CC CC=cc +``` +Since dual socket nodes are commonplace, we recommend MPI-3 as the default with the use of +one rank per socket. If using the Intel MPI library, threads should be pinned to NUMA domains using +``` + export I_MPI_PIN=1 +``` +This is the default. + +#### Expected Skylake Gold 6148 dual socket (single prec, single node 20+20 cores) performance using NUMA MPI mapping: + +mpirun -n 2 benchmarks/Benchmark_dwf --grid 16.16.16.16 --mpi 2.1.1.1 --cacheblocking 2.2.2.2 --dslash-asm --shm 1024 --threads 18 + +TBA + + +### Build setup for AMD EPYC / RYZEN + +The AMD EPYC is a multichip module comprising 32 cores spread over four distinct chips each with 8 cores. +So, even with a single socket node there is a quad-chip module.
Dual socket nodes with 64 cores total +are common. Each chip within the module exposes a separate NUMA domain. +There are four NUMA domains per socket and we recommend one MPI rank per NUMA domain. +MPI-3 is recommended with the use of four ranks per socket, +and 8 threads per rank. + +The following configuration is recommended for the AMD EPYC platform. + +``` bash +../configure --enable-precision=double\ + --enable-simd=AVX2 \ + --enable-comms=mpi3 \ + CXX=mpicxx +``` + +If gmp and mpfr are NOT in standard places (/usr/) these flags may be needed: +``` bash + --with-gmp=<path> \ + --with-mpfr=<path> \ +``` +where `<path>` is the UNIX prefix where GMP and MPFR are installed. + +Using MPICH and g++ v4.9.2, best performance can be obtained using explicit GOMP_CPU_AFFINITY flags for each MPI rank. +This can be done by launching each MPI rank through a wrapper script, omp_bind.sh, that sets the affinity. + +It is recommended to run 8 MPI ranks on a single dual socket AMD EPYC, with 8 threads per rank using MPI3 and +shared memory to communicate within this node: + +mpirun -np 8 ./omp_bind.sh ./Benchmark_dwf --mpi 2.2.2.1 --dslash-unroll --threads 8 --grid 16.16.16.16 --cacheblocking 4.4.4.4 + +Where omp_bind.sh does the following: +``` +#!/bin/bash + +numanode=` expr $PMI_RANK % 8 ` +basecore=`expr $numanode \* 16` +core0=`expr $basecore + 0 ` +core1=`expr $basecore + 2 ` +core2=`expr $basecore + 4 ` +core3=`expr $basecore + 6 ` +core4=`expr $basecore + 8 ` +core5=`expr $basecore + 10 ` +core6=`expr $basecore + 12 ` +core7=`expr $basecore + 14 ` + +export GOMP_CPU_AFFINITY="$core0 $core1 $core2 $core3 $core4 $core5 $core6 $core7" +echo GOMP_CPU_AFFINITY $GOMP_CPU_AFFINITY + +$@ +``` + +Performance: + +#### Expected AMD EPYC 7601 dual socket (single prec, single node 32+32 cores) performance using NUMA MPI mapping: + +mpirun -np 8 ./omp_bind.sh ./Benchmark_dwf --threads 8 --mpi 2.2.2.1 --dslash-unroll --grid 16.16.16.16 --cacheblocking 4.4.4.4 + +TBA + +### Build setup for BlueGene/Q + +To be written...
+ +### Build setup for ARM Neon + +To be written... + +### Build setup for laptops, other compilers, non-cluster builds + +Many versions of g++ and clang++ work with Grid; using them merely involves replacing CXX (and MPICXX) +and omitting the `--enable-mkl` flag. + +Single node builds are enabled with +``` + --enable-comms=none +``` + +FFTW support that is not in the default search path may then be enabled with +``` + --with-fftw=<path> +``` + +BLAS will not be compiled in by default, and Lanczos will default to Eigen diagonalisation. + diff --git a/TODO b/TODO index df8554cc..c37cbf8b 100644 --- a/TODO +++ b/TODO @@ -1,6 +1,35 @@ TODO: --------------- +Large item work list: + +1)- BG/Q port and check +2)- Christoph's local basis expansion Lanczos +3)- Precision conversion and sort out localConvert <-- partial + + - Consistent linear solver flop count/rate -- PARTIAL, time but no flop/s yet +4)- Physical propagator interface +5)- Conserved currents +6)- Multigrid Wilson and DWF, compare to other Multigrid implementations +7)- HDCR resume + +Recent DONE + +-- MultiRHS with spread out extra dim -- Go through filesystem with SciDAC I/O. <--- DONE +-- Lanczos Remove DenseVector, DenseMatrix; Use Eigen instead. <-- DONE +-- GaugeFix into central location <-- DONE +-- Scidac and Ildg metadata handling <-- DONE +-- Binary I/O MPI2 IO <-- DONE +-- Binary I/O speed up & x-strips <-- DONE +-- Cut down the exterior overhead <-- DONE +-- Interior legs from SHM comms <-- DONE +-- Half-precision comms <-- DONE +-- Merge high precision reduction into develop <-- DONE +-- BlockCG, BCGrQ <-- DONE +-- multiRHS DWF; benchmark on Cori/BNL for comms elimination <-- DONE + -- slice* linalg routines for multiRHS, BlockCG + +----- * Forces; the UdSdU term in gauge force term is half of what I think it should be. This is a consequence of taking ONLY the first term in: @@ -21,16 +50,8 @@ TODO: This means we must double the force in the Test_xxx_force routines, and is the origin of the factor of two.
This 2x is applied by hand in the fermion routines and in the Test_rect_force routine. - -Policies: - -* Link smearing/boundary conds; Policy class based implementation ; framework more in place - * Support different boundary conditions (finite temp, chem. potential ... ) -* Support different fermion representations? - - contained entirely within the integrator presently - - Sign of force term. - Reversibility test. @@ -41,11 +62,6 @@ Policies: - Audit oIndex usage for cb behaviour -- Rectangle gauge actions. - Iwasaki, - Symanzik, - ... etc... - - Prepare multigrid for HMC. - Alternate setup schemes. - Support for ILDG --- ugly, not done @@ -55,9 +71,11 @@ Policies: - FFTnD ? - Gparity; hand opt use template specialisation elegance to enable the optimised paths ? + - Gparity force term; Gparity (R)HMC. -- Random number state save restore + - Mobius implementation clean up to rmove #if 0 stale code sequences + - CG -- profile carefully, kernel fusion, whole CG performance measurements. ================================================================ @@ -90,6 +108,7 @@ Insert/Extract Not sure of status of this -- reverify. Things are working nicely now though. * Make the Tensor types and Complex etc... play more nicely. + - TensorRemove is a hack, come up with a long term rationalised approach to Complex vs. Scalar > > QDP forces use of "toDouble" to get back to non tensor scalar. This role is presently taken TensorRemove, but I want to introduce a syntax that does not require this. @@ -112,6 +131,8 @@ Not sure of status of this -- reverify. Things are working nicely now though. RECENT --------------- + - Support different fermion representations? -- DONE + - contained entirely within the integrator presently - Clean up HMC -- DONE - LorentzScalar gets Gauge link type (cleaner). -- DONE - Simplified the integrators a bit. 
-- DONE @@ -123,6 +144,26 @@ RECENT - Parallel io improvements -- DONE - Plaquette and link trace checks into nersc reader from the Grid_nersc_io.cc test. -- DONE + +DONE: +- MultiArray -- MultiRHS done +- ConjugateGradientMultiShift -- DONE +- MCR -- DONE +- Remez -- Mike or Boost? -- DONE +- Proto (ET) -- DONE +- uBlas -- DONE ; Eigen +- Potentially Useful Boost libraries -- DONE ; Eigen +- Aligned allocator; memory pool -- DONE +- Multiprecision -- DONE +- Serialization -- DONE +- Regex -- Not needed +- Tokenize -- Why? + +- Random number state save restore -- DONE +- Rectangle gauge actions. -- DONE + Iwasaki, + Symanzik, + ... etc... Done: Cayley, Partial , ContFrac force terms. DONE @@ -207,6 +248,7 @@ Done FUNCTIONALITY: it pleases me to keep track of things I have done (keeps me arguably sane) ====================================================================================================== +* Link smearing/boundary conds; Policy class based implementation ; framework more in place -- DONE * Command line args for geometry, simd, etc. layout. Is it necessary to have -- DONE user pass these? Is this a QCD specific? 
diff --git a/VERSION b/VERSION index e7abbba7..bfad377d 100644 --- a/VERSION +++ b/VERSION @@ -1,6 +1,5 @@ -Version : 0.6.0 +Version : 0.7.0 -- AVX512, AVX2, AVX, SSE good -- Clang 3.5 and above, ICPC v16 and above, GCC 4.9 and above -- MPI and MPI3 -- HiRep, Smearing, Generic gauge group +- Clang 3.5 and above, ICPC v16 and above, GCC 6.3 and above recommended +- MPI and MPI3 comms optimisations for KNL and OPA finished +- Half precision comms diff --git a/benchmarks/Benchmark_ITT.cc b/benchmarks/Benchmark_ITT.cc new file mode 100644 index 00000000..666e4830 --- /dev/null +++ b/benchmarks/Benchmark_ITT.cc @@ -0,0 +1,800 @@ + /************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + Source file: ./benchmarks/Benchmark_memory_bandwidth.cc + + Copyright (C) 2015 + +Author: Peter Boyle +Author: paboyle + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ + See the full license in the file "LICENSE" in the top level distribution directory + *************************************************************************************/ + /* END LEGAL */ +#include + +using namespace std; +using namespace Grid; +using namespace Grid::QCD; + +typedef WilsonFermion5D WilsonFermion5DR; +typedef WilsonFermion5D WilsonFermion5DF; +typedef WilsonFermion5D WilsonFermion5DD; + + +std::vector L_list; +std::vector Ls_list; +std::vector mflop_list; + +double mflop_ref; +double mflop_ref_err; + +int NN_global; + +struct time_statistics{ + double mean; + double err; + double min; + double max; + + void statistics(std::vector v){ + double sum = std::accumulate(v.begin(), v.end(), 0.0); + mean = sum / v.size(); + + std::vector diff(v.size()); + std::transform(v.begin(), v.end(), diff.begin(), [=](double x) { return x - mean; }); + double sq_sum = std::inner_product(diff.begin(), diff.end(), diff.begin(), 0.0); + err = std::sqrt(sq_sum / (v.size()*(v.size() - 1))); + + auto result = std::minmax_element(v.begin(), v.end()); + min = *result.first; + max = *result.second; +} +}; + +void comms_header(){ + std::cout < simd_layout = GridDefaultSimd(Nd,vComplexD::Nsimd()); + std::vector mpi_layout = GridDefaultMpi(); + + for(int mu=0;mu1) nmu++; + + std::vector t_time(Nloop); + time_statistics timestat; + + std::cout< latt_size ({lat*mpi_layout[0], + lat*mpi_layout[1], + lat*mpi_layout[2], + lat*mpi_layout[3]}); + + GridCartesian Grid(latt_size,simd_layout,mpi_layout); + RealD Nrank = Grid._Nprocessors; + RealD Nnode = Grid.NodeCount(); + RealD ppn = Nrank/Nnode; + + std::vector xbuf(8); + std::vector rbuf(8); + Grid.ShmBufferFreeAll(); + for(int d=0;d<8;d++){ + xbuf[d] = (HalfSpinColourVectorD *)Grid.ShmBufferMalloc(lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + rbuf[d] = (HalfSpinColourVectorD *)Grid.ShmBufferMalloc(lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + bzero((void *)xbuf[d],lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + 
bzero((void *)rbuf[d],lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + } + + int bytes=lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD); + int ncomm; + double dbytes; + std::vector times(Nloop); + for(int i=0;i1 ) { + + int xmit_to_rank; + int recv_from_rank; + if ( dir == mu ) { + int comm_proc=1; + Grid.ShiftedRanks(mu,comm_proc,xmit_to_rank,recv_from_rank); + } else { + int comm_proc = mpi_layout[mu]-1; + Grid.ShiftedRanks(mu,comm_proc,xmit_to_rank,recv_from_rank); + } + tbytes= Grid.StencilSendToRecvFrom((void *)&xbuf[dir][0], xmit_to_rank, + (void *)&rbuf[dir][0], recv_from_rank, + bytes,dir); + +#ifdef GRID_OMP +#pragma omp atomic +#endif + ncomm++; + +#ifdef GRID_OMP +#pragma omp atomic +#endif + dbytes+=tbytes; + } + } + Grid.Barrier(); + double stop=usecond(); + t_time[i] = stop-start; // microseconds + } + + timestat.statistics(t_time); + // for(int i=0;i > LatticeVec; + typedef iVector Vec; + + std::vector simd_layout = GridDefaultSimd(Nd,vReal::Nsimd()); + std::vector mpi_layout = GridDefaultMpi(); + + std::cout<({45,12,81,9})); + for(int lat=8;lat<=lmax;lat+=4){ + + std::vector latt_size ({lat*mpi_layout[0],lat*mpi_layout[1],lat*mpi_layout[2],lat*mpi_layout[3]}); + int64_t vol= latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; + GridCartesian Grid(latt_size,simd_layout,mpi_layout); + + NP= Grid.RankCount(); + NN =Grid.NodeCount(); + + Vec rn ; random(sRNG,rn); + + LatticeVec z(&Grid); z=rn; + LatticeVec x(&Grid); x=rn; + LatticeVec y(&Grid); y=rn; + double a=2.0; + + uint64_t Nloop=NLOOP; + + double start=usecond(); + for(int i=0;i mflops_all; + + /////////////////////////////////////////////////////// + // Set/Get the layout & grid size + /////////////////////////////////////////////////////// + int threads = GridThread::GetThreads(); + std::vector mpi = GridDefaultMpi(); assert(mpi.size()==4); + std::vector local({L,L,L,L}); + + GridCartesian * TmpGrid = SpaceTimeGrid::makeFourDimGrid(std::vector({64,64,64,64}), + 
GridDefaultSimd(Nd,vComplex::Nsimd()),GridDefaultMpi()); + uint64_t NP = TmpGrid->RankCount(); + uint64_t NN = TmpGrid->NodeCount(); + NN_global=NN; + uint64_t SHM=NP/NN; + + std::vector internal; + if ( SHM == 1 ) internal = std::vector({1,1,1,1}); + else if ( SHM == 2 ) internal = std::vector({2,1,1,1}); + else if ( SHM == 4 ) internal = std::vector({2,2,1,1}); + else if ( SHM == 8 ) internal = std::vector({2,2,2,1}); + else assert(0); + + std::vector nodes({mpi[0]/internal[0],mpi[1]/internal[1],mpi[2]/internal[2],mpi[3]/internal[3]}); + std::vector latt4({local[0]*nodes[0],local[1]*nodes[1],local[2]*nodes[2],local[3]*nodes[3]}); + + ///////// Welcome message //////////// + std::cout< seeds4({1,2,3,4}); + std::vector seeds5({5,6,7,8}); + GridParallelRNG RNG4(UGrid); RNG4.SeedFixedIntegers(seeds4); + GridParallelRNG RNG5(sFGrid); RNG5.SeedFixedIntegers(seeds5); + std::cout << GridLogMessage << "Initialised RNGs" << std::endl; + + ///////// Source preparation //////////// + LatticeFermion src (sFGrid); random(RNG5,src); + LatticeFermion tmp (sFGrid); + + RealD N2 = 1.0/::sqrt(norm2(src)); + src = src*N2; + + LatticeGaugeField Umu(UGrid); SU3::HotConfiguration(RNG4,Umu); + + WilsonFermion5DR sDw(Umu,*sFGrid,*sFrbGrid,*sUGrid,*sUrbGrid,M5); + LatticeFermion src_e (sFrbGrid); + LatticeFermion src_o (sFrbGrid); + LatticeFermion r_e (sFrbGrid); + LatticeFermion r_o (sFrbGrid); + LatticeFermion r_eo (sFGrid); + LatticeFermion err (sFGrid); + { + + pickCheckerboard(Even,src_e,src); + pickCheckerboard(Odd,src_o,src); + +#if defined(AVX512) + const int num_cases = 6; + std::string fmt("A/S ; A/O ; U/S ; U/O ; G/S ; G/O "); +#else + const int num_cases = 4; + std::string fmt("U/S ; U/O ; G/S ; G/O "); +#endif + controls Cases [] = { +#ifdef AVX512 + { QCD::WilsonKernelsStatic::OptInlineAsm , QCD::WilsonKernelsStatic::CommsThenCompute ,CartesianCommunicator::CommunicatorPolicySequential }, + { QCD::WilsonKernelsStatic::OptInlineAsm , QCD::WilsonKernelsStatic::CommsAndCompute 
,CartesianCommunicator::CommunicatorPolicySequential }, +#endif + { QCD::WilsonKernelsStatic::OptHandUnroll, QCD::WilsonKernelsStatic::CommsThenCompute ,CartesianCommunicator::CommunicatorPolicySequential }, + { QCD::WilsonKernelsStatic::OptHandUnroll, QCD::WilsonKernelsStatic::CommsAndCompute ,CartesianCommunicator::CommunicatorPolicySequential }, + { QCD::WilsonKernelsStatic::OptGeneric , QCD::WilsonKernelsStatic::CommsThenCompute ,CartesianCommunicator::CommunicatorPolicySequential }, + { QCD::WilsonKernelsStatic::OptGeneric , QCD::WilsonKernelsStatic::CommsAndCompute ,CartesianCommunicator::CommunicatorPolicySequential } + }; + + for(int c=0;cBarrier(); + for(int i=0;iBarrier(); + double t1=usecond(); + + sDw.ZeroCounters(); + time_statistics timestat; + std::vector t_time(ncall); + for(uint64_t i=0;iBarrier(); + + double volume=Ls; for(int mu=0;mumflops_best ) mflops_best = mflops; + if ( mflops mflops_all; + + /////////////////////////////////////////////////////// + // Set/Get the layout & grid size + /////////////////////////////////////////////////////// + int threads = GridThread::GetThreads(); + std::vector mpi = GridDefaultMpi(); assert(mpi.size()==4); + std::vector local({L,L,L,L}); + + GridCartesian * TmpGrid = SpaceTimeGrid::makeFourDimGrid(std::vector({64,64,64,64}), + GridDefaultSimd(Nd,vComplex::Nsimd()),GridDefaultMpi()); + uint64_t NP = TmpGrid->RankCount(); + uint64_t NN = TmpGrid->NodeCount(); + NN_global=NN; + uint64_t SHM=NP/NN; + + std::vector internal; + if ( SHM == 1 ) internal = std::vector({1,1,1,1}); + else if ( SHM == 2 ) internal = std::vector({2,1,1,1}); + else if ( SHM == 4 ) internal = std::vector({2,2,1,1}); + else if ( SHM == 8 ) internal = std::vector({2,2,2,1}); + else assert(0); + + std::vector nodes({mpi[0]/internal[0],mpi[1]/internal[1],mpi[2]/internal[2],mpi[3]/internal[3]}); + std::vector latt4({local[0]*nodes[0],local[1]*nodes[1],local[2]*nodes[2],local[3]*nodes[3]}); + + ///////// Welcome message //////////// + 
std::cout< seeds4({1,2,3,4}); + std::vector seeds5({5,6,7,8}); + GridParallelRNG RNG4(UGrid); RNG4.SeedFixedIntegers(seeds4); + GridParallelRNG RNG5(FGrid); RNG5.SeedFixedIntegers(seeds5); + std::cout << GridLogMessage << "Initialised RNGs" << std::endl; + + ///////// Source preparation //////////// + LatticeFermion src (FGrid); random(RNG5,src); + LatticeFermion ref (FGrid); + LatticeFermion tmp (FGrid); + + RealD N2 = 1.0/::sqrt(norm2(src)); + src = src*N2; + + LatticeGaugeField Umu(UGrid); SU3::HotConfiguration(RNG4,Umu); + + DomainWallFermionR Dw(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5); + + //////////////////////////////////// + // Naive wilson implementation + //////////////////////////////////// + { + LatticeGaugeField Umu5d(FGrid); + std::vector U(4,FGrid); + for(int ss=0;ssoSites();ss++){ + for(int s=0;s(Umu5d,mu); + } + for(int mu=0;muBarrier(); + for(int i=0;iBarrier(); + double t1=usecond(); + // uint64_t ncall = (uint64_t) 2.5*1000.0*1000.0*nwarm/(t1-t0); + // if (ncall < 500) ncall = 500; + uint64_t ncall = 1000; + + FGrid->Broadcast(0,&ncall,sizeof(ncall)); + + // std::cout << GridLogMessage << " Estimate " << ncall << " calls per second"< t_time(ncall); + for(uint64_t i=0;iBarrier(); + + double volume=Ls; for(int mu=0;mumflops_best ) mflops_best = mflops; + if ( mflops({8,2,2,2}); +#else + LebesgueOrder::Block = std::vector({2,2,2,2}); +#endif + Benchmark::Decomposition(); + + int do_memory=1; + int do_comms =1; + int do_su3 =0; + int do_wilson=1; + int do_dwf =1; + + if ( do_su3 ) { + // empty for now + } +#if 1 + int sel=2; + std::vector L_list({8,12,16,24}); +#else + int sel=1; + std::vector L_list({8,12}); +#endif + int selm1=sel-1; + std::vector robust_list; + + std::vector wilson; + std::vector dwf4; + std::vector dwf5; + + if ( do_wilson ) { + int Ls=1; + std::cout<1) ) { + std::cout< v){ + double sum = std::accumulate(v.begin(), v.end(), 0.0); + mean = sum / v.size(); + + std::vector diff(v.size()); + std::transform(v.begin(), v.end(), 
diff.begin(), [=](double x) { return x - mean; }); + double sq_sum = std::inner_product(diff.begin(), diff.end(), diff.begin(), 0.0); + err = std::sqrt(sq_sum / (v.size()*(v.size() - 1))); + + auto result = std::minmax_element(v.begin(), v.end()); + min = *result.first; + max = *result.second; +} +}; + +void header(){ + std::cout <1) nmu++; + std::cout << GridLogMessage << "Number of iterations to average: "<< Nloop << std::endl; + std::vector t_time(Nloop); + time_statistics timestat; + std::cout< latt_size ({lat*mpi_layout[0], lat*mpi_layout[1], @@ -58,15 +88,23 @@ int main (int argc, char ** argv) lat*mpi_layout[3]}); GridCartesian Grid(latt_size,simd_layout,mpi_layout); + RealD Nrank = Grid._Nprocessors; + RealD Nnode = Grid.NodeCount(); + RealD ppn = Nrank/Nnode; - std::vector > xbuf(8,std::vector(lat*lat*lat*Ls)); - std::vector > rbuf(8,std::vector(lat*lat*lat*Ls)); + std::vector > xbuf(8); + std::vector > rbuf(8); int ncomm; int bytes=lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD); + for(int mu=0;mu<8;mu++){ + xbuf[mu].resize(lat*lat*lat*Ls); + rbuf[mu].resize(lat*lat*lat*Ls); + // std::cout << " buffers " << std::hex << (uint64_t)&xbuf[mu][0] <<" " << (uint64_t)&rbuf[mu][0] < requests; @@ -79,7 +117,6 @@ int main (int argc, char ** argv) int comm_proc=1; int xmit_to_rank; int recv_from_rank; - Grid.ShiftedRanks(mu,comm_proc,xmit_to_rank,recv_from_rank); Grid.SendToRecvFromBegin(requests, (void *)&xbuf[mu][0], @@ -102,18 +139,24 @@ int main (int argc, char ** argv) } Grid.SendToRecvFromComplete(requests); Grid.Barrier(); - + double stop=usecond(); + t_time[i] = stop-start; // microseconds } - double stop=usecond(); - double dbytes = bytes; - double xbytes = Nloop*dbytes*2.0*ncomm; + timestat.statistics(t_time); + + double dbytes = bytes*ppn; + double xbytes = dbytes*2.0*ncomm; double rbytes = xbytes; double bidibytes = xbytes+rbytes; - double time = stop-start; // microseconds + std::cout< latt_size ({lat,lat,lat,lat}); GridCartesian 
Grid(latt_size,simd_layout,mpi_layout); + RealD Nrank = Grid._Nprocessors; + RealD Nnode = Grid.NodeCount(); + RealD ppn = Nrank/Nnode; - std::vector > xbuf(8,std::vector(lat*lat*lat*Ls)); - std::vector > rbuf(8,std::vector(lat*lat*lat*Ls)); + std::vector > xbuf(8); + std::vector > rbuf(8); + for(int mu=0;mu<8;mu++){ + xbuf[mu].resize(lat*lat*lat*Ls); + rbuf[mu].resize(lat*lat*lat*Ls); + // std::cout << " buffers " << std::hex << (uint64_t)&xbuf[mu][0] <<" " << (uint64_t)&rbuf[mu][0] < latt_size ({lat*mpi_layout[0], lat*mpi_layout[1], @@ -209,6 +266,9 @@ int main (int argc, char ** argv) lat*mpi_layout[3]}); GridCartesian Grid(latt_size,simd_layout,mpi_layout); + RealD Nrank = Grid._Nprocessors; + RealD Nnode = Grid.NodeCount(); + RealD ppn = Nrank/Nnode; std::vector xbuf(8); std::vector rbuf(8); @@ -216,73 +276,86 @@ int main (int argc, char ** argv) for(int d=0;d<8;d++){ xbuf[d] = (HalfSpinColourVectorD *)Grid.ShmBufferMalloc(lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); rbuf[d] = (HalfSpinColourVectorD *)Grid.ShmBufferMalloc(lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + bzero((void *)xbuf[d],lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + bzero((void *)rbuf[d],lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); } int ncomm; int bytes=lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD); - double start=usecond(); + double dbytes; for(int i=0;i requests; - ncomm=0; for(int mu=0;mu<4;mu++){ + if (mpi_layout[mu]>1 ) { ncomm++; int comm_proc=1; int xmit_to_rank; int recv_from_rank; - Grid.ShiftedRanks(mu,comm_proc,xmit_to_rank,recv_from_rank); - Grid.StencilSendToRecvFromBegin(requests, - (void *)&xbuf[mu][0], - xmit_to_rank, - (void *)&rbuf[mu][0], - recv_from_rank, - bytes); + dbytes+= + Grid.StencilSendToRecvFromBegin(requests, + (void *)&xbuf[mu][0], + xmit_to_rank, + (void *)&rbuf[mu][0], + recv_from_rank, + bytes,mu); comm_proc = mpi_layout[mu]-1; Grid.ShiftedRanks(mu,comm_proc,xmit_to_rank,recv_from_rank); - Grid.StencilSendToRecvFromBegin(requests, - (void 
*)&xbuf[mu+4][0], - xmit_to_rank, - (void *)&rbuf[mu+4][0], - recv_from_rank, - bytes); + dbytes+= + Grid.StencilSendToRecvFromBegin(requests, + (void *)&xbuf[mu+4][0], + xmit_to_rank, + (void *)&rbuf[mu+4][0], + recv_from_rank, + bytes,mu+4); } } - Grid.StencilSendToRecvFromComplete(requests); + Grid.StencilSendToRecvFromComplete(requests,0); Grid.Barrier(); - + double stop=usecond(); + t_time[i] = stop-start; // microseconds + } - double stop=usecond(); - double dbytes = bytes; - double xbytes = Nloop*dbytes*2.0*ncomm; - double rbytes = xbytes; - double bidibytes = xbytes+rbytes; + timestat.statistics(t_time); + + dbytes=dbytes*ppn; + double xbytes = dbytes*0.5; + double rbytes = dbytes*0.5; + double bidibytes = dbytes; + + std::cout< latt_size ({lat*mpi_layout[0], lat*mpi_layout[1], @@ -290,6 +363,9 @@ int main (int argc, char ** argv) lat*mpi_layout[3]}); GridCartesian Grid(latt_size,simd_layout,mpi_layout); + RealD Nrank = Grid._Nprocessors; + RealD Nnode = Grid.NodeCount(); + RealD ppn = Nrank/Nnode; std::vector xbuf(8); std::vector rbuf(8); @@ -297,16 +373,18 @@ int main (int argc, char ** argv) for(int d=0;d<8;d++){ xbuf[d] = (HalfSpinColourVectorD *)Grid.ShmBufferMalloc(lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); rbuf[d] = (HalfSpinColourVectorD *)Grid.ShmBufferMalloc(lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + bzero((void *)xbuf[d],lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + bzero((void *)rbuf[d],lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); } int ncomm; int bytes=lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD); - - double start=usecond(); + double dbytes; for(int i=0;i requests; - + dbytes=0; ncomm=0; for(int mu=0;mu<4;mu++){ @@ -318,44 +396,146 @@ int main (int argc, char ** argv) int recv_from_rank; Grid.ShiftedRanks(mu,comm_proc,xmit_to_rank,recv_from_rank); - Grid.StencilSendToRecvFromBegin(requests, - (void *)&xbuf[mu][0], - xmit_to_rank, - (void *)&rbuf[mu][0], - recv_from_rank, - bytes); - // 
Grid.StencilSendToRecvFromComplete(requests); - // requests.resize(0); + dbytes+= + Grid.StencilSendToRecvFromBegin(requests, + (void *)&xbuf[mu][0], + xmit_to_rank, + (void *)&rbuf[mu][0], + recv_from_rank, + bytes,mu); + Grid.StencilSendToRecvFromComplete(requests,mu); + requests.resize(0); comm_proc = mpi_layout[mu]-1; Grid.ShiftedRanks(mu,comm_proc,xmit_to_rank,recv_from_rank); - Grid.StencilSendToRecvFromBegin(requests, - (void *)&xbuf[mu+4][0], - xmit_to_rank, - (void *)&rbuf[mu+4][0], - recv_from_rank, - bytes); - Grid.StencilSendToRecvFromComplete(requests); + dbytes+= + Grid.StencilSendToRecvFromBegin(requests, + (void *)&xbuf[mu+4][0], + xmit_to_rank, + (void *)&rbuf[mu+4][0], + recv_from_rank, + bytes,mu+4); + Grid.StencilSendToRecvFromComplete(requests,mu+4); requests.resize(0); } } Grid.Barrier(); - + double stop=usecond(); + t_time[i] = stop-start; // microseconds + } - double stop=usecond(); - double dbytes = bytes; - double xbytes = Nloop*dbytes*2.0*ncomm; - double rbytes = xbytes; - double bidibytes = xbytes+rbytes; + timestat.statistics(t_time); - double time = stop-start; // microseconds + dbytes=dbytes*ppn; + double xbytes = dbytes*0.5; + double rbytes = dbytes*0.5; + double bidibytes = dbytes; - std::cout< latt_size ({lat*mpi_layout[0], + lat*mpi_layout[1], + lat*mpi_layout[2], + lat*mpi_layout[3]}); + + GridCartesian Grid(latt_size,simd_layout,mpi_layout); + RealD Nrank = Grid._Nprocessors; + RealD Nnode = Grid.NodeCount(); + RealD ppn = Nrank/Nnode; + + std::vector xbuf(8); + std::vector rbuf(8); + Grid.ShmBufferFreeAll(); + for(int d=0;d<8;d++){ + xbuf[d] = (HalfSpinColourVectorD *)Grid.ShmBufferMalloc(lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + rbuf[d] = (HalfSpinColourVectorD *)Grid.ShmBufferMalloc(lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + bzero((void *)xbuf[d],lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + bzero((void *)rbuf[d],lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD)); + } + + int ncomm; + int 
bytes=lat*lat*lat*Ls*sizeof(HalfSpinColourVectorD); + double dbytes; + for(int i=0;i requests; + dbytes=0; + ncomm=0; + + parallel_for(int dir=0;dir<8;dir++){ + + double tbytes; + int mu =dir % 4; + + if (mpi_layout[mu]>1 ) { + + ncomm++; + int xmit_to_rank; + int recv_from_rank; + if ( dir == mu ) { + int comm_proc=1; + Grid.ShiftedRanks(mu,comm_proc,xmit_to_rank,recv_from_rank); + } else { + int comm_proc = mpi_layout[mu]-1; + Grid.ShiftedRanks(mu,comm_proc,xmit_to_rank,recv_from_rank); + } + + tbytes= Grid.StencilSendToRecvFrom((void *)&xbuf[dir][0], xmit_to_rank, + (void *)&rbuf[dir][0], recv_from_rank, bytes,dir); + +#pragma omp atomic + dbytes+=tbytes; + } + } + Grid.Barrier(); + double stop=usecond(); + t_time[i] = stop-start; // microseconds + } + + timestat.statistics(t_time); + + dbytes=dbytes*ppn; + double xbytes = dbytes*0.5; + double rbytes = dbytes*0.5; + double bidibytes = dbytes; + + + std::cout< -Author: paboyle + Author: Peter Boyle + Author: paboyle This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. - This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. - You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
- See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ @@ -37,27 +31,33 @@ struct scal { d internal; }; - Gamma::GammaMatrix Gmu [] = { - Gamma::GammaX, - Gamma::GammaY, - Gamma::GammaZ, - Gamma::GammaT + Gamma::Algebra Gmu [] = { + Gamma::Algebra::GammaX, + Gamma::Algebra::GammaY, + Gamma::Algebra::GammaZ, + Gamma::Algebra::GammaT }; typedef WilsonFermion5D WilsonFermion5DR; typedef WilsonFermion5D WilsonFermion5DF; typedef WilsonFermion5D WilsonFermion5DD; - int main (int argc, char ** argv) { Grid_init(&argc,&argv); + int threads = GridThread::GetThreads(); std::cout< latt4 = GridDefaultLatt(); - const int Ls=8; + int Ls=16; + for(int i=0;i> Ls; + } + + GridCartesian * UGrid = SpaceTimeGrid::makeFourDimGrid(GridDefaultLatt(), GridDefaultSimd(Nd,vComplex::Nsimd()),GridDefaultMpi()); GridRedBlackCartesian * UrbGrid = SpaceTimeGrid::makeFourDimRedBlackGrid(UGrid); GridCartesian * FGrid = SpaceTimeGrid::makeFiveDimGrid(Ls,UGrid); @@ -71,35 +71,66 @@ int main (int argc, char ** argv) std::vector seeds4({1,2,3,4}); std::vector seeds5({5,6,7,8}); - + + std::cout << GridLogMessage << "Initialising 4d RNG" << std::endl; GridParallelRNG RNG4(UGrid); RNG4.SeedFixedIntegers(seeds4); + std::cout << GridLogMessage << "Initialising 5d RNG" << std::endl; GridParallelRNG RNG5(FGrid); RNG5.SeedFixedIntegers(seeds5); + std::cout << GridLogMessage << "Initialised RNGs" << std::endl; LatticeFermion src (FGrid); random(RNG5,src); +#if 0 + src = zero; + { + std::vector origin({0,0,0,latt4[2]-1,0}); + SpinColourVectorF tmp; + tmp=zero; + tmp()(0)(0)=Complex(-2.0,0.0); + std::cout << " source site 0 " << tmp<(Umu,mu); + // if (mu !=2 ) ttmp = 0; + // ttmp = ttmp* pow(10.0,mu); + PokeIndex(Umu,ttmp,mu); + } + std::cout << GridLogMessage << "Forced to diagonal " << std::endl; +#endif + //////////////////////////////////// + // Naive wilson implementation + 
//////////////////////////////////// // replicate across fifth dimension + LatticeGaugeField Umu5d(FGrid); + std::vector U(4,FGrid); for(int ss=0;ssoSites();ss++){ for(int s=0;s U(4,FGrid); for(int mu=0;mu(Umu5d,mu); } + std::cout << GridLogMessage << "Setting up Cshift based reference " << std::endl; if (1) { @@ -120,8 +151,7 @@ int main (int argc, char ** argv) RealD M5 =1.8; RealD NP = UGrid->_Nprocessors; - - DomainWallFermionR Dw(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5); + RealD NN = UGrid->NodeCount(); std::cout << GridLogMessage<< "*****************************************************************" <Barrier(); Dw.ZeroCounters(); + Dw.Dhop(src,result,0); + std::cout<1.0e-4) ) { + std::cout << "RESULT\n " << result<Barrier(); + exit(-1); + } + */ assert (norm2(err)< 1.0e-4 ); Dw.Report(); } + DomainWallFermionRL DwH(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5); + if (1) { + FGrid->Barrier(); + DwH.ZeroCounters(); + DwH.Dhop(src,result,0); + double t0=usecond(); + for(int i=0;iBarrier(); + + double volume=Ls; for(int mu=0;mu site({s,x,y,z,t}); - SpinColourVector tmp; - peekSite(tmp,src,site); - pokeSite(tmp,ssrc,site); - }}}}} + + localConvert(src,ssrc); std::cout<Barrier(); - double t0=usecond(); + sDw.Dhop(ssrc,sresult,0); sDw.ZeroCounters(); + double t0=usecond(); for(int i=0;i site({s,x,y,z,t}); - SpinColourVector normal, simd; - peekSite(normal,result,site); - peekSite(simd,sresult,site); - sum=sum+norm2(normal-simd); - if (norm2(normal-simd) > 1.0e-6 ) { - std::cout << "site "< 1.0e-4 ){ + std::cout<< "sD REF\n " < 1.0e-4 ){ + std::cout<< "sD REF\n " <::DhopEO "<::DhopEO "<Barrier(); + sDw.DhopEO(ssrc_o, sr_e, DaggerNo); sDw.ZeroCounters(); - sDw.stat.init("DhopEO"); + // sDw.stat.init("DhopEO"); double t0=usecond(); for (int i = 0; i < ncall; i++) { sDw.DhopEO(ssrc_o, sr_e, DaggerNo); } double t1=usecond(); FGrid->Barrier(); - sDw.stat.print(); + // sDw.stat.print(); double volume=Ls; for(int mu=0;mu1.0e-4) { + + if(( error>1.0e-4) ) { 
setCheckerboard(ssrc,ssrc_o); setCheckerboard(ssrc,ssrc_e); - std::cout<< ssrc << std::endl; + std::cout<< "DIFF\n " <1.0e-4)){ + std::cout<< "DAG RESULT\n " <Barrier(); + Dw.DhopEO(src_o,r_e,DaggerNo); double t0=usecond(); for(int i=0;i1.0e-4)){ + std::cout<< "Deo RESULT\n " < & L, int Ls, int threads, int report =0 ); diff --git a/benchmarks/Benchmark_gparity.cc b/benchmarks/Benchmark_gparity.cc new file mode 100644 index 00000000..f6036aa8 --- /dev/null +++ b/benchmarks/Benchmark_gparity.cc @@ -0,0 +1,190 @@ +#include +#include +using namespace std; +using namespace Grid; +using namespace Grid::QCD; + +template +struct scal { + d internal; +}; + + Gamma::Algebra Gmu [] = { + Gamma::Algebra::GammaX, + Gamma::Algebra::GammaY, + Gamma::Algebra::GammaZ, + Gamma::Algebra::GammaT + }; + +typedef typename GparityDomainWallFermionF::FermionField GparityLatticeFermionF; +typedef typename GparityDomainWallFermionD::FermionField GparityLatticeFermionD; + + + +int main (int argc, char ** argv) +{ + Grid_init(&argc,&argv); + + int Ls=16; + for(int i=0;i> Ls; + } + + + int threads = GridThread::GetThreads(); + std::cout<_Nprocessors; + RealD NN = UGrid->NodeCount(); + + std::cout << GridLogMessage<< "*****************************************************************" <Barrier(); + Dw.ZeroCounters(); + Dw.Dhop(src,result,0); + std::cout<Barrier(); + + double volume=Ls; for(int mu=0;muBarrier(); + DwH.ZeroCounters(); + DwH.Dhop(src,result,0); + double t0=usecond(); + for(int i=0;iBarrier(); + + double volume=Ls; for(int mu=0;muBarrier(); + DwD.ZeroCounters(); + DwD.Dhop(src_d,result_d,0); + std::cout<Barrier(); + + double volume=Ls; for(int mu=0;mu latt_size ({lat*mpi_layout[0],lat*mpi_layout[1],lat*mpi_layout[2],lat*mpi_layout[3]}); - int vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; + int64_t vol= latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; GridCartesian Grid(latt_size,simd_layout,mpi_layout); uint64_t Nloop=NLOOP; - // GridParallelRNG pRNG(&Grid); 
pRNG.SeedRandomDevice(); + // GridParallelRNG pRNG(&Grid); pRNG.SeedFixedIntegers(std::vector({45,12,81,9})); - LatticeVec z(&Grid); //random(pRNG,z); - LatticeVec x(&Grid); //random(pRNG,x); - LatticeVec y(&Grid); //random(pRNG,y); + LatticeVec z(&Grid);// random(pRNG,z); + LatticeVec x(&Grid);// random(pRNG,x); + LatticeVec y(&Grid);// random(pRNG,y); double a=2.0; @@ -83,7 +83,7 @@ int main (int argc, char ** argv) double time = (stop-start)/Nloop*1000; double flops=vol*Nvec*2;// mul,add - double bytes=3*vol*Nvec*sizeof(Real); + double bytes=3.0*vol*Nvec*sizeof(Real); std::cout< latt_size ({lat*mpi_layout[0],lat*mpi_layout[1],lat*mpi_layout[2],lat*mpi_layout[3]}); - int vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; + int64_t vol= latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; GridCartesian Grid(latt_size,simd_layout,mpi_layout); - // GridParallelRNG pRNG(&Grid); pRNG.SeedRandomDevice(); + // GridParallelRNG pRNG(&Grid); pRNG.SeedFixedIntegers(std::vector({45,12,81,9})); - LatticeVec z(&Grid); //random(pRNG,z); - LatticeVec x(&Grid); //random(pRNG,x); - LatticeVec y(&Grid); //random(pRNG,y); + LatticeVec z(&Grid);// random(pRNG,z); + LatticeVec x(&Grid);// random(pRNG,x); + LatticeVec y(&Grid);// random(pRNG,y); double a=2.0; uint64_t Nloop=NLOOP; @@ -119,7 +119,7 @@ int main (int argc, char ** argv) double time = (stop-start)/Nloop*1000; double flops=vol*Nvec*2;// mul,add - double bytes=3*vol*Nvec*sizeof(Real); + double bytes=3.0*vol*Nvec*sizeof(Real); std::cout< latt_size ({lat*mpi_layout[0],lat*mpi_layout[1],lat*mpi_layout[2],lat*mpi_layout[3]}); - int vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; + int64_t vol= latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; uint64_t Nloop=NLOOP; GridCartesian Grid(latt_size,simd_layout,mpi_layout); - // GridParallelRNG pRNG(&Grid); pRNG.SeedRandomDevice(); + // GridParallelRNG pRNG(&Grid); pRNG.SeedFixedIntegers(std::vector({45,12,81,9})); - LatticeVec z(&Grid); //random(pRNG,z); - 
LatticeVec x(&Grid); //random(pRNG,x); - LatticeVec y(&Grid); //random(pRNG,y); + LatticeVec z(&Grid);// random(pRNG,z); + LatticeVec x(&Grid);// random(pRNG,x); + LatticeVec y(&Grid);// random(pRNG,y); RealD a=2.0; @@ -154,7 +154,7 @@ int main (int argc, char ** argv) double stop=usecond(); double time = (stop-start)/Nloop*1000; - double bytes=2*vol*Nvec*sizeof(Real); + double bytes=2.0*vol*Nvec*sizeof(Real); double flops=vol*Nvec*1;// mul std::cout< latt_size ({lat*mpi_layout[0],lat*mpi_layout[1],lat*mpi_layout[2],lat*mpi_layout[3]}); - int vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; + int64_t vol= latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; uint64_t Nloop=NLOOP; GridCartesian Grid(latt_size,simd_layout,mpi_layout); - // GridParallelRNG pRNG(&Grid); pRNG.SeedRandomDevice(); - LatticeVec z(&Grid); //random(pRNG,z); - LatticeVec x(&Grid); //random(pRNG,x); - LatticeVec y(&Grid); //random(pRNG,y); + // GridParallelRNG pRNG(&Grid); pRNG.SeedFixedIntegers(std::vector({45,12,81,9})); + LatticeVec z(&Grid);// random(pRNG,z); + LatticeVec x(&Grid);// random(pRNG,x); + LatticeVec y(&Grid);// random(pRNG,y); RealD a=2.0; Real nn; double start=usecond(); @@ -187,7 +187,7 @@ int main (int argc, char ** argv) double stop=usecond(); double time = (stop-start)/Nloop*1000; - double bytes=vol*Nvec*sizeof(Real); + double bytes=1.0*vol*Nvec*sizeof(Real); double flops=vol*Nvec*2;// mul,add std::cout<Barrier(); \ + t0=usecond(); \ + for(int i=0;iBarrier(); \ + zDw.CayleyReport(); \ + std::cout<Barrier(); \ + t0=usecond(); \ + for(int i=0;iBarrier(); \ + Dw.CayleyReport(); \ + std::cout< gamma(Ls,std::complex(1.0,0.0)); + ZMobiusFermionVec5dR zDw(Umu,*sFGrid,*sFrbGrid,*sUGrid,*sUrbGrid,mass,M5,gamma,b,c); + std::cout<Barrier(); @@ -173,10 +209,13 @@ int main (int argc, char ** argv) BENCH_DW_MEO(Dhop ,src,result); BENCH_DW_MEO(DhopEO ,src_o,r_e); - BENCH_DW(Meooe ,src_o,r_e); + BENCH_DW_SSC(Meooe ,src_o,r_e); BENCH_DW(Mooee ,src_o,r_o); 
BENCH_DW(MooeeInv,src_o,r_o); + BENCH_ZDW(Mooee ,src_o,r_o); + BENCH_ZDW(MooeeInv,src_o,r_o); + } Grid_finalize(); diff --git a/benchmarks/Benchmark_staggered.cc b/benchmarks/Benchmark_staggered.cc index 121dc0d5..dc2dcf91 100644 --- a/benchmarks/Benchmark_staggered.cc +++ b/benchmarks/Benchmark_staggered.cc @@ -51,7 +51,7 @@ int main (int argc, char ** argv) std::vector seeds({1,2,3,4}); GridParallelRNG pRNG(&Grid); pRNG.SeedFixedIntegers(seeds); - // pRNG.SeedRandomDevice(); + // pRNG.SeedFixedIntegers(std::vector({45,12,81,9}); typedef typename ImprovedStaggeredFermionR::FermionField FermionField; typename ImprovedStaggeredFermionR::ImplParams params; diff --git a/benchmarks/Benchmark_su3.cc b/benchmarks/Benchmark_su3.cc index b6d1d303..035af2d9 100644 --- a/benchmarks/Benchmark_su3.cc +++ b/benchmarks/Benchmark_su3.cc @@ -35,13 +35,14 @@ using namespace Grid::QCD; int main (int argc, char ** argv) { Grid_init(&argc,&argv); +#define LMAX (64) - int Nloop=1000; + int64_t Nloop=20; std::vector simd_layout = GridDefaultSimd(Nd,vComplex::Nsimd()); std::vector mpi_layout = GridDefaultMpi(); - int threads = GridThread::GetThreads(); + int64_t threads = GridThread::GetThreads(); std::cout< latt_size ({lat*mpi_layout[0],lat*mpi_layout[1],lat*mpi_layout[2],lat*mpi_layout[3]}); - int vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; + int64_t vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; GridCartesian Grid(latt_size,simd_layout,mpi_layout); - // GridParallelRNG pRNG(&Grid); pRNG.SeedRandomDevice(); + GridParallelRNG pRNG(&Grid); pRNG.SeedFixedIntegers(std::vector({45,12,81,9})); - LatticeColourMatrix z(&Grid);// random(pRNG,z); - LatticeColourMatrix x(&Grid);// random(pRNG,x); - LatticeColourMatrix y(&Grid);// random(pRNG,y); + LatticeColourMatrix z(&Grid); random(pRNG,z); + LatticeColourMatrix x(&Grid); random(pRNG,x); + LatticeColourMatrix y(&Grid); random(pRNG,y); double start=usecond(); - for(int i=0;i latt_size 
({lat*mpi_layout[0],lat*mpi_layout[1],lat*mpi_layout[2],lat*mpi_layout[3]}); - int vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; + int64_t vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; GridCartesian Grid(latt_size,simd_layout,mpi_layout); - // GridParallelRNG pRNG(&Grid); pRNG.SeedRandomDevice(); + GridParallelRNG pRNG(&Grid); pRNG.SeedFixedIntegers(std::vector({45,12,81,9})); - LatticeColourMatrix z(&Grid); //random(pRNG,z); - LatticeColourMatrix x(&Grid); //random(pRNG,x); - LatticeColourMatrix y(&Grid); //random(pRNG,y); + LatticeColourMatrix z(&Grid); random(pRNG,z); + LatticeColourMatrix x(&Grid); random(pRNG,x); + LatticeColourMatrix y(&Grid); random(pRNG,y); double start=usecond(); - for(int i=0;i latt_size ({lat*mpi_layout[0],lat*mpi_layout[1],lat*mpi_layout[2],lat*mpi_layout[3]}); - int vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; + int64_t vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; GridCartesian Grid(latt_size,simd_layout,mpi_layout); - // GridParallelRNG pRNG(&Grid); pRNG.SeedRandomDevice(); + GridParallelRNG pRNG(&Grid); pRNG.SeedFixedIntegers(std::vector({45,12,81,9})); - LatticeColourMatrix z(&Grid); //random(pRNG,z); - LatticeColourMatrix x(&Grid); //random(pRNG,x); - LatticeColourMatrix y(&Grid); //random(pRNG,y); + LatticeColourMatrix z(&Grid); random(pRNG,z); + LatticeColourMatrix x(&Grid); random(pRNG,x); + LatticeColourMatrix y(&Grid); random(pRNG,y); double start=usecond(); - for(int i=0;i latt_size ({lat*mpi_layout[0],lat*mpi_layout[1],lat*mpi_layout[2],lat*mpi_layout[3]}); - int vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; + int64_t vol = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; GridCartesian Grid(latt_size,simd_layout,mpi_layout); - // GridParallelRNG pRNG(&Grid); pRNG.SeedRandomDevice(); + GridParallelRNG pRNG(&Grid); pRNG.SeedFixedIntegers(std::vector({45,12,81,9})); - LatticeColourMatrix z(&Grid); //random(pRNG,z); - LatticeColourMatrix x(&Grid); 
//random(pRNG,x); - LatticeColourMatrix y(&Grid); //random(pRNG,y); + LatticeColourMatrix z(&Grid); random(pRNG,z); + LatticeColourMatrix x(&Grid); random(pRNG,x); + LatticeColourMatrix y(&Grid); random(pRNG,y); double start=usecond(); - for(int i=0;i seeds({1,2,3,4}); GridParallelRNG pRNG(&Grid); pRNG.SeedFixedIntegers(seeds); - // pRNG.SeedRandomDevice(); + // pRNG.SeedFixedIntegers(std::vector({45,12,81,9}); LatticeFermion src (&Grid); random(pRNG,src); LatticeFermion result(&Grid); result=zero; @@ -106,7 +106,7 @@ int main (int argc, char ** argv) { // Naive wilson implementation ref = zero; for(int mu=0;mu]]) AC_CHECK_DECLS([be64toh],[], [], [[#include ]]) +############## Standard libraries +AC_CHECK_LIB([m],[cos]) +AC_CHECK_LIB([stdc++],[abort]) + ############### GMP and MPFR AC_ARG_WITH([gmp], [AS_HELP_STRING([--with-gmp=prefix], @@ -60,16 +74,23 @@ AC_ARG_WITH([mpfr], [AM_CXXFLAGS="-I$with_mpfr/include $AM_CXXFLAGS"] [AM_LDFLAGS="-L$with_mpfr/lib $AM_LDFLAGS"]) -############### FFTW3 -AC_ARG_WITH([fftw], +############### FFTW3 +AC_ARG_WITH([fftw], [AS_HELP_STRING([--with-fftw=prefix], [try this for a non-standard install prefix of the FFTW3 library])], [AM_CXXFLAGS="-I$with_fftw/include $AM_CXXFLAGS"] [AM_LDFLAGS="-L$with_fftw/lib $AM_LDFLAGS"]) -############### lapack +############### LIME +AC_ARG_WITH([lime], + [AS_HELP_STRING([--with-lime=prefix], + [try this for a non-standard install prefix of the LIME library])], + [AM_CXXFLAGS="-I$with_lime/include $AM_CXXFLAGS"] + [AM_LDFLAGS="-L$with_lime/lib $AM_LDFLAGS"]) + +############### lapack AC_ARG_ENABLE([lapack], - [AC_HELP_STRING([--enable-lapack=yes|no|prefix], [enable LAPACK])], + [AC_HELP_STRING([--enable-lapack=yes|no|prefix], [enable LAPACK])], [ac_LAPACK=${enable_lapack}], [ac_LAPACK=no]) case ${ac_LAPACK} in @@ -83,6 +104,18 @@ case ${ac_LAPACK} in AC_DEFINE([USE_LAPACK],[1],[use LAPACK]);; esac +############### FP16 conversions +AC_ARG_ENABLE([sfw-fp16], + 
[AC_HELP_STRING([--enable-sfw-fp16=yes|no], [enable software fp16 comms])], + [ac_SFW_FP16=${enable_sfw_fp16}], [ac_SFW_FP16=yes]) +case ${ac_SFW_FP16} in + yes) + AC_DEFINE([SFW_FP16],[1],[software conversion to fp16]);; + no);; + *) + AC_MSG_ERROR(["SFW FP16 option not supported ${ac_SFW_FP16}"]);; +esac + ############### MKL AC_ARG_ENABLE([mkl], [AC_HELP_STRING([--enable-mkl=yes|no|prefix], [enable Intel MKL for LAPACK & FFTW])], @@ -99,9 +132,16 @@ case ${ac_MKL} in AC_DEFINE([USE_MKL], [1], [Define to 1 if you use the Intel MKL]);; esac +############### HDF5 +AC_ARG_WITH([hdf5], + [AS_HELP_STRING([--with-hdf5=prefix], + [try this for a non-standard install prefix of the HDF5 library])], + [AM_CXXFLAGS="-I$with_hdf5/include $AM_CXXFLAGS"] + [AM_LDFLAGS="-L$with_hdf5/lib $AM_LDFLAGS"]) + ############### first-touch AC_ARG_ENABLE([numa], - [AC_HELP_STRING([--enable-numa=yes|no|prefix], [enable first touch numa opt])], + [AC_HELP_STRING([--enable-numa=yes|no|prefix], [enable first touch numa opt])], [ac_NUMA=${enable_NUMA}],[ac_NUMA=no]) case ${ac_NUMA} in @@ -127,8 +167,8 @@ if test "${ac_MKL}x" != "nox"; then fi AC_SEARCH_LIBS([__gmpf_init], [gmp], - [AC_SEARCH_LIBS([mpfr_init], [mpfr], - [AC_DEFINE([HAVE_LIBMPFR], [1], + [AC_SEARCH_LIBS([mpfr_init], [mpfr], + [AC_DEFINE([HAVE_LIBMPFR], [1], [Define to 1 if you have the `MPFR' library])] [have_mpfr=true], [AC_MSG_ERROR([MPFR library not found])])] [AC_DEFINE([HAVE_LIBGMP], [1], [Define to 1 if you have the `GMP' library])] @@ -137,7 +177,7 @@ AC_SEARCH_LIBS([__gmpf_init], [gmp], if test "${ac_LAPACK}x" != "nox"; then AC_SEARCH_LIBS([LAPACKE_sbdsdc], [lapack], [], [AC_MSG_ERROR("LAPACK enabled but library not found")]) -fi +fi AC_SEARCH_LIBS([fftw_execute], [fftw3], [AC_SEARCH_LIBS([fftwf_execute], [fftw3f], [], @@ -145,6 +185,29 @@ AC_SEARCH_LIBS([fftw_execute], [fftw3], [AC_DEFINE([HAVE_FFTW], [1], [Define to 1 if you have the `FFTW' library])] [have_fftw=true]) +AC_SEARCH_LIBS([limeCreateReader], [lime], + 
[AC_DEFINE([HAVE_LIME], [1], [Define to 1 if you have the `LIME' library])] + [have_lime=true], + [AC_MSG_WARN(C-LIME library was not found in your system. +In order to use ILGG file format please install or provide the correct path to your installation +Info at: http://usqcd.jlab.org/usqcd-docs/c-lime/)]) + +AC_SEARCH_LIBS([crc32], [z], + [AC_DEFINE([HAVE_ZLIB], [1], [Define to 1 if you have the `LIBZ' library])] + [have_zlib=true] [LIBS="${LIBS} -lz"], + [AC_MSG_ERROR(zlib library was not found in your system.)]) + +AC_SEARCH_LIBS([move_pages], [numa], + [AC_DEFINE([HAVE_LIBNUMA], [1], [Define to 1 if you have the `LIBNUMA' library])] + [have_libnuma=true] [LIBS="${LIBS} -lnuma"], + [AC_MSG_WARN(libnuma library was not found in your system. Some optimisations will not apply)]) + +AC_SEARCH_LIBS([H5Fopen], [hdf5_cpp], + [AC_DEFINE([HAVE_HDF5], [1], [Define to 1 if you have the `HDF5' library])] + [have_hdf5=true] + [LIBS="${LIBS} -lhdf5"], [], [-lhdf5]) +AM_CONDITIONAL(BUILD_HDF5, [ test "${have_hdf5}X" == "trueX" ]) + CXXFLAGS=$CXXFLAGS_CPY LDFLAGS=$LDFLAGS_CPY @@ -163,19 +226,26 @@ case ${ax_cv_cxx_compiler_vendor} in case ${ac_SIMD} in SSE4) AC_DEFINE([SSE4],[1],[SSE4 intrinsics]) - SIMD_FLAGS='-msse4.2';; + case ${ac_SFW_FP16} in + yes) + SIMD_FLAGS='-msse4.2';; + no) + SIMD_FLAGS='-msse4.2 -mf16c';; + *) + AC_MSG_ERROR(["SFW_FP16 must be either yes or no value ${ac_SFW_FP16} "]);; + esac;; AVX) AC_DEFINE([AVX1],[1],[AVX intrinsics]) - SIMD_FLAGS='-mavx';; + SIMD_FLAGS='-mavx -mf16c';; AVXFMA4) AC_DEFINE([AVXFMA4],[1],[AVX intrinsics with FMA4]) - SIMD_FLAGS='-mavx -mfma4';; + SIMD_FLAGS='-mavx -mfma4 -mf16c';; AVXFMA) AC_DEFINE([AVXFMA],[1],[AVX intrinsics with FMA3]) - SIMD_FLAGS='-mavx -mfma';; + SIMD_FLAGS='-mavx -mfma -mf16c';; AVX2) AC_DEFINE([AVX2],[1],[AVX2 intrinsics]) - SIMD_FLAGS='-mavx2 -mfma';; + SIMD_FLAGS='-mavx2 -mfma -mf16c';; AVX512) AC_DEFINE([AVX512],[1],[AVX512 intrinsics]) SIMD_FLAGS='-mavx512f -mavx512pf -mavx512er -mavx512cd';; @@ 
-184,6 +254,7 @@ case ${ax_cv_cxx_compiler_vendor} in SIMD_FLAGS='';; KNL) AC_DEFINE([AVX512],[1],[AVX512 intrinsics]) + AC_DEFINE([KNL],[1],[Knights landing processor]) SIMD_FLAGS='-march=knl';; GEN) AC_DEFINE([GEN],[1],[generic vector code]) @@ -191,6 +262,9 @@ case ${ax_cv_cxx_compiler_vendor} in [generic SIMD vector width (in bytes)]) SIMD_GEN_WIDTH_MSG=" (width= $ac_gen_simd_width)" SIMD_FLAGS='';; + NEONv8) + AC_DEFINE([NEONV8],[1],[ARMv8 NEON]) + SIMD_FLAGS='-march=armv8-a';; QPX|BGQ) AC_DEFINE([QPX],[1],[QPX intrinsics for BG/Q]) SIMD_FLAGS='';; @@ -219,6 +293,7 @@ case ${ax_cv_cxx_compiler_vendor} in SIMD_FLAGS='';; KNL) AC_DEFINE([AVX512],[1],[AVX512 intrinsics for Knights Landing]) + AC_DEFINE([KNL],[1],[Knights landing processor]) SIMD_FLAGS='-xmic-avx512';; GEN) AC_DEFINE([GEN],[1],[generic vector code]) @@ -256,8 +331,41 @@ case ${ac_PRECISION} in double) AC_DEFINE([GRID_DEFAULT_PRECISION_DOUBLE],[1],[GRID_DEFAULT_PRECISION is DOUBLE] ) ;; + *) + AC_MSG_ERROR([${ac_PRECISION} unsupported --enable-precision option]); + ;; esac +###################### Shared memory allocation technique under MPI3 +AC_ARG_ENABLE([shm],[AC_HELP_STRING([--enable-shm=shmget|shmopen|hugetlbfs], + [Select SHM allocation technique])],[ac_SHM=${enable_shm}],[ac_SHM=shmopen]) + +case ${ac_SHM} in + + shmget) + AC_DEFINE([GRID_MPI3_SHMGET],[1],[GRID_MPI3_SHMGET] ) + ;; + + shmopen) + AC_DEFINE([GRID_MPI3_SHMOPEN],[1],[GRID_MPI3_SHMOPEN] ) + ;; + + hugetlbfs) + AC_DEFINE([GRID_MPI3_SHMMMAP],[1],[GRID_MPI3_SHMMMAP] ) + ;; + + *) + AC_MSG_ERROR([${ac_SHM} unsupported --enable-shm option]); + ;; +esac + +###################### Shared base path for SHMMMAP +AC_ARG_ENABLE([shmpath],[AC_HELP_STRING([--enable-shmpath=path], + [Select SHM mmap base path for hugetlbfs])], + [ac_SHMPATH=${enable_shmpath}], + [ac_SHMPATH=/var/lib/hugetlbfs/pagesize-2MB/]) +AC_DEFINE_UNQUOTED([GRID_SHM_PATH],["$ac_SHMPATH"],[Path to a hugetlbfs filesystem for MMAPing]) + ############### communication type 
selection AC_ARG_ENABLE([comms],[AC_HELP_STRING([--enable-comms=none|mpi|mpi-auto|mpi3|mpi3-auto|shmem], [Select communications])],[ac_COMMS=${enable_comms}],[ac_COMMS=none]) @@ -267,14 +375,14 @@ case ${ac_COMMS} in AC_DEFINE([GRID_COMMS_NONE],[1],[GRID_COMMS_NONE] ) comms_type='none' ;; - mpi3l*) - AC_DEFINE([GRID_COMMS_MPI3L],[1],[GRID_COMMS_MPI3L] ) - comms_type='mpi3l' - ;; mpi3*) AC_DEFINE([GRID_COMMS_MPI3],[1],[GRID_COMMS_MPI3] ) comms_type='mpi3' ;; + mpit) + AC_DEFINE([GRID_COMMS_MPIT],[1],[GRID_COMMS_MPIT] ) + comms_type='mpit' + ;; mpi*) AC_DEFINE([GRID_COMMS_MPI],[1],[GRID_COMMS_MPI] ) comms_type='mpi' @@ -284,7 +392,7 @@ case ${ac_COMMS} in comms_type='shmem' ;; *) - AC_MSG_ERROR([${ac_COMMS} unsupported --enable-comms option]); + AC_MSG_ERROR([${ac_COMMS} unsupported --enable-comms option]); ;; esac case ${ac_COMMS} in @@ -302,13 +410,13 @@ esac AM_CONDITIONAL(BUILD_COMMS_SHMEM, [ test "${comms_type}X" == "shmemX" ]) AM_CONDITIONAL(BUILD_COMMS_MPI, [ test "${comms_type}X" == "mpiX" ]) AM_CONDITIONAL(BUILD_COMMS_MPI3, [ test "${comms_type}X" == "mpi3X" ] ) -AM_CONDITIONAL(BUILD_COMMS_MPI3L, [ test "${comms_type}X" == "mpi3lX" ] ) +AM_CONDITIONAL(BUILD_COMMS_MPIT, [ test "${comms_type}X" == "mpitX" ] ) AM_CONDITIONAL(BUILD_COMMS_NONE, [ test "${comms_type}X" == "noneX" ]) ############### RNG selection -AC_ARG_ENABLE([rng],[AC_HELP_STRING([--enable-rng=ranlux48|mt19937],\ +AC_ARG_ENABLE([rng],[AC_HELP_STRING([--enable-rng=ranlux48|mt19937|sitmo],\ [Select Random Number Generator to be used])],\ - [ac_RNG=${enable_rng}],[ac_RNG=ranlux48]) + [ac_RNG=${enable_rng}],[ac_RNG=sitmo]) case ${ac_RNG} in ranlux48) @@ -317,8 +425,11 @@ case ${ac_RNG} in mt19937) AC_DEFINE([RNG_MT19937],[1],[RNG_MT19937] ) ;; + sitmo) + AC_DEFINE([RNG_SITMO],[1],[RNG_SITMO] ) + ;; *) - AC_MSG_ERROR([${ac_RNG} unsupported --enable-rng option]); + AC_MSG_ERROR([${ac_RNG} unsupported --enable-rng option]); ;; esac @@ -335,7 +446,7 @@ case ${ac_TIMERS} in 
AC_DEFINE([TIMERS_OFF],[1],[TIMERS_OFF] ) ;; *) - AC_MSG_ERROR([${ac_TIMERS} unsupported --enable-timers option]); + AC_MSG_ERROR([${ac_TIMERS} unsupported --enable-timers option]); ;; esac @@ -347,7 +458,7 @@ case ${ac_CHROMA} in yes|no) ;; *) - AC_MSG_ERROR([${ac_CHROMA} unsupported --enable-chroma option]); + AC_MSG_ERROR([${ac_CHROMA} unsupported --enable-chroma option]); ;; esac @@ -368,29 +479,31 @@ DX_INIT_DOXYGEN([$PACKAGE_NAME], [doxygen.cfg]) ############### Ouput cwd=`pwd -P`; cd ${srcdir}; abs_srcdir=`pwd -P`; cd ${cwd} +GRID_CXXFLAGS="$AM_CXXFLAGS $CXXFLAGS" +GRID_LDFLAGS="$AM_LDFLAGS $LDFLAGS" +GRID_LIBS=$LIBS +GRID_SHORT_SHA=`git rev-parse --short HEAD` +GRID_SHA=`git rev-parse HEAD` +GRID_BRANCH=`git rev-parse --abbrev-ref HEAD` AM_CXXFLAGS="-I${abs_srcdir}/include $AM_CXXFLAGS" AM_CFLAGS="-I${abs_srcdir}/include $AM_CFLAGS" AM_LDFLAGS="-L${cwd}/lib $AM_LDFLAGS" AC_SUBST([AM_CFLAGS]) AC_SUBST([AM_CXXFLAGS]) AC_SUBST([AM_LDFLAGS]) -AC_CONFIG_FILES(Makefile) -AC_CONFIG_FILES(lib/Makefile) -AC_CONFIG_FILES(tests/Makefile) -AC_CONFIG_FILES(tests/IO/Makefile) -AC_CONFIG_FILES(tests/core/Makefile) -AC_CONFIG_FILES(tests/debug/Makefile) -AC_CONFIG_FILES(tests/forces/Makefile) -AC_CONFIG_FILES(tests/hmc/Makefile) -AC_CONFIG_FILES(tests/solver/Makefile) -AC_CONFIG_FILES(tests/qdpxx/Makefile) -AC_CONFIG_FILES(benchmarks/Makefile) -AC_OUTPUT +AC_SUBST([GRID_CXXFLAGS]) +AC_SUBST([GRID_LDFLAGS]) +AC_SUBST([GRID_LIBS]) +AC_SUBST([GRID_SHA]) +AC_SUBST([GRID_BRANCH]) + +git_commit=`cd $srcdir && ./scripts/configure.commit` echo "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Summary of configuration for $PACKAGE v$VERSION ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - +----- GIT VERSION ------------------------------------- +$git_commit ----- PLATFORM ---------------------------------------- architecture (build) : $build_cpu os (build) : $build_os @@ -400,13 +513,18 @@ compiler vendor : ${ax_cv_cxx_compiler_vendor} compiler version : 
${ax_cv_gxx_version} ----- BUILD OPTIONS ----------------------------------- SIMD : ${ac_SIMD}${SIMD_GEN_WIDTH_MSG} -Threading : ${ac_openmp} +Threading : ${ac_openmp} Communications type : ${comms_type} +Shared memory allocator : ${ac_SHM} +Shared memory mmap path : ${ac_SHMPATH} Default precision : ${ac_PRECISION} -RNG choice : ${ac_RNG} +Software FP16 conversion : ${ac_SFW_FP16} +RNG choice : ${ac_RNG} GMP : `if test "x$have_gmp" = xtrue; then echo yes; else echo no; fi` LAPACK : ${ac_LAPACK} FFTW : `if test "x$have_fftw" = xtrue; then echo yes; else echo no; fi` +LIME (ILDG support) : `if test "x$have_lime" = xtrue; then echo yes; else echo no; fi` +HDF5 : `if test "x$have_hdf5" = xtrue; then echo yes; else echo no; fi` build DOXYGEN documentation : `if test "$DX_FLAG_doc" = '1'; then echo yes; else echo no; fi` ----- BUILD FLAGS ------------------------------------- CXXFLAGS: @@ -415,7 +533,32 @@ LDFLAGS: `echo ${AM_LDFLAGS} ${LDFLAGS} | tr ' ' '\n' | sed 's/^-/ -/g'` LIBS: `echo ${LIBS} | tr ' ' '\n' | sed 's/^-/ -/g'` --------------------------------------------------------" > config.summary +-------------------------------------------------------" > grid.configure.summary + +GRID_SUMMARY="`cat grid.configure.summary`" +AM_SUBST_NOTMAKE([GRID_SUMMARY]) +AC_SUBST([GRID_SUMMARY]) + +AC_CONFIG_FILES([grid-config], [chmod +x grid-config]) +AC_CONFIG_FILES(Makefile) +AC_CONFIG_FILES(lib/Makefile) +AC_CONFIG_FILES(tests/Makefile) +AC_CONFIG_FILES(tests/IO/Makefile) +AC_CONFIG_FILES(tests/core/Makefile) +AC_CONFIG_FILES(tests/debug/Makefile) +AC_CONFIG_FILES(tests/forces/Makefile) +AC_CONFIG_FILES(tests/hadrons/Makefile) +AC_CONFIG_FILES(tests/hmc/Makefile) +AC_CONFIG_FILES(tests/solver/Makefile) +AC_CONFIG_FILES(tests/smearing/Makefile) +AC_CONFIG_FILES(tests/qdpxx/Makefile) +AC_CONFIG_FILES(tests/testu01/Makefile) +AC_CONFIG_FILES(benchmarks/Makefile) +AC_CONFIG_FILES(extras/Makefile) +AC_CONFIG_FILES(extras/Hadrons/Makefile) +AC_OUTPUT + echo "" -cat 
config.summary +cat grid.configure.summary echo "" + diff --git a/extras/Hadrons/Application.cc b/extras/Hadrons/Application.cc new file mode 100644 index 00000000..90ebcfd7 --- /dev/null +++ b/extras/Hadrons/Application.cc @@ -0,0 +1,318 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Application.cc + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include + +using namespace Grid; +using namespace QCD; +using namespace Hadrons; + +#define BIG_SEP "===============" +#define SEP "---------------" + +/****************************************************************************** + * Application implementation * + ******************************************************************************/ +// constructors //////////////////////////////////////////////////////////////// +Application::Application(void) +{ + LOG(Message) << "Modules available:" << std::endl; + auto list = ModuleFactory::getInstance().getBuilderList(); + for (auto &m: list) + { + LOG(Message) << " " << m << std::endl; + } + auto dim = GridDefaultLatt(), mpi = GridDefaultMpi(), loc(dim); + locVol_ = 1; + for (unsigned int d = 0; d < dim.size(); ++d) + { + loc[d] /= mpi[d]; + locVol_ *= loc[d]; + } + LOG(Message) << "Global lattice: " << dim << std::endl; + LOG(Message) << "MPI partition : " << mpi << std::endl; + LOG(Message) << "Local lattice : " << loc << std::endl; +} + +Application::Application(const Application::GlobalPar &par) +: Application() +{ + setPar(par); +} + +Application::Application(const std::string parameterFileName) +: Application() +{ + parameterFileName_ = parameterFileName; +} + +// environment shortcut //////////////////////////////////////////////////////// +Environment & Application::env(void) const +{ + return Environment::getInstance(); +} + +// access ////////////////////////////////////////////////////////////////////// +void Application::setPar(const Application::GlobalPar &par) +{ + par_ = par; + env().setSeed(strToVec(par_.seed)); +} + +const Application::GlobalPar & Application::getPar(void) +{ + return par_; +} + +// execute ///////////////////////////////////////////////////////////////////// +void 
Application::run(void) +{ + if (!parameterFileName_.empty() and (env().getNModule() == 0)) + { + parseParameterFile(parameterFileName_); + } + if (!scheduled_) + { + schedule(); + } + printSchedule(); + configLoop(); +} + +// parse parameter file //////////////////////////////////////////////////////// +class ObjectId: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(ObjectId, + std::string, name, + std::string, type); +}; + +void Application::parseParameterFile(const std::string parameterFileName) +{ + XmlReader reader(parameterFileName); + GlobalPar par; + ObjectId id; + + LOG(Message) << "Building application from '" << parameterFileName << "'..." << std::endl; + read(reader, "parameters", par); + setPar(par); + push(reader, "modules"); + push(reader, "module"); + do + { + read(reader, "id", id); + env().createModule(id.name, id.type, reader); + } while (reader.nextElement("module")); + pop(reader); + pop(reader); +} + +void Application::saveParameterFile(const std::string parameterFileName) +{ + XmlWriter writer(parameterFileName); + ObjectId id; + const unsigned int nMod = env().getNModule(); + + LOG(Message) << "Saving application to '" << parameterFileName << "'..." 
<< std::endl; + write(writer, "parameters", getPar()); + push(writer, "modules"); + for (unsigned int i = 0; i < nMod; ++i) + { + push(writer, "module"); + id.name = env().getModuleName(i); + id.type = env().getModule(i)->getRegisteredName(); + write(writer, "id", id); + env().getModule(i)->saveParameters(writer, "options"); + pop(writer); + } + pop(writer); + pop(writer); +} + +// schedule computation //////////////////////////////////////////////////////// +#define MEM_MSG(size)\ +sizeString((size)*locVol_) << " (" << sizeString(size) << "/site)" + +#define DEFINE_MEMPEAK /* declares the lambda memPeak: runs a program as a silent dry run and returns what executeProgram() reports */ \ +GeneticScheduler::ObjFunc memPeak = \ +[this](const std::vector &program)\ +{\ + unsigned int memPeak;\ + bool msg;\ + \ + msg = HadronsLogMessage.isActive(); /* remember the logger state before silencing it */\ + HadronsLogMessage.Active(false);\ + env().dryRun(true);\ + memPeak = env().executeProgram(program);\ + env().dryRun(false);\ + env().freeAll();\ + HadronsLogMessage.Active(msg); /* restore the saved state instead of unconditionally re-enabling logging */\ + \ + return memPeak;\ +} + +void Application::schedule(void) +{ + DEFINE_MEMPEAK; + + // build module dependency graph + LOG(Message) << "Building module graph..." << std::endl; + auto graph = env().makeModuleGraph(); + auto con = graph.getConnectedComponents(); + + // constrained topological sort using a genetic algorithm + LOG(Message) << "Scheduling computation..." << std::endl; + LOG(Message) << " #module= " << graph.size() << std::endl; + LOG(Message) << " population size= " << par_.genetic.popSize << std::endl; + LOG(Message) << " max. generation= " << par_.genetic.maxGen << std::endl; + LOG(Message) << " max. cst.
generation= " << par_.genetic.maxCstGen << std::endl; + LOG(Message) << " mutation rate= " << par_.genetic.mutationRate << std::endl; + + unsigned int k = 0, gen, prevPeak, nCstPeak = 0; + std::random_device rd; + GeneticScheduler::Parameters par; + + par.popSize = par_.genetic.popSize; + par.mutationRate = par_.genetic.mutationRate; + par.seed = rd(); + memPeak_ = 0; + CartesianCommunicator::BroadcastWorld(0, &(par.seed), sizeof(par.seed)); + for (unsigned int i = 0; i < con.size(); ++i) + { + GeneticScheduler scheduler(con[i], memPeak, par); + + gen = 0; + do + { + LOG(Debug) << "Generation " << gen << ":" << std::endl; + scheduler.nextGeneration(); + if (gen != 0) + { + if (prevPeak == scheduler.getMinValue()) + { + nCstPeak++; + } + else + { + nCstPeak = 0; + } + } + + prevPeak = scheduler.getMinValue(); + if (gen % 10 == 0) + { + LOG(Iterative) << "Generation " << gen << ": " + << MEM_MSG(scheduler.getMinValue()) << std::endl; + } + + gen++; + } while ((gen < par_.genetic.maxGen) + and (nCstPeak < par_.genetic.maxCstGen)); + auto &t = scheduler.getMinSchedule(); + if (scheduler.getMinValue() > memPeak_) + { + memPeak_ = scheduler.getMinValue(); + } + for (unsigned int j = 0; j < t.size(); ++j) + { + program_.push_back(t[j]); + } + } + scheduled_ = true; +} + +void Application::saveSchedule(const std::string filename) +{ + TextWriter writer(filename); + std::vector program; + + if (!scheduled_) + { + HADRON_ERROR("Computation not scheduled"); + } + LOG(Message) << "Saving current schedule to '" << filename << "'..." + << std::endl; + for (auto address: program_) + { + program.push_back(env().getModuleName(address)); + } + write(writer, "schedule", program); +} + +void Application::loadSchedule(const std::string filename) +{ + DEFINE_MEMPEAK; + + TextReader reader(filename); + std::vector program; + + LOG(Message) << "Loading schedule from '" << filename << "'..." 
+ << std::endl; + read(reader, "schedule", program); + program_.clear(); + for (auto &name: program) + { + program_.push_back(env().getModuleAddress(name)); + } + scheduled_ = true; + memPeak_ = memPeak(program_); +} + +void Application::printSchedule(void) +{ + if (!scheduled_) + { + HADRON_ERROR("Computation not scheduled"); + } + LOG(Message) << "Schedule (memory peak: " << MEM_MSG(memPeak_) << "):" + << std::endl; + for (unsigned int i = 0; i < program_.size(); ++i) + { + LOG(Message) << std::setw(4) << i + 1 << ": " + << env().getModuleName(program_[i]) << std::endl; + } +} + +// loop on configurations ////////////////////////////////////////////////////// +void Application::configLoop(void) +{ + auto range = par_.trajCounter; + + for (unsigned int t = range.start; t < range.end; t += range.step) + { + LOG(Message) << BIG_SEP << " Starting measurement for trajectory " << t + << " " << BIG_SEP << std::endl; + env().setTrajectory(t); + env().executeProgram(program_); + } + LOG(Message) << BIG_SEP << " End of measurement " << BIG_SEP << std::endl; + env().freeAll(); +} diff --git a/extras/Hadrons/Application.hpp b/extras/Hadrons/Application.hpp new file mode 100644 index 00000000..fce9b6eb --- /dev/null +++ b/extras/Hadrons/Application.hpp @@ -0,0 +1,132 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Application.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_Application_hpp_ +#define Hadrons_Application_hpp_ + +#include +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Main program manager * + ******************************************************************************/ +class Application +{ +public: + class TrajRange: Serializable + { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(TrajRange, + unsigned int, start, + unsigned int, end, + unsigned int, step); + }; + class GeneticPar: Serializable + { + public: + GeneticPar(void): + popSize{20}, maxGen{1000}, maxCstGen{100}, mutationRate{.1} {}; + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(GeneticPar, + unsigned int, popSize, + unsigned int, maxGen, + unsigned int, maxCstGen, + double , mutationRate); + }; + class GlobalPar: Serializable + { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(GlobalPar, + TrajRange, trajCounter, + GeneticPar, genetic, + std::string, seed); + }; +public: + // constructors + Application(void); + Application(const GlobalPar &par); + Application(const std::string parameterFileName); + // destructor + virtual ~Application(void) = default; + // access + void setPar(const GlobalPar &par); + const GlobalPar & getPar(void); + // module creation + template + void createModule(const std::string name); + template + void createModule(const std::string name, const typename M::Par &par); + // execute + void run(void); + // XML parameter file I/O + void parseParameterFile(const std::string 
parameterFileName); + void saveParameterFile(const std::string parameterFileName); + // schedule computation + void schedule(void); + void saveSchedule(const std::string filename); + void loadSchedule(const std::string filename); + void printSchedule(void); + // loop on configurations + void configLoop(void); +private: + // environment shortcut + Environment & env(void) const; +private: + long unsigned int locVol_; + std::string parameterFileName_{""}; + GlobalPar par_; + std::vector program_; + Environment::Size memPeak_; + bool scheduled_{false}; +}; + +/****************************************************************************** + * Application template implementation * + ******************************************************************************/ +// module creation ///////////////////////////////////////////////////////////// +template +void Application::createModule(const std::string name) +{ + env().createModule(name); +} + +template +void Application::createModule(const std::string name, + const typename M::Par &par) +{ + env().createModule(name, par); +} + +END_HADRONS_NAMESPACE + +#endif // Hadrons_Application_hpp_ diff --git a/extras/Hadrons/Environment.cc b/extras/Hadrons/Environment.cc new file mode 100644 index 00000000..0e7a4326 --- /dev/null +++ b/extras/Hadrons/Environment.cc @@ -0,0 +1,793 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Environment.cc + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
+ +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include +#include + +using namespace Grid; +using namespace QCD; +using namespace Hadrons; + +/****************************************************************************** + * Environment implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +Environment::Environment(void) +{ + dim_ = GridDefaultLatt(); + nd_ = dim_.size(); + grid4d_.reset(SpaceTimeGrid::makeFourDimGrid( + dim_, GridDefaultSimd(nd_, vComplex::Nsimd()), + GridDefaultMpi())); + gridRb4d_.reset(SpaceTimeGrid::makeFourDimRedBlackGrid(grid4d_.get())); + auto loc = getGrid()->LocalDimensions(); + locVol_ = 1; + for (unsigned int d = 0; d < loc.size(); ++d) + { + locVol_ *= loc[d]; + } + rng4d_.reset(new GridParallelRNG(grid4d_.get())); +} + +// dry run ///////////////////////////////////////////////////////////////////// +void Environment::dryRun(const bool isDry) +{ + dryRun_ = isDry; +} + +bool Environment::isDryRun(void) const +{ + return dryRun_; +} + +// trajectory number /////////////////////////////////////////////////////////// +void Environment::setTrajectory(const unsigned int traj) +{ + traj_ = traj; +} + +unsigned int Environment::getTrajectory(void) const +{ + return traj_; +} + +// grids 
/////////////////////////////////////////////////////////////////////// +void Environment::createGrid(const unsigned int Ls) +{ + if (grid5d_.find(Ls) == grid5d_.end()) + { + auto g = getGrid(); + + grid5d_[Ls].reset(SpaceTimeGrid::makeFiveDimGrid(Ls, g)); + gridRb5d_[Ls].reset(SpaceTimeGrid::makeFiveDimRedBlackGrid(Ls, g)); + } +} + +GridCartesian * Environment::getGrid(const unsigned int Ls) const +{ + try + { + if (Ls == 1) + { + return grid4d_.get(); + } + else + { + return grid5d_.at(Ls).get(); + } + } + catch(std::out_of_range &) + { + HADRON_ERROR("no grid with Ls= " << Ls); + } +} + +GridRedBlackCartesian * Environment::getRbGrid(const unsigned int Ls) const +{ + try + { + if (Ls == 1) + { + return gridRb4d_.get(); + } + else + { + return gridRb5d_.at(Ls).get(); + } + } + catch(std::out_of_range &) + { + HADRON_ERROR("no red-black 5D grid with Ls= " << Ls); + } +} + +unsigned int Environment::getNd(void) const +{ + return nd_; +} + +std::vector Environment::getDim(void) const +{ + return dim_; +} + +int Environment::getDim(const unsigned int mu) const +{ + return dim_[mu]; +} + +// random number generator ///////////////////////////////////////////////////// +void Environment::setSeed(const std::vector &seed) +{ + rng4d_->SeedFixedIntegers(seed); +} + +GridParallelRNG * Environment::get4dRng(void) const +{ + return rng4d_.get(); +} + +// module management /////////////////////////////////////////////////////////// +void Environment::pushModule(Environment::ModPt &pt) +{ + std::string name = pt->getName(); + + if (!hasModule(name)) + { + std::vector inputAddress; + unsigned int address; + ModuleInfo m; + + m.data = std::move(pt); + m.type = typeIdPt(*m.data.get()); + m.name = name; + auto input = m.data->getInput(); + for (auto &in: input) + { + if (!hasObject(in)) + { + addObject(in , -1); + } + m.input.push_back(objectAddress_[in]); + } + auto output = m.data->getOutput(); + module_.push_back(std::move(m)); + address = static_cast(module_.size() - 1); + 
moduleAddress_[name] = address; + for (auto &out: output) + { + if (!hasObject(out)) + { + addObject(out, address); + } + else + { + if (object_[objectAddress_[out]].module < 0) + { + object_[objectAddress_[out]].module = address; + } + else + { + HADRON_ERROR("object '" + out + + "' is already produced by module '" + + module_[object_[getObjectAddress(out)].module].name + + "' (while pushing module '" + name + "')"); + } + } + } + } + else + { + HADRON_ERROR("module '" + name + "' already exists"); + } +} + +unsigned int Environment::getNModule(void) const +{ + return module_.size(); +} + +void Environment::createModule(const std::string name, const std::string type, + XmlReader &reader) +{ + auto &factory = ModuleFactory::getInstance(); + auto pt = factory.create(type, name); + + pt->parseParameters(reader, "options"); + pushModule(pt); +} + +ModuleBase * Environment::getModule(const unsigned int address) const +{ + if (hasModule(address)) + { + return module_[address].data.get(); + } + else + { + HADRON_ERROR("no module with address " + std::to_string(address)); + } +} + +ModuleBase * Environment::getModule(const std::string name) const +{ + return getModule(getModuleAddress(name)); +} + +unsigned int Environment::getModuleAddress(const std::string name) const +{ + if (hasModule(name)) + { + return moduleAddress_.at(name); + } + else + { + HADRON_ERROR("no module with name '" + name + "'"); + } +} + +std::string Environment::getModuleName(const unsigned int address) const +{ + if (hasModule(address)) + { + return module_[address].name; + } + else + { + HADRON_ERROR("no module with address " + std::to_string(address)); + } +} + +std::string Environment::getModuleType(const unsigned int address) const +{ + if (hasModule(address)) + { + return typeName(module_[address].type); + } + else + { + HADRON_ERROR("no module with address " + std::to_string(address)); + } +} + +std::string Environment::getModuleType(const std::string name) const +{ + return 
getModuleType(getModuleAddress(name)); +} + +std::string Environment::getModuleNamespace(const unsigned int address) const +{ + std::string type = getModuleType(address), ns; + + auto pos2 = type.rfind("::"); + auto pos1 = type.rfind("::", pos2 - 2); + + return type.substr(pos1 + 2, pos2 - pos1 - 2); +} + +std::string Environment::getModuleNamespace(const std::string name) const +{ + return getModuleNamespace(getModuleAddress(name)); +} + +bool Environment::hasModule(const unsigned int address) const +{ + return (address < module_.size()); +} + +bool Environment::hasModule(const std::string name) const +{ + return (moduleAddress_.find(name) != moduleAddress_.end()); +} + +Graph Environment::makeModuleGraph(void) const +{ + Graph moduleGraph; + + for (unsigned int i = 0; i < module_.size(); ++i) + { + moduleGraph.addVertex(i); + for (auto &j: module_[i].input) + { + moduleGraph.addEdge(object_[j].module, i); + } + } + + return moduleGraph; +} + +#define BIG_SEP "===============" +#define SEP "---------------" +#define MEM_MSG(size)\ +sizeString((size)*locVol_) << " (" << sizeString(size) << "/site)" + +Environment::Size +Environment::executeProgram(const std::vector &p) +{ + Size memPeak = 0, sizeBefore, sizeAfter; + std::vector> freeProg; + bool continueCollect, nothingFreed; + + // build garbage collection schedule + freeProg.resize(p.size()); + for (unsigned int i = 0; i < object_.size(); ++i) + { + auto pred = [i, this](const unsigned int j) + { + auto &in = module_[j].input; + auto it = std::find(in.begin(), in.end(), i); + + return (it != in.end()) or (j == object_[i].module); + }; + auto it = std::find_if(p.rbegin(), p.rend(), pred); + if (it != p.rend()) + { + freeProg[p.rend() - it - 1].insert(i); + } + } + + // program execution + for (unsigned int i = 0; i < p.size(); ++i) + { + // execute module + if (!isDryRun()) + { + LOG(Message) << SEP << " Measurement step " << i+1 << "/" + << p.size() << " (module '" << module_[p[i]].name + << "') " << SEP << 
std::endl; + } + (*module_[p[i]].data)(); + sizeBefore = getTotalSize(); + // print used memory after execution + if (!isDryRun()) + { + LOG(Message) << "Allocated objects: " << MEM_MSG(sizeBefore) + << std::endl; + } + if (sizeBefore > memPeak) + { + memPeak = sizeBefore; + } + // garbage collection for step i + if (!isDryRun()) + { + LOG(Message) << "Garbage collection..." << std::endl; + } + nothingFreed = true; + do + { + continueCollect = false; + auto toFree = freeProg[i]; + for (auto &j: toFree) + { + // continue garbage collection while there are still + // objects without owners + continueCollect = continueCollect or !hasOwners(j); + if(freeObject(j)) + { + // if an object has been freed, remove it from + // the garbage collection schedule + freeProg[i].erase(j); + nothingFreed = false; + } + } + } while (continueCollect); + // any remaining objects in step i garbage collection schedule + // is scheduled for step i + 1 + if (i + 1 < p.size()) + { + for (auto &j: freeProg[i]) + { + freeProg[i + 1].insert(j); + } + } + // print used memory after garbage collection if necessary + if (!isDryRun()) + { + sizeAfter = getTotalSize(); + if (sizeBefore != sizeAfter) + { + LOG(Message) << "Allocated objects: " << MEM_MSG(sizeAfter) + << std::endl; + } + else + { + LOG(Message) << "Nothing to free" << std::endl; + } + } + } + + return memPeak; +} + +Environment::Size Environment::executeProgram(const std::vector &p) +{ + std::vector pAddress; + + for (auto &n: p) + { + pAddress.push_back(getModuleAddress(n)); + } + + return executeProgram(pAddress); +} + +// general memory management /////////////////////////////////////////////////// +void Environment::addObject(const std::string name, const int moduleAddress) +{ + if (!hasObject(name)) + { + ObjInfo info; + + info.name = name; + info.module = moduleAddress; + object_.push_back(std::move(info)); + objectAddress_[name] = static_cast(object_.size() - 1); + } + else + { + HADRON_ERROR("object '" + name + "' already 
exists"); + } +} + +void Environment::registerObject(const unsigned int address, + const unsigned int size, const unsigned int Ls) +{ + if (!hasRegisteredObject(address)) + { + if (hasObject(address)) + { + object_[address].size = size; + object_[address].Ls = Ls; + object_[address].isRegistered = true; + } + else + { + HADRON_ERROR("no object with address " + std::to_string(address)); + } + } + else + { + HADRON_ERROR("object with address " + std::to_string(address) + + " already registered"); + } +} + +void Environment::registerObject(const std::string name, + const unsigned int size, const unsigned int Ls) +{ + if (!hasObject(name)) + { + addObject(name); + } + registerObject(getObjectAddress(name), size, Ls); +} + +unsigned int Environment::getObjectAddress(const std::string name) const +{ + if (hasObject(name)) + { + return objectAddress_.at(name); + } + else + { + HADRON_ERROR("no object with name '" + name + "'"); + } +} + +std::string Environment::getObjectName(const unsigned int address) const +{ + if (hasObject(address)) + { + return object_[address].name; + } + else + { + HADRON_ERROR("no object with address " + std::to_string(address)); + } +} + +std::string Environment::getObjectType(const unsigned int address) const +{ + if (hasRegisteredObject(address)) + { + if (object_[address].type) + { + return typeName(object_[address].type); + } + else + { + return ""; + } + } + else if (hasObject(address)) + { + HADRON_ERROR("object with address " + std::to_string(address) + + " exists but is not registered"); + } + else + { + HADRON_ERROR("no object with address " + std::to_string(address)); + } +} + +std::string Environment::getObjectType(const std::string name) const +{ + return getObjectType(getObjectAddress(name)); +} + +Environment::Size Environment::getObjectSize(const unsigned int address) const +{ + if (hasRegisteredObject(address)) + { + return object_[address].size; + } + else if (hasObject(address)) + { + HADRON_ERROR("object with address " + 
std::to_string(address) + + " exists but is not registered"); + } + else + { + HADRON_ERROR("no object with address " + std::to_string(address)); + } +} + +Environment::Size Environment::getObjectSize(const std::string name) const +{ + return getObjectSize(getObjectAddress(name)); +} + +unsigned int Environment::getObjectModule(const unsigned int address) const +{ + if (hasObject(address)) + { + return object_[address].module; + } + else + { + HADRON_ERROR("no object with address " + std::to_string(address)); + } +} + +unsigned int Environment::getObjectModule(const std::string name) const +{ + return getObjectModule(getObjectAddress(name)); +} + +unsigned int Environment::getObjectLs(const unsigned int address) const +{ + if (hasRegisteredObject(address)) + { + return object_[address].Ls; + } + else if (hasObject(address)) + { + HADRON_ERROR("object with address " + std::to_string(address) + + " exists but is not registered"); + } + else + { + HADRON_ERROR("no object with address " + std::to_string(address)); + } +} + +unsigned int Environment::getObjectLs(const std::string name) const +{ + return getObjectLs(getObjectAddress(name)); +} + +bool Environment::hasObject(const unsigned int address) const +{ + return (address < object_.size()); +} + +bool Environment::hasObject(const std::string name) const +{ + auto it = objectAddress_.find(name); + + return ((it != objectAddress_.end()) and hasObject(it->second)); +} + +bool Environment::hasRegisteredObject(const unsigned int address) const +{ + if (hasObject(address)) + { + return object_[address].isRegistered; + } + else + { + return false; + } +} + +bool Environment::hasRegisteredObject(const std::string name) const +{ + if (hasObject(name)) + { + return hasRegisteredObject(getObjectAddress(name)); + } + else + { + return false; + } +} + +bool Environment::hasCreatedObject(const unsigned int address) const +{ + if (hasObject(address)) + { + return (object_[address].data != nullptr); + } + else + { + return false; 
+ } +} + +bool Environment::hasCreatedObject(const std::string name) const +{ + if (hasObject(name)) + { + return hasCreatedObject(getObjectAddress(name)); + } + else + { + return false; + } +} + +bool Environment::isObject5d(const unsigned int address) const +{ + return (getObjectLs(address) > 1); +} + +bool Environment::isObject5d(const std::string name) const +{ + return (getObjectLs(name) > 1); +} + +Environment::Size Environment::getTotalSize(void) const +{ + Environment::Size size = 0; + + for (auto &o: object_) + { + if (o.isRegistered) + { + size += o.size; + } + } + + return size; +} + +void Environment::addOwnership(const unsigned int owner, + const unsigned int property) +{ + if (hasObject(property)) + { + object_[property].owners.insert(owner); + } + else + { + HADRON_ERROR("no object with address " + std::to_string(property)); + } + if (hasObject(owner)) + { + object_[owner].properties.insert(property); + } + else + { + HADRON_ERROR("no object with address " + std::to_string(owner)); + } +} + +void Environment::addOwnership(const std::string owner, + const std::string property) +{ + addOwnership(getObjectAddress(owner), getObjectAddress(property)); +} + +bool Environment::hasOwners(const unsigned int address) const +{ + + if (hasObject(address)) + { + return (!object_[address].owners.empty()); + } + else + { + HADRON_ERROR("no object with address " + std::to_string(address)); + } +} + +bool Environment::hasOwners(const std::string name) const +{ + return hasOwners(getObjectAddress(name)); +} + +bool Environment::freeObject(const unsigned int address) +{ + if (!hasOwners(address)) + { + if (!isDryRun() and object_[address].isRegistered) + { + LOG(Message) << "Destroying object '" << object_[address].name + << "'" << std::endl; + } + for (auto &p: object_[address].properties) + { + object_[p].owners.erase(address); + } + object_[address].size = 0; + object_[address].Ls = 0; + object_[address].isRegistered = false; + object_[address].type = nullptr; + 
object_[address].owners.clear(); + object_[address].properties.clear(); + object_[address].data.reset(nullptr); + + return true; + } + else + { + return false; + } +} + +bool Environment::freeObject(const std::string name) +{ + return freeObject(getObjectAddress(name)); +} + +void Environment::freeAll(void) +{ + for (unsigned int i = 0; i < object_.size(); ++i) + { + freeObject(i); + } +} + +void Environment::printContent(void) +{ + LOG(Message) << "Modules: " << std::endl; + for (unsigned int i = 0; i < module_.size(); ++i) + { + LOG(Message) << std::setw(4) << i << ": " + << getModuleName(i) << std::endl; + } + LOG(Message) << "Objects: " << std::endl; + for (unsigned int i = 0; i < object_.size(); ++i) + { + LOG(Message) << std::setw(4) << i << ": " + << getObjectName(i) << std::endl; + } +} diff --git a/extras/Hadrons/Environment.hpp b/extras/Hadrons/Environment.hpp new file mode 100644 index 00000000..13264bd5 --- /dev/null +++ b/extras/Hadrons/Environment.hpp @@ -0,0 +1,427 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Environment.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_Environment_hpp_ +#define Hadrons_Environment_hpp_ + +#include +#include + +#ifndef SITE_SIZE_TYPE +#define SITE_SIZE_TYPE unsigned int +#endif + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Global environment * + ******************************************************************************/ +// forward declaration of Module +class ModuleBase; + +class Object +{ +public: + Object(void) = default; + virtual ~Object(void) = default; +}; + +template +class Holder: public Object +{ +public: + Holder(void) = default; + Holder(T *pt); + virtual ~Holder(void) = default; + T & get(void) const; + T * getPt(void) const; + void reset(T *pt); +private: + std::unique_ptr objPt_{nullptr}; +}; + +class Environment +{ + SINGLETON(Environment); +public: + typedef SITE_SIZE_TYPE Size; + typedef std::unique_ptr ModPt; + typedef std::unique_ptr GridPt; + typedef std::unique_ptr GridRbPt; + typedef std::unique_ptr RngPt; + typedef std::unique_ptr LatticePt; +private: + struct ModuleInfo + { + const std::type_info *type{nullptr}; + std::string name; + ModPt data{nullptr}; + std::vector input; + }; + struct ObjInfo + { + Size size{0}; + unsigned int Ls{0}; + bool isRegistered{false}; + const std::type_info *type{nullptr}; + std::string name; + int module{-1}; + std::set owners, properties; + std::unique_ptr data{nullptr}; + }; +public: + // dry run + void dryRun(const bool isDry); + bool isDryRun(void) const; + // trajectory number + void setTrajectory(const unsigned int traj); + unsigned int getTrajectory(void) const; + // grids + void createGrid(const unsigned int Ls); + GridCartesian * getGrid(const unsigned int Ls = 1) const; + GridRedBlackCartesian * getRbGrid(const unsigned int Ls = 1) const; + std::vector 
getDim(void) const; + int getDim(const unsigned int mu) const; + unsigned int getNd(void) const; + // random number generator + void setSeed(const std::vector &seed); + GridParallelRNG * get4dRng(void) const; + // module management + void pushModule(ModPt &pt); + template + void createModule(const std::string name); + template + void createModule(const std::string name, + const typename M::Par &par); + void createModule(const std::string name, + const std::string type, + XmlReader &reader); + unsigned int getNModule(void) const; + ModuleBase * getModule(const unsigned int address) const; + ModuleBase * getModule(const std::string name) const; + template + M * getModule(const unsigned int address) const; + template + M * getModule(const std::string name) const; + unsigned int getModuleAddress(const std::string name) const; + std::string getModuleName(const unsigned int address) const; + std::string getModuleType(const unsigned int address) const; + std::string getModuleType(const std::string name) const; + std::string getModuleNamespace(const unsigned int address) const; + std::string getModuleNamespace(const std::string name) const; + bool hasModule(const unsigned int address) const; + bool hasModule(const std::string name) const; + Graph makeModuleGraph(void) const; + Size executeProgram(const std::vector &p); + Size executeProgram(const std::vector &p); + // general memory management + void addObject(const std::string name, + const int moduleAddress = -1); + void registerObject(const unsigned int address, + const unsigned int size, + const unsigned int Ls = 1); + void registerObject(const std::string name, + const unsigned int size, + const unsigned int Ls = 1); + template + unsigned int lattice4dSize(void) const; + template + void registerLattice(const unsigned int address, + const unsigned int Ls = 1); + template + void registerLattice(const std::string name, + const unsigned int Ls = 1); + template + void setObject(const unsigned int address, T *object); + 
template + void setObject(const std::string name, T *object); + template + T * getObject(const unsigned int address) const; + template + T * getObject(const std::string name) const; + template + T * createLattice(const unsigned int address); + template + T * createLattice(const std::string name); + unsigned int getObjectAddress(const std::string name) const; + std::string getObjectName(const unsigned int address) const; + std::string getObjectType(const unsigned int address) const; + std::string getObjectType(const std::string name) const; + Size getObjectSize(const unsigned int address) const; + Size getObjectSize(const std::string name) const; + unsigned int getObjectModule(const unsigned int address) const; + unsigned int getObjectModule(const std::string name) const; + unsigned int getObjectLs(const unsigned int address) const; + unsigned int getObjectLs(const std::string name) const; + bool hasObject(const unsigned int address) const; + bool hasObject(const std::string name) const; + bool hasRegisteredObject(const unsigned int address) const; + bool hasRegisteredObject(const std::string name) const; + bool hasCreatedObject(const unsigned int address) const; + bool hasCreatedObject(const std::string name) const; + bool isObject5d(const unsigned int address) const; + bool isObject5d(const std::string name) const; + template + bool isObjectOfType(const unsigned int address) const; + template + bool isObjectOfType(const std::string name) const; + Environment::Size getTotalSize(void) const; + void addOwnership(const unsigned int owner, + const unsigned int property); + void addOwnership(const std::string owner, + const std::string property); + bool hasOwners(const unsigned int address) const; + bool hasOwners(const std::string name) const; + bool freeObject(const unsigned int address); + bool freeObject(const std::string name); + void freeAll(void); + void printContent(void); +private: + // general + bool dryRun_{false}; + unsigned int traj_, locVol_; + // grids + 
std::vector dim_; + GridPt grid4d_; + std::map grid5d_; + GridRbPt gridRb4d_; + std::map gridRb5d_; + unsigned int nd_; + // random number generator + RngPt rng4d_; + // module and related maps + std::vector module_; + std::map moduleAddress_; + // lattice store + std::map lattice_; + // object store + std::vector object_; + std::map objectAddress_; +}; + +/****************************************************************************** + * Holder template implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +Holder::Holder(T *pt) +: objPt_(pt) +{} + +// access ////////////////////////////////////////////////////////////////////// +template +T & Holder::get(void) const +{ + return *objPt_; // dereference the owned pointer; '&objPt_.get()' was the address of a temporary T*, not a T& +} + +template +T * Holder::getPt(void) const +{ + return objPt_.get(); +} + +template +void Holder::reset(T *pt) +{ + objPt_.reset(pt); +} + +/****************************************************************************** + * Environment template implementation * + ******************************************************************************/ +// module management /////////////////////////////////////////////////////////// +template +void Environment::createModule(const std::string name) +{ + ModPt pt(new M(name)); + + pushModule(pt); +} + +template +void Environment::createModule(const std::string name, + const typename M::Par &par) +{ + ModPt pt(new M(name)); + + static_cast(pt.get())->setPar(par); + pushModule(pt); +} + +template +M * Environment::getModule(const unsigned int address) const +{ + if (auto *pt = dynamic_cast(getModule(address))) + { + return pt; + } + else + { + HADRON_ERROR("module '" + module_[address].name + + "' does not have type " + typeid(M).name() + + "(object type: " + getModuleType(address) + ")"); + } +} + +template +M * Environment::getModule(const std::string name) const +{ + return
getModule(getModuleAddress(name)); +} + +template +unsigned int Environment::lattice4dSize(void) const +{ + return sizeof(typename T::vector_object)/getGrid()->Nsimd(); +} + +template +void Environment::registerLattice(const unsigned int address, + const unsigned int Ls) +{ + createGrid(Ls); + registerObject(address, Ls*lattice4dSize(), Ls); +} + +template +void Environment::registerLattice(const std::string name, const unsigned int Ls) +{ + createGrid(Ls); + registerObject(name, Ls*lattice4dSize(), Ls); +} + +template +void Environment::setObject(const unsigned int address, T *object) +{ + if (hasRegisteredObject(address)) + { + object_[address].data.reset(new Holder(object)); + object_[address].type = &typeid(T); + } + else if (hasObject(address)) + { + HADRON_ERROR("object with address " + std::to_string(address) + + " exists but is not registered"); + } + else + { + HADRON_ERROR("no object with address " + std::to_string(address)); + } +} + +template +void Environment::setObject(const std::string name, T *object) +{ + setObject(getObjectAddress(name), object); +} + +template +T * Environment::getObject(const unsigned int address) const +{ + if (hasRegisteredObject(address)) + { + if (auto h = dynamic_cast *>(object_[address].data.get())) + { + return h->getPt(); + } + else + { + HADRON_ERROR("object with address " + std::to_string(address) + + " does not have type '" + typeName(&typeid(T)) + + "' (has type '" + getObjectType(address) + "')"); + } + } + else if (hasObject(address)) + { + HADRON_ERROR("object with address " + std::to_string(address) + + " exists but is not registered"); + } + else + { + HADRON_ERROR("no object with address " + std::to_string(address)); + } +} + +template +T * Environment::getObject(const std::string name) const +{ + return getObject(getObjectAddress(name)); +} + +template +T * Environment::createLattice(const unsigned int address) +{ + GridCartesian *g = getGrid(getObjectLs(address)); + + setObject(address, new T(g)); + + return 
getObject(address); +} + +template +T * Environment::createLattice(const std::string name) +{ + return createLattice(getObjectAddress(name)); +} + +template +bool Environment::isObjectOfType(const unsigned int address) const +{ + if (hasRegisteredObject(address)) + { + if (auto h = dynamic_cast *>(object_[address].data.get())) + { + return true; + } + else + { + return false; + } + } + else if (hasObject(address)) + { + HADRON_ERROR("object with address " + std::to_string(address) + + " exists but is not registered"); + } + else + { + HADRON_ERROR("no object with address " + std::to_string(address)); + } +} + +template +bool Environment::isObjectOfType(const std::string name) const +{ + return isObjectOfType(getObjectAddress(name)); +} + +END_HADRONS_NAMESPACE + +#endif // Hadrons_Environment_hpp_ diff --git a/extras/Hadrons/Factory.hpp b/extras/Hadrons/Factory.hpp new file mode 100644 index 00000000..da86acae --- /dev/null +++ b/extras/Hadrons/Factory.hpp @@ -0,0 +1,106 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Factory.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_Factory_hpp_ +#define Hadrons_Factory_hpp_ + +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * abstract factory class * + ******************************************************************************/ +template +class Factory +{ +public: + typedef std::function(const std::string)> Func; +public: + // constructor + Factory(void) = default; + // destructor + virtual ~Factory(void) = default; + // registration + void registerBuilder(const std::string type, const Func &f); + // get builder list + std::vector getBuilderList(void) const; + // factory + std::unique_ptr create(const std::string type, + const std::string name) const; +private: + std::map builder_; +}; + +/****************************************************************************** + * template implementation * + ******************************************************************************/ +// registration //////////////////////////////////////////////////////////////// +template +void Factory::registerBuilder(const std::string type, const Func &f) +{ + builder_[type] = f; +} + +// get module list ///////////////////////////////////////////////////////////// +template +std::vector Factory::getBuilderList(void) const +{ + std::vector list; + + for (auto &b: builder_) + { + list.push_back(b.first); + } + + return list; +} + +// factory ///////////////////////////////////////////////////////////////////// +template +std::unique_ptr Factory::create(const std::string type, + const std::string name) const +{ + Func func; + + try + { + func = builder_.at(type); + } + catch (std::out_of_range &) + { + HADRON_ERROR("object of type '" + type + "' unknown"); + } + + return func(name); +} + +END_HADRONS_NAMESPACE + +#endif // 
Hadrons_Factory_hpp_ diff --git a/extras/Hadrons/GeneticScheduler.hpp b/extras/Hadrons/GeneticScheduler.hpp new file mode 100644 index 00000000..d0c52596 --- /dev/null +++ b/extras/Hadrons/GeneticScheduler.hpp @@ -0,0 +1,329 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/GeneticScheduler.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_GeneticScheduler_hpp_ +#define Hadrons_GeneticScheduler_hpp_ + +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Scheduler based on a genetic algorithm * + ******************************************************************************/ +template +class GeneticScheduler +{ +public: + typedef std::vector Gene; + typedef std::pair GenePair; + typedef std::function ObjFunc; + struct Parameters + { + double mutationRate; + unsigned int popSize, seed; + }; +public: + // constructor + GeneticScheduler(Graph &graph, const ObjFunc &func, + const Parameters &par); + // destructor + virtual ~GeneticScheduler(void) = default; + // access + const Gene & getMinSchedule(void); + int getMinValue(void); + // breed a new generation + void nextGeneration(void); + // heuristic benchmarks + void benchmarkCrossover(const unsigned int nIt); + // print population + friend std::ostream & operator<<(std::ostream &out, + const GeneticScheduler &s) + { + out << "["; + for (auto &p: s.population_) + { + out << p.first << ", "; + } + out << "\b\b]"; + + return out; + } +private: + // evolution steps + void initPopulation(void); + void doCrossover(void); + void doMutation(void); + // genetic operators + GenePair selectPair(void); + void crossover(Gene &c1, Gene &c2, const Gene &p1, const Gene &p2); + void mutation(Gene &m, const Gene &c); + +private: + Graph &graph_; + const ObjFunc func_; // held by value (like par_): a reference member dangles if the caller passes a temporary callable + const Parameters par_; + std::multimap population_; + std::mt19937 gen_; +}; + +/****************************************************************************** + * template implementation * + ******************************************************************************/ +// constructor
///////////////////////////////////////////////////////////////// +template +GeneticScheduler::GeneticScheduler(Graph &graph, const ObjFunc &func, + const Parameters &par) +: graph_(graph) +, func_(func) +, par_(par) +{ + gen_.seed(par_.seed); +} + +// access ////////////////////////////////////////////////////////////////////// +template +const typename GeneticScheduler::Gene & +GeneticScheduler::getMinSchedule(void) +{ + return population_.begin()->second; +} + +template +int GeneticScheduler::getMinValue(void) +{ + return population_.begin()->first; +} + +// breed a new generation ////////////////////////////////////////////////////// +template +void GeneticScheduler::nextGeneration(void) +{ + // random initialization of the population if necessary + if (population_.size() != par_.popSize) + { + initPopulation(); + } + LOG(Debug) << "Starting population:\n" << *this << std::endl; + + // random mutations + //PARALLEL_FOR_LOOP + for (unsigned int i = 0; i < par_.popSize; ++i) + { + doMutation(); + } + LOG(Debug) << "After mutations:\n" << *this << std::endl; + + // mating + //PARALLEL_FOR_LOOP + for (unsigned int i = 0; i < par_.popSize/2; ++i) + { + doCrossover(); + } + LOG(Debug) << "After mating:\n" << *this << std::endl; + + // grim reaper + auto it = population_.begin(); + + std::advance(it, par_.popSize); + population_.erase(it, population_.end()); + LOG(Debug) << "After grim reaper:\n" << *this << std::endl; +} + +// evolution steps ///////////////////////////////////////////////////////////// +template +void GeneticScheduler::initPopulation(void) +{ + population_.clear(); + for (unsigned int i = 0; i < par_.popSize; ++i) + { + auto p = graph_.topoSort(gen_); + + population_.insert(std::make_pair(func_(p), p)); + } +} + +template +void GeneticScheduler::doCrossover(void) +{ + auto p = selectPair(); + Gene &p1 = *(p.first), &p2 = *(p.second); + Gene c1, c2; + + crossover(c1, c2, p1, p2); + PARALLEL_CRITICAL + { + population_.insert(std::make_pair(func_(c1), 
c1)); + population_.insert(std::make_pair(func_(c2), c2)); + } +} + +template +void GeneticScheduler::doMutation(void) +{ + std::uniform_real_distribution mdis(0., 1.); + std::uniform_int_distribution pdis(0, population_.size() - 1); + + if (mdis(gen_) < par_.mutationRate) + { + Gene m; + auto it = population_.begin(); + + std::advance(it, pdis(gen_)); + mutation(m, it->second); + PARALLEL_CRITICAL + { + population_.insert(std::make_pair(func_(m), m)); + } + } +} + +// genetic operators /////////////////////////////////////////////////////////// +template +typename GeneticScheduler::GenePair GeneticScheduler::selectPair(void) +{ + std::vector prob; + unsigned int ind; + Gene *p1, *p2; + + for (auto &c: population_) + { + prob.push_back(1./c.first); + } + do + { + double probCpy; + + std::discrete_distribution dis1(prob.begin(), prob.end()); + auto rIt = population_.begin(); + ind = dis1(gen_); + std::advance(rIt, ind); + p1 = &(rIt->second); + probCpy = prob[ind]; + prob[ind] = 0.; + std::discrete_distribution dis2(prob.begin(), prob.end()); + rIt = population_.begin(); + std::advance(rIt, dis2(gen_)); + p2 = &(rIt->second); + prob[ind] = probCpy; + } while (p1 == p2); + + return std::make_pair(p1, p2); +} + +template +void GeneticScheduler::crossover(Gene &c1, Gene &c2, const Gene &p1, + const Gene &p2) +{ + Gene buf; + std::uniform_int_distribution dis(0, p1.size() - 1); + unsigned int cut = dis(gen_); + + c1.clear(); + buf = p2; + for (unsigned int i = 0; i < cut; ++i) + { + c1.push_back(p1[i]); + buf.erase(std::find(buf.begin(), buf.end(), p1[i])); + } + for (unsigned int i = 0; i < buf.size(); ++i) + { + c1.push_back(buf[i]); + } + c2.clear(); + buf = p2; + for (unsigned int i = cut; i < p1.size(); ++i) + { + buf.erase(std::find(buf.begin(), buf.end(), p1[i])); + } + for (unsigned int i = 0; i < buf.size(); ++i) + { + c2.push_back(buf[i]); + } + for (unsigned int i = cut; i < p1.size(); ++i) + { + c2.push_back(p1[i]); + } +} + +template +void 
GeneticScheduler::mutation(Gene &m, const Gene &c) +{ + Gene buf; + std::uniform_int_distribution dis(0, c.size() - 1); + unsigned int cut = dis(gen_); + Graph g1 = graph_, g2 = graph_; + + for (unsigned int i = 0; i < cut; ++i) + { + g1.removeVertex(c[i]); + } + for (unsigned int i = cut; i < c.size(); ++i) + { + g2.removeVertex(c[i]); + } + if (g1.size() > 0) + { + buf = g1.topoSort(gen_); + } + if (g2.size() > 0) + { + m = g2.topoSort(gen_); + } + for (unsigned int i = cut; i < c.size(); ++i) + { + m.push_back(buf[i - cut]); + } +} + +template +void GeneticScheduler::benchmarkCrossover(const unsigned int nIt) +{ + Gene p1, p2, c1, c2; + double neg = 0., eq = 0., pos = 0., total; + int improvement; + + LOG(Message) << "Benchmarking crossover..." << std::endl; + for (unsigned int i = 0; i < nIt; ++i) + { + p1 = graph_.topoSort(gen_); + p2 = graph_.topoSort(gen_); + crossover(c1, c2, p1, p2); + improvement = (func_(c1) + func_(c2) - func_(p1) - func_(p2))/2; + if (improvement < 0) neg++; else if (improvement == 0) eq++; else pos++; + } + total = neg + eq + pos; + LOG(Message) << " -: " << neg/total << " =: " << eq/total + << " +: " << pos/total << std::endl; +} + +END_HADRONS_NAMESPACE + +#endif // Hadrons_GeneticScheduler_hpp_ diff --git a/extras/Hadrons/Global.cc b/extras/Hadrons/Global.cc new file mode 100644 index 00000000..7b0b8fb6 --- /dev/null +++ b/extras/Hadrons/Global.cc @@ -0,0 +1,82 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Global.cc + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
+ +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include + +using namespace Grid; +using namespace QCD; +using namespace Hadrons; + +HadronsLogger Hadrons::HadronsLogError(1,"Error"); +HadronsLogger Hadrons::HadronsLogWarning(1,"Warning"); +HadronsLogger Hadrons::HadronsLogMessage(1,"Message"); +HadronsLogger Hadrons::HadronsLogIterative(1,"Iterative"); +HadronsLogger Hadrons::HadronsLogDebug(1,"Debug"); + +// pretty size formatting ////////////////////////////////////////////////////// +std::string Hadrons::sizeString(long unsigned int bytes) + +{ + constexpr unsigned int bufSize = 256; + const char *suffixes[7] = {"", "K", "M", "G", "T", "P", "E"}; + char buf[bufSize]; + long unsigned int s = 0; + double count = bytes; + + while (count >= 1024 && s < 6) // 6 is the last valid index of suffixes[7] + { + s++; + count /= 1024; + } + if (count - floor(count) == 0.0) + { + snprintf(buf, bufSize, "%d %sB", (int)count, suffixes[s]); + } + else + { + snprintf(buf, bufSize, "%.1f %sB", count, suffixes[s]); + } + + return std::string(buf); +} + +// type utilities ////////////////////////////////////////////////////////////// +constexpr unsigned int maxNameSize = 1024u; + +std::string Hadrons::typeName(const std::type_info *info) +{ + char *buf; + std::string name; + + buf = abi::__cxa_demangle(info->name(), nullptr, nullptr, nullptr); + name = (buf != nullptr) ? buf : info->name(); // __cxa_demangle returns NULL on failure: fall back to the mangled name + free(buf); + + return name; +} diff --git a/extras/Hadrons/Global.hpp
b/extras/Hadrons/Global.hpp new file mode 100644 index 00000000..9de01623 --- /dev/null +++ b/extras/Hadrons/Global.hpp @@ -0,0 +1,179 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Global.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_Global_hpp_ +#define Hadrons_Global_hpp_ + +#include +#include +#include +#include + +#define BEGIN_HADRONS_NAMESPACE \ +namespace Grid {\ +using namespace QCD;\ +namespace Hadrons {\ +using Grid::operator<<; +#define END_HADRONS_NAMESPACE }} + +#define BEGIN_MODULE_NAMESPACE(name)\ +namespace name {\ +using Grid::operator<<; +#define END_MODULE_NAMESPACE } + +/* the 'using Grid::operator<<;' statement prevents a very nasty compilation + * error with GCC 5 (clang & GCC 6 compile fine without it). 
+ */ + +#ifndef FIMPL +#define FIMPL WilsonImplR +#endif +#ifndef SIMPL +#define SIMPL ScalarImplCR +#endif + +BEGIN_HADRONS_NAMESPACE + +// type aliases +#define FERM_TYPE_ALIASES(FImpl, suffix)\ +typedef FermionOperator FMat##suffix; \ +typedef typename FImpl::FermionField FermionField##suffix; \ +typedef typename FImpl::PropagatorField PropagatorField##suffix; \ +typedef typename FImpl::SitePropagator SitePropagator##suffix; \ +typedef std::vector \ + SlicedPropagator##suffix; + +#define GAUGE_TYPE_ALIASES(FImpl, suffix)\ +typedef typename FImpl::DoubledGaugeField DoubledGaugeField##suffix; + +#define SCALAR_TYPE_ALIASES(SImpl, suffix)\ +typedef typename SImpl::Field ScalarField##suffix;\ +typedef typename SImpl::Field PropagatorField##suffix; + +#define SOLVER_TYPE_ALIASES(FImpl, suffix)\ +typedef std::function SolverFn##suffix; + +#define SINK_TYPE_ALIASES(suffix)\ +typedef std::function SinkFn##suffix; + +#define FGS_TYPE_ALIASES(FImpl, suffix)\ +FERM_TYPE_ALIASES(FImpl, suffix)\ +GAUGE_TYPE_ALIASES(FImpl, suffix)\ +SOLVER_TYPE_ALIASES(FImpl, suffix) + +// logger +class HadronsLogger: public Logger +{ +public: + HadronsLogger(int on, std::string nm): Logger("Hadrons", on, nm, + GridLogColours, "BLACK"){}; +}; + +#define LOG(channel) std::cout << HadronsLog##channel +#define HADRON_ERROR(msg)\ +LOG(Error) << msg << " (" << __FUNCTION__ << " at " << __FILE__ << ":"\ + << __LINE__ << ")" << std::endl;\ +abort(); + +#define DEBUG_VAR(var) LOG(Debug) << #var << "= " << (var) << std::endl; + +extern HadronsLogger HadronsLogError; +extern HadronsLogger HadronsLogWarning; +extern HadronsLogger HadronsLogMessage; +extern HadronsLogger HadronsLogIterative; +extern HadronsLogger HadronsLogDebug; + +// singleton pattern +#define SINGLETON(name)\ +public:\ + name(const name &e) = delete;\ + void operator=(const name &e) = delete;\ + static name & getInstance(void)\ + {\ + static name e;\ + return e;\ + }\ +private:\ + name(void); + +#define SINGLETON_DEFCTOR(name)\ 
+public:\ + name(const name &e) = delete;\ + void operator=(const name &e) = delete;\ + static name & getInstance(void)\ + {\ + static name e;\ + return e;\ + }\ +private:\ + name(void) = default; + +// pretty size formating +std::string sizeString(long unsigned int bytes); + +// type utilities +template +const std::type_info * typeIdPt(const T &x) +{ + return &typeid(x); +} + +std::string typeName(const std::type_info *info); + +template +const std::type_info * typeIdPt(void) +{ + return &typeid(T); +} + +template +std::string typeName(const T &x) +{ + return typeName(typeIdPt(x)); +} + +template +std::string typeName(void) +{ + return typeName(typeIdPt()); +} + +// default writers/readers +#ifdef HAVE_HDF5 +typedef Hdf5Reader CorrReader; +typedef Hdf5Writer CorrWriter; +#else +typedef XmlReader CorrReader; +typedef XmlWriter CorrWriter; +#endif + +END_HADRONS_NAMESPACE + +#endif // Hadrons_Global_hpp_ diff --git a/extras/Hadrons/Graph.hpp b/extras/Hadrons/Graph.hpp new file mode 100644 index 00000000..df255517 --- /dev/null +++ b/extras/Hadrons/Graph.hpp @@ -0,0 +1,760 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Graph.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. 
+ +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_Graph_hpp_ +#define Hadrons_Graph_hpp_ + +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Oriented graph class * + ******************************************************************************/ +// I/O for edges +template +std::ostream & operator<<(std::ostream &out, const std::pair &e) +{ + out << "\"" << e.first << "\" -> \"" << e.second << "\""; + + return out; +} + +// main class +template +class Graph +{ +public: + typedef std::pair Edge; +public: + // constructor + Graph(void); + // destructor + virtual ~Graph(void) = default; + // access + void addVertex(const T &value); + void addEdge(const Edge &e); + void addEdge(const T &start, const T &end); + std::vector getVertices(void) const; + void removeVertex(const T &value); + void removeEdge(const Edge &e); + void removeEdge(const T &start, const T &end); + unsigned int size(void) const; + // tests + bool gotValue(const T &value) const; + // graph topological manipulations + std::vector getAdjacentVertices(const T &value) const; + std::vector getChildren(const T &value) const; + std::vector getParents(const T &value) const; + std::vector getRoots(void) const; + std::vector> getConnectedComponents(void) const; + std::vector topoSort(void); + template + std::vector topoSort(Gen &gen); + std::vector> allTopoSort(void); + // I/O + friend std::ostream & operator<<(std::ostream &out, const Graph &g) + { + out << "{"; + for (auto &e: g.edgeSet_) + { + out << e << ", "; + } + if (g.edgeSet_.size() != 0) + { + out << "\b\b"; + } + out 
<< "}"; + + return out; + } +private: + // vertex marking + void mark(const T &value, const bool doMark = true); + void markAll(const bool doMark = true); + void unmark(const T &value); + void unmarkAll(void); + bool isMarked(const T &value) const; + const T * getFirstMarked(const bool isMarked = true) const; + template + const T * getRandomMarked(const bool isMarked, Gen &gen); + const T * getFirstUnmarked(void) const; + template + const T * getRandomUnmarked(Gen &gen); + // prune marked/unmarked vertices + void removeMarked(const bool isMarked = true); + void removeUnmarked(void); + // depth-first search marking + void depthFirstSearch(void); + void depthFirstSearch(const T &root); +private: + std::map isMarked_; + std::set edgeSet_; +}; + +// build dependency matrix from topological sorts +template +std::map> +makeDependencyMatrix(const std::vector> &topSort); + +/****************************************************************************** + * template implementation * + ****************************************************************************** + * in all the following, V is the number of vertices and E is the number of edges; + * in the worst case E = V^2 + */ + +// constructor ///////////////////////////////////////////////////////////////// +template +Graph::Graph(void) +{} + +// access ////////////////////////////////////////////////////////////////////// +// complexity: O(log(V)) +template +void Graph::addVertex(const T &value) +{ + isMarked_[value] = false; +} + +// complexity: O(log(V)) +template +void Graph::addEdge(const Edge &e) +{ + addVertex(e.first); + addVertex(e.second); + edgeSet_.insert(e); +} + +// complexity: O(log(V)) +template +void Graph::addEdge(const T &start, const T &end) +{ + addEdge(Edge(start, end)); +} + +template +std::vector Graph::getVertices(void) const +{ + std::vector vertex; + + for (auto &v: isMarked_) + { + vertex.push_back(v.first); + } + + return vertex; +} + +// complexity: O(V*log(V)) +template +void
Graph::removeVertex(const T &value) +{ + // remove vertex from the mark table + auto vIt = isMarked_.find(value); + + if (vIt != isMarked_.end()) + { + isMarked_.erase(vIt); + } + else + { + HADRON_ERROR("vertex " << value << " does not exists"); + } + + // remove all edges containing the vertex + auto pred = [&value](const Edge &e) + { + return ((e.first == value) or (e.second == value)); + }; + auto eIt = find_if(edgeSet_.begin(), edgeSet_.end(), pred); + + while (eIt != edgeSet_.end()) + { + edgeSet_.erase(eIt); + eIt = find_if(edgeSet_.begin(), edgeSet_.end(), pred); + } +} + +// complexity: O(log(V)) +template +void Graph::removeEdge(const Edge &e) +{ + auto eIt = edgeSet_.find(e); + + if (eIt != edgeSet_.end()) + { + edgeSet_.erase(eIt); + } + else + { + HADRON_ERROR("edge " << e << " does not exists"); + } +} + +// complexity: O(log(V)) +template +void Graph::removeEdge(const T &start, const T &end) +{ + removeEdge(Edge(start, end)); +} + +// complexity: O(1) +template +unsigned int Graph::size(void) const +{ + return isMarked_.size(); +} + +// tests /////////////////////////////////////////////////////////////////////// +// complexity: O(log(V)) +template +bool Graph::gotValue(const T &value) const +{ + auto it = isMarked_.find(value); + + if (it == isMarked_.end()) + { + return false; + } + else + { + return true; + } +} + +// vertex marking ////////////////////////////////////////////////////////////// +// complexity: O(log(V)) +template +void Graph::mark(const T &value, const bool doMark) +{ + if (gotValue(value)) + { + isMarked_[value] = doMark; + } + else + { + HADRON_ERROR("vertex " << value << " does not exists"); + } +} + +// complexity: O(V*log(V)) +template +void Graph::markAll(const bool doMark) +{ + for (auto &v: isMarked_) + { + mark(v.first, doMark); + } +} + +// complexity: O(log(V)) +template +void Graph::unmark(const T &value) +{ + mark(value, false); +} + +// complexity: O(V*log(V)) +template +void Graph::unmarkAll(void) +{ + 
markAll(false); +} + +// complexity: O(log(V)) +template +bool Graph::isMarked(const T &value) const +{ + if (gotValue(value)) + { + return isMarked_.at(value); + } + else + { + HADRON_ERROR("vertex " << value << " does not exists"); + + return false; + } +} + +// complexity: O(V) (linear std::find_if over the mark table) +template +const T * Graph::getFirstMarked(const bool isMarked) const +{ + auto pred = [&isMarked](const std::pair &v) + { + return (v.second == isMarked); + }; + auto vIt = std::find_if(isMarked_.begin(), isMarked_.end(), pred); + + if (vIt != isMarked_.end()) + { + return &(vIt->first); + } + else + { + return nullptr; + } +} + +// complexity: O(V) (std::advance then linear std::find_if); NOTE(review): size() - 1 wraps around for an empty graph, confirm callers never hit this +template +template +const T * Graph::getRandomMarked(const bool isMarked, Gen &gen) +{ + auto pred = [&isMarked](const std::pair &v) + { + return (v.second == isMarked); + }; + std::uniform_int_distribution dis(0, size() - 1); + auto rIt = isMarked_.begin(); + + std::advance(rIt, dis(gen)); + auto vIt = std::find_if(rIt, isMarked_.end(), pred); + if (vIt != isMarked_.end()) + { + return &(vIt->first); + } + else + { + vIt = std::find_if(isMarked_.begin(), rIt, pred); + if (vIt != rIt) + { + return &(vIt->first); + } + else + { + return nullptr; + } + } +} + +// complexity: O(V) (forwards to getFirstMarked) +template +const T * Graph::getFirstUnmarked(void) const +{ + return getFirstMarked(false); +} + +// complexity: O(V) (forwards to getRandomMarked) +template +template +const T * Graph::getRandomUnmarked(Gen &gen) +{ + return getRandomMarked(false, gen); +} + +// prune marked/unmarked vertices ////////////////////////////////////////////// +// complexity: O(V^2*log(V)) +template +void Graph::removeMarked(const bool isMarked) +{ + auto isMarkedCopy = isMarked_; + + for (auto &v: isMarkedCopy) + { + if (v.second == isMarked) + { + removeVertex(v.first); + } + } +} + +// complexity: O(V^2*log(V)) +template +void Graph::removeUnmarked(void) +{ + removeMarked(false); +} + +// depth-first search marking ////////////////////////////////////////////////// +// complexity: O(V*log(V))
+template +void Graph::depthFirstSearch(void) +{ + depthFirstSearch(isMarked_.begin()->first); +} + +// complexity: O(V*log(V)) +template +void Graph::depthFirstSearch(const T &root) +{ + std::vector adjacentVertex; + + mark(root); + adjacentVertex = getAdjacentVertices(root); + for (auto &v: adjacentVertex) + { + if (!isMarked(v)) + { + depthFirstSearch(v); + } + } +} + +// graph topological manipulations ///////////////////////////////////////////// +// complexity: O(E) (single pass over the edge set) +template +std::vector Graph::getAdjacentVertices(const T &value) const +{ + std::vector adjacentVertex; + + auto pred = [&value](const Edge &e) + { + return ((e.first == value) or (e.second == value)); + }; + auto eIt = find_if(edgeSet_.begin(), edgeSet_.end(), pred); + + while (eIt != edgeSet_.end()) + { + if (eIt->first == value) + { + adjacentVertex.push_back((*eIt).second); + } + else if (eIt->second == value) + { + adjacentVertex.push_back((*eIt).first); + } + eIt = find_if(++eIt, edgeSet_.end(), pred); + } + + return adjacentVertex; +} + +// complexity: O(E) (single pass over the edge set) +template +std::vector Graph::getChildren(const T &value) const +{ + std::vector child; + + auto pred = [&value](const Edge &e) + { + return (e.first == value); + }; + auto eIt = find_if(edgeSet_.begin(), edgeSet_.end(), pred); + + while (eIt != edgeSet_.end()) + { + child.push_back((*eIt).second); + eIt = find_if(++eIt, edgeSet_.end(), pred); + } + + return child; +} + +// complexity: O(E) (single pass over the edge set) +template +std::vector Graph::getParents(const T &value) const +{ + std::vector parent; + + auto pred = [&value](const Edge &e) + { + return (e.second == value); + }; + auto eIt = find_if(edgeSet_.begin(), edgeSet_.end(), pred); + + while (eIt != edgeSet_.end()) + { + parent.push_back((*eIt).first); + eIt = find_if(++eIt, edgeSet_.end(), pred); + } + + return parent; +} + +// complexity: O(V*E) (one getParents() scan per vertex) +template +std::vector Graph::getRoots(void) const +{ + std::vector root; + + for (auto &v: isMarked_) + { + auto parent =
getParents(v.first); + + if (parent.size() == 0) + { + root.push_back(v.first); + } + } + + return root; +} + +// complexity: O(V^2*log(V)) +template +std::vector> Graph::getConnectedComponents(void) const +{ + std::vector> res; + Graph copy(*this); + + while (copy.size() > 0) + { + copy.depthFirstSearch(); + res.push_back(copy); + res.back().removeUnmarked(); + res.back().unmarkAll(); + copy.removeMarked(); + copy.unmarkAll(); + } + + return res; +} + +// topological sort using a directed DFS algorithm +// complexity: O(V*log(V)) +template +std::vector Graph::topoSort(void) +{ + std::stack buf; + std::vector res; + const T *vPt; + std::map tmpMarked(isMarked_); + + // visit function + std::function visit = [&](const T &v) + { + if (tmpMarked.at(v)) + { + HADRON_ERROR("cannot topologically sort a cyclic graph"); + } + if (!isMarked(v)) + { + std::vector child = getChildren(v); + + tmpMarked[v] = true; + for (auto &c: child) + { + visit(c); + } + mark(v); + tmpMarked[v] = false; + buf.push(v); + } + }; + + // reset temporary marks + for (auto &v: tmpMarked) + { + tmpMarked.at(v.first) = false; + } + + // loop on unmarked vertices + unmarkAll(); + vPt = getFirstUnmarked(); + while (vPt) + { + visit(*vPt); + vPt = getFirstUnmarked(); + } + unmarkAll(); + + // create result vector + while (!buf.empty()) + { + res.push_back(buf.top()); + buf.pop(); + } + + return res; +} + +// random version of the topological sort +// complexity: O(V*log(V)) +template +template +std::vector Graph::topoSort(Gen &gen) +{ + std::stack buf; + std::vector res; + const T *vPt; + std::map tmpMarked(isMarked_); + + // visit function + std::function visit = [&](const T &v) + { + if (tmpMarked.at(v)) + { + HADRON_ERROR("cannot topologically sort a cyclic graph"); + } + if (!isMarked(v)) + { + std::vector child = getChildren(v); + + tmpMarked[v] = true; + std::shuffle(child.begin(), child.end(), gen); + for (auto &c: child) + { + visit(c); + } + mark(v); + tmpMarked[v] = false; + buf.push(v); + } 
+ }; + + // reset temporary marks + for (auto &v: tmpMarked) + { + tmpMarked.at(v.first) = false; + } + + // loop on unmarked vertices + unmarkAll(); + vPt = getRandomUnmarked(gen); + while (vPt) + { + visit(*vPt); + vPt = getRandomUnmarked(gen); + } + unmarkAll(); + + // create result vector + while (!buf.empty()) + { + res.push_back(buf.top()); + buf.pop(); + } + + return res; +} + +// generate all possible topological sorts +// Y. L. Varol & D. Rotem, Comput. J. 24(1), pp. 83–84, 1981 +// http://comjnl.oupjournals.org/cgi/doi/10.1093/comjnl/24.1.83 +// complexity: O(V*log(V)) (from the paper, but really ?) +template +std::vector> Graph::allTopoSort(void) +{ + std::vector> res; + std::map> iMat; + + // create incidence matrix + for (auto &v1: isMarked_) + for (auto &v2: isMarked_) + { + iMat[v1.first][v2.first] = false; + } + for (auto &v: isMarked_) + { + auto cVec = getChildren(v.first); + + for (auto &c: cVec) + { + iMat[v.first][c] = true; + } + } + + // generate initial topological sort + res.push_back(topoSort()); + + // generate all other topological sorts by permutation + std::vector p = res[0]; + const unsigned int n = size(); + std::vector loc(n); + unsigned int i, k, k1; + T obj_k, obj_k1; + bool isFinal; + + for (unsigned int j = 0; j < n; ++j) + { + loc[j] = j; + } + i = 0; + while (i + 1 < n) // not (i < n-1): n is unsigned, so n-1 wraps around for an empty graph + { + k = loc[i]; + k1 = k + 1; + obj_k = p[k]; + if (k1 >= n) + { + isFinal = true; + obj_k1 = obj_k; + } + else + { + isFinal = false; + obj_k1 = p[k1]; + } + if (iMat[res[0][i]][obj_k1] or isFinal) + { + for (unsigned int l = k; l >= i + 1; --l) + { + p[l] = p[l-1]; + } + p[i] = obj_k; + loc[i] = i; + i++; + } + else + { + p[k] = obj_k1; + p[k1] = obj_k; + loc[i] = k1; + i = 0; + res.push_back(p); + } + } + + return res; +} + +// build dependency matrix from topological sorts ////////////////////////////// +// complexity: something like O(V^2*log(V!)) +template +std::map> +makeDependencyMatrix(const std::vector> &topSort) +{ + std::map> m; + const std::vector
&vList = topSort[0]; + + for (auto &v1: vList) + for (auto &v2: vList) + { + bool dep = true; + + for (auto &t: topSort) + { + auto i1 = std::find(t.begin(), t.end(), v1); + auto i2 = std::find(t.begin(), t.end(), v2); + + dep = dep and (i1 - i2 > 0); + if (!dep) break; + } + m[v1][v2] = dep; + } + + return m; +} + +END_HADRONS_NAMESPACE + +#endif // Hadrons_Graph_hpp_ diff --git a/extras/Hadrons/HadronsXmlRun.cc b/extras/Hadrons/HadronsXmlRun.cc new file mode 100644 index 00000000..0dff8f9a --- /dev/null +++ b/extras/Hadrons/HadronsXmlRun.cc @@ -0,0 +1,80 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/HadronsXmlRun.cc + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include + +using namespace Grid; +using namespace QCD; +using namespace Hadrons; + +int main(int argc, char *argv[]) +{ + // parse command line + std::string parameterFileName, scheduleFileName = ""; + + if (argc < 2) + { + std::cerr << "usage: " << argv[0] << " [] [Grid options]"; + std::cerr << std::endl; + std::exit(EXIT_FAILURE); + } + parameterFileName = argv[1]; + if (argc > 2) + { + if (argv[2][0] != '-') + { + scheduleFileName = argv[2]; + } + } + + // initialization + Grid_init(&argc, &argv); + HadronsLogError.Active(GridLogError.isActive()); + HadronsLogWarning.Active(GridLogWarning.isActive()); + HadronsLogMessage.Active(GridLogMessage.isActive()); + HadronsLogIterative.Active(GridLogIterative.isActive()); + HadronsLogDebug.Active(GridLogDebug.isActive()); + LOG(Message) << "Grid initialized" << std::endl; + + // execution + Application application(parameterFileName); + + application.parseParameterFile(parameterFileName); + if (!scheduleFileName.empty()) + { + application.loadSchedule(scheduleFileName); + } + application.run(); + + // epilogue + LOG(Message) << "Grid is finalizing now" << std::endl; + Grid_finalize(); + + return EXIT_SUCCESS; +} diff --git a/extras/Hadrons/HadronsXmlSchedule.cc b/extras/Hadrons/HadronsXmlSchedule.cc new file mode 100644 index 00000000..a8ca9a63 --- /dev/null +++ b/extras/Hadrons/HadronsXmlSchedule.cc @@ -0,0 +1,72 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/HadronsXmlSchedule.cc + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free 
Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include + +using namespace Grid; +using namespace QCD; +using namespace Hadrons; + +int main(int argc, char *argv[]) +{ + // parse command line + std::string parameterFileName, scheduleFileName; + + if (argc < 3) + { + std::cerr << "usage: " << argv[0] << " [Grid options]"; + std::cerr << std::endl; + std::exit(EXIT_FAILURE); + } + parameterFileName = argv[1]; + scheduleFileName = argv[2]; + + // initialization + Grid_init(&argc, &argv); + HadronsLogError.Active(GridLogError.isActive()); + HadronsLogWarning.Active(GridLogWarning.isActive()); + HadronsLogMessage.Active(GridLogMessage.isActive()); + HadronsLogIterative.Active(GridLogIterative.isActive()); + HadronsLogDebug.Active(GridLogDebug.isActive()); + LOG(Message) << "Grid initialized" << std::endl; + + // execution + Application application; + + application.parseParameterFile(parameterFileName); + application.schedule(); + application.printSchedule(); + application.saveSchedule(scheduleFileName); + + // epilogue + LOG(Message) << "Grid is finalizing now" << std::endl; + Grid_finalize(); + + return EXIT_SUCCESS; +} diff --git a/extras/Hadrons/Makefile.am b/extras/Hadrons/Makefile.am new file mode 100644 index 00000000..9cb23600 --- /dev/null +++ b/extras/Hadrons/Makefile.am @@ -0,0 +1,29 
@@ +lib_LIBRARIES = libHadrons.a +bin_PROGRAMS = HadronsXmlRun HadronsXmlSchedule + +include modules.inc + +libHadrons_a_SOURCES = \ + $(modules_cc) \ + Application.cc \ + Environment.cc \ + Global.cc \ + Module.cc +libHadrons_adir = $(pkgincludedir)/Hadrons +nobase_libHadrons_a_HEADERS = \ + $(modules_hpp) \ + Application.hpp \ + Environment.hpp \ + Factory.hpp \ + GeneticScheduler.hpp \ + Global.hpp \ + Graph.hpp \ + Module.hpp \ + Modules.hpp \ + ModuleFactory.hpp + +HadronsXmlRun_SOURCES = HadronsXmlRun.cc +HadronsXmlRun_LDADD = libHadrons.a -lGrid + +HadronsXmlSchedule_SOURCES = HadronsXmlSchedule.cc +HadronsXmlSchedule_LDADD = libHadrons.a -lGrid diff --git a/extras/Hadrons/Module.cc b/extras/Hadrons/Module.cc new file mode 100644 index 00000000..2549a931 --- /dev/null +++ b/extras/Hadrons/Module.cc @@ -0,0 +1,71 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Module.cc + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include + +using namespace Grid; +using namespace QCD; +using namespace Hadrons; + +/****************************************************************************** + * ModuleBase implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +ModuleBase::ModuleBase(const std::string name) +: name_(name) +, env_(Environment::getInstance()) +{} + +// access ////////////////////////////////////////////////////////////////////// +std::string ModuleBase::getName(void) const +{ + return name_; +} + +Environment & ModuleBase::env(void) const +{ + return env_; +} + +// get factory registration name if available +std::string ModuleBase::getRegisteredName(void) +{ + HADRON_ERROR("module '" + getName() + "' has a type not registered" + + " in the factory"); +} + +// execution /////////////////////////////////////////////////////////////////// +void ModuleBase::operator()(void) +{ + setup(); + if (!env().isDryRun()) + { + execute(); + } +} diff --git a/extras/Hadrons/Module.hpp b/extras/Hadrons/Module.hpp new file mode 100644 index 00000000..071e254a --- /dev/null +++ b/extras/Hadrons/Module.hpp @@ -0,0 +1,198 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Module.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
+ +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_Module_hpp_ +#define Hadrons_Module_hpp_ + +#include +#include + +BEGIN_HADRONS_NAMESPACE + +// module registration macros +#define MODULE_REGISTER(mod, base)\ +class mod: public base\ +{\ +public:\ + typedef base Base;\ + using Base::Base;\ + virtual std::string getRegisteredName(void)\ + {\ + return std::string(#mod);\ + }\ +};\ +class mod##ModuleRegistrar\ +{\ +public:\ + mod##ModuleRegistrar(void)\ + {\ + ModuleFactory &modFac = ModuleFactory::getInstance();\ + modFac.registerBuilder(#mod, [&](const std::string name)\ + {\ + return std::unique_ptr(new mod(name));\ + });\ + }\ +};\ +static mod##ModuleRegistrar mod##ModuleRegistrarInstance; + +#define MODULE_REGISTER_NS(mod, base, ns)\ +class mod: public base\ +{\ +public:\ + typedef base Base;\ + using Base::Base;\ + virtual std::string getRegisteredName(void)\ + {\ + return std::string(#ns "::" #mod);\ + }\ +};\ +class ns##mod##ModuleRegistrar\ +{\ +public:\ + ns##mod##ModuleRegistrar(void)\ + {\ + ModuleFactory &modFac = ModuleFactory::getInstance();\ + modFac.registerBuilder(#ns "::" #mod, [&](const std::string name)\ + {\ + return std::unique_ptr(new ns::mod(name));\ + });\ + }\ +};\ +static ns##mod##ModuleRegistrar ns##mod##ModuleRegistrarInstance; + +#define ARG(...) 
__VA_ARGS__ + +/****************************************************************************** + * Module class * + ******************************************************************************/ +// base class +class ModuleBase +{ +public: + // constructor + ModuleBase(const std::string name); + // destructor + virtual ~ModuleBase(void) = default; + // access + std::string getName(void) const; + Environment &env(void) const; + // get factory registration name if available + virtual std::string getRegisteredName(void); + // dependencies/products + virtual std::vector getInput(void) = 0; + virtual std::vector getOutput(void) = 0; + // parse parameters + virtual void parseParameters(XmlReader &reader, const std::string name) = 0; + virtual void saveParameters(XmlWriter &writer, const std::string name) = 0; + // setup + virtual void setup(void) {}; + // execution + void operator()(void); + virtual void execute(void) = 0; +private: + std::string name_; + Environment &env_; +}; + +// derived class, templating the parameter class +template +class Module: public ModuleBase +{ +public: + typedef P Par; +public: + // constructor + Module(const std::string name); + // destructor + virtual ~Module(void) = default; + // parse parameters + virtual void parseParameters(XmlReader &reader, const std::string name); + virtual void saveParameters(XmlWriter &writer, const std::string name); + // parameter access + const P & par(void) const; + void setPar(const P &par); +private: + P par_; +}; + +// no parameter type +class NoPar {}; + +template <> +class Module: public ModuleBase +{ +public: + // constructor + Module(const std::string name): ModuleBase(name) {}; + // destructor + virtual ~Module(void) = default; + // parse parameters (do nothing) + virtual void parseParameters(XmlReader &reader, const std::string name) {}; + virtual void saveParameters(XmlWriter &writer, const std::string name) + { + push(writer, "options"); + pop(writer); + }; +}; + 
+/****************************************************************************** + * Template implementation * + ******************************************************************************/ +template +Module

::Module(const std::string name) +: ModuleBase(name) +{} + +template +void Module

::parseParameters(XmlReader &reader, const std::string name) +{ + read(reader, name, par_); +} + +template +void Module

::saveParameters(XmlWriter &writer, const std::string name) +{ + write(writer, name, par_); +} + +template +const P & Module

::par(void) const +{ + return par_; +} + +template +void Module

::setPar(const P &par) +{ + par_ = par; +} + +END_HADRONS_NAMESPACE + +#endif // Hadrons_Module_hpp_ diff --git a/extras/Hadrons/ModuleFactory.hpp b/extras/Hadrons/ModuleFactory.hpp new file mode 100644 index 00000000..48ab305c --- /dev/null +++ b/extras/Hadrons/ModuleFactory.hpp @@ -0,0 +1,49 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/ModuleFactory.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_ModuleFactory_hpp_ +#define Hadrons_ModuleFactory_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * ModuleFactory * + ******************************************************************************/ +class ModuleFactory: public Factory +{ + SINGLETON_DEFCTOR(ModuleFactory) +}; + +END_HADRONS_NAMESPACE + +#endif // Hadrons_ModuleFactory_hpp_ diff --git a/extras/Hadrons/Modules.hpp b/extras/Hadrons/Modules.hpp new file mode 100644 index 00000000..c27254aa --- /dev/null +++ b/extras/Hadrons/Modules.hpp @@ -0,0 +1,25 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/extras/Hadrons/Modules/MAction/DWF.hpp b/extras/Hadrons/Modules/MAction/DWF.hpp new file mode 100644 index 00000000..78e0916c --- /dev/null +++ b/extras/Hadrons/Modules/MAction/DWF.hpp @@ -0,0 +1,140 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MAction/DWF.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MAction_DWF_hpp_ +#define Hadrons_MAction_DWF_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Domain wall quark action * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MAction) + +class DWFPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(DWFPar, + std::string, gauge, + unsigned int, Ls, + double , mass, + double , M5, + std::string , boundary); +}; + +template +class TDWF: public Module +{ +public: + FGS_TYPE_ALIASES(FImpl,); +public: + // constructor + TDWF(const std::string name); + // destructor + virtual ~TDWF(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(DWF, TDWF, MAction); + +/****************************************************************************** + * DWF template implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +TDWF::TDWF(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector TDWF::getInput(void) +{ + std::vector in = {par().gauge}; + + return in; +} + +template +std::vector TDWF::getOutput(void) +{ 
+ std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void TDWF::setup(void) +{ + unsigned int size; + + size = 2*env().template lattice4dSize(); + env().registerObject(getName(), size, par().Ls); +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TDWF::execute(void) +{ + LOG(Message) << "Setting up domain wall fermion matrix with m= " + << par().mass << ", M5= " << par().M5 << " and Ls= " + << par().Ls << " using gauge field '" << par().gauge << "'" + << std::endl; + LOG(Message) << "Fermion boundary conditions: " << par().boundary + << std::endl; + env().createGrid(par().Ls); + auto &U = *env().template getObject(par().gauge); + auto &g4 = *env().getGrid(); + auto &grb4 = *env().getRbGrid(); + auto &g5 = *env().getGrid(par().Ls); + auto &grb5 = *env().getRbGrid(par().Ls); + std::vector boundary = strToVec(par().boundary); + typename DomainWallFermion::ImplParams implParams(boundary); + FMat *fMatPt = new DomainWallFermion(U, g5, grb5, g4, grb4, + par().mass, par().M5, + implParams); + env().setObject(getName(), fMatPt); +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MAction_DWF_hpp_ diff --git a/extras/Hadrons/Modules/MAction/Wilson.hpp b/extras/Hadrons/Modules/MAction/Wilson.hpp new file mode 100644 index 00000000..aab54245 --- /dev/null +++ b/extras/Hadrons/Modules/MAction/Wilson.hpp @@ -0,0 +1,132 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MAction/Wilson.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any 
later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MAction_Wilson_hpp_ +#define Hadrons_MAction_Wilson_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * TWilson quark action * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MAction) + +class WilsonPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(WilsonPar, + std::string, gauge, + double , mass, + std::string, boundary); +}; + +template +class TWilson: public Module +{ +public: + FGS_TYPE_ALIASES(FImpl,); +public: + // constructor + TWilson(const std::string name); + // destructor + virtual ~TWilson(void) = default; + // dependencies/products + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(Wilson, TWilson, MAction); + +/****************************************************************************** + * TWilson template implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +TWilson::TWilson(const std::string name) +: Module(name) +{} + +// dependencies/products 
/////////////////////////////////////////////////////// +template +std::vector TWilson::getInput(void) +{ + std::vector in = {par().gauge}; + + return in; +} + +template +std::vector TWilson::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void TWilson::setup(void) +{ + unsigned int size; + + size = 2*env().template lattice4dSize(); + env().registerObject(getName(), size); +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TWilson::execute() +{ + LOG(Message) << "Setting up TWilson fermion matrix with m= " << par().mass + << " using gauge field '" << par().gauge << "'" << std::endl; + LOG(Message) << "Fermion boundary conditions: " << par().boundary + << std::endl; + auto &U = *env().template getObject(par().gauge); + auto &grid = *env().getGrid(); + auto &gridRb = *env().getRbGrid(); + std::vector boundary = strToVec(par().boundary); + typename WilsonFermion::ImplParams implParams(boundary); + FMat *fMatPt = new WilsonFermion(U, grid, gridRb, par().mass, + implParams); + env().setObject(getName(), fMatPt); +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MAction_Wilson_hpp_ diff --git a/extras/Hadrons/Modules/MContraction/Baryon.hpp b/extras/Hadrons/Modules/MContraction/Baryon.hpp new file mode 100644 index 00000000..78bde5a2 --- /dev/null +++ b/extras/Hadrons/Modules/MContraction/Baryon.hpp @@ -0,0 +1,131 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MContraction/Baryon.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, 
or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MContraction_Baryon_hpp_ +#define Hadrons_MContraction_Baryon_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Baryon * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MContraction) + +class BaryonPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(BaryonPar, + std::string, q1, + std::string, q2, + std::string, q3, + std::string, output); +}; + +template +class TBaryon: public Module +{ +public: + FERM_TYPE_ALIASES(FImpl1, 1); + FERM_TYPE_ALIASES(FImpl2, 2); + FERM_TYPE_ALIASES(FImpl3, 3); + class Result: Serializable + { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(Result, + std::vector>>, corr); + }; +public: + // constructor + TBaryon(const std::string name); + // destructor + virtual ~TBaryon(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(Baryon, ARG(TBaryon), MContraction); + +/****************************************************************************** + * TBaryon implementation * + ******************************************************************************/ +// 
constructor ///////////////////////////////////////////////////////////////// +template +TBaryon::TBaryon(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector TBaryon::getInput(void) +{ + std::vector input = {par().q1, par().q2, par().q3}; + + return input; +} + +template +std::vector TBaryon::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TBaryon::execute(void) +{ + LOG(Message) << "Computing baryon contractions '" << getName() << "' using" + << " quarks '" << par().q1 << "', '" << par().q2 << "', and '" + << par().q3 << "'" << std::endl; + + CorrWriter writer(par().output); + PropagatorField1 &q1 = *env().template getObject(par().q1); + PropagatorField2 &q2 = *env().template getObject(par().q2); + PropagatorField3 &q3 = *env().template getObject(par().q3); + LatticeComplex c(env().getGrid()); + Result result; + + // FIXME: do contractions + + // write(writer, "baryon", result); +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MContraction_Baryon_hpp_ diff --git a/extras/Hadrons/Modules/MContraction/DiscLoop.hpp b/extras/Hadrons/Modules/MContraction/DiscLoop.hpp new file mode 100644 index 00000000..4f782cd3 --- /dev/null +++ b/extras/Hadrons/Modules/MContraction/DiscLoop.hpp @@ -0,0 +1,144 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MContraction/DiscLoop.hpp + +Copyright (C) 2017 + +Author: Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
+ +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MContraction_DiscLoop_hpp_ +#define Hadrons_MContraction_DiscLoop_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * DiscLoop * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MContraction) + +class DiscLoopPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(DiscLoopPar, + std::string, q_loop, + Gamma::Algebra, gamma, + std::string, output); +}; + +template +class TDiscLoop: public Module +{ + FERM_TYPE_ALIASES(FImpl,); + class Result: Serializable + { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(Result, + Gamma::Algebra, gamma, + std::vector, corr); + }; +public: + // constructor + TDiscLoop(const std::string name); + // destructor + virtual ~TDiscLoop(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(DiscLoop, TDiscLoop, MContraction); + +/****************************************************************************** + * TDiscLoop implementation * + ******************************************************************************/ +// constructor 
///////////////////////////////////////////////////////////////// +template +TDiscLoop::TDiscLoop(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector TDiscLoop::getInput(void) +{ + std::vector in = {par().q_loop}; + + return in; +} + +template +std::vector TDiscLoop::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void TDiscLoop::setup(void) +{ + +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TDiscLoop::execute(void) +{ + LOG(Message) << "Computing disconnected loop contraction '" << getName() + << "' using '" << par().q_loop << "' with " << par().gamma + << " insertion." << std::endl; + + CorrWriter writer(par().output); + PropagatorField &q_loop = *env().template getObject(par().q_loop); + LatticeComplex c(env().getGrid()); + Gamma gamma(par().gamma); + std::vector buf; + Result result; + + c = trace(gamma*q_loop); + sliceSum(c, buf, Tp); + + result.gamma = par().gamma; + result.corr.resize(buf.size()); + for (unsigned int t = 0; t < buf.size(); ++t) + { + result.corr[t] = TensorRemove(buf[t]); + } + + write(writer, "disc", result); +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MContraction_DiscLoop_hpp_ diff --git a/extras/Hadrons/Modules/MContraction/Gamma3pt.hpp b/extras/Hadrons/Modules/MContraction/Gamma3pt.hpp new file mode 100644 index 00000000..7f643d49 --- /dev/null +++ b/extras/Hadrons/Modules/MContraction/Gamma3pt.hpp @@ -0,0 +1,170 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MContraction/Gamma3pt.hpp + +Copyright (C) 2017 + +Author: Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it 
under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MContraction_Gamma3pt_hpp_ +#define Hadrons_MContraction_Gamma3pt_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/* + * 3pt contraction with gamma matrix insertion. 
+ * + * Schematic: + * + * q2 q3 + * /----<------*------<----¬ + * / gamma \ + * / \ + * i * * f + * \ / + * \ / + * \----------->----------/ + * q1 + * + * trace(g5*q1*adj(q2)*g5*gamma*q3) + */ + +/****************************************************************************** + * Gamma3pt * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MContraction) + +class Gamma3ptPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(Gamma3ptPar, + std::string, q1, + std::string, q2, + std::string, q3, + Gamma::Algebra, gamma, + std::string, output); +}; + +template +class TGamma3pt: public Module +{ + FERM_TYPE_ALIASES(FImpl1, 1); + FERM_TYPE_ALIASES(FImpl2, 2); + FERM_TYPE_ALIASES(FImpl3, 3); + class Result: Serializable + { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(Result, + Gamma::Algebra, gamma, + std::vector, corr); + }; +public: + // constructor + TGamma3pt(const std::string name); + // destructor + virtual ~TGamma3pt(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(Gamma3pt, ARG(TGamma3pt), MContraction); + +/****************************************************************************** + * TGamma3pt implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +TGamma3pt::TGamma3pt(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector TGamma3pt::getInput(void) +{ + std::vector in = {par().q1, par().q2, par().q3}; + + return in; +} + +template +std::vector TGamma3pt::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup 
/////////////////////////////////////////////////////////////////////// +template +void TGamma3pt::setup(void) +{ + +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TGamma3pt::execute(void) +{ + LOG(Message) << "Computing 3pt contractions '" << getName() << "' using" + << " quarks '" << par().q1 << "', '" << par().q2 << "' and '" + << par().q3 << "', with " << par().gamma << " insertion." + << std::endl; + + CorrWriter writer(par().output); + PropagatorField1 &q1 = *env().template getObject(par().q1); + PropagatorField2 &q2 = *env().template getObject(par().q2); + PropagatorField3 &q3 = *env().template getObject(par().q3); + LatticeComplex c(env().getGrid()); + Gamma g5(Gamma::Algebra::Gamma5); + Gamma gamma(par().gamma); + std::vector buf; + Result result; + + c = trace(g5*q1*adj(q2)*(g5*gamma)*q3); + sliceSum(c, buf, Tp); + + result.gamma = par().gamma; + result.corr.resize(buf.size()); + for (unsigned int t = 0; t < buf.size(); ++t) + { + result.corr[t] = TensorRemove(buf[t]); + } + + write(writer, "gamma3pt", result); +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MContraction_Gamma3pt_hpp_ diff --git a/extras/Hadrons/Modules/MContraction/Meson.hpp b/extras/Hadrons/Modules/MContraction/Meson.hpp new file mode 100644 index 00000000..7810326a --- /dev/null +++ b/extras/Hadrons/Modules/MContraction/Meson.hpp @@ -0,0 +1,244 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MContraction/Meson.hpp + +Copyright (C) 2015 +Copyright (C) 2016 +Copyright (C) 2017 + +Author: Antonin Portelli + Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
+ +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MContraction_Meson_hpp_ +#define Hadrons_MContraction_Meson_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/* + + Meson contractions + ----------------------------- + + * options: + - q1: input propagator 1 (string) + - q2: input propagator 2 (string) + - gammas: gamma products to insert at sink & source, pairs of gamma matrices + (space-separated strings) in angled brackets (i.e. <g_sink g_src>), + in a sequence (e.g. "<Gamma5 Gamma5><Gamma5 GammaT>"). + + Special values: "all" - perform all possible contractions. + - mom: momentum insertion, space-separated float sequence (e.g ".1 .2 1. 0."), + given as multiples of (2*pi) / L. 
+*/ + +/****************************************************************************** + * TMeson * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MContraction) + +typedef std::pair GammaPair; + +class MesonPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(MesonPar, + std::string, q1, + std::string, q2, + std::string, gammas, + std::string, sink, + std::string, output); +}; + +template +class TMeson: public Module +{ +public: + FERM_TYPE_ALIASES(FImpl1, 1); + FERM_TYPE_ALIASES(FImpl2, 2); + FERM_TYPE_ALIASES(ScalarImplCR, Scalar); + SINK_TYPE_ALIASES(Scalar); + class Result: Serializable + { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(Result, + Gamma::Algebra, gamma_snk, + Gamma::Algebra, gamma_src, + std::vector, corr); + }; +public: + // constructor + TMeson(const std::string name); + // destructor + virtual ~TMeson(void) = default; + // dependencies/products + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + virtual void parseGammaString(std::vector &gammaList); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(Meson, ARG(TMeson), MContraction); + +/****************************************************************************** + * TMeson implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +TMeson::TMeson(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector TMeson::getInput(void) +{ + std::vector input = {par().q1, par().q2, par().sink}; + + return input; +} + +template +std::vector TMeson::getOutput(void) +{ + std::vector output = {getName()}; + + return output; +} + +template +void TMeson::parseGammaString(std::vector &gammaList) +{ + gammaList.clear(); + // Determine gamma matrices to insert at source/sink. 
+ if (par().gammas.compare("all") == 0) + { + // Do all contractions. + for (unsigned int i = 1; i < Gamma::nGamma; i += 2) + { + for (unsigned int j = 1; j < Gamma::nGamma; j += 2) + { + gammaList.push_back(std::make_pair((Gamma::Algebra)i, + (Gamma::Algebra)j)); + } + } + } + else + { + // Parse individual contractions from input string. + gammaList = strToVec(par().gammas); + } +} + + +// execution /////////////////////////////////////////////////////////////////// +#define mesonConnected(q1, q2, gSnk, gSrc) \ +(g5*(gSnk))*(q1)*(adj(gSrc)*g5)*adj(q2) + +template +void TMeson::execute(void) +{ + LOG(Message) << "Computing meson contractions '" << getName() << "' using" + << " quarks '" << par().q1 << "' and '" << par().q2 << "'" + << std::endl; + + CorrWriter writer(par().output); + std::vector buf; + std::vector result; + Gamma g5(Gamma::Algebra::Gamma5); + std::vector gammaList; + int nt = env().getDim(Tp); + + parseGammaString(gammaList); + result.resize(gammaList.size()); + for (unsigned int i = 0; i < result.size(); ++i) + { + result[i].gamma_snk = gammaList[i].first; + result[i].gamma_src = gammaList[i].second; + result[i].corr.resize(nt); + } + if (env().template isObjectOfType(par().q1) and + env().template isObjectOfType(par().q2)) + { + SlicedPropagator1 &q1 = *env().template getObject(par().q1); + SlicedPropagator2 &q2 = *env().template getObject(par().q2); + + LOG(Message) << "(propagator already sinked)" << std::endl; + for (unsigned int i = 0; i < result.size(); ++i) + { + Gamma gSnk(gammaList[i].first); + Gamma gSrc(gammaList[i].second); + + for (unsigned int t = 0; t < result[i].corr.size(); ++t) + { + result[i].corr[t] = TensorRemove(trace(mesonConnected(q1[t], q2[t], gSnk, gSrc))); + } + } + } + else + { + PropagatorField1 &q1 = *env().template getObject(par().q1); + PropagatorField2 &q2 = *env().template getObject(par().q2); + LatticeComplex c(env().getGrid()); + + LOG(Message) << "(using sink '" << par().sink << "')" << std::endl; + for (unsigned int i = 
0; i < result.size(); ++i) + { + Gamma gSnk(gammaList[i].first); + Gamma gSrc(gammaList[i].second); + std::string ns; + + ns = env().getModuleNamespace(env().getObjectModule(par().sink)); + if (ns == "MSource") + { + PropagatorField1 &sink = + *env().template getObject(par().sink); + + c = trace(mesonConnected(q1, q2, gSnk, gSrc)*sink); + sliceSum(c, buf, Tp); + } + else if (ns == "MSink") + { + SinkFnScalar &sink = *env().template getObject(par().sink); + + c = trace(mesonConnected(q1, q2, gSnk, gSrc)); + buf = sink(c); + } + for (unsigned int t = 0; t < buf.size(); ++t) + { + result[i].corr[t] = TensorRemove(buf[t]); + } + } + } + write(writer, "meson", result); +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MContraction_Meson_hpp_ diff --git a/extras/Hadrons/Modules/MContraction/WeakHamiltonian.hpp b/extras/Hadrons/Modules/MContraction/WeakHamiltonian.hpp new file mode 100644 index 00000000..0a3c2e31 --- /dev/null +++ b/extras/Hadrons/Modules/MContraction/WeakHamiltonian.hpp @@ -0,0 +1,114 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MContraction/WeakHamiltonian.hpp + +Copyright (C) 2017 + +Author: Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MContraction_WeakHamiltonian_hpp_ +#define Hadrons_MContraction_WeakHamiltonian_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * WeakHamiltonian * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MContraction) + +/******************************************************************************* + * Utilities for contractions involving the Weak Hamiltonian. + ******************************************************************************/ +//// Sum and store correlator. +#define MAKE_DIAG(exp, buf, res, n)\ +sliceSum(exp, buf, Tp);\ +res.name = (n);\ +res.corr.resize(buf.size());\ +for (unsigned int t = 0; t < buf.size(); ++t)\ +{\ + res.corr[t] = TensorRemove(buf[t]);\ +} + +//// Contraction of mu index: use 'mu' variable in exp. 
+#define SUM_MU(buf,exp)\ +buf = zero;\ +for (unsigned int mu = 0; mu < ndim; ++mu)\ +{\ + buf += exp;\ +} + +enum +{ + i_V = 0, + i_A = 1, + n_i = 2 +}; + +class WeakHamiltonianPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(WeakHamiltonianPar, + std::string, q1, + std::string, q2, + std::string, q3, + std::string, q4, + std::string, output); +}; + +#define MAKE_WEAK_MODULE(modname)\ +class T##modname: public Module\ +{\ +public:\ + FERM_TYPE_ALIASES(FIMPL,)\ + class Result: Serializable\ + {\ + public:\ + GRID_SERIALIZABLE_CLASS_MEMBERS(Result,\ + std::string, name,\ + std::vector, corr);\ + };\ +public:\ + /* constructor */ \ + T##modname(const std::string name);\ + /* destructor */ \ + virtual ~T##modname(void) = default;\ + /* dependency relation */ \ + virtual std::vector getInput(void);\ + virtual std::vector getOutput(void);\ + /* setup */ \ + virtual void setup(void);\ + /* execution */ \ + virtual void execute(void);\ + std::vector VA_label = {"V", "A"};\ +};\ +MODULE_REGISTER_NS(modname, T##modname, MContraction); + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MContraction_WeakHamiltonian_hpp_ diff --git a/extras/Hadrons/Modules/MContraction/WeakHamiltonianEye.cc b/extras/Hadrons/Modules/MContraction/WeakHamiltonianEye.cc new file mode 100644 index 00000000..a44c2534 --- /dev/null +++ b/extras/Hadrons/Modules/MContraction/WeakHamiltonianEye.cc @@ -0,0 +1,137 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MContraction/WeakHamiltonianEye.cc + +Copyright (C) 2017 + +Author: Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
+ +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include + +using namespace Grid; +using namespace Hadrons; +using namespace MContraction; + +/* + * Weak Hamiltonian current-current contractions, Eye-type. + * + * These contractions are generated by the Q1 and Q2 operators in the physical + * basis (see e.g. Fig 3 of arXiv:1507.03094). + * + * Schematics: q4 | + * /-<-¬ | + * / \ | q2 q3 + * \ / | /----<------*------<----¬ + * q2 \ / q3 | / /-*-¬ \ + * /-----<-----* *-----<----¬ | / / \ \ + * i * H_W * f | i * \ / q4 * f + * \ / | \ \->-/ / + * \ / | \ / + * \---------->---------/ | \----------->----------/ + * q1 | q1 + * | + * Saucer (S) | Eye (E) + * + * S: trace(q3*g5*q1*adj(q2)*g5*gL[mu][p_1]*q4*gL[mu][p_2]) + * E: trace(q3*g5*q1*adj(q2)*g5*gL[mu][p_1])*trace(q4*gL[mu][p_2]) + */ + +/****************************************************************************** + * TWeakHamiltonianEye implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +TWeakHamiltonianEye::TWeakHamiltonianEye(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +std::vector TWeakHamiltonianEye::getInput(void) +{ + std::vector in = {par().q1, par().q2, par().q3, par().q4}; + + return in; +} + +std::vector 
TWeakHamiltonianEye::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +void TWeakHamiltonianEye::setup(void) +{ + +} + +// execution /////////////////////////////////////////////////////////////////// +void TWeakHamiltonianEye::execute(void) +{ + LOG(Message) << "Computing Weak Hamiltonian (Eye type) contractions '" + << getName() << "' using quarks '" << par().q1 << "', '" + << par().q2 << ", '" << par().q3 << "' and '" << par().q4 + << "'." << std::endl; + + CorrWriter writer(par().output); + PropagatorField &q1 = *env().template getObject(par().q1); + PropagatorField &q2 = *env().template getObject(par().q2); + PropagatorField &q3 = *env().template getObject(par().q3); + PropagatorField &q4 = *env().template getObject(par().q4); + Gamma g5 = Gamma(Gamma::Algebra::Gamma5); + LatticeComplex expbuf(env().getGrid()); + std::vector corrbuf; + std::vector result(n_eye_diag); + unsigned int ndim = env().getNd(); + + PropagatorField tmp1(env().getGrid()); + LatticeComplex tmp2(env().getGrid()); + std::vector S_body(ndim, tmp1); + std::vector S_loop(ndim, tmp1); + std::vector E_body(ndim, tmp2); + std::vector E_loop(ndim, tmp2); + + // Setup for S-type contractions. + for (int mu = 0; mu < ndim; ++mu) + { + S_body[mu] = MAKE_SE_BODY(q1, q2, q3, GammaL(Gamma::gmu[mu])); + S_loop[mu] = MAKE_SE_LOOP(q4, GammaL(Gamma::gmu[mu])); + } + + // Perform S-type contractions. + SUM_MU(expbuf, trace(S_body[mu]*S_loop[mu])) + MAKE_DIAG(expbuf, corrbuf, result[S_diag], "HW_S") + + // Recycle sub-expressions for E-type contractions. + for (unsigned int mu = 0; mu < ndim; ++mu) + { + E_body[mu] = trace(S_body[mu]); + E_loop[mu] = trace(S_loop[mu]); + } + + // Perform E-type contractions. 
+ SUM_MU(expbuf, E_body[mu]*E_loop[mu]) + MAKE_DIAG(expbuf, corrbuf, result[E_diag], "HW_E") + + write(writer, "HW_Eye", result); +} diff --git a/extras/Hadrons/Modules/MContraction/WeakHamiltonianEye.hpp b/extras/Hadrons/Modules/MContraction/WeakHamiltonianEye.hpp new file mode 100644 index 00000000..3a2b9309 --- /dev/null +++ b/extras/Hadrons/Modules/MContraction/WeakHamiltonianEye.hpp @@ -0,0 +1,58 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MContraction/WeakHamiltonianEye.hpp + +Copyright (C) 2017 + +Author: Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MContraction_WeakHamiltonianEye_hpp_ +#define Hadrons_MContraction_WeakHamiltonianEye_hpp_ + +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * WeakHamiltonianEye * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MContraction) + +enum +{ + S_diag = 0, + E_diag = 1, + n_eye_diag = 2 +}; + +// Saucer and Eye subdiagram contractions. +#define MAKE_SE_BODY(Q_1, Q_2, Q_3, gamma) (Q_3*g5*Q_1*adj(Q_2)*g5*gamma) +#define MAKE_SE_LOOP(Q_loop, gamma) (Q_loop*gamma) + +MAKE_WEAK_MODULE(WeakHamiltonianEye) + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MContraction_WeakHamiltonianEye_hpp_ diff --git a/extras/Hadrons/Modules/MContraction/WeakHamiltonianNonEye.cc b/extras/Hadrons/Modules/MContraction/WeakHamiltonianNonEye.cc new file mode 100644 index 00000000..2c4df68a --- /dev/null +++ b/extras/Hadrons/Modules/MContraction/WeakHamiltonianNonEye.cc @@ -0,0 +1,139 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MContraction/WeakHamiltonianNonEye.cc + +Copyright (C) 2017 + +Author: Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. 
+ +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include + +using namespace Grid; +using namespace Hadrons; +using namespace MContraction; + +/* + * Weak Hamiltonian current-current contractions, Non-Eye-type. + * + * These contractions are generated by the Q1 and Q2 operators in the physical + * basis (see e.g. Fig 3 of arXiv:1507.03094). + * + * Schematic: + * q2 q3 | q2 q3 + * /--<--¬ /--<--¬ | /--<--¬ /--<--¬ + * / \ / \ | / \ / \ + * / \ / \ | / \ / \ + * / \ / \ | / \ / \ + * i * * H_W * f | i * * * H_W * f + * \ * | | \ / \ / + * \ / \ / | \ / \ / + * \ / \ / | \ / \ / + * \ / \ / | \-->--/ \-->--/ + * \-->--/ \-->--/ | q1 q4 + * q1 q4 | + * Connected (C) | Wing (W) + * + * C: trace(q1*adj(q2)*g5*gL[mu]*q3*adj(q4)*g5*gL[mu]) + * W: trace(q1*adj(q2)*g5*gL[mu])*trace(q3*adj(q4)*g5*gL[mu]) + * + */ + +/****************************************************************************** + * TWeakHamiltonianNonEye implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +TWeakHamiltonianNonEye::TWeakHamiltonianNonEye(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +std::vector TWeakHamiltonianNonEye::getInput(void) +{ + std::vector in = {par().q1, par().q2, par().q3, par().q4}; + + return in; +} + +std::vector TWeakHamiltonianNonEye::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +void 
TWeakHamiltonianNonEye::setup(void) +{ + +} + +// execution /////////////////////////////////////////////////////////////////// +void TWeakHamiltonianNonEye::execute(void) +{ + LOG(Message) << "Computing Weak Hamiltonian (Non-Eye type) contractions '" + << getName() << "' using quarks '" << par().q1 << "', '" + << par().q2 << ", '" << par().q3 << "' and '" << par().q4 + << "'." << std::endl; + + CorrWriter writer(par().output); + PropagatorField &q1 = *env().template getObject(par().q1); + PropagatorField &q2 = *env().template getObject(par().q2); + PropagatorField &q3 = *env().template getObject(par().q3); + PropagatorField &q4 = *env().template getObject(par().q4); + Gamma g5 = Gamma(Gamma::Algebra::Gamma5); + LatticeComplex expbuf(env().getGrid()); + std::vector corrbuf; + std::vector result(n_noneye_diag); + unsigned int ndim = env().getNd(); + + PropagatorField tmp1(env().getGrid()); + LatticeComplex tmp2(env().getGrid()); + std::vector C_i_side_loop(ndim, tmp1); + std::vector C_f_side_loop(ndim, tmp1); + std::vector W_i_side_loop(ndim, tmp2); + std::vector W_f_side_loop(ndim, tmp2); + + // Setup for C-type contractions. + for (int mu = 0; mu < ndim; ++mu) + { + C_i_side_loop[mu] = MAKE_CW_SUBDIAG(q1, q2, GammaL(Gamma::gmu[mu])); + C_f_side_loop[mu] = MAKE_CW_SUBDIAG(q3, q4, GammaL(Gamma::gmu[mu])); + } + + // Perform C-type contractions. + SUM_MU(expbuf, trace(C_i_side_loop[mu]*C_f_side_loop[mu])) + MAKE_DIAG(expbuf, corrbuf, result[C_diag], "HW_C") + + // Recycle sub-expressions for W-type contractions. + for (unsigned int mu = 0; mu < ndim; ++mu) + { + W_i_side_loop[mu] = trace(C_i_side_loop[mu]); + W_f_side_loop[mu] = trace(C_f_side_loop[mu]); + } + + // Perform W-type contractions. 
+ SUM_MU(expbuf, W_i_side_loop[mu]*W_f_side_loop[mu]) + MAKE_DIAG(expbuf, corrbuf, result[W_diag], "HW_W") + + write(writer, "HW_NonEye", result); +} diff --git a/extras/Hadrons/Modules/MContraction/WeakHamiltonianNonEye.hpp b/extras/Hadrons/Modules/MContraction/WeakHamiltonianNonEye.hpp new file mode 100644 index 00000000..eb5abe3c --- /dev/null +++ b/extras/Hadrons/Modules/MContraction/WeakHamiltonianNonEye.hpp @@ -0,0 +1,57 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MContraction/WeakHamiltonianNonEye.hpp + +Copyright (C) 2017 + +Author: Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MContraction_WeakHamiltonianNonEye_hpp_ +#define Hadrons_MContraction_WeakHamiltonianNonEye_hpp_ + +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * WeakHamiltonianNonEye * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MContraction) + +enum +{ + W_diag = 0, + C_diag = 1, + n_noneye_diag = 2 +}; + +// Wing and Connected subdiagram contractions +#define MAKE_CW_SUBDIAG(Q_1, Q_2, gamma) (Q_1*adj(Q_2)*g5*gamma) + +MAKE_WEAK_MODULE(WeakHamiltonianNonEye) + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MContraction_WeakHamiltonianNonEye_hpp_ diff --git a/extras/Hadrons/Modules/MContraction/WeakNeutral4ptDisc.cc b/extras/Hadrons/Modules/MContraction/WeakNeutral4ptDisc.cc new file mode 100644 index 00000000..6685f292 --- /dev/null +++ b/extras/Hadrons/Modules/MContraction/WeakNeutral4ptDisc.cc @@ -0,0 +1,135 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MContraction/WeakNeutral4ptDisc.cc + +Copyright (C) 2017 + +Author: Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. 
+ +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include + +using namespace Grid; +using namespace Hadrons; +using namespace MContraction; + +/* + * Weak Hamiltonian + current contractions, disconnected topology for neutral + * mesons. + * + * These contractions are generated by operators Q_1,...,10 of the dS=1 Weak + * Hamiltonian in the physical basis and an additional current J (see e.g. + * Fig 11 of arXiv:1507.03094). + * + * Schematic: + * + * q2 q4 q3 + * /--<--¬ /---<--¬ /---<--¬ + * / \ / \ / \ + * i * * H_W | J * * f + * \ / \ / \ / + * \--->---/ \-------/ \------/ + * q1 + * + * options + * - q1: input propagator 1 (string) + * - q2: input propagator 2 (string) + * - q3: input propagator 3 (string), assumed to be sequential propagator + * - q4: input propagator 4 (string), assumed to be a loop + * + * type 1: trace(q1*adj(q2)*g5*gL[mu])*trace(loop*gL[mu])*trace(q3*g5) + * type 2: trace(q1*adj(q2)*g5*gL[mu]*loop*gL[mu])*trace(q3*g5) + */ + +/******************************************************************************* + * TWeakNeutral4ptDisc implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +TWeakNeutral4ptDisc::TWeakNeutral4ptDisc(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +std::vector TWeakNeutral4ptDisc::getInput(void) +{ + std::vector in = {par().q1, par().q2, par().q3, par().q4}; + + return in; +} + +std::vector TWeakNeutral4ptDisc::getOutput(void) +{ + std::vector out = {getName()}; + + 
return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +void TWeakNeutral4ptDisc::setup(void) +{ + +} + +// execution /////////////////////////////////////////////////////////////////// +void TWeakNeutral4ptDisc::execute(void) +{ + LOG(Message) << "Computing Weak Hamiltonian neutral disconnected contractions '" + << getName() << "' using quarks '" << par().q1 << "', '" + << par().q2 << ", '" << par().q3 << "' and '" << par().q4 + << "'." << std::endl; + + CorrWriter writer(par().output); + PropagatorField &q1 = *env().template getObject(par().q1); + PropagatorField &q2 = *env().template getObject(par().q2); + PropagatorField &q3 = *env().template getObject(par().q3); + PropagatorField &q4 = *env().template getObject(par().q4); + Gamma g5 = Gamma(Gamma::Algebra::Gamma5); + LatticeComplex expbuf(env().getGrid()); + std::vector corrbuf; + std::vector result(n_neut_disc_diag); + unsigned int ndim = env().getNd(); + + PropagatorField tmp(env().getGrid()); + std::vector meson(ndim, tmp); + std::vector loop(ndim, tmp); + LatticeComplex curr(env().getGrid()); + + // Setup for type 1 contractions. + for (int mu = 0; mu < ndim; ++mu) + { + meson[mu] = MAKE_DISC_MESON(q1, q2, GammaL(Gamma::gmu[mu])); + loop[mu] = MAKE_DISC_LOOP(q4, GammaL(Gamma::gmu[mu])); + } + curr = MAKE_DISC_CURR(q3, GammaL(Gamma::Algebra::Gamma5)); + + // Perform type 1 contractions. + SUM_MU(expbuf, trace(meson[mu]*loop[mu])) + expbuf *= curr; + MAKE_DIAG(expbuf, corrbuf, result[neut_disc_1_diag], "HW_disc0_1") + + // Perform type 2 contractions. 
+ SUM_MU(expbuf, trace(meson[mu])*trace(loop[mu])) + expbuf *= curr; + MAKE_DIAG(expbuf, corrbuf, result[neut_disc_2_diag], "HW_disc0_2") + + write(writer, "HW_disc0", result); +} diff --git a/extras/Hadrons/Modules/MContraction/WeakNeutral4ptDisc.hpp b/extras/Hadrons/Modules/MContraction/WeakNeutral4ptDisc.hpp new file mode 100644 index 00000000..f26d4636 --- /dev/null +++ b/extras/Hadrons/Modules/MContraction/WeakNeutral4ptDisc.hpp @@ -0,0 +1,59 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MContraction/WeakNeutral4ptDisc.hpp + +Copyright (C) 2017 + +Author: Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MContraction_WeakNeutral4ptDisc_hpp_ +#define Hadrons_MContraction_WeakNeutral4ptDisc_hpp_ + +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * WeakNeutral4ptDisc * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MContraction) + +enum +{ + neut_disc_1_diag = 0, + neut_disc_2_diag = 1, + n_neut_disc_diag = 2 +}; + +// Neutral 4pt disconnected subdiagram contractions. +#define MAKE_DISC_MESON(Q_1, Q_2, gamma) (Q_1*adj(Q_2)*g5*gamma) +#define MAKE_DISC_LOOP(Q_LOOP, gamma) (Q_LOOP*gamma) +#define MAKE_DISC_CURR(Q_c, gamma) (trace(Q_c*gamma)) + +MAKE_WEAK_MODULE(WeakNeutral4ptDisc) + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MContraction_WeakNeutral4ptDisc_hpp_ diff --git a/extras/Hadrons/Modules/MFermion/GaugeProp.hpp b/extras/Hadrons/Modules/MFermion/GaugeProp.hpp new file mode 100644 index 00000000..b4f9edcc --- /dev/null +++ b/extras/Hadrons/Modules/MFermion/GaugeProp.hpp @@ -0,0 +1,160 @@ +#ifndef Hadrons_MFermion_GaugeProp_hpp_ +#define Hadrons_MFermion_GaugeProp_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * GaugeProp * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MFermion) + +class GaugePropPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(GaugePropPar, + std::string, source, + std::string, solver); +}; + +template +class TGaugeProp: public Module +{ +public: + FGS_TYPE_ALIASES(FImpl,); +public: + // constructor + TGaugeProp(const std::string name); + // destructor + virtual ~TGaugeProp(void) = default; + // dependency 
relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +private: + unsigned int Ls_; + SolverFn *solver_{nullptr}; +}; + +MODULE_REGISTER_NS(GaugeProp, TGaugeProp, MFermion); + +/****************************************************************************** + * TGaugeProp implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +TGaugeProp::TGaugeProp(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector TGaugeProp::getInput(void) +{ + std::vector in = {par().source, par().solver}; + + return in; +} + +template +std::vector TGaugeProp::getOutput(void) +{ + std::vector out = {getName(), getName() + "_5d"}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void TGaugeProp::setup(void) +{ + Ls_ = env().getObjectLs(par().solver); + env().template registerLattice(getName()); + if (Ls_ > 1) + { + env().template registerLattice(getName() + "_5d", Ls_); + } +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TGaugeProp::execute(void) +{ + LOG(Message) << "Computing quark propagator '" << getName() << "'" + << std::endl; + + FermionField source(env().getGrid(Ls_)), sol(env().getGrid(Ls_)), + tmp(env().getGrid()); + std::string propName = (Ls_ == 1) ? 
getName() : (getName() + "_5d"); + PropagatorField &prop = *env().template createLattice(propName); + PropagatorField &fullSrc = *env().template getObject(par().source); + SolverFn &solver = *env().template getObject(par().solver); + if (Ls_ > 1) + { + env().template createLattice(getName()); + } + + LOG(Message) << "Inverting using solver '" << par().solver + << "' on source '" << par().source << "'" << std::endl; + for (unsigned int s = 0; s < Ns; ++s) + for (unsigned int c = 0; c < Nc; ++c) + { + LOG(Message) << "Inversion for spin= " << s << ", color= " << c + << std::endl; + // source conversion for 4D sources + if (!env().isObject5d(par().source)) + { + if (Ls_ == 1) + { + PropToFerm(source, fullSrc, s, c); + } + else + { + source = zero; + PropToFerm(tmp, fullSrc, s, c); + InsertSlice(tmp, source, 0, 0); + InsertSlice(tmp, source, Ls_-1, 0); + axpby_ssp_pplus(source, 0., source, 1., source, 0, 0); + axpby_ssp_pminus(source, 0., source, 1., source, Ls_-1, Ls_-1); + } + } + // source conversion for 5D sources + else + { + if (Ls_ != env().getObjectLs(par().source)) + { + HADRON_ERROR("Ls mismatch between quark action and source"); + } + else + { + PropToFerm(source, fullSrc, s, c); + } + } + sol = zero; + solver(sol, source); + FermToProp(prop, sol, s, c); + // create 4D propagators from 5D one if necessary + if (Ls_ > 1) + { + PropagatorField &p4d = + *env().template getObject(getName()); + + axpby_ssp_pminus(sol, 0., sol, 1., sol, 0, 0); + axpby_ssp_pplus(sol, 1., sol, 1., sol, 0, Ls_-1); + ExtractSlice(tmp, sol, 0, 0); + FermToProp(p4d, tmp, s, c); + } + } +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MFermion_GaugeProp_hpp_ diff --git a/extras/Hadrons/Modules/MGauge/Load.cc b/extras/Hadrons/Modules/MGauge/Load.cc new file mode 100644 index 00000000..062e7e98 --- /dev/null +++ b/extras/Hadrons/Modules/MGauge/Load.cc @@ -0,0 +1,78 @@ +/************************************************************************************* + +Grid 
physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MGauge/Load.cc + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include + +using namespace Grid; +using namespace Hadrons; +using namespace MGauge; + +/****************************************************************************** +* TLoad implementation * +******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +TLoad::TLoad(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +std::vector TLoad::getInput(void) +{ + std::vector in; + + return in; +} + +std::vector TLoad::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +void TLoad::setup(void) +{ + env().registerLattice(getName()); +} + +// execution /////////////////////////////////////////////////////////////////// +void TLoad::execute(void) +{ + 
FieldMetaData header; + std::string fileName = par().file + "." + + std::to_string(env().getTrajectory()); + + LOG(Message) << "Loading NERSC configuration from file '" << fileName + << "'" << std::endl; + LatticeGaugeField &U = *env().createLattice(getName()); + NerscIO::readConfiguration(U, header, fileName); + LOG(Message) << "NERSC header:" << std::endl; + dump_meta_data(header, LOG(Message)); +} diff --git a/extras/Hadrons/Modules/MGauge/Load.hpp b/extras/Hadrons/Modules/MGauge/Load.hpp new file mode 100644 index 00000000..5ff6da0f --- /dev/null +++ b/extras/Hadrons/Modules/MGauge/Load.hpp @@ -0,0 +1,73 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MGauge/Load.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MGauge_Load_hpp_ +#define Hadrons_MGauge_Load_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Load a NERSC configuration * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MGauge) + +class LoadPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(LoadPar, + std::string, file); +}; + +class TLoad: public Module +{ +public: + // constructor + TLoad(const std::string name); + // destructor + virtual ~TLoad(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(Load, TLoad, MGauge); + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MGauge_Load_hpp_ diff --git a/extras/Hadrons/Modules/MGauge/Random.cc b/extras/Hadrons/Modules/MGauge/Random.cc new file mode 100644 index 00000000..c10fdfc3 --- /dev/null +++ b/extras/Hadrons/Modules/MGauge/Random.cc @@ -0,0 +1,69 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MGauge/Random.cc + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
+ +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include + +using namespace Grid; +using namespace Hadrons; +using namespace MGauge; + +/****************************************************************************** +* TRandom implementation * +******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +TRandom::TRandom(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +std::vector TRandom::getInput(void) +{ + return std::vector(); +} + +std::vector TRandom::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +void TRandom::setup(void) +{ + env().registerLattice(getName()); +} + +// execution /////////////////////////////////////////////////////////////////// +void TRandom::execute(void) +{ + LOG(Message) << "Generating random gauge configuration" << std::endl; + LatticeGaugeField &U = *env().createLattice(getName()); + SU3::HotConfiguration(*env().get4dRng(), U); +} diff --git a/extras/Hadrons/Modules/MGauge/Random.hpp b/extras/Hadrons/Modules/MGauge/Random.hpp new file mode 100644 index 00000000..a97d25cf --- /dev/null +++ b/extras/Hadrons/Modules/MGauge/Random.hpp @@ -0,0 +1,66 @@ 
+/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MGauge/Random.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MGauge_Random_hpp_ +#define Hadrons_MGauge_Random_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Random gauge * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MGauge) + +class TRandom: public Module +{ +public: + // constructor + TRandom(const std::string name); + // destructor + virtual ~TRandom(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(Random, TRandom, MGauge); + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MGauge_Random_hpp_ diff --git 
a/extras/Hadrons/Modules/MGauge/StochEm.cc b/extras/Hadrons/Modules/MGauge/StochEm.cc new file mode 100644 index 00000000..c7a9fc4f --- /dev/null +++ b/extras/Hadrons/Modules/MGauge/StochEm.cc @@ -0,0 +1,88 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MGauge/StochEm.cc + +Copyright (C) 2015 +Copyright (C) 2016 + + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ +#include + +using namespace Grid; +using namespace Hadrons; +using namespace MGauge; + +/****************************************************************************** +* TStochEm implementation * +******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +TStochEm::TStochEm(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +std::vector TStochEm::getInput(void) +{ + std::vector in; + + return in; +} + +std::vector TStochEm::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +void TStochEm::setup(void) +{ + if (!env().hasRegisteredObject("_" + getName() + "_weight")) + { + env().registerLattice("_" + getName() + "_weight"); + } + env().registerLattice(getName()); +} + +// execution /////////////////////////////////////////////////////////////////// +void TStochEm::execute(void) +{ + PhotonR photon(par().gauge, par().zmScheme); + EmField &a = *env().createLattice(getName()); + EmComp *w; + + if (!env().hasCreatedObject("_" + getName() + "_weight")) + { + LOG(Message) << "Caching stochatic EM potential weight (gauge: " + << par().gauge << ", zero-mode scheme: " + << par().zmScheme << ")..." << std::endl; + w = env().createLattice("_" + getName() + "_weight"); + photon.StochasticWeight(*w); + } + else + { + w = env().getObject("_" + getName() + "_weight"); + } + LOG(Message) << "Generating stochatic EM potential..." 
<< std::endl; + photon.StochasticField(a, *env().get4dRng(), *w); +} diff --git a/extras/Hadrons/Modules/MGauge/StochEm.hpp b/extras/Hadrons/Modules/MGauge/StochEm.hpp new file mode 100644 index 00000000..12ce9fdc --- /dev/null +++ b/extras/Hadrons/Modules/MGauge/StochEm.hpp @@ -0,0 +1,75 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MGauge/StochEm.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef Hadrons_MGauge_StochEm_hpp_ +#define Hadrons_MGauge_StochEm_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * StochEm * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MGauge) + +class StochEmPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(StochEmPar, + PhotonR::Gauge, gauge, + PhotonR::ZmScheme, zmScheme); +}; + +class TStochEm: public Module +{ +public: + typedef PhotonR::GaugeField EmField; + typedef PhotonR::GaugeLinkField EmComp; +public: + // constructor + TStochEm(const std::string name); + // destructor + virtual ~TStochEm(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(StochEm, TStochEm, MGauge); + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MGauge_StochEm_hpp_ diff --git a/extras/Hadrons/Modules/MGauge/Unit.cc b/extras/Hadrons/Modules/MGauge/Unit.cc new file mode 100644 index 00000000..18d75c59 --- /dev/null +++ b/extras/Hadrons/Modules/MGauge/Unit.cc @@ -0,0 +1,69 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MGauge/Unit.cc + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
+ +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include + +using namespace Grid; +using namespace Hadrons; +using namespace MGauge; + +/****************************************************************************** +* TUnit implementation * +******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +TUnit::TUnit(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +std::vector TUnit::getInput(void) +{ + return std::vector(); +} + +std::vector TUnit::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +void TUnit::setup(void) +{ + env().registerLattice(getName()); +} + +// execution /////////////////////////////////////////////////////////////////// +void TUnit::execute(void) +{ + LOG(Message) << "Creating unit gauge configuration" << std::endl; + LatticeGaugeField &U = *env().createLattice(getName()); + SU3::ColdConfiguration(*env().get4dRng(), U); +} diff --git a/extras/Hadrons/Modules/MGauge/Unit.hpp b/extras/Hadrons/Modules/MGauge/Unit.hpp new file mode 100644 index 00000000..7cd15ef7 --- /dev/null +++ b/extras/Hadrons/Modules/MGauge/Unit.hpp @@ -0,0 +1,66 @@ 
+/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MGauge/Unit.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MGauge_Unit_hpp_ +#define Hadrons_MGauge_Unit_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Unit gauge * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MGauge) + +class TUnit: public Module +{ +public: + // constructor + TUnit(const std::string name); + // destructor + virtual ~TUnit(void) = default; + // dependencies/products + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(Unit, TUnit, MGauge); + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MGauge_Unit_hpp_ diff --git 
a/extras/Hadrons/Modules/MLoop/NoiseLoop.hpp b/extras/Hadrons/Modules/MLoop/NoiseLoop.hpp new file mode 100644 index 00000000..5d2c4a13 --- /dev/null +++ b/extras/Hadrons/Modules/MLoop/NoiseLoop.hpp @@ -0,0 +1,132 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MLoop/NoiseLoop.hpp + +Copyright (C) 2016 + +Author: Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MLoop_NoiseLoop_hpp_ +#define Hadrons_MLoop_NoiseLoop_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/* + + Noise loop propagator + ----------------------------- + * loop_x = q_x * adj(eta_x) + + * options: + - q = Result of inversion on noise source. + - eta = noise source. 
+ + */ + + +/****************************************************************************** + * NoiseLoop * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MLoop) + +class NoiseLoopPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(NoiseLoopPar, + std::string, q, + std::string, eta); +}; + +template +class TNoiseLoop: public Module +{ +public: + FERM_TYPE_ALIASES(FImpl,); +public: + // constructor + TNoiseLoop(const std::string name); + // destructor + virtual ~TNoiseLoop(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(NoiseLoop, TNoiseLoop, MLoop); + +/****************************************************************************** + * TNoiseLoop implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +TNoiseLoop::TNoiseLoop(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector TNoiseLoop::getInput(void) +{ + std::vector in = {par().q, par().eta}; + + return in; +} + +template +std::vector TNoiseLoop::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void TNoiseLoop::setup(void) +{ + env().template registerLattice(getName()); +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TNoiseLoop::execute(void) +{ + PropagatorField &loop = *env().template createLattice(getName()); + PropagatorField &q = *env().template getObject(par().q); + PropagatorField &eta = *env().template getObject(par().eta); + loop = q*adj(eta); +} + 
+END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MLoop_NoiseLoop_hpp_ diff --git a/extras/Hadrons/Modules/MScalar/ChargedProp.cc b/extras/Hadrons/Modules/MScalar/ChargedProp.cc new file mode 100644 index 00000000..cd8dc244 --- /dev/null +++ b/extras/Hadrons/Modules/MScalar/ChargedProp.cc @@ -0,0 +1,226 @@ +#include +#include + +using namespace Grid; +using namespace Hadrons; +using namespace MScalar; + +/****************************************************************************** +* TChargedProp implementation * +******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +TChargedProp::TChargedProp(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +std::vector TChargedProp::getInput(void) +{ + std::vector in = {par().source, par().emField}; + + return in; +} + +std::vector TChargedProp::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +void TChargedProp::setup(void) +{ + freeMomPropName_ = FREEMOMPROP(par().mass); + phaseName_.clear(); + for (unsigned int mu = 0; mu < env().getNd(); ++mu) + { + phaseName_.push_back("_shiftphase_" + std::to_string(mu)); + } + GFSrcName_ = "_" + getName() + "_DinvSrc"; + if (!env().hasRegisteredObject(freeMomPropName_)) + { + env().registerLattice(freeMomPropName_); + } + if (!env().hasRegisteredObject(phaseName_[0])) + { + for (unsigned int mu = 0; mu < env().getNd(); ++mu) + { + env().registerLattice(phaseName_[mu]); + } + } + if (!env().hasRegisteredObject(GFSrcName_)) + { + env().registerLattice(GFSrcName_); + } + env().registerLattice(getName()); +} + +// execution /////////////////////////////////////////////////////////////////// +void TChargedProp::execute(void) +{ + // CACHING ANALYTIC EXPRESSIONS + ScalarField &source = 
*env().getObject(par().source); + Complex ci(0.0,1.0); + FFT fft(env().getGrid()); + + // cache free scalar propagator + if (!env().hasCreatedObject(freeMomPropName_)) + { + LOG(Message) << "Caching momentum space free scalar propagator" + << " (mass= " << par().mass << ")..." << std::endl; + freeMomProp_ = env().createLattice(freeMomPropName_); + SIMPL::MomentumSpacePropagator(*freeMomProp_, par().mass); + } + else + { + freeMomProp_ = env().getObject(freeMomPropName_); + } + // cache G*F*src + if (!env().hasCreatedObject(GFSrcName_)) + + { + GFSrc_ = env().createLattice(GFSrcName_); + fft.FFT_all_dim(*GFSrc_, source, FFT::forward); + *GFSrc_ = (*freeMomProp_)*(*GFSrc_); + } + else + { + GFSrc_ = env().getObject(GFSrcName_); + } + // cache phases + if (!env().hasCreatedObject(phaseName_[0])) + { + std::vector &l = env().getGrid()->_fdimensions; + + LOG(Message) << "Caching shift phases..." << std::endl; + for (unsigned int mu = 0; mu < env().getNd(); ++mu) + { + Real twoPiL = M_PI*2./l[mu]; + + phase_.push_back(env().createLattice(phaseName_[mu])); + LatticeCoordinate(*(phase_[mu]), mu); + *(phase_[mu]) = exp(ci*twoPiL*(*(phase_[mu]))); + } + } + else + { + for (unsigned int mu = 0; mu < env().getNd(); ++mu) + { + phase_.push_back(env().getObject(phaseName_[mu])); + } + } + + // PROPAGATOR CALCULATION + LOG(Message) << "Computing charged scalar propagator" + << " (mass= " << par().mass + << ", charge= " << par().charge << ")..." 
<< std::endl; + + ScalarField &prop = *env().createLattice(getName()); + ScalarField buf(env().getGrid()); + ScalarField &GFSrc = *GFSrc_, &G = *freeMomProp_; + double q = par().charge; + + // G*F*Src + prop = GFSrc; + + // - q*G*momD1*G*F*Src (momD1 = F*D1*Finv) + buf = GFSrc; + momD1(buf, fft); + buf = G*buf; + prop = prop - q*buf; + + // + q^2*G*momD1*G*momD1*G*F*Src (here buf = G*momD1*G*F*Src) + momD1(buf, fft); + prop = prop + q*q*G*buf; + + // - q^2*G*momD2*G*F*Src (momD2 = F*D2*Finv) + buf = GFSrc; + momD2(buf, fft); + prop = prop - q*q*G*buf; + + // final FT + fft.FFT_all_dim(prop, prop, FFT::backward); + + // OUTPUT IF NECESSARY + if (!par().output.empty()) + { + std::string filename = par().output + "." + + std::to_string(env().getTrajectory()); + + LOG(Message) << "Saving zero-momentum projection to '" + << filename << "'..." << std::endl; + + CorrWriter writer(filename); + std::vector vecBuf; + std::vector result; + + sliceSum(prop, vecBuf, Tp); + result.resize(vecBuf.size()); + for (unsigned int t = 0; t < vecBuf.size(); ++t) + { + result[t] = TensorRemove(vecBuf[t]); + } + write(writer, "charge", q); + write(writer, "prop", result); + } +} + +void TChargedProp::momD1(ScalarField &s, FFT &fft) +{ + EmField &A = *env().getObject(par().emField); + ScalarField buf(env().getGrid()), result(env().getGrid()), + Amu(env().getGrid()); + Complex ci(0.0,1.0); + + result = zero; + + for (unsigned int mu = 0; mu < env().getNd(); ++mu) + { + Amu = peekLorentz(A, mu); + buf = (*phase_[mu])*s; + fft.FFT_all_dim(buf, buf, FFT::backward); + buf = Amu*buf; + fft.FFT_all_dim(buf, buf, FFT::forward); + result = result - ci*buf; + } + fft.FFT_all_dim(s, s, FFT::backward); + for (unsigned int mu = 0; mu < env().getNd(); ++mu) + { + Amu = peekLorentz(A, mu); + buf = Amu*s; + fft.FFT_all_dim(buf, buf, FFT::forward); + result = result + ci*adj(*phase_[mu])*buf; + } + + s = result; +} + +void TChargedProp::momD2(ScalarField &s, FFT &fft) +{ + EmField &A = 
*env().getObject(par().emField); + ScalarField buf(env().getGrid()), result(env().getGrid()), + Amu(env().getGrid()); + + result = zero; + + for (unsigned int mu = 0; mu < env().getNd(); ++mu) + { + Amu = peekLorentz(A, mu); + buf = (*phase_[mu])*s; + fft.FFT_all_dim(buf, buf, FFT::backward); + buf = Amu*Amu*buf; + fft.FFT_all_dim(buf, buf, FFT::forward); + result = result + .5*buf; + } + fft.FFT_all_dim(s, s, FFT::backward); + for (unsigned int mu = 0; mu < env().getNd(); ++mu) + { + Amu = peekLorentz(A, mu); + buf = Amu*Amu*s; + fft.FFT_all_dim(buf, buf, FFT::forward); + result = result + .5*adj(*phase_[mu])*buf; + } + + s = result; +} diff --git a/extras/Hadrons/Modules/MScalar/ChargedProp.hpp b/extras/Hadrons/Modules/MScalar/ChargedProp.hpp new file mode 100644 index 00000000..fbe75c05 --- /dev/null +++ b/extras/Hadrons/Modules/MScalar/ChargedProp.hpp @@ -0,0 +1,61 @@ +#ifndef Hadrons_MScalar_ChargedProp_hpp_ +#define Hadrons_MScalar_ChargedProp_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Charged scalar propagator * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MScalar) + +class ChargedPropPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(ChargedPropPar, + std::string, emField, + std::string, source, + double, mass, + double, charge, + std::string, output); +}; + +class TChargedProp: public Module +{ +public: + SCALAR_TYPE_ALIASES(SIMPL,); + typedef PhotonR::GaugeField EmField; + typedef PhotonR::GaugeLinkField EmComp; +public: + // constructor + TChargedProp(const std::string name); + // destructor + virtual ~TChargedProp(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +private: + void momD1(ScalarField &s, FFT &fft); + void 
momD2(ScalarField &s, FFT &fft); +private: + std::string freeMomPropName_, GFSrcName_; + std::vector phaseName_; + ScalarField *freeMomProp_, *GFSrc_; + std::vector phase_; + EmField *A; +}; + +MODULE_REGISTER_NS(ChargedProp, TChargedProp, MScalar); + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MScalar_ChargedProp_hpp_ diff --git a/extras/Hadrons/Modules/MScalar/FreeProp.cc b/extras/Hadrons/Modules/MScalar/FreeProp.cc new file mode 100644 index 00000000..674867e3 --- /dev/null +++ b/extras/Hadrons/Modules/MScalar/FreeProp.cc @@ -0,0 +1,79 @@ +#include +#include + +using namespace Grid; +using namespace Hadrons; +using namespace MScalar; + +/****************************************************************************** +* TFreeProp implementation * +******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +TFreeProp::TFreeProp(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +std::vector TFreeProp::getInput(void) +{ + std::vector in = {par().source}; + + return in; +} + +std::vector TFreeProp::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +void TFreeProp::setup(void) +{ + freeMomPropName_ = FREEMOMPROP(par().mass); + + if (!env().hasRegisteredObject(freeMomPropName_)) + { + env().registerLattice(freeMomPropName_); + } + env().registerLattice(getName()); +} + +// execution /////////////////////////////////////////////////////////////////// +void TFreeProp::execute(void) +{ + ScalarField &prop = *env().createLattice(getName()); + ScalarField &source = *env().getObject(par().source); + ScalarField *freeMomProp; + + if (!env().hasCreatedObject(freeMomPropName_)) + { + LOG(Message) << "Caching momentum space free scalar propagator" + << " (mass= " << 
par().mass << ")..." << std::endl; + freeMomProp = env().createLattice(freeMomPropName_); + SIMPL::MomentumSpacePropagator(*freeMomProp, par().mass); + } + else + { + freeMomProp = env().getObject(freeMomPropName_); + } + LOG(Message) << "Computing free scalar propagator..." << std::endl; + SIMPL::FreePropagator(source, prop, *freeMomProp); + + if (!par().output.empty()) + { + TextWriter writer(par().output + "." + + std::to_string(env().getTrajectory())); + std::vector buf; + std::vector result; + + sliceSum(prop, buf, Tp); + result.resize(buf.size()); + for (unsigned int t = 0; t < buf.size(); ++t) + { + result[t] = TensorRemove(buf[t]); + } + write(writer, "prop", result); + } +} diff --git a/extras/Hadrons/Modules/MScalar/FreeProp.hpp b/extras/Hadrons/Modules/MScalar/FreeProp.hpp new file mode 100644 index 00000000..97cf288a --- /dev/null +++ b/extras/Hadrons/Modules/MScalar/FreeProp.hpp @@ -0,0 +1,50 @@ +#ifndef Hadrons_MScalar_FreeProp_hpp_ +#define Hadrons_MScalar_FreeProp_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * FreeProp * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MScalar) + +class FreePropPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(FreePropPar, + std::string, source, + double, mass, + std::string, output); +}; + +class TFreeProp: public Module +{ +public: + SCALAR_TYPE_ALIASES(SIMPL,); +public: + // constructor + TFreeProp(const std::string name); + // destructor + virtual ~TFreeProp(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +private: + std::string freeMomPropName_; +}; + +MODULE_REGISTER_NS(FreeProp, TFreeProp, MScalar); + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // 
Hadrons_MScalar_FreeProp_hpp_ diff --git a/extras/Hadrons/Modules/MScalar/Scalar.hpp b/extras/Hadrons/Modules/MScalar/Scalar.hpp new file mode 100644 index 00000000..db702ff2 --- /dev/null +++ b/extras/Hadrons/Modules/MScalar/Scalar.hpp @@ -0,0 +1,6 @@ +#ifndef Hadrons_Scalar_hpp_ +#define Hadrons_Scalar_hpp_ + +#define FREEMOMPROP(m) "_scalar_mom_prop_" + std::to_string(m) + +#endif // Hadrons_Scalar_hpp_ diff --git a/extras/Hadrons/Modules/MSink/Point.hpp b/extras/Hadrons/Modules/MSink/Point.hpp new file mode 100644 index 00000000..7b3aa9de --- /dev/null +++ b/extras/Hadrons/Modules/MSink/Point.hpp @@ -0,0 +1,114 @@ +#ifndef Hadrons_MSink_Point_hpp_ +#define Hadrons_MSink_Point_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Point * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MSink) + +class PointPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(PointPar, + std::string, mom); +}; + +template +class TPoint: public Module +{ +public: + FERM_TYPE_ALIASES(FImpl,); + SINK_TYPE_ALIASES(); +public: + // constructor + TPoint(const std::string name); + // destructor + virtual ~TPoint(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(Point, TPoint, MSink); +MODULE_REGISTER_NS(ScalarPoint, TPoint, MSink); + +/****************************************************************************** + * TPoint implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +TPoint::TPoint(const std::string name) +: Module(name) +{} + +// dependencies/products 
/////////////////////////////////////////////////////// +template +std::vector TPoint::getInput(void) +{ + std::vector in; + + return in; +} + +template +std::vector TPoint::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void TPoint::setup(void) +{ + unsigned int size; + + size = env().template lattice4dSize(); + env().registerObject(getName(), size); +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TPoint::execute(void) +{ + std::vector p = strToVec(par().mom); + LatticeComplex ph(env().getGrid()), coor(env().getGrid()); + Complex i(0.0,1.0); + + LOG(Message) << "Setting up point sink function for momentum [" + << par().mom << "]" << std::endl; + ph = zero; + for(unsigned int mu = 0; mu < env().getNd(); mu++) + { + LatticeCoordinate(coor, mu); + ph = ph + (p[mu]/env().getGrid()->_fdimensions[mu])*coor; + } + ph = exp((Real)(2*M_PI)*i*ph); + auto sink = [ph](const PropagatorField &field) + { + SlicedPropagator res; + PropagatorField tmp = ph*field; + + sliceSum(tmp, res, Tp); + + return res; + }; + env().setObject(getName(), new SinkFn(sink)); +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MSink_Point_hpp_ diff --git a/extras/Hadrons/Modules/MSolver/RBPrecCG.hpp b/extras/Hadrons/Modules/MSolver/RBPrecCG.hpp new file mode 100644 index 00000000..b1f63a5d --- /dev/null +++ b/extras/Hadrons/Modules/MSolver/RBPrecCG.hpp @@ -0,0 +1,132 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MSolver/RBPrecCG.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; 
either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MSolver_RBPrecCG_hpp_ +#define Hadrons_MSolver_RBPrecCG_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * Schur red-black preconditioned CG * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MSolver) + +class RBPrecCGPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(RBPrecCGPar, + std::string, action, + double , residual); +}; + +template +class TRBPrecCG: public Module +{ +public: + FGS_TYPE_ALIASES(FImpl,); +public: + // constructor + TRBPrecCG(const std::string name); + // destructor + virtual ~TRBPrecCG(void) = default; + // dependencies/products + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(RBPrecCG, TRBPrecCG, MSolver); + +/****************************************************************************** + * TRBPrecCG template implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template 
+TRBPrecCG::TRBPrecCG(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector TRBPrecCG::getInput(void) +{ + std::vector in = {par().action}; + + return in; +} + +template +std::vector TRBPrecCG::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void TRBPrecCG::setup(void) +{ + auto Ls = env().getObjectLs(par().action); + + env().registerObject(getName(), 0, Ls); + env().addOwnership(getName(), par().action); +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TRBPrecCG::execute(void) +{ + auto &mat = *(env().template getObject(par().action)); + auto solver = [&mat, this](FermionField &sol, const FermionField &source) + { + ConjugateGradient cg(par().residual, 10000); + SchurRedBlackDiagMooeeSolve schurSolver(cg); + + schurSolver(mat, source, sol); + }; + + LOG(Message) << "setting up Schur red-black preconditioned CG for" + << " action '" << par().action << "' with residual " + << par().residual << std::endl; + env().setObject(getName(), new SolverFn(solver)); +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MSolver_RBPrecCG_hpp_ diff --git a/extras/Hadrons/Modules/MSource/Point.hpp b/extras/Hadrons/Modules/MSource/Point.hpp new file mode 100644 index 00000000..0c415807 --- /dev/null +++ b/extras/Hadrons/Modules/MSource/Point.hpp @@ -0,0 +1,136 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MSource/Point.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; 
either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MSource_Point_hpp_ +#define Hadrons_MSource_Point_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/* + + Point source + ------------ + * src_x = delta_x,position + + * options: + - position: space-separated integer sequence (e.g. "0 1 1 0") + + */ + +/****************************************************************************** + * TPoint * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MSource) + +class PointPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(PointPar, + std::string, position); +}; + +template +class TPoint: public Module +{ +public: + FERM_TYPE_ALIASES(FImpl,); +public: + // constructor + TPoint(const std::string name); + // destructor + virtual ~TPoint(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(Point, TPoint, MSource); +MODULE_REGISTER_NS(ScalarPoint, TPoint, MSource); + +/****************************************************************************** + * TPoint template implementation * + 
******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +TPoint::TPoint(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector TPoint::getInput(void) +{ + std::vector in; + + return in; +} + +template +std::vector TPoint::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void TPoint::setup(void) +{ + env().template registerLattice(getName()); +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TPoint::execute(void) +{ + std::vector position = strToVec(par().position); + typename SitePropagator::scalar_object id; + + LOG(Message) << "Creating point source at position [" << par().position + << "]" << std::endl; + PropagatorField &src = *env().template createLattice(getName()); + id = 1.; + src = zero; + pokeSite(id, src, position); +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MSource_Point_hpp_ diff --git a/extras/Hadrons/Modules/MSource/SeqGamma.hpp b/extras/Hadrons/Modules/MSource/SeqGamma.hpp new file mode 100644 index 00000000..e2129a46 --- /dev/null +++ b/extras/Hadrons/Modules/MSource/SeqGamma.hpp @@ -0,0 +1,164 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MSource/SeqGamma.hpp + +Copyright (C) 2015 +Copyright (C) 2016 +Copyright (C) 2017 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
+ +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MSource_SeqGamma_hpp_ +#define Hadrons_MSource_SeqGamma_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/* + + Sequential source + ----------------------------- + * src_x = q_x * theta(x_3 - tA) * theta(tB - x_3) * gamma * exp(i x.mom) + + * options: + - q: input propagator (string) + - tA: begin timeslice (integer) + - tB: end timeslice (integer) + - gamma: gamma product to insert (integer) + - mom: momentum insertion, space-separated float sequence (e.g. ".1 .2 1. 
0.") + + */ + +/****************************************************************************** + * SeqGamma * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MSource) + +class SeqGammaPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(SeqGammaPar, + std::string, q, + unsigned int, tA, + unsigned int, tB, + Gamma::Algebra, gamma, + std::string, mom); +}; + +template +class TSeqGamma: public Module +{ +public: + FGS_TYPE_ALIASES(FImpl,); +public: + // constructor + TSeqGamma(const std::string name); + // destructor + virtual ~TSeqGamma(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(SeqGamma, TSeqGamma, MSource); + +/****************************************************************************** + * TSeqGamma implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +TSeqGamma::TSeqGamma(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector TSeqGamma::getInput(void) +{ + std::vector in = {par().q}; + + return in; +} + +template +std::vector TSeqGamma::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void TSeqGamma::setup(void) +{ + env().template registerLattice(getName()); +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TSeqGamma::execute(void) +{ + if (par().tA == par().tB) + { + LOG(Message) << "Generating gamma_" << par().gamma + << " sequential source at t= " << par().tA << std::endl; + } + else + { + LOG(Message) << 
"Generating gamma_" << par().gamma + << " sequential source for " + << par().tA << " <= t <= " << par().tB << std::endl; + } + PropagatorField &src = *env().template createLattice(getName()); + PropagatorField &q = *env().template getObject(par().q); + Lattice> t(env().getGrid()); + LatticeComplex ph(env().getGrid()), coor(env().getGrid()); + Gamma g(par().gamma); + std::vector p; + Complex i(0.0,1.0); + + p = strToVec(par().mom); + ph = zero; + for(unsigned int mu = 0; mu < env().getNd(); mu++) + { + LatticeCoordinate(coor, mu); + ph = ph + p[mu]*coor*((1./(env().getGrid()->_fdimensions[mu]))); + } + ph = exp((Real)(2*M_PI)*i*ph); + LatticeCoordinate(t, Tp); + src = where((t >= par().tA) and (t <= par().tB), ph*(g*q), 0.*q); +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MSource_SeqGamma_hpp_ diff --git a/extras/Hadrons/Modules/MSource/Wall.hpp b/extras/Hadrons/Modules/MSource/Wall.hpp new file mode 100644 index 00000000..4de37e4d --- /dev/null +++ b/extras/Hadrons/Modules/MSource/Wall.hpp @@ -0,0 +1,147 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MSource/Wall.hpp + +Copyright (C) 2017 + +Author: Andrew Lawson + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MSource_WallSource_hpp_ +#define Hadrons_MSource_WallSource_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/* + + Wall source + ----------------------------- + * src_x = delta(x_3 - tW) * exp(i x.mom) + + * options: + - tW: source timeslice (integer) + - mom: momentum insertion, space-separated float sequence (e.g ".1 .2 1. 0.") + + */ + +/****************************************************************************** + * Wall * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MSource) + +class WallPar: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(WallPar, + unsigned int, tW, + std::string, mom); +}; + +template +class TWall: public Module +{ +public: + FERM_TYPE_ALIASES(FImpl,); +public: + // constructor + TWall(const std::string name); + // destructor + virtual ~TWall(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(Wall, TWall, MSource); + +/****************************************************************************** + * TWall implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +TWall::TWall(const std::string name) +: Module(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector TWall::getInput(void) +{ + std::vector in; + + return in; +} + +template +std::vector TWall::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup 
/////////////////////////////////////////////////////////////////////// +template +void TWall::setup(void) +{ + env().template registerLattice(getName()); +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TWall::execute(void) +{ + LOG(Message) << "Generating wall source at t = " << par().tW + << " with momentum " << par().mom << std::endl; + + PropagatorField &src = *env().template createLattice(getName()); + Lattice> t(env().getGrid()); + LatticeComplex ph(env().getGrid()), coor(env().getGrid()); + std::vector p; + Complex i(0.0,1.0); + + p = strToVec(par().mom); + ph = zero; + for(unsigned int mu = 0; mu < env().getNd(); mu++) /* bound by the environment's dimension count, since the loop indexes the grid's _fdimensions */ + { + LatticeCoordinate(coor, mu); + ph = ph + p[mu]*coor*((1./(env().getGrid()->_fdimensions[mu]))); + } + ph = exp((Real)(2*M_PI)*i*ph); + LatticeCoordinate(t, Tp); + src = 1.; + src = where((t == par().tW), src*ph, 0.*src); +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MSource_WallSource_hpp_ diff --git a/extras/Hadrons/Modules/MSource/Z2.hpp b/extras/Hadrons/Modules/MSource/Z2.hpp new file mode 100644 index 00000000..a7f7a3e6 --- /dev/null +++ b/extras/Hadrons/Modules/MSource/Z2.hpp @@ -0,0 +1,152 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: extras/Hadrons/Modules/MSource/Z2.hpp + +Copyright (C) 2015 +Copyright (C) 2016 + +Author: Antonin Portelli + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. 
+ +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef Hadrons_MSource_Z2_hpp_ +#define Hadrons_MSource_Z2_hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/* + + Z_2 stochastic source + ----------------------------- + * src_x = eta_x * theta(x_3 - tA) * theta(tB - x_3) + + the eta_x are independent uniform random numbers in {+/- 1 +/- i} + + * options: + - tA: begin timeslice (integer) + - tB: end timeslice (integer) + + */ + +/****************************************************************************** + * Z2 stochastic source * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(MSource) + +class Z2Par: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(Z2Par, + unsigned int, tA, + unsigned int, tB); +}; + +template +class TZ2: public Module +{ +public: + FERM_TYPE_ALIASES(FImpl,); +public: + // constructor + TZ2(const std::string name); + // destructor + virtual ~TZ2(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(Z2, TZ2, MSource); +MODULE_REGISTER_NS(ScalarZ2, TZ2, MSource); + +/****************************************************************************** + * TZ2 template implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +TZ2::TZ2(const std::string name) +: Module(name) +{} + +// dependencies/products 
/////////////////////////////////////////////////////// +template +std::vector TZ2::getInput(void) +{ + std::vector in; + + return in; +} + +template +std::vector TZ2::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void TZ2::setup(void) +{ + env().template registerLattice(getName()); +} + +// execution /////////////////////////////////////////////////////////////////// +template +void TZ2::execute(void) +{ + Lattice> t(env().getGrid()); + LatticeComplex eta(env().getGrid()); + Complex shift(1., 1.); + + if (par().tA == par().tB) + { + LOG(Message) << "Generating Z_2 wall source at t= " << par().tA + << std::endl; + } + else + { + LOG(Message) << "Generating Z_2 band for " << par().tA << " <= t <= " + << par().tB << std::endl; + } + PropagatorField &src = *env().template createLattice(getName()); + LatticeCoordinate(t, Tp); + bernoulli(*env().get4dRng(), eta); + eta = (2.*eta - shift)*(1./::sqrt(2.)); + eta = where((t >= par().tA) and (t <= par().tB), eta, 0.*eta); + src = 1.; + src = src*eta; +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons_MSource_Z2_hpp_ diff --git a/extras/Hadrons/Modules/templates/Module.cc.template b/extras/Hadrons/Modules/templates/Module.cc.template new file mode 100644 index 00000000..0c509d6d --- /dev/null +++ b/extras/Hadrons/Modules/templates/Module.cc.template @@ -0,0 +1,39 @@ +#include + +using namespace Grid; +using namespace Hadrons; + +/****************************************************************************** +* T___FILEBASENAME___ implementation * +******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +T___FILEBASENAME___::T___FILEBASENAME___(const std::string name) +: Module<___FILEBASENAME___Par>(name) +{} + +// dependencies/products 
/////////////////////////////////////////////////////// +std::vector T___FILEBASENAME___::getInput(void) +{ + std::vector in; + + return in; +} + +std::vector T___FILEBASENAME___::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +void T___FILEBASENAME___::setup(void) +{ + +} + +// execution /////////////////////////////////////////////////////////////////// +void T___FILEBASENAME___::execute(void) +{ + +} diff --git a/extras/Hadrons/Modules/templates/Module.hpp.template b/extras/Hadrons/Modules/templates/Module.hpp.template new file mode 100644 index 00000000..fb43260f --- /dev/null +++ b/extras/Hadrons/Modules/templates/Module.hpp.template @@ -0,0 +1,40 @@ +#ifndef Hadrons____FILEBASENAME____hpp_ +#define Hadrons____FILEBASENAME____hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * ___FILEBASENAME___ * + ******************************************************************************/ +class ___FILEBASENAME___Par: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(___FILEBASENAME___Par, + unsigned int, i); +}; + +class T___FILEBASENAME___: public Module<___FILEBASENAME___Par> +{ +public: + // constructor + T___FILEBASENAME___(const std::string name); + // destructor + virtual ~T___FILEBASENAME___(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER(___FILEBASENAME___, T___FILEBASENAME___); + +END_HADRONS_NAMESPACE + +#endif // Hadrons____FILEBASENAME____hpp_ diff --git a/extras/Hadrons/Modules/templates/Module_in_NS.cc.template b/extras/Hadrons/Modules/templates/Module_in_NS.cc.template new file mode 100644 index 00000000..8b2a0ec0 --- /dev/null +++ 
b/extras/Hadrons/Modules/templates/Module_in_NS.cc.template @@ -0,0 +1,40 @@ +#include + +using namespace Grid; +using namespace Hadrons; +using namespace ___NAMESPACE___; + +/****************************************************************************** +* T___FILEBASENAME___ implementation * +******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +T___FILEBASENAME___::T___FILEBASENAME___(const std::string name) +: Module<___FILEBASENAME___Par>(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +std::vector T___FILEBASENAME___::getInput(void) +{ + std::vector in; + + return in; +} + +std::vector T___FILEBASENAME___::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +void T___FILEBASENAME___::setup(void) +{ + +} + +// execution /////////////////////////////////////////////////////////////////// +void T___FILEBASENAME___::execute(void) +{ + +} diff --git a/extras/Hadrons/Modules/templates/Module_in_NS.hpp.template b/extras/Hadrons/Modules/templates/Module_in_NS.hpp.template new file mode 100644 index 00000000..ea77b12a --- /dev/null +++ b/extras/Hadrons/Modules/templates/Module_in_NS.hpp.template @@ -0,0 +1,44 @@ +#ifndef Hadrons____NAMESPACE_______FILEBASENAME____hpp_ +#define Hadrons____NAMESPACE_______FILEBASENAME____hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * ___FILEBASENAME___ * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(___NAMESPACE___) + +class ___FILEBASENAME___Par: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(___FILEBASENAME___Par, + unsigned int, i); +}; + +class T___FILEBASENAME___: public Module<___FILEBASENAME___Par> +{ 
+public: + // constructor + T___FILEBASENAME___(const std::string name); + // destructor + virtual ~T___FILEBASENAME___(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(___FILEBASENAME___, T___FILEBASENAME___, ___NAMESPACE___); + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons____NAMESPACE_______FILEBASENAME____hpp_ diff --git a/extras/Hadrons/Modules/templates/Module_tmp.hpp.template b/extras/Hadrons/Modules/templates/Module_tmp.hpp.template new file mode 100644 index 00000000..2ee053a9 --- /dev/null +++ b/extras/Hadrons/Modules/templates/Module_tmp.hpp.template @@ -0,0 +1,81 @@ +#ifndef Hadrons____FILEBASENAME____hpp_ +#define Hadrons____FILEBASENAME____hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * ___FILEBASENAME___ * + ******************************************************************************/ +class ___FILEBASENAME___Par: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(___FILEBASENAME___Par, + unsigned int, i); +}; + +template +class T___FILEBASENAME___: public Module<___FILEBASENAME___Par> +{ +public: + // constructor + T___FILEBASENAME___(const std::string name); + // destructor + virtual ~T___FILEBASENAME___(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER(___FILEBASENAME___, T___FILEBASENAME___); + +/****************************************************************************** + * T___FILEBASENAME___ implementation * + ******************************************************************************/ +// constructor 
///////////////////////////////////////////////////////////////// +template +T___FILEBASENAME___::T___FILEBASENAME___(const std::string name) +: Module<___FILEBASENAME___Par>(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector T___FILEBASENAME___::getInput(void) +{ + std::vector in; + + return in; +} + +template +std::vector T___FILEBASENAME___::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void T___FILEBASENAME___::setup(void) +{ + +} + +// execution /////////////////////////////////////////////////////////////////// +template +void T___FILEBASENAME___::execute(void) +{ + +} + +END_HADRONS_NAMESPACE + +#endif // Hadrons____FILEBASENAME____hpp_ diff --git a/extras/Hadrons/Modules/templates/Module_tmp_in_NS.hpp.template b/extras/Hadrons/Modules/templates/Module_tmp_in_NS.hpp.template new file mode 100644 index 00000000..b79c0ad3 --- /dev/null +++ b/extras/Hadrons/Modules/templates/Module_tmp_in_NS.hpp.template @@ -0,0 +1,85 @@ +#ifndef Hadrons____NAMESPACE_______FILEBASENAME____hpp_ +#define Hadrons____NAMESPACE_______FILEBASENAME____hpp_ + +#include +#include +#include + +BEGIN_HADRONS_NAMESPACE + +/****************************************************************************** + * ___FILEBASENAME___ * + ******************************************************************************/ +BEGIN_MODULE_NAMESPACE(___NAMESPACE___) + +class ___FILEBASENAME___Par: Serializable +{ +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(___FILEBASENAME___Par, + unsigned int, i); +}; + +template +class T___FILEBASENAME___: public Module<___FILEBASENAME___Par> +{ +public: + // constructor + T___FILEBASENAME___(const std::string name); + // destructor + virtual ~T___FILEBASENAME___(void) = default; + // dependency relation + virtual std::vector getInput(void); + virtual std::vector getOutput(void); + // setup + 
virtual void setup(void); + // execution + virtual void execute(void); +}; + +MODULE_REGISTER_NS(___FILEBASENAME___, T___FILEBASENAME___, ___NAMESPACE___); + +/****************************************************************************** + * T___FILEBASENAME___ implementation * + ******************************************************************************/ +// constructor ///////////////////////////////////////////////////////////////// +template +T___FILEBASENAME___::T___FILEBASENAME___(const std::string name) +: Module<___FILEBASENAME___Par>(name) +{} + +// dependencies/products /////////////////////////////////////////////////////// +template +std::vector T___FILEBASENAME___::getInput(void) +{ + std::vector in; + + return in; +} + +template +std::vector T___FILEBASENAME___::getOutput(void) +{ + std::vector out = {getName()}; + + return out; +} + +// setup /////////////////////////////////////////////////////////////////////// +template +void T___FILEBASENAME___::setup(void) +{ + +} + +// execution /////////////////////////////////////////////////////////////////// +template +void T___FILEBASENAME___::execute(void) +{ + +} + +END_MODULE_NAMESPACE + +END_HADRONS_NAMESPACE + +#endif // Hadrons____NAMESPACE_______FILEBASENAME____hpp_ diff --git a/extras/Hadrons/add_module.sh b/extras/Hadrons/add_module.sh new file mode 100755 index 00000000..d5d23ea4 --- /dev/null +++ b/extras/Hadrons/add_module.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +if (( $# != 1 && $# != 2)); then + echo "usage: `basename $0` []" 1>&2 + exit 1 +fi +NAME=$1 +NS=$2 + +if (( $# == 1 )); then + if [ -e "Modules/${NAME}.cc" ] || [ -e "Modules/${NAME}.hpp" ]; then + echo "error: files Modules/${NAME}.* already exists" 1>&2 + exit 1 + fi + sed "s/___FILEBASENAME___/${NAME}/g" Modules/templates/Module.cc.template > Modules/${NAME}.cc + sed "s/___FILEBASENAME___/${NAME}/g" Modules/templates/Module.hpp.template > Modules/${NAME}.hpp +elif (( $# == 2 )); then + mkdir -p Modules/${NS} + if [ -e 
"Modules/${NS}/${NAME}.cc" ] || [ -e "Modules/${NS}/${NAME}.hpp" ]; then + echo "error: files Modules/${NS}/${NAME}.* already exists" 1>&2 + exit 1 + fi + TMPCC=".${NS}.${NAME}.tmp.cc" + TMPHPP=".${NS}.${NAME}.tmp.hpp" + sed "s/___FILEBASENAME___/${NAME}/g" Modules/templates/Module_in_NS.cc.template > ${TMPCC} + sed "s/___FILEBASENAME___/${NAME}/g" Modules/templates/Module_in_NS.hpp.template > ${TMPHPP} + sed "s/___NAMESPACE___/${NS}/g" ${TMPCC} > Modules/${NS}/${NAME}.cc + sed "s/___NAMESPACE___/${NS}/g" ${TMPHPP} > Modules/${NS}/${NAME}.hpp + rm -f ${TMPCC} ${TMPHPP} +fi +./make_module_list.sh diff --git a/extras/Hadrons/add_module_template.sh b/extras/Hadrons/add_module_template.sh new file mode 100755 index 00000000..0069fcea --- /dev/null +++ b/extras/Hadrons/add_module_template.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +if (( $# != 1 && $# != 2)); then + echo "usage: `basename $0` []" 1>&2 + exit 1 +fi +NAME=$1 +NS=$2 + +if (( $# == 1 )); then + if [ -e "Modules/${NAME}.cc" ] || [ -e "Modules/${NAME}.hpp" ]; then + echo "error: files Modules/${NAME}.* already exists" 1>&2 + exit 1 + fi + sed "s/___FILEBASENAME___/${NAME}/g" Modules/templates/Module_tmp.hpp.template > Modules/${NAME}.hpp +elif (( $# == 2 )); then + mkdir -p Modules/${NS} + if [ -e "Modules/${NS}/${NAME}.cc" ] || [ -e "Modules/${NS}/${NAME}.hpp" ]; then + echo "error: files Modules/${NS}/${NAME}.* already exists" 1>&2 + exit 1 + fi + TMPCC=".${NS}.${NAME}.tmp.cc" + TMPHPP=".${NS}.${NAME}.tmp.hpp" + sed "s/___FILEBASENAME___/${NAME}/g" Modules/templates/Module_tmp_in_NS.hpp.template > ${TMPHPP} + sed "s/___NAMESPACE___/${NS}/g" ${TMPHPP} > Modules/${NS}/${NAME}.hpp + rm -f ${TMPCC} ${TMPHPP} +fi +./make_module_list.sh diff --git a/extras/Hadrons/make_module_list.sh b/extras/Hadrons/make_module_list.sh new file mode 100755 index 00000000..ddc56ff6 --- /dev/null +++ b/extras/Hadrons/make_module_list.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +echo 'modules_cc =\' > modules.inc +find Modules -name 
'*.cc' -type f -print | sed 's/^/ /;$q;s/$/ \\/' >> modules.inc +echo '' >> modules.inc +echo 'modules_hpp =\' >> modules.inc +find Modules -name '*.hpp' -type f -print | sed 's/^/ /;$q;s/$/ \\/' >> modules.inc +echo '' >> modules.inc +rm -f Modules.hpp +for f in `find Modules -name '*.hpp'`; do + echo "#include " >> Modules.hpp +done diff --git a/extras/Hadrons/modules.inc b/extras/Hadrons/modules.inc new file mode 100644 index 00000000..669b08ba --- /dev/null +++ b/extras/Hadrons/modules.inc @@ -0,0 +1,38 @@ +modules_cc =\ + Modules/MContraction/WeakHamiltonianEye.cc \ + Modules/MContraction/WeakHamiltonianNonEye.cc \ + Modules/MContraction/WeakNeutral4ptDisc.cc \ + Modules/MGauge/Load.cc \ + Modules/MGauge/Random.cc \ + Modules/MGauge/StochEm.cc \ + Modules/MGauge/Unit.cc \ + Modules/MScalar/ChargedProp.cc \ + Modules/MScalar/FreeProp.cc + +modules_hpp =\ + Modules/MAction/DWF.hpp \ + Modules/MAction/Wilson.hpp \ + Modules/MContraction/Baryon.hpp \ + Modules/MContraction/DiscLoop.hpp \ + Modules/MContraction/Gamma3pt.hpp \ + Modules/MContraction/Meson.hpp \ + Modules/MContraction/WeakHamiltonian.hpp \ + Modules/MContraction/WeakHamiltonianEye.hpp \ + Modules/MContraction/WeakHamiltonianNonEye.hpp \ + Modules/MContraction/WeakNeutral4ptDisc.hpp \ + Modules/MFermion/GaugeProp.hpp \ + Modules/MGauge/Load.hpp \ + Modules/MGauge/Random.hpp \ + Modules/MGauge/StochEm.hpp \ + Modules/MGauge/Unit.hpp \ + Modules/MLoop/NoiseLoop.hpp \ + Modules/MScalar/ChargedProp.hpp \ + Modules/MScalar/FreeProp.hpp \ + Modules/MScalar/Scalar.hpp \ + Modules/MSink/Point.hpp \ + Modules/MSolver/RBPrecCG.hpp \ + Modules/MSource/Point.hpp \ + Modules/MSource/SeqGamma.hpp \ + Modules/MSource/Wall.hpp \ + Modules/MSource/Z2.hpp + diff --git a/extras/Makefile.am b/extras/Makefile.am new file mode 100644 index 00000000..d8c2b675 --- /dev/null +++ b/extras/Makefile.am @@ -0,0 +1 @@ +SUBDIRS = Hadrons \ No newline at end of file diff --git a/extras/qed-fvol/Global.cc b/extras/qed-fvol/Global.cc 
new file mode 100644 index 00000000..57ed97cc --- /dev/null +++ b/extras/qed-fvol/Global.cc @@ -0,0 +1,11 @@ +#include + +using namespace Grid; +using namespace QCD; +using namespace QedFVol; + +QedFVolLogger QedFVol::QedFVolLogError(1,"Error"); +QedFVolLogger QedFVol::QedFVolLogWarning(1,"Warning"); +QedFVolLogger QedFVol::QedFVolLogMessage(1,"Message"); +QedFVolLogger QedFVol::QedFVolLogIterative(1,"Iterative"); +QedFVolLogger QedFVol::QedFVolLogDebug(1,"Debug"); diff --git a/extras/qed-fvol/Global.hpp b/extras/qed-fvol/Global.hpp new file mode 100644 index 00000000..7f07200d --- /dev/null +++ b/extras/qed-fvol/Global.hpp @@ -0,0 +1,42 @@ +#ifndef QedFVol_Global_hpp_ +#define QedFVol_Global_hpp_ + +#include + +#define BEGIN_QEDFVOL_NAMESPACE \ +namespace Grid {\ +using namespace QCD;\ +namespace QedFVol {\ +using Grid::operator<<; +#define END_QEDFVOL_NAMESPACE }} + +/* the 'using Grid::operator<<;' statement prevents a very nasty compilation + * error with GCC (clang compiles fine without it). 
+ */ + +BEGIN_QEDFVOL_NAMESPACE + +class QedFVolLogger: public Logger +{ +public: + QedFVolLogger(int on, std::string nm): Logger("QedFVol", on, nm, + GridLogColours, "BLACK"){}; +}; + +#define LOG(channel) std::cout << QedFVolLog##channel +#define QEDFVOL_ERROR(msg)\ +LOG(Error) << msg << " (" << __FUNCTION__ << " at " << __FILE__ << ":"\ + << __LINE__ << ")" << std::endl;\ +abort(); + +#define DEBUG_VAR(var) LOG(Debug) << #var << "= " << (var) << std::endl; + +extern QedFVolLogger QedFVolLogError; +extern QedFVolLogger QedFVolLogWarning; +extern QedFVolLogger QedFVolLogMessage; +extern QedFVolLogger QedFVolLogIterative; +extern QedFVolLogger QedFVolLogDebug; + +END_QEDFVOL_NAMESPACE + +#endif // QedFVol_Global_hpp_ diff --git a/extras/qed-fvol/Makefile.am b/extras/qed-fvol/Makefile.am new file mode 100644 index 00000000..0a9030c7 --- /dev/null +++ b/extras/qed-fvol/Makefile.am @@ -0,0 +1,9 @@ +AM_CXXFLAGS += -I$(top_srcdir)/extras + +bin_PROGRAMS = qed-fvol + +qed_fvol_SOURCES = \ + qed-fvol.cc \ + Global.cc + +qed_fvol_LDADD = -lGrid diff --git a/extras/qed-fvol/WilsonLoops.h b/extras/qed-fvol/WilsonLoops.h new file mode 100644 index 00000000..98db6b7a --- /dev/null +++ b/extras/qed-fvol/WilsonLoops.h @@ -0,0 +1,265 @@ +#ifndef QEDFVOL_WILSONLOOPS_H +#define QEDFVOL_WILSONLOOPS_H + +#include + +BEGIN_QEDFVOL_NAMESPACE + +template class NewWilsonLoops : public Gimpl { +public: + INHERIT_GIMPL_TYPES(Gimpl); + + typedef typename Gimpl::GaugeLinkField GaugeMat; + typedef typename Gimpl::GaugeField GaugeLorentz; + + ////////////////////////////////////////////////// + // directed plaquette oriented in mu,nu plane + ////////////////////////////////////////////////// + static void dirPlaquette(GaugeMat &plaq, const std::vector &U, + const int mu, const int nu) { + // Annoyingly, must use either scope resolution to find dependent base + // class, + // or this-> ; there is no "this" in a static method. 
This forces explicit + // Gimpl scope + // resolution throughout the usage in this file, and rather defeats the + // purpose of deriving + // from Gimpl. + plaq = Gimpl::CovShiftBackward( + U[mu], mu, Gimpl::CovShiftBackward( + U[nu], nu, Gimpl::CovShiftForward(U[mu], mu, U[nu]))); + } + ////////////////////////////////////////////////// + // trace of directed plaquette oriented in mu,nu plane + ////////////////////////////////////////////////// + static void traceDirPlaquette(LatticeComplex &plaq, + const std::vector &U, const int mu, + const int nu) { + GaugeMat sp(U[0]._grid); + dirPlaquette(sp, U, mu, nu); + plaq = trace(sp); + } + ////////////////////////////////////////////////// + // sum over all planes of plaquette + ////////////////////////////////////////////////// + static void sitePlaquette(LatticeComplex &Plaq, + const std::vector &U) { + LatticeComplex sitePlaq(U[0]._grid); + Plaq = zero; + for (int mu = 1; mu < U[0]._grid->_ndimension; mu++) { + for (int nu = 0; nu < mu; nu++) { + traceDirPlaquette(sitePlaq, U, mu, nu); + Plaq = Plaq + sitePlaq; + } + } + } + ////////////////////////////////////////////////// + // sum over all x,y,z,t and over all planes of plaquette + ////////////////////////////////////////////////// + static Real sumPlaquette(const GaugeLorentz &Umu) { + std::vector U(4, Umu._grid); + + for (int mu = 0; mu < Umu._grid->_ndimension; mu++) { + U[mu] = PeekIndex(Umu, mu); + } + + LatticeComplex Plaq(Umu._grid); + + sitePlaquette(Plaq, U); + + TComplex Tp = sum(Plaq); + Complex p = TensorRemove(Tp); + return p.real(); + } + ////////////////////////////////////////////////// + // average over all x,y,z,t and over all planes of plaquette + ////////////////////////////////////////////////// + static Real avgPlaquette(const GaugeLorentz &Umu) { + int ndim = Umu._grid->_ndimension; + Real sumplaq = sumPlaquette(Umu); + Real vol = Umu._grid->gSites(); + Real faces = (1.0 * ndim * (ndim - 1)) / 2.0; + return sumplaq / vol / faces / Nc; // Nc 
dependent... FIXME + } + + ////////////////////////////////////////////////// + // Wilson loop of size (R1, R2), oriented in mu,nu plane + ////////////////////////////////////////////////// + static void wilsonLoop(GaugeMat &wl, const std::vector &U, + const int Rmu, const int Rnu, + const int mu, const int nu) { + wl = U[nu]; + + for(int i = 0; i < Rnu-1; i++){ + wl = Gimpl::CovShiftForward(U[nu], nu, wl); + } + + for(int i = 0; i < Rmu; i++){ + wl = Gimpl::CovShiftForward(U[mu], mu, wl); + } + + for(int i = 0; i < Rnu; i++){ + wl = Gimpl::CovShiftBackward(U[nu], nu, wl); + } + + for(int i = 0; i < Rmu; i++){ + wl = Gimpl::CovShiftBackward(U[mu], mu, wl); + } + } + ////////////////////////////////////////////////// + // trace of Wilson Loop oriented in mu,nu plane + ////////////////////////////////////////////////// + static void traceWilsonLoop(LatticeComplex &wl, + const std::vector &U, + const int Rmu, const int Rnu, + const int mu, const int nu) { + GaugeMat sp(U[0]._grid); + wilsonLoop(sp, U, Rmu, Rnu, mu, nu); + wl = trace(sp); + } + ////////////////////////////////////////////////// + // sum over all planes of Wilson loop + ////////////////////////////////////////////////// + static void siteWilsonLoop(LatticeComplex &Wl, + const std::vector &U, + const int R1, const int R2) { + LatticeComplex siteWl(U[0]._grid); + Wl = zero; + for (int mu = 1; mu < U[0]._grid->_ndimension; mu++) { + for (int nu = 0; nu < mu; nu++) { + traceWilsonLoop(siteWl, U, R1, R2, mu, nu); + Wl = Wl + siteWl; + traceWilsonLoop(siteWl, U, R2, R1, mu, nu); + Wl = Wl + siteWl; + } + } + } + ////////////////////////////////////////////////// + // sum over planes of Wilson loop with length R1 + // in the time direction + ////////////////////////////////////////////////// + static void siteTimelikeWilsonLoop(LatticeComplex &Wl, + const std::vector &U, + const int R1, const int R2) { + LatticeComplex siteWl(U[0]._grid); + + int ndim = U[0]._grid->_ndimension; + + Wl = zero; + for (int nu = 
0; nu < ndim - 1; nu++) { + traceWilsonLoop(siteWl, U, R1, R2, ndim-1, nu); + Wl = Wl + siteWl; + } + } + ////////////////////////////////////////////////// + // sum Wilson loop over all planes orthogonal to the time direction + ////////////////////////////////////////////////// + static void siteSpatialWilsonLoop(LatticeComplex &Wl, + const std::vector &U, + const int R1, const int R2) { + LatticeComplex siteWl(U[0]._grid); + + Wl = zero; + for (int mu = 1; mu < U[0]._grid->_ndimension - 1; mu++) { + for (int nu = 0; nu < mu; nu++) { + traceWilsonLoop(siteWl, U, R1, R2, mu, nu); + Wl = Wl + siteWl; + traceWilsonLoop(siteWl, U, R2, R1, mu, nu); + Wl = Wl + siteWl; + } + } + } + ////////////////////////////////////////////////// + // sum over all x,y,z,t and over all planes of Wilson loop + ////////////////////////////////////////////////// + static Real sumWilsonLoop(const GaugeLorentz &Umu, + const int R1, const int R2) { + std::vector U(4, Umu._grid); + + for (int mu = 0; mu < Umu._grid->_ndimension; mu++) { + U[mu] = PeekIndex(Umu, mu); + } + + LatticeComplex Wl(Umu._grid); + + siteWilsonLoop(Wl, U, R1, R2); + + TComplex Tp = sum(Wl); + Complex p = TensorRemove(Tp); + return p.real(); + } + ////////////////////////////////////////////////// + // sum over all x,y,z,t and over all planes of timelike Wilson loop + ////////////////////////////////////////////////// + static Real sumTimelikeWilsonLoop(const GaugeLorentz &Umu, + const int R1, const int R2) { + std::vector U(4, Umu._grid); + + for (int mu = 0; mu < Umu._grid->_ndimension; mu++) { + U[mu] = PeekIndex(Umu, mu); + } + + LatticeComplex Wl(Umu._grid); + + siteTimelikeWilsonLoop(Wl, U, R1, R2); + + TComplex Tp = sum(Wl); + Complex p = TensorRemove(Tp); + return p.real(); + } + ////////////////////////////////////////////////// + // sum over all x,y,z,t and over all planes of spatial Wilson loop + ////////////////////////////////////////////////// + static Real sumSpatialWilsonLoop(const GaugeLorentz &Umu, + 
const int R1, const int R2) { + std::vector U(4, Umu._grid); + + for (int mu = 0; mu < Umu._grid->_ndimension; mu++) { + U[mu] = PeekIndex(Umu, mu); + } + + LatticeComplex Wl(Umu._grid); + + siteSpatialWilsonLoop(Wl, U, R1, R2); + + TComplex Tp = sum(Wl); + Complex p = TensorRemove(Tp); + return p.real(); + } + ////////////////////////////////////////////////// + // average over all x,y,z,t and over all planes of Wilson loop + ////////////////////////////////////////////////// + static Real avgWilsonLoop(const GaugeLorentz &Umu, + const int R1, const int R2) { + int ndim = Umu._grid->_ndimension; + Real sumWl = sumWilsonLoop(Umu, R1, R2); + Real vol = Umu._grid->gSites(); + Real faces = 1.0 * ndim * (ndim - 1); + return sumWl / vol / faces / Nc; // Nc dependent... FIXME + } + ////////////////////////////////////////////////// + // average over all x,y,z,t and over all planes of timelike Wilson loop + ////////////////////////////////////////////////// + static Real avgTimelikeWilsonLoop(const GaugeLorentz &Umu, + const int R1, const int R2) { + int ndim = Umu._grid->_ndimension; + Real sumWl = sumTimelikeWilsonLoop(Umu, R1, R2); + Real vol = Umu._grid->gSites(); + Real faces = 1.0 * (ndim - 1); + return sumWl / vol / faces / Nc; // Nc dependent... FIXME + } + ////////////////////////////////////////////////// + // average over all x,y,z,t and over all planes of spatial Wilson loop + ////////////////////////////////////////////////// + static Real avgSpatialWilsonLoop(const GaugeLorentz &Umu, + const int R1, const int R2) { + int ndim = Umu._grid->_ndimension; + Real sumWl = sumSpatialWilsonLoop(Umu, R1, R2); + Real vol = Umu._grid->gSites(); + Real faces = 1.0 * (ndim - 1) * (ndim - 2); + return sumWl / vol / faces / Nc; // Nc dependent... 
FIXME + } +}; + +END_QEDFVOL_NAMESPACE + +#endif // QEDFVOL_WILSONLOOPS_H \ No newline at end of file diff --git a/extras/qed-fvol/qed-fvol.cc b/extras/qed-fvol/qed-fvol.cc new file mode 100644 index 00000000..3ecac2fc --- /dev/null +++ b/extras/qed-fvol/qed-fvol.cc @@ -0,0 +1,88 @@ +#include +#include + +using namespace Grid; +using namespace QCD; +using namespace QedFVol; + +typedef PeriodicGaugeImpl QedPeriodicGimplR; +typedef PhotonR::GaugeField EmField; +typedef PhotonR::GaugeLinkField EmComp; + +const int NCONFIGS = 10; +const int NWILSON = 10; + +int main(int argc, char *argv[]) +{ + // parse command line + std::string parameterFileName; + + if (argc < 2) + { + std::cerr << "usage: " << argv[0] << " [Grid options]"; + std::cerr << std::endl; + std::exit(EXIT_FAILURE); + } + parameterFileName = argv[1]; + + // initialization + Grid_init(&argc, &argv); + QedFVolLogError.Active(GridLogError.isActive()); + QedFVolLogWarning.Active(GridLogWarning.isActive()); + QedFVolLogMessage.Active(GridLogMessage.isActive()); + QedFVolLogIterative.Active(GridLogIterative.isActive()); + QedFVolLogDebug.Active(GridLogDebug.isActive()); + LOG(Message) << "Grid initialized" << std::endl; + + // QED stuff + std::vector latt_size = GridDefaultLatt(); + std::vector simd_layout = GridDefaultSimd(4, vComplex::Nsimd()); + std::vector mpi_layout = GridDefaultMpi(); + GridCartesian grid(latt_size,simd_layout,mpi_layout); + GridParallelRNG pRNG(&grid); + PhotonR photon(PhotonR::Gauge::feynman, + PhotonR::ZmScheme::qedL); + EmField a(&grid); + EmField expA(&grid); + + Complex imag_unit(0, 1); + + Real wlA; + std::vector logWlAvg(NWILSON, 0.0), logWlTime(NWILSON, 0.0), logWlSpace(NWILSON, 0.0); + + pRNG.SeedRandomDevice(); + + LOG(Message) << "Wilson loop calculation beginning" << std::endl; + for(int ic = 0; ic < NCONFIGS; ic++){ + LOG(Message) << "Configuration " << ic <::avgWilsonLoop(expA, iw, iw) * 3; + logWlAvg[iw-1] -= 2*log(wlA); + wlA = NewWilsonLoops::avgTimelikeWilsonLoop(expA, 
iw, iw) * 3; + logWlTime[iw-1] -= 2*log(wlA); + wlA = NewWilsonLoops::avgSpatialWilsonLoop(expA, iw, iw) * 3; + logWlSpace[iw-1] -= 2*log(wlA); + } + } + LOG(Message) << "Wilson loop calculation completed" << std::endl; + + // Calculate Wilson loops + for(int iw=1; iw<=10; iw++){ + LOG(Message) << iw << 'x' << iw << " Wilson loop" << std::endl; + LOG(Message) << "-2log(W) average: " << logWlAvg[iw-1]/NCONFIGS << std::endl; + LOG(Message) << "-2log(W) timelike: " << logWlTime[iw-1]/NCONFIGS << std::endl; + LOG(Message) << "-2log(W) spatial: " << logWlSpace[iw-1]/NCONFIGS << std::endl; + } + + // epilogue + LOG(Message) << "Grid is finalizing now" << std::endl; + Grid_finalize(); + + return EXIT_SUCCESS; +} diff --git a/gcc-bug-report/README b/gcc-bug-report/README index 294d0a6c..a879ebe3 100644 --- a/gcc-bug-report/README +++ b/gcc-bug-report/README @@ -20,4 +20,17 @@ The simple testcase in this directory is the submitted bug report that encapsula problem. The test case works with icpc and with clang++, but fails consistently on g++ current variants. -Peter \ No newline at end of file +Peter + + +************ + +Second GCC bug reported, see Issue 100. + +https://wandbox.org/permlink/tzssJza6R9XnqANw +https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80652 + +Getting Travis fails under gcc-5 for Test_simd, now that I added more comprehensive testing to the +CI test suite. The limitations of Travis runtime limits & weak cores are being shown. + +Travis uses 5.4.1 for g++-5. diff --git a/grid-config.in b/grid-config.in new file mode 100755 index 00000000..bd340846 --- /dev/null +++ b/grid-config.in @@ -0,0 +1,86 @@ +#! /bin/sh + +prefix=@prefix@ +exec_prefix=@exec_prefix@ +includedir=@includedir@ + +usage() +{ + cat < + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
+ +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ + +#ifndef DISABLE_WARNINGS_H +#define DISABLE_WARNINGS_H + + //disables and intel compiler specific warning (in json.hpp) +#pragma warning disable 488 + + +#endif diff --git a/lib/Grid.h b/lib/Grid.h index 0c5983f3..9dcc207b 100644 --- a/lib/Grid.h +++ b/lib/Grid.h @@ -38,52 +38,12 @@ Author: paboyle #ifndef GRID_H #define GRID_H -/////////////////// -// Std C++ dependencies -/////////////////// -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -/////////////////// -// Grid headers -/////////////////// -#include -#include "Config.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - - +#include +#include +#include +#include +#include +#include +#include #endif diff --git a/lib/algorithms/iterative/MatrixUtils.h b/lib/GridCore.h similarity index 54% rename from lib/algorithms/iterative/MatrixUtils.h rename to lib/GridCore.h index 39b5c043..55396a37 100644 --- a/lib/algorithms/iterative/MatrixUtils.h +++ b/lib/GridCore.h @@ -2,11 +2,13 @@ Grid physics library, www.github.com/paboyle/Grid - Source file: ./lib/algorithms/iterative/MatrixUtils.h + Source file: ./lib/Grid.h Copyright (C) 2015 
Author: Peter Boyle +Author: azusayamaguchi +Author: paboyle This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -25,51 +27,34 @@ Author: Peter Boyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#ifndef GRID_MATRIX_UTILS_H -#define GRID_MATRIX_UTILS_H +// +// Grid.h +// simd +// +// Created by Peter Boyle on 09/05/2014. +// Copyright (c) 2014 University of Edinburgh. All rights reserved. +// -namespace Grid { +#ifndef GRID_BASE_H +#define GRID_BASE_H - namespace MatrixUtils { +#include - template inline void Size(Matrix& A,int &N,int &M){ - N=A.size(); assert(N>0); - M=A[0].size(); - for(int i=0;i +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include - template inline void SizeSquare(Matrix& A,int &N) - { - int M; - Size(A,N,M); - assert(N==M); - } - - template inline void Fill(Matrix& A,T & val) - { - int N,M; - Size(A,N,M); - for(int i=0;i inline void Diagonal(Matrix& A,T & val) - { - int N; - SizeSquare(A,N); - for(int i=0;i inline void Identity(Matrix& A) - { - Fill(A,0.0); - Diagonal(A,1.0); - } - - }; -} #endif diff --git a/lib/qcd/hmc/HMC.cc b/lib/GridQCDcore.h similarity index 75% rename from lib/qcd/hmc/HMC.cc rename to lib/GridQCDcore.h index 3cb39111..7f50761f 100644 --- a/lib/qcd/hmc/HMC.cc +++ b/lib/GridQCDcore.h @@ -2,12 +2,12 @@ Grid physics library, www.github.com/paboyle/Grid - Source file: ./lib/qcd/hmc/HMC.cc + Source file: ./lib/Grid.h Copyright (C) 2015 Author: Peter Boyle -Author: neo +Author: azusayamaguchi Author: paboyle This program is free software; you can redistribute it and/or modify @@ -27,10 +27,16 @@ Author: paboyle See the full license in the file "LICENSE" in the top level distribution directory 
*************************************************************************************/ /* END LEGAL */ -#include +#ifndef GRID_QCD_CORE_H +#define GRID_QCD_CORE_H -namespace Grid{ - namespace QCD{ +///////////////////////// +// Core Grid QCD headers +///////////////////////// +#include +#include +#include +#include +#include - } -} +#endif diff --git a/lib/GridStd.h b/lib/GridStd.h new file mode 100644 index 00000000..097e62ab --- /dev/null +++ b/lib/GridStd.h @@ -0,0 +1,29 @@ +#ifndef GRID_STD_H +#define GRID_STD_H + +/////////////////// +// Std C++ dependencies +/////////////////// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/////////////////// +// Grid config +/////////////////// +#include "Config.h" + +#endif /* GRID_STD_H */ diff --git a/lib/Grid_Eigen_Dense.h b/lib/Grid_Eigen_Dense.h new file mode 100644 index 00000000..4fb5b831 --- /dev/null +++ b/lib/Grid_Eigen_Dense.h @@ -0,0 +1,9 @@ +#pragma once +#if defined __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#endif +#include +#if defined __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/lib/Hadrons b/lib/Hadrons new file mode 120000 index 00000000..1f422592 --- /dev/null +++ b/lib/Hadrons @@ -0,0 +1 @@ +../extras/Hadrons \ No newline at end of file diff --git a/lib/Makefile.am b/lib/Makefile.am index a779135f..6dd7899e 100644 --- a/lib/Makefile.am +++ b/lib/Makefile.am @@ -1,4 +1,5 @@ extra_sources= +extra_headers= if BUILD_COMMS_MPI extra_sources+=communicator/Communicator_mpi.cc extra_sources+=communicator/Communicator_base.cc @@ -9,8 +10,8 @@ if BUILD_COMMS_MPI3 extra_sources+=communicator/Communicator_base.cc endif -if BUILD_COMMS_MPI3L - extra_sources+=communicator/Communicator_mpi3_leader.cc +if BUILD_COMMS_MPIT + extra_sources+=communicator/Communicator_mpit.cc extra_sources+=communicator/Communicator_base.cc endif @@ -24,6 
+25,12 @@ if BUILD_COMMS_NONE extra_sources+=communicator/Communicator_base.cc endif +if BUILD_HDF5 + extra_sources+=serialisation/Hdf5IO.cc + extra_headers+=serialisation/Hdf5IO.h + extra_headers+=serialisation/Hdf5Type.h +endif + # # Libraries # @@ -32,6 +39,9 @@ include Eigen.inc lib_LIBRARIES = libGrid.a -libGrid_a_SOURCES = $(CCFILES) $(extra_sources) +CCFILES += $(extra_sources) +HFILES += $(extra_headers) + +libGrid_a_SOURCES = $(CCFILES) libGrid_adir = $(pkgincludedir) nobase_dist_pkginclude_HEADERS = $(HFILES) $(eigen_files) Config.h diff --git a/lib/Old/Endeavour.tgz b/lib/Old/Endeavour.tgz deleted file mode 100644 index 33bfbc01..00000000 Binary files a/lib/Old/Endeavour.tgz and /dev/null differ diff --git a/lib/Old/Tensor_peek.h b/lib/Old/Tensor_peek.h deleted file mode 100644 index eecb3cd5..00000000 --- a/lib/Old/Tensor_peek.h +++ /dev/null @@ -1,154 +0,0 @@ - /************************************************************************************* - - Grid physics library, www.github.com/paboyle/Grid - - Source file: ./lib/Old/Tensor_peek.h - - Copyright (C) 2015 - -Author: Peter Boyle - - This program is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License along - with this program; if not, write to the Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
- - See the full license in the file "LICENSE" in the top level distribution directory - *************************************************************************************/ - /* END LEGAL */ -#ifndef GRID_MATH_PEEK_H -#define GRID_MATH_PEEK_H -namespace Grid { - -////////////////////////////////////////////////////////////////////////////// -// Peek on a specific index; returns a scalar in that index, tensor inherits rest -////////////////////////////////////////////////////////////////////////////// -// If we hit the right index, return scalar with no further recursion - -//template inline ComplexF peekIndex(const ComplexF arg) { return arg;} -//template inline ComplexD peekIndex(const ComplexD arg) { return arg;} -//template inline RealF peekIndex(const RealF arg) { return arg;} -//template inline RealD peekIndex(const RealD arg) { return arg;} -#if 0 -// Scalar peek, no indices -template::TensorLevel == Level >::type * =nullptr> inline - auto peekIndex(const iScalar &arg) -> iScalar -{ - return arg; -} -// Vector peek, one index -template::TensorLevel == Level >::type * =nullptr> inline - auto peekIndex(const iVector &arg,int i) -> iScalar // Index matches -{ - iScalar ret; // return scalar - ret._internal = arg._internal[i]; - return ret; -} -// Matrix peek, two indices -template::TensorLevel == Level >::type * =nullptr> inline - auto peekIndex(const iMatrix &arg,int i,int j) -> iScalar -{ - iScalar ret; // return scalar - ret._internal = arg._internal[i][j]; - return ret; -} - -///////////// -// No match peek for scalar,vector,matrix must forward on either 0,1,2 args. 
Must have 9 routines with notvalue -///////////// -// scalar -template::TensorLevel != Level >::type * =nullptr> inline - auto peekIndex(const iScalar &arg) -> iScalar(arg._internal))> -{ - iScalar(arg._internal))> ret; - ret._internal= peekIndex(arg._internal); - return ret; -} -template::TensorLevel != Level >::type * =nullptr> inline - auto peekIndex(const iScalar &arg,int i) -> iScalar(arg._internal,i))> -{ - iScalar(arg._internal,i))> ret; - ret._internal=peekIndex(arg._internal,i); - return ret; -} -template::TensorLevel != Level >::type * =nullptr> inline - auto peekIndex(const iScalar &arg,int i,int j) -> iScalar(arg._internal,i,j))> -{ - iScalar(arg._internal,i,j))> ret; - ret._internal=peekIndex(arg._internal,i,j); - return ret; -} -// vector -template::TensorLevel != Level >::type * =nullptr> inline -auto peekIndex(const iVector &arg) -> iVector(arg._internal[0])),N> -{ - iVector(arg._internal[0])),N> ret; - for(int ii=0;ii(arg._internal[ii]); - } - return ret; -} -template::TensorLevel != Level >::type * =nullptr> inline - auto peekIndex(const iVector &arg,int i) -> iVector(arg._internal[0],i)),N> -{ - iVector(arg._internal[0],i)),N> ret; - for(int ii=0;ii(arg._internal[ii],i); - } - return ret; -} -template::TensorLevel != Level >::type * =nullptr> inline - auto peekIndex(const iVector &arg,int i,int j) -> iVector(arg._internal[0],i,j)),N> -{ - iVector(arg._internal[0],i,j)),N> ret; - for(int ii=0;ii(arg._internal[ii],i,j); - } - return ret; -} - -// matrix -template::TensorLevel != Level >::type * =nullptr> inline -auto peekIndex(const iMatrix &arg) -> iMatrix(arg._internal[0][0])),N> -{ - iMatrix(arg._internal[0][0])),N> ret; - for(int ii=0;ii(arg._internal[ii][jj]);// Could avoid this because peeking a scalar is dumb - }} - return ret; -} -template::TensorLevel != Level >::type * =nullptr> inline - auto peekIndex(const iMatrix &arg,int i) -> iMatrix(arg._internal[0][0],i)),N> -{ - iMatrix(arg._internal[0][0],i)),N> ret; - for(int 
ii=0;ii(arg._internal[ii][jj],i); - }} - return ret; -} -template::TensorLevel != Level >::type * =nullptr> inline - auto peekIndex(const iMatrix &arg,int i,int j) -> iMatrix(arg._internal[0][0],i,j)),N> -{ - iMatrix(arg._internal[0][0],i,j)),N> ret; - for(int ii=0;ii(arg._internal[ii][jj],i,j); - }} - return ret; -} -#endif - - -} -#endif diff --git a/lib/Old/Tensor_poke.h b/lib/Old/Tensor_poke.h deleted file mode 100644 index 83d09cf1..00000000 --- a/lib/Old/Tensor_poke.h +++ /dev/null @@ -1,127 +0,0 @@ - /************************************************************************************* - - Grid physics library, www.github.com/paboyle/Grid - - Source file: ./lib/Old/Tensor_poke.h - - Copyright (C) 2015 - -Author: Peter Boyle - - This program is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License along - with this program; if not, write to the Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
- - See the full license in the file "LICENSE" in the top level distribution directory - *************************************************************************************/ - /* END LEGAL */ -#ifndef GRID_MATH_POKE_H -#define GRID_MATH_POKE_H -namespace Grid { - -////////////////////////////////////////////////////////////////////////////// -// Poke a specific index; -////////////////////////////////////////////////////////////////////////////// -#if 0 -// Scalar poke -template::TensorLevel == Level >::type * =nullptr> inline - void pokeIndex(iScalar &ret, const iScalar &arg) -{ - ret._internal = arg._internal; -} -// Vector poke, one index -template::TensorLevel == Level >::type * =nullptr> inline - void pokeIndex(iVector &ret, const iScalar &arg,int i) -{ - ret._internal[i] = arg._internal; -} -//Matrix poke, two indices -template::TensorLevel == Level >::type * =nullptr> inline - void pokeIndex(iMatrix &ret, const iScalar &arg,int i,int j) -{ - ret._internal[i][j] = arg._internal; -} - -///////////// -// No match poke for scalar,vector,matrix must forward on either 0,1,2 args. 
Must have 9 routines with notvalue -///////////// -// scalar -template::TensorLevel != Level >::type * =nullptr> inline -void pokeIndex(iScalar &ret, const iScalar(ret._internal))> &arg) -{ - pokeIndex(ret._internal,arg._internal); -} -template::TensorLevel != Level >::type * =nullptr> inline - void pokeIndex(iScalar &ret, const iScalar(ret._internal,0))> &arg, int i) - -{ - pokeIndex(ret._internal,arg._internal,i); -} -template::TensorLevel != Level >::type * =nullptr> inline - void pokeIndex(iScalar &ret, const iScalar(ret._internal,0,0))> &arg,int i,int j) -{ - pokeIndex(ret._internal,arg._internal,i,j); -} - -// Vector -template::TensorLevel != Level >::type * =nullptr> inline - void pokeIndex(iVector &ret, iVector(ret._internal)),N> &arg) -{ - for(int ii=0;ii(ret._internal[ii],arg._internal[ii]); - } -} -template::TensorLevel != Level >::type * =nullptr> inline - void pokeIndex(iVector &ret, const iVector(ret._internal,0)),N> &arg,int i) -{ - for(int ii=0;ii(ret._internal[ii],arg._internal[ii],i); - } -} -template::TensorLevel != Level >::type * =nullptr> inline - void pokeIndex(iVector &ret, const iVector(ret._internal,0,0)),N> &arg,int i,int j) -{ - for(int ii=0;ii(ret._internal[ii],arg._internal[ii],i,j); - } -} - -// Matrix -template::TensorLevel != Level >::type * =nullptr> inline - void pokeIndex(iMatrix &ret, const iMatrix(ret._internal)),N> &arg) -{ - for(int ii=0;ii(ret._internal[ii][jj],arg._internal[ii][jj]); - }} -} -template::TensorLevel != Level >::type * =nullptr> inline - void pokeIndex(iMatrix &ret, const iMatrix(ret._internal,0)),N> &arg,int i) -{ - for(int ii=0;ii(ret._internal[ii][jj],arg._internal[ii][jj],i); - }} -} -template::TensorLevel != Level >::type * =nullptr> inline - void pokeIndex(iMatrix &ret, const iMatrix(ret._internal,0,0)),N> &arg, int i,int j) -{ - for(int ii=0;ii(ret._internal[ii][jj],arg._internal[ii][jj],i,j); - }} -} -#endif - -} -#endif diff --git a/lib/Algorithms.h b/lib/algorithms/Algorithms.h similarity index 89% 
rename from lib/Algorithms.h rename to lib/algorithms/Algorithms.h index 67eb11c3..07ae839c 100644 --- a/lib/Algorithms.h +++ b/lib/algorithms/Algorithms.h @@ -1,6 +1,6 @@ /************************************************************************************* - Grid physics library, www.github.com/paboyle/Grid + Grid physics library, www.github.com/paboyle/Grid Source file: ./lib/Algorithms.h @@ -37,20 +37,21 @@ Author: Peter Boyle #include #include #include +#include #include #include #include #include - #include #include +#include // Lanczos support -#include +//#include #include - #include +#include // Eigen/lanczos // EigCg diff --git a/lib/algorithms/CoarsenedMatrix.h b/lib/algorithms/CoarsenedMatrix.h index fd9acc91..c2910151 100644 --- a/lib/algorithms/CoarsenedMatrix.h +++ b/lib/algorithms/CoarsenedMatrix.h @@ -267,8 +267,7 @@ namespace Grid { SimpleCompressor compressor; Stencil.HaloExchange(in,compressor); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ siteVector res = zero; siteVector nbr; int ptype; @@ -380,8 +379,7 @@ PARALLEL_FOR_LOOP Subspace.ProjectToSubspace(oProj,oblock); // blockProject(iProj,iblock,Subspace.subspace); // blockProject(oProj,oblock,Subspace.subspace); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ for(int j=0;j({55,72,19,17,34})); Lattice > val(Grid()); random(RNG,val); Complex one(1.0); diff --git a/lib/FFT.h b/lib/algorithms/FFT.h similarity index 99% rename from lib/FFT.h rename to lib/algorithms/FFT.h index 240f338b..ec558ad9 100644 --- a/lib/FFT.h +++ b/lib/algorithms/FFT.h @@ -230,6 +230,7 @@ namespace Grid { // Barrel shift and collect global pencil std::vector lcoor(Nd), gcoor(Nd); result = source; + int pc = processor_coor[dim]; for(int p=0;plSites();idx++) { sgrid->LocalIndexToLocalCoor(idx,cbuf); peekLocalSite(s,result,cbuf); - cbuf[dim]+=p*L; + cbuf[dim]+=((pc+p) % processors[dim])*L; + // cbuf[dim]+=p*L; 
pokeLocalSite(s,pgbuf,cbuf); } } @@ -278,7 +280,6 @@ namespace Grid { flops+= flops_call*NN; // writing out result - int pc = processor_coor[dim]; PARALLEL_REGION { std::vector clbuf(Nd), cgbuf(Nd); diff --git a/lib/algorithms/LinearOperator.h b/lib/algorithms/LinearOperator.h index ea47d43b..6cb77296 100644 --- a/lib/algorithms/LinearOperator.h +++ b/lib/algorithms/LinearOperator.h @@ -235,7 +235,7 @@ namespace Grid { Field tmp(in._grid); _Mat.MeooeDag(in,tmp); - _Mat.MooeeInvDag(tmp,out); + _Mat.MooeeInvDag(tmp,out); _Mat.MeooeDag(out,tmp); _Mat.MooeeDag(in,out); diff --git a/lib/algorithms/approx/.dirstamp b/lib/algorithms/approx/.dirstamp deleted file mode 100644 index e69de29b..00000000 diff --git a/lib/algorithms/approx/Chebyshev.h b/lib/algorithms/approx/Chebyshev.h index 6837ae99..2793f138 100644 --- a/lib/algorithms/approx/Chebyshev.h +++ b/lib/algorithms/approx/Chebyshev.h @@ -197,8 +197,9 @@ namespace Grid { void operator() (LinearOperatorBase &Linop, const Field &in, Field &out) { GridBase *grid=in._grid; -//std::cout << "Chevyshef(): in._grid="< + class Forecast + { + public: + virtual Field operator()(Matrix &Mat, const Field& phi, const std::vector& chi) = 0; + }; + + // Implementation of Brower et al.'s chronological inverter (arXiv:hep-lat/9509012), + // used to forecast solutions across poles of the EOFA heatbath. 
+ // + // Modified from CPS (cps_pp/src/util/dirac_op/d_op_base/comsrc/minresext.C) + template + class ChronoForecast : public Forecast + { + public: + Field operator()(Matrix &Mat, const Field& phi, const std::vector& prev_solns) + { + int degree = prev_solns.size(); + Field chi(phi); // forecasted solution + + // Trivial cases + if(degree == 0){ chi = zero; return chi; } + else if(degree == 1){ return prev_solns[0]; } + + RealD dot; + ComplexD xp; + Field r(phi); // residual + Field Mv(phi); + std::vector v(prev_solns); // orthonormalized previous solutions + std::vector MdagMv(degree,phi); + + // Array to hold the matrix elements + std::vector> G(degree, std::vector(degree)); + + // Solution and source vectors + std::vector a(degree); + std::vector b(degree); + + // Orthonormalize the vector basis + for(int i=0; i std::abs(G[k][k])){ k = j; } } + if(k != i){ + xp = b[k]; + b[k] = b[i]; + b[i] = xp; + for(int j=0; j=0; i--){ + a[i] = 0.0; + for(int j=i+1; j +#include namespace Grid { double MultiShiftFunction::approx(double x) diff --git a/lib/algorithms/approx/Remez.cc b/lib/algorithms/approx/Remez.cc index 38d60088..ca00a330 100644 --- a/lib/algorithms/approx/Remez.cc +++ b/lib/algorithms/approx/Remez.cc @@ -20,7 +20,7 @@ #include #include -#include +#include // Constructor AlgRemez::AlgRemez(double lower, double upper, long precision) diff --git a/lib/algorithms/approx/Remez.h b/lib/algorithms/approx/Remez.h index 31938779..71b1093b 100644 --- a/lib/algorithms/approx/Remez.h +++ b/lib/algorithms/approx/Remez.h @@ -16,7 +16,7 @@ #define INCLUDED_ALG_REMEZ_H #include -#include +#include #ifdef HAVE_LIBGMP #include "bigfloat.h" diff --git a/lib/algorithms/iterative/BlockConjugateGradient.h b/lib/algorithms/iterative/BlockConjugateGradient.h new file mode 100644 index 00000000..e0eeddcb --- /dev/null +++ b/lib/algorithms/iterative/BlockConjugateGradient.h @@ -0,0 +1,606 @@ +/************************************************************************************* + 
+Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/algorithms/iterative/BlockConjugateGradient.h + +Copyright (C) 2017 + +Author: Azusa Yamaguchi +Author: Peter Boyle + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef GRID_BLOCK_CONJUGATE_GRADIENT_H +#define GRID_BLOCK_CONJUGATE_GRADIENT_H + + +namespace Grid { + +enum BlockCGtype { BlockCG, BlockCGrQ, CGmultiRHS }; + +////////////////////////////////////////////////////////////////////////// +// Block conjugate gradient. Dimension zero should be the block direction +////////////////////////////////////////////////////////////////////////// +template +class BlockConjugateGradient : public OperatorFunction { + public: + + + typedef typename Field::scalar_type scomplex; + + int blockDim ; + int Nblock; + + BlockCGtype CGtype; + bool ErrorOnNoConverge; // throw an assert when the CG fails to converge. + // Defaults true. + RealD Tolerance; + Integer MaxIterations; + Integer IterationsToComplete; //Number of iterations the CG took to finish. 
Filled in upon completion + + BlockConjugateGradient(BlockCGtype cgtype,int _Orthog,RealD tol, Integer maxit, bool err_on_no_conv = true) + : Tolerance(tol), CGtype(cgtype), blockDim(_Orthog), MaxIterations(maxit), ErrorOnNoConverge(err_on_no_conv) + {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Thin QR factorisation (google it) +//////////////////////////////////////////////////////////////////////////////////////////////////// +void ThinQRfact (Eigen::MatrixXcd &m_rr, + Eigen::MatrixXcd &C, + Eigen::MatrixXcd &Cinv, + Field & Q, + const Field & R) +{ + int Orthog = blockDim; // First dimension is block dim; this is an assumption + //////////////////////////////////////////////////////////////////////////////////////////////////// + //Dimensions + // R_{ferm x Nblock} = Q_{ferm x Nblock} x C_{Nblock x Nblock} -> ferm x Nblock + // + // Rdag R = m_rr = Herm = L L^dag <-- Cholesky decomposition (LLT routine in Eigen) + // + // Q C = R => Q = R C^{-1} + // + // Want Ident = Q^dag Q = C^{-dag} R^dag R C^{-1} = C^{-dag} L L^dag C^{-1} = 1_{Nblock x Nblock} + // + // Set C = L^{dag}, and then Q^dag Q = ident + // + // Checks: + // Cdag C = Rdag R ; passes. 
+ // QdagQ = 1 ; passes + //////////////////////////////////////////////////////////////////////////////////////////////////// + sliceInnerProductMatrix(m_rr,R,R,Orthog); + + // Force manifest hermitian to avoid rounding related + m_rr = 0.5*(m_rr+m_rr.adjoint()); + +#if 0 + std::cout << " Calling Cholesky ldlt on m_rr " << m_rr < &Linop, const Field &Src, Field &Psi) +{ + if ( CGtype == BlockCGrQ ) { + BlockCGrQsolve(Linop,Src,Psi); + } else if (CGtype == BlockCG ) { + BlockCGsolve(Linop,Src,Psi); + } else if (CGtype == CGmultiRHS ) { + CGmultiRHSsolve(Linop,Src,Psi); + } else { + assert(0); + } +} + +//////////////////////////////////////////////////////////////////////////// +// BlockCGrQ implementation: +//-------------------------- +// X is guess/Solution +// B is RHS +// Solve A X_i = B_i ; i refers to Nblock index +//////////////////////////////////////////////////////////////////////////// +void BlockCGrQsolve(LinearOperatorBase &Linop, const Field &B, Field &X) +{ + int Orthog = blockDim; // First dimension is block dim; this is an assumption + Nblock = B._grid->_fdimensions[Orthog]; + + std::cout< residuals(Nblock); + std::vector ssq(Nblock); + + sliceNorm(ssq,B,Orthog); + RealD sssum=0; + for(int b=0;b Thin QR factorisation (google it) + * for k: + * Z = AD + * M = [D^dag Z]^{-1} + * X = X + D MC + * QS = Q - ZM + * D = Q + D S^dag + * C = S C + */ + /////////////////////////////////////// + // Initial block: initial search dir is guess + /////////////////////////////////////// + std::cout << GridLogMessage<<"BlockCGrQ algorithm initialisation " < Thin QR factorisation (google it) + + Linop.HermOp(X, AD); + tmp = B - AD; + //std::cout << GridLogMessage << " initial tmp " << norm2(tmp)<< std::endl; + ThinQRfact (m_rr, m_C, m_Cinv, Q, tmp); + //std::cout << GridLogMessage << " initial Q " << norm2(Q)<< std::endl; + //std::cout << GridLogMessage << " m_rr " << m_rr< max_resid ) max_resid = rr; + } + + std::cout << GridLogIterative << "\titeration "< &Linop, 
const Field &Src, Field &Psi) +{ + int Orthog = blockDim; // First dimension is block dim; this is an assumption + Nblock = Src._grid->_fdimensions[Orthog]; + + std::cout< residuals(Nblock); + std::vector ssq(Nblock); + + sliceNorm(ssq,Src,Orthog); + RealD sssum=0; + for(int b=0;b max_resid ) max_resid = rr; + } + + if ( max_resid < Tolerance*Tolerance ) { + + SolverTimer.Stop(); + + std::cout << GridLogMessage<<"BlockCG converged in "< &Linop, const Field &Src, Field &Psi) +{ + int Orthog = blockDim; // First dimension is block dim + Nblock = Src._grid->_fdimensions[Orthog]; + + std::cout< v_pAp(Nblock); + std::vector v_rr (Nblock); + std::vector v_rr_inv(Nblock); + std::vector v_alpha(Nblock); + std::vector v_beta(Nblock); + + // Initial residual computation & set up + std::vector residuals(Nblock); + std::vector ssq(Nblock); + + sliceNorm(ssq,Src,Orthog); + RealD sssum=0; + for(int b=0;b max_resid ) max_resid = rr; + } + + if ( max_resid < Tolerance*Tolerance ) { + + SolverTimer.Stop(); + + std::cout << GridLogMessage<<"MultiRHS solver converged in " < { // Defaults true. RealD Tolerance; Integer MaxIterations; + Integer IterationsToComplete; //Number of iterations the CG took to finish. 
Filled in upon completion + ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true) : Tolerance(tol), MaxIterations(maxit), @@ -76,18 +78,12 @@ class ConjugateGradient : public OperatorFunction { cp = a; ssq = norm2(src); - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: guess " << guess << std::endl; - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: src " << ssq << std::endl; - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: mp " << d << std::endl; - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: mmp " << b << std::endl; - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: cp,r " << cp << std::endl; - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: p " << a << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradient: guess " << guess << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradient: src " << ssq << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradient: mp " << d << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradient: mmp " << b << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradient: cp,r " << cp << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradient: p " << a << std::endl; RealD rsq = Tolerance * Tolerance * ssq; @@ -97,8 +93,7 @@ class ConjugateGradient : public OperatorFunction { } std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: k=0 residual " << cp << " target " << rsq - << std::endl; + << "ConjugateGradient: k=0 residual " << cp << " target " << rsq << std::endl; GridStopWatch LinalgTimer; GridStopWatch MatrixTimer; @@ -128,8 +123,11 @@ class ConjugateGradient : public OperatorFunction { p = p * b + r; 
LinalgTimer.Stop(); + std::cout << GridLogIterative << "ConjugateGradient: Iteration " << k << " residual " << cp << " target " << rsq << std::endl; + std::cout << GridLogDebug << "a = "<< a << " b_pred = "<< b_pred << " b = "<< b << std::endl; + std::cout << GridLogDebug << "qq = "<< qq << " d = "<< d << " c = "<< c << std::endl; // Stopping condition if (cp <= rsq) { @@ -137,31 +135,33 @@ class ConjugateGradient : public OperatorFunction { Linop.HermOpAndNorm(psi, mmp, d, qq); p = mmp - src; - RealD mmpnorm = sqrt(norm2(mmp)); - RealD psinorm = sqrt(norm2(psi)); RealD srcnorm = sqrt(norm2(src)); RealD resnorm = sqrt(norm2(p)); RealD true_residual = resnorm / srcnorm; - std::cout << GridLogMessage - << "ConjugateGradient: Converged on iteration " << k << std::endl; - std::cout << GridLogMessage << "Computed residual " << sqrt(cp / ssq) - << " true residual " << true_residual << " target " - << Tolerance << std::endl; - std::cout << GridLogMessage << "Time elapsed: Iterations " - << SolverTimer.Elapsed() << " Matrix " - << MatrixTimer.Elapsed() << " Linalg " - << LinalgTimer.Elapsed(); - std::cout << std::endl; + std::cout << GridLogMessage << "ConjugateGradient Converged on iteration " << k << std::endl; + std::cout << GridLogMessage << "\tComputed residual " << sqrt(cp / ssq)< { public: RealD Tolerance; + RealD InnerTolerance; //Initial tolerance for inner CG. 
Defaults to Tolerance but can be changed Integer MaxInnerIterations; Integer MaxOuterIterations; GridBase* SinglePrecGrid; //Grid for single-precision fields @@ -42,12 +43,16 @@ namespace Grid { LinearOperatorBase &Linop_f; LinearOperatorBase &Linop_d; + Integer TotalInnerIterations; //Number of inner CG iterations + Integer TotalOuterIterations; //Number of restarts + Integer TotalFinalStepIterations; //Number of CG iterations in final patch-up step + //Option to speed up *inner single precision* solves using a LinearFunction that produces a guess LinearFunction *guesser; MixedPrecisionConjugateGradient(RealD tol, Integer maxinnerit, Integer maxouterit, GridBase* _sp_grid, LinearOperatorBase &_Linop_f, LinearOperatorBase &_Linop_d) : Linop_f(_Linop_f), Linop_d(_Linop_d), - Tolerance(tol), MaxInnerIterations(maxinnerit), MaxOuterIterations(maxouterit), SinglePrecGrid(_sp_grid), + Tolerance(tol), InnerTolerance(tol), MaxInnerIterations(maxinnerit), MaxOuterIterations(maxouterit), SinglePrecGrid(_sp_grid), OuterLoopNormMult(100.), guesser(NULL){ }; void useGuesser(LinearFunction &g){ @@ -55,6 +60,8 @@ namespace Grid { } void operator() (const FieldD &src_d_in, FieldD &sol_d){ + TotalInnerIterations = 0; + GridStopWatch TotalTimer; TotalTimer.Start(); @@ -74,7 +81,7 @@ namespace Grid { FieldD src_d(DoublePrecGrid); src_d = src_d_in; //source for next inner iteration, computed from residual during operation - RealD inner_tol = Tolerance; + RealD inner_tol = InnerTolerance; FieldF src_f(SinglePrecGrid); src_f.checkerboard = cb; @@ -89,7 +96,9 @@ namespace Grid { GridStopWatch PrecChangeTimer; - for(Integer outer_iter = 0; outer_iter < MaxOuterIterations; outer_iter++){ + Integer &outer_iter = TotalOuterIterations; //so it will be equal to the final iteration count + + for(outer_iter = 0; outer_iter < MaxOuterIterations; outer_iter++){ //Compute double precision rsd and also new RHS vector. 
Linop_d.HermOp(sol_d, tmp_d); RealD norm = axpy_norm(src_d, -1., tmp_d, src_d_in); //src_d is residual vector @@ -117,6 +126,7 @@ namespace Grid { InnerCGtimer.Start(); CG_f(Linop_f, src_f, sol_f); InnerCGtimer.Stop(); + TotalInnerIterations += CG_f.IterationsToComplete; //Convert sol back to double and add to double prec solution PrecChangeTimer.Start(); @@ -131,9 +141,11 @@ namespace Grid { ConjugateGradient CG_d(Tolerance, MaxInnerIterations); CG_d(Linop_d, src_d_in, sol_d); + TotalFinalStepIterations = CG_d.IterationsToComplete; TotalTimer.Stop(); - std::cout< + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + + See the full license in the file "LICENSE" in the top level distribution directory + *************************************************************************************/ + /* END LEGAL */ +#ifndef GRID_CONJUGATE_GRADIENT_RELIABLE_UPDATE_H +#define GRID_CONJUGATE_GRADIENT_RELIABLE_UPDATE_H + +namespace Grid { + + template::value == 2, int>::type = 0,typename std::enable_if< getPrecision::value == 1, int>::type = 0> + class ConjugateGradientReliableUpdate : public LinearFunction { + public: + bool ErrorOnNoConverge; // throw an assert when the CG fails to converge. + // Defaults true. + RealD Tolerance; + Integer MaxIterations; + Integer IterationsToComplete; //Number of iterations the CG took to finish. 
Filled in upon completion + Integer ReliableUpdatesPerformed; + + bool DoFinalCleanup; //Final DP cleanup, defaults to true + Integer IterationsToCleanup; //Final DP cleanup step iterations + + LinearOperatorBase &Linop_f; + LinearOperatorBase &Linop_d; + GridBase* SinglePrecGrid; + RealD Delta; //reliable update parameter + + //Optional ability to switch to a different linear operator once the tolerance reaches a certain point. Useful for single/half -> single/single + LinearOperatorBase *Linop_fallback; + RealD fallback_transition_tol; + + + ConjugateGradientReliableUpdate(RealD tol, Integer maxit, RealD _delta, GridBase* _sp_grid, LinearOperatorBase &_Linop_f, LinearOperatorBase &_Linop_d, bool err_on_no_conv = true) + : Tolerance(tol), + MaxIterations(maxit), + Delta(_delta), + Linop_f(_Linop_f), + Linop_d(_Linop_d), + SinglePrecGrid(_sp_grid), + ErrorOnNoConverge(err_on_no_conv), + DoFinalCleanup(true), + Linop_fallback(NULL) + {}; + + void setFallbackLinop(LinearOperatorBase &_Linop_fallback, const RealD _fallback_transition_tol){ + Linop_fallback = &_Linop_fallback; + fallback_transition_tol = _fallback_transition_tol; + } + + void operator()(const FieldD &src, FieldD &psi) { + LinearOperatorBase *Linop_f_use = &Linop_f; + bool using_fallback = false; + + psi.checkerboard = src.checkerboard; + conformable(psi, src); + + RealD cp, c, a, d, b, ssq, qq, b_pred; + + FieldD p(src); + FieldD mmp(src); + FieldD r(src); + + // Initial residual computation & set up + RealD guess = norm2(psi); + assert(std::isnan(guess) == 0); + + Linop_d.HermOpAndNorm(psi, mmp, d, b); + + r = src - mmp; + p = r; + + a = norm2(p); + cp = a; + ssq = norm2(src); + + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: guess " << guess << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: src " << ssq << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << 
"ConjugateGradientReliableUpdate: mp " << d << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: mmp " << b << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: cp,r " << cp << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: p " << a << std::endl; + + RealD rsq = Tolerance * Tolerance * ssq; + + // Check if guess is really REALLY good :) + if (cp <= rsq) { + std::cout << GridLogMessage << "ConjugateGradientReliableUpdate guess was REALLY good\n"; + std::cout << GridLogMessage << "\tComputed residual " << sqrt(cp / ssq)<HermOpAndNorm(p_f, mmp_f, d, qq); + MatrixTimer.Stop(); + + LinalgTimer.Start(); + + a = c / d; + b_pred = a * (a * qq - d) / c; + + cp = axpy_norm(r_f, -a, mmp_f, r_f); + b = cp / c; + + // Fuse these loops ; should be really easy + psi_f = a * p_f + psi_f; + //p_f = p_f * b + r_f; + + LinalgTimer.Stop(); + + std::cout << GridLogIterative << "ConjugateGradientReliableUpdate: Iteration " << k + << " residual " << cp << " target " << rsq << std::endl; + std::cout << GridLogDebug << "a = "<< a << " b_pred = "<< b_pred << " b = "<< b << std::endl; + std::cout << GridLogDebug << "qq = "<< qq << " d = "<< d << " c = "<< c << std::endl; + + if(cp > MaxResidSinceLastRelUp){ + std::cout << GridLogIterative << "ConjugateGradientReliableUpdate: updating MaxResidSinceLastRelUp : " << MaxResidSinceLastRelUp << " -> " << cp << std::endl; + MaxResidSinceLastRelUp = cp; + } + + // Stopping condition + if (cp <= rsq) { + //Although not written in the paper, I assume that I have to add on the final solution + precisionChange(mmp, psi_f); + psi = psi + mmp; + + + SolverTimer.Stop(); + Linop_d.HermOpAndNorm(psi, mmp, d, qq); + p = mmp - src; + + RealD srcnorm = sqrt(norm2(src)); + RealD resnorm = sqrt(norm2(p)); + RealD true_residual = resnorm / srcnorm; + + std::cout << GridLogMessage << 
"ConjugateGradientReliableUpdate Converged on iteration " << k << " after " << l << " reliable updates" << std::endl; + std::cout << GridLogMessage << "\tComputed residual " << sqrt(cp / ssq)< CG(Tolerance,MaxIterations); + CG.ErrorOnNoConverge = ErrorOnNoConverge; + CG(Linop_d,src,psi); + IterationsToCleanup = CG.IterationsToComplete; + } + else if (ErrorOnNoConverge) assert(true_residual / Tolerance < 10000.0); + + std::cout << GridLogMessage << "ConjugateGradientReliableUpdate complete.\n"; + return; + } + else if(cp < Delta * MaxResidSinceLastRelUp) { //reliable update + std::cout << GridLogMessage << "ConjugateGradientReliableUpdate " + << cp << "(residual) < " << Delta << "(Delta) * " << MaxResidSinceLastRelUp << "(MaxResidSinceLastRelUp) on iteration " << k << " : performing reliable update\n"; + precisionChange(mmp, psi_f); + psi = psi + mmp; + + Linop_d.HermOpAndNorm(psi, mmp, d, qq); + r = src - mmp; + + psi_f = zero; + precisionChange(r_f, r); + cp = norm2(r); + MaxResidSinceLastRelUp = cp; + + b = cp/c; + + std::cout << GridLogMessage << "ConjugateGradientReliableUpdate new residual " << cp << std::endl; + + l = l+1; + } + + p_f = p_f * b + r_f; //update search vector after reliable update appears to help convergence + + if(!using_fallback && Linop_fallback != NULL && cp < fallback_transition_tol){ + std::cout << GridLogMessage << "ConjugateGradientReliableUpdate switching to fallback linear operator on iteration " << k << " at residual " << cp << std::endl; + Linop_f_use = Linop_fallback; + using_fallback = true; + } + + + } + std::cout << GridLogMessage << "ConjugateGradientReliableUpdate did NOT converge" + << std::endl; + + if (ErrorOnNoConverge) assert(0); + IterationsToComplete = k; + ReliableUpdatesPerformed = l; + } + }; + + +}; + + + +#endif diff --git a/lib/algorithms/iterative/DenseMatrix.h b/lib/algorithms/iterative/DenseMatrix.h deleted file mode 100644 index d86add21..00000000 --- a/lib/algorithms/iterative/DenseMatrix.h +++ /dev/null @@ 
-1,137 +0,0 @@ - /************************************************************************************* - - Grid physics library, www.github.com/paboyle/Grid - - Source file: ./lib/algorithms/iterative/DenseMatrix.h - - Copyright (C) 2015 - -Author: Peter Boyle -Author: paboyle - - This program is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License along - with this program; if not, write to the Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. - - See the full license in the file "LICENSE" in the top level distribution directory - *************************************************************************************/ - /* END LEGAL */ -#ifndef GRID_DENSE_MATRIX_H -#define GRID_DENSE_MATRIX_H - -namespace Grid { - ///////////////////////////////////////////////////////////// - // Matrix untils - ///////////////////////////////////////////////////////////// - -template using DenseVector = std::vector; -template using DenseMatrix = DenseVector >; - -template void Size(DenseVector & vec, int &N) -{ - N= vec.size(); -} -template void Size(DenseMatrix & mat, int &N,int &M) -{ - N= mat.size(); - M= mat[0].size(); -} - -template void SizeSquare(DenseMatrix & mat, int &N) -{ - int M; Size(mat,N,M); - assert(N==M); -} - -template void Resize(DenseVector & mat, int N) { - mat.resize(N); -} -template void Resize(DenseMatrix & mat, int N, int M) { - mat.resize(N); - for(int i=0;i void Fill(DenseMatrix & mat, T&val) { - int N,M; - Size(mat,N,M); - 
for(int i=0;i DenseMatrix Transpose(DenseMatrix & mat){ - int N,M; - Size(mat,N,M); - DenseMatrix C; Resize(C,M,N); - for(int i=0;i void Unity(DenseMatrix &A){ - int N; SizeSquare(A,N); - for(int i=0;i -void PlusUnit(DenseMatrix & A,T c){ - int dim; SizeSquare(A,dim); - for(int i=0;i -DenseMatrix HermitianConj(DenseMatrix &mat){ - - int dim; SizeSquare(mat,dim); - - DenseMatrix C; Resize(C,dim,dim); - - for(int i=0;i -DenseMatrix GetSubMtx(DenseMatrix &A,int row_st, int row_end, int col_st, int col_end) -{ - DenseMatrix H; Resize(H,row_end - row_st,col_end-col_st); - - for(int i = row_st; i - - This program is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License along - with this program; if not, write to the Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
- - See the full license in the file "LICENSE" in the top level distribution directory - *************************************************************************************/ - /* END LEGAL */ -#ifndef GRID_EIGENSORT_H -#define GRID_EIGENSORT_H - - -namespace Grid { - ///////////////////////////////////////////////////////////// - // Eigen sorter to begin with - ///////////////////////////////////////////////////////////// - -template -class SortEigen { - private: - -//hacking for testing for now - private: - static bool less_lmd(RealD left,RealD right){ - return left > right; - } - static bool less_pair(std::pair& left, - std::pair& right){ - return left.first > (right.first); - } - - - public: - - void push(DenseVector& lmd, - DenseVector& evec,int N) { - DenseVector cpy(lmd.size(),evec[0]._grid); - for(int i=0;i > emod(lmd.size()); - for(int i=0;i(lmd[i],&cpy[i]); - - partial_sort(emod.begin(),emod.begin()+N,emod.end(),less_pair); - - typename DenseVector >::iterator it = emod.begin(); - for(int i=0;ifirst; - evec[i]=*(it->second); - ++it; - } - } - void push(DenseVector& lmd,int N) { - std::partial_sort(lmd.begin(),lmd.begin()+N,lmd.end(),less_lmd); - } - bool saturated(RealD lmd, RealD thrs) { - return fabs(lmd) > fabs(thrs); - } -}; - -} -#endif diff --git a/lib/algorithms/iterative/Francis.h b/lib/algorithms/iterative/Francis.h deleted file mode 100644 index 08ecbd7b..00000000 --- a/lib/algorithms/iterative/Francis.h +++ /dev/null @@ -1,525 +0,0 @@ - /************************************************************************************* - - Grid physics library, www.github.com/paboyle/Grid - - Source file: ./lib/algorithms/iterative/Francis.h - - Copyright (C) 2015 - -Author: Peter Boyle - - This program is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. 
- - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License along - with this program; if not, write to the Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. - - See the full license in the file "LICENSE" in the top level distribution directory - *************************************************************************************/ - /* END LEGAL */ -#ifndef FRANCIS_H -#define FRANCIS_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -//#include -//#include -//#include - -namespace Grid { - -template int SymmEigensystem(DenseMatrix &Ain, DenseVector &evals, DenseMatrix &evecs, RealD small); -template int Eigensystem(DenseMatrix &Ain, DenseVector &evals, DenseMatrix &evecs, RealD small); - -/** - Find the eigenvalues of an upper hessenberg matrix using the Francis QR algorithm. -H = - x x x x x x x x x - x x x x x x x x x - 0 x x x x x x x x - 0 0 x x x x x x x - 0 0 0 x x x x x x - 0 0 0 0 x x x x x - 0 0 0 0 0 x x x x - 0 0 0 0 0 0 x x x - 0 0 0 0 0 0 0 x x -Factorization is P T P^H where T is upper triangular (mod cc blocks) and P is orthagonal/unitary. 
-**/ -template -int QReigensystem(DenseMatrix &Hin, DenseVector &evals, DenseMatrix &evecs, RealD small) -{ - DenseMatrix H = Hin; - - int N ; SizeSquare(H,N); - int M = N; - - Fill(evals,0); - Fill(evecs,0); - - T s,t,x=0,y=0,z=0; - T u,d; - T apd,amd,bc; - DenseVector p(N,0); - T nrm = Norm(H); ///DenseMatrix Norm - int n, m; - int e = 0; - int it = 0; - int tot_it = 0; - int l = 0; - int r = 0; - DenseMatrix P; Resize(P,N,N); Unity(P); - DenseVector trows(N,0); - - /// Check if the matrix is really hessenberg, if not abort - RealD sth = 0; - for(int j=0;j small){ - std::cout << "Non hessenberg H = " << sth << " > " << small << std::endl; - exit(1); - } - } - } - - do{ - std::cout << "Francis QR Step N = " << N << std::endl; - /** Check for convergence - x x x x x - 0 x x x x - 0 0 x x x - 0 0 x x x - 0 0 0 0 x - for this matrix l = 4 - **/ - do{ - l = Chop_subdiag(H,nrm,e,small); - r = 0; ///May have converged on more than one eval - ///Single eval - if(l == N-1){ - evals[e] = H[l][l]; - N--; e++; r++; it = 0; - } - ///RealD eval - if(l == N-2){ - trows[l+1] = 1; ///Needed for UTSolve - apd = H[l][l] + H[l+1][l+1]; - amd = H[l][l] - H[l+1][l+1]; - bc = (T)4.0*H[l+1][l]*H[l][l+1]; - evals[e] = (T)0.5*( apd + sqrt(amd*amd + bc) ); - evals[e+1] = (T)0.5*( apd - sqrt(amd*amd + bc) ); - N-=2; e+=2; r++; it = 0; - } - } while(r>0); - - if(N ==0) break; - - DenseVector ck; Resize(ck,3); - DenseVector v; Resize(v,3); - - for(int m = N-3; m >= l; m--){ - ///Starting vector essentially random shift. 
- if(it%10 == 0 && N >= 3 && it > 0){ - s = (T)1.618033989*( abs( H[N-1][N-2] ) + abs( H[N-2][N-3] ) ); - t = (T)0.618033989*( abs( H[N-1][N-2] ) + abs( H[N-2][N-3] ) ); - x = H[m][m]*H[m][m] + H[m][m+1]*H[m+1][m] - s*H[m][m] + t; - y = H[m+1][m]*(H[m][m] + H[m+1][m+1] - s); - z = H[m+1][m]*H[m+2][m+1]; - } - ///Starting vector implicit Q theorem - else{ - s = (H[N-2][N-2] + H[N-1][N-1]); - t = (H[N-2][N-2]*H[N-1][N-1] - H[N-2][N-1]*H[N-1][N-2]); - x = H[m][m]*H[m][m] + H[m][m+1]*H[m+1][m] - s*H[m][m] + t; - y = H[m+1][m]*(H[m][m] + H[m+1][m+1] - s); - z = H[m+1][m]*H[m+2][m+1]; - } - ck[0] = x; ck[1] = y; ck[2] = z; - - if(m == l) break; - - /** Some stupid thing from numerical recipies, seems to work**/ - // PAB.. for heaven's sake quote page, purpose, evidence it works. - // what sort of comment is that!?!?!? - u=abs(H[m][m-1])*(abs(y)+abs(z)); - d=abs(x)*(abs(H[m-1][m-1])+abs(H[m][m])+abs(H[m+1][m+1])); - if ((T)abs(u+d) == (T)abs(d) ){ - l = m; break; - } - - //if (u < small){l = m; break;} - } - if(it > 100000){ - std::cout << "QReigensystem: bugger it got stuck after 100000 iterations" << std::endl; - std::cout << "got " << e << " evals " << l << " " << N << std::endl; - exit(1); - } - normalize(ck); ///Normalization cancels in PHP anyway - T beta; - Householder_vector(ck, 0, 2, v, beta); - Householder_mult(H,v,beta,0,l,l+2,0); - Householder_mult(H,v,beta,0,l,l+2,1); - ///Accumulate eigenvector - Householder_mult(P,v,beta,0,l,l+2,1); - int sw = 0; ///Are we on the last row? 
- for(int k=l;k(ck, 0, 2-sw, v, beta); - Householder_mult(H,v, beta,0,k+1,k+3-sw,0); - Householder_mult(H,v, beta,0,k+1,k+3-sw,1); - ///Accumulate eigenvector - Householder_mult(P,v, beta,0,k+1,k+3-sw,1); - } - it++; - tot_it++; - }while(N > 1); - N = evals.size(); - ///Annoying - UT solves in reverse order; - DenseVector tmp; Resize(tmp,N); - for(int i=0;i -int my_Wilkinson(DenseMatrix &Hin, DenseVector &evals, DenseMatrix &evecs, RealD small) -{ - /** - Find the eigenvalues of an upper Hessenberg matrix using the Wilkinson QR algorithm. - H = - x x 0 0 0 0 - x x x 0 0 0 - 0 x x x 0 0 - 0 0 x x x 0 - 0 0 0 x x x - 0 0 0 0 x x - Factorization is P T P^H where T is upper triangular (mod cc blocks) and P is orthagonal/unitary. **/ - return my_Wilkinson(Hin, evals, evecs, small, small); -} - -template -int my_Wilkinson(DenseMatrix &Hin, DenseVector &evals, DenseMatrix &evecs, RealD small, RealD tol) -{ - int N; SizeSquare(Hin,N); - int M = N; - - ///I don't want to modify the input but matricies must be passed by reference - //Scale a matrix by its "norm" - //RealD Hnorm = abs( Hin.LargestDiag() ); H = H*(1.0/Hnorm); - DenseMatrix H; H = Hin; - - RealD Hnorm = abs(Norm(Hin)); - H = H * (1.0 / Hnorm); - - // TODO use openmp and memset - Fill(evals,0); - Fill(evecs,0); - - T s, t, x = 0, y = 0, z = 0; - T u, d; - T apd, amd, bc; - DenseVector p; Resize(p,N); Fill(p,0); - - T nrm = Norm(H); ///DenseMatrix Norm - int n, m; - int e = 0; - int it = 0; - int tot_it = 0; - int l = 0; - int r = 0; - DenseMatrix P; Resize(P,N,N); - Unity(P); - DenseVector trows(N, 0); - /// Check if the matrix is really symm tridiag - RealD sth = 0; - for(int j = 0; j < N; ++j) - { - for(int i = j + 2; i < N; ++i) - { - if(abs(H[i][j]) > tol || abs(H[j][i]) > tol) - { - std::cout << "Non Tridiagonal H(" << i << ","<< j << ") = |" << Real( real( H[j][i] ) ) << "| > " << tol << std::endl; - std::cout << "Warning tridiagonalize and call again" << std::endl; - // exit(1); // see what is going on - 
//return; - } - } - } - - do{ - do{ - //Jasper - //Check if the subdiagonal term is small enough ( 0); - //Jasper - //Already converged - //-------------- - if(N == 0) break; - - DenseVector ck,v; Resize(ck,2); Resize(v,2); - - for(int m = N - 3; m >= l; m--) - { - ///Starting vector essentially random shift. - if(it%10 == 0 && N >= 3 && it > 0) - { - t = abs(H[N - 1][N - 2]) + abs(H[N - 2][N - 3]); - x = H[m][m] - t; - z = H[m + 1][m]; - } else { - ///Starting vector implicit Q theorem - d = (H[N - 2][N - 2] - H[N - 1][N - 1]) * (T) 0.5; - t = H[N - 1][N - 1] - H[N - 1][N - 2] * H[N - 1][N - 2] - / (d + sign(d) * sqrt(d * d + H[N - 1][N - 2] * H[N - 1][N - 2])); - x = H[m][m] - t; - z = H[m + 1][m]; - } - //Jasper - //why it is here???? - //----------------------- - if(m == l) - break; - - u = abs(H[m][m - 1]) * (abs(y) + abs(z)); - d = abs(x) * (abs(H[m - 1][m - 1]) + abs(H[m][m]) + abs(H[m + 1][m + 1])); - if ((T)abs(u + d) == (T)abs(d)) - { - l = m; - break; - } - } - //Jasper - if(it > 1000000) - { - std::cout << "Wilkinson: bugger it got stuck after 100000 iterations" << std::endl; - std::cout << "got " << e << " evals " << l << " " << N << std::endl; - exit(1); - } - // - T s, c; - Givens_calc(x, z, c, s); - Givens_mult(H, l, l + 1, c, -s, 0); - Givens_mult(H, l, l + 1, c, s, 1); - Givens_mult(P, l, l + 1, c, s, 1); - // - for(int k = l; k < N - 2; ++k) - { - x = H.A[k + 1][k]; - z = H.A[k + 2][k]; - Givens_calc(x, z, c, s); - Givens_mult(H, k + 1, k + 2, c, -s, 0); - Givens_mult(H, k + 1, k + 2, c, s, 1); - Givens_mult(P, k + 1, k + 2, c, s, 1); - } - it++; - tot_it++; - }while(N > 1); - - N = evals.size(); - ///Annoying - UT solves in reverse order; - DenseVector tmp(N); - for(int i = 0; i < N; ++i) - tmp[i] = evals[N-i-1]; - evals = tmp; - // - UTeigenvectors(H, trows, evals, evecs); - //UTSymmEigenvectors(H, trows, evals, evecs); - for(int i = 0; i < evals.size(); ++i) - { - evecs[i] = P * evecs[i]; - normalize(evecs[i]); - evals[i] = evals[i] * Hnorm; - 
} - // // FIXME this is to test - // Hin.write("evecs3", evecs); - // Hin.write("evals3", evals); - // // check rsd - // for(int i = 0; i < M; i++) { - // vector Aevec = Hin * evecs[i]; - // RealD norm2(0.); - // for(int j = 0; j < M; j++) { - // norm2 += (Aevec[j] - evals[i] * evecs[i][j]) * (Aevec[j] - evals[i] * evecs[i][j]); - // } - // } - return tot_it; -} - -template -void Hess(DenseMatrix &A, DenseMatrix &Q, int start){ - - /** - turn a matrix A = - x x x x x - x x x x x - x x x x x - x x x x x - x x x x x - into - x x x x x - x x x x x - 0 x x x x - 0 0 x x x - 0 0 0 x x - with householder rotations - Slow. - */ - int N ; SizeSquare(A,N); - DenseVector p; Resize(p,N); Fill(p,0); - - for(int k=start;k ck,v; Resize(ck,N-k-1); Resize(v,N-k-1); - for(int i=k+1;i(ck, 0, ck.size()-1, v, beta); ///Householder vector - Householder_mult(A,v,beta,start,k+1,N-1,0); ///A -> PA - Householder_mult(A,v,beta,start,k+1,N-1,1); ///PA -> PAP^H - ///Accumulate eigenvector - Householder_mult(Q,v,beta,start,k+1,N-1,1); ///Q -> QP^H - } - /*for(int l=0;l -void Tri(DenseMatrix &A, DenseMatrix &Q, int start){ -///Tridiagonalize a matrix - int N; SizeSquare(A,N); - Hess(A,Q,start); - /*for(int l=0;l -void ForceTridiagonal(DenseMatrix &A){ -///Tridiagonalize a matrix - int N ; SizeSquare(A,N); - for(int l=0;l -int my_SymmEigensystem(DenseMatrix &Ain, DenseVector &evals, DenseVector > &evecs, RealD small){ - ///Solve a symmetric eigensystem, not necessarily in tridiagonal form - int N; SizeSquare(Ain,N); - DenseMatrix A; A = Ain; - DenseMatrix Q; Resize(Q,N,N); Unity(Q); - Tri(A,Q,0); - int it = my_Wilkinson(A, evals, evecs, small); - for(int k=0;k -int Wilkinson(DenseMatrix &Ain, DenseVector &evals, DenseVector > &evecs, RealD small){ - return my_Wilkinson(Ain, evals, evecs, small); -} - -template -int SymmEigensystem(DenseMatrix &Ain, DenseVector &evals, DenseVector > &evecs, RealD small){ - return my_SymmEigensystem(Ain, evals, evecs, small); -} - -template -int 
Eigensystem(DenseMatrix &Ain, DenseVector &evals, DenseVector > &evecs, RealD small){ -///Solve a general eigensystem, not necessarily in tridiagonal form - int N = Ain.dim; - DenseMatrix A(N); A = Ain; - DenseMatrix Q(N);Q.Unity(); - Hess(A,Q,0); - int it = QReigensystem(A, evals, evecs, small); - for(int k=0;k - - This program is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License along - with this program; if not, write to the Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
- - See the full license in the file "LICENSE" in the top level distribution directory - *************************************************************************************/ - /* END LEGAL */ -#ifndef HOUSEHOLDER_H -#define HOUSEHOLDER_H - -#define TIMER(A) std::cout << GridLogMessage << __FUNC__ << " file "<< __FILE__ <<" line " << __LINE__ << std::endl; -#define ENTER() std::cout << GridLogMessage << "ENTRY "<<__FUNC__ << " file "<< __FILE__ <<" line " << __LINE__ << std::endl; -#define LEAVE() std::cout << GridLogMessage << "EXIT "<<__FUNC__ << " file "<< __FILE__ <<" line " << __LINE__ << std::endl; - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace Grid { -/** Comparison function for finding the max element in a vector **/ -template bool cf(T i, T j) { - return abs(i) < abs(j); -} - -/** - Calculate a real Givens angle - **/ -template inline void Givens_calc(T y, T z, T &c, T &s){ - - RealD mz = (RealD)abs(z); - - if(mz==0.0){ - c = 1; s = 0; - } - if(mz >= (RealD)abs(y)){ - T t = -y/z; - s = (T)1.0 / sqrt ((T)1.0 + t * t); - c = s * t; - } else { - T t = -z/y; - c = (T)1.0 / sqrt ((T)1.0 + t * t); - s = c * t; - } -} - -template inline void Givens_mult(DenseMatrix &A, int i, int k, T c, T s, int dir) -{ - int q ; SizeSquare(A,q); - - if(dir == 0){ - for(int j=0;j inline void Householder_vector(DenseVector input, int k, int j, DenseVector &v, T &beta) -{ - int N ; Size(input,N); - T m = *max_element(input.begin() + k, input.begin() + j + 1, cf ); - - if(abs(m) > 0.0){ - T alpha = 0; - - for(int i=k; i 0.0) v[k] = v[k] + (v[k]/abs(v[k]))*alpha; - else v[k] = -alpha; - } else{ - for(int i=k; i inline void Householder_vector(DenseVector input, int k, int j, int dir, DenseVector &v, T &beta) -{ - int N = input.size(); - T m = *max_element(input.begin() + k, input.begin() + j + 1, cf); - - if(abs(m) > 0.0){ - T alpha = 0; - - for(int i=k; i 0.0) v[dir] = v[dir] + (v[dir]/abs(v[dir]))*alpha; - else v[dir] = 
-alpha; - }else{ - for(int i=k; i inline void Householder_mult(DenseMatrix &A , DenseVector v, T beta, int l, int k, int j, int trans) -{ - int N ; SizeSquare(A,N); - - if(abs(beta) > 0.0){ - for(int p=l; p inline void Householder_mult_tri(DenseMatrix &A , DenseVector v, T beta, int l, int M, int k, int j, int trans) -{ - if(abs(beta) > 0.0){ - - int N ; SizeSquare(A,N); - - DenseMatrix tmp; Resize(tmp,N,N); Fill(tmp,0); - - T s; - for(int p=l; p -Author: paboyle +Author: Chulwoo Jung +Author: Guido Cossu This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -30,435 +31,122 @@ Author: paboyle #define GRID_IRL_H #include //memset -#ifdef USE_LAPACK -void LAPACK_dstegr(char *jobz, char *range, int *n, double *d, double *e, - double *vl, double *vu, int *il, int *iu, double *abstol, - int *m, double *w, double *z, int *ldz, int *isuppz, - double *work, int *lwork, int *iwork, int *liwork, - int *info); -#endif -#include "DenseMatrix.h" -#include "EigenSort.h" namespace Grid { + enum IRLdiagonalisation { + IRLdiagonaliseWithDSTEGR, + IRLdiagonaliseWithQR, + IRLdiagonaliseWithEigen + }; + +//////////////////////////////////////////////////////////////////////////////// +// Helper class for sorting the evalues AND evectors by Field +// Use pointer swizzle on vectors +//////////////////////////////////////////////////////////////////////////////// +template +class SortEigen { + private: + static bool less_lmd(RealD left,RealD right){ + return left > right; + } + static bool less_pair(std::pair& left, + std::pair& right){ + return left.first > (right.first); + } + + public: + void push(std::vector& lmd,std::vector& evec,int N) { + + //////////////////////////////////////////////////////////////////////// + // PAB: FIXME: VERY VERY VERY wasteful: takes a copy of the entire vector set. 
+ // : The vector reorder should be done by pointer swizzle somehow + //////////////////////////////////////////////////////////////////////// + std::vector cpy(lmd.size(),evec[0]._grid); + for(int i=0;i > emod(lmd.size()); + + for(int i=0;i(lmd[i],&cpy[i]); + + partial_sort(emod.begin(),emod.begin()+N,emod.end(),less_pair); + + typename std::vector >::iterator it = emod.begin(); + for(int i=0;ifirst; + evec[i]=*(it->second); + ++it; + } + } + void push(std::vector& lmd,int N) { + std::partial_sort(lmd.begin(),lmd.begin()+N,lmd.end(),less_lmd); + } + bool saturated(RealD lmd, RealD thrs) { + return fabs(lmd) > fabs(thrs); + } +}; + ///////////////////////////////////////////////////////////// // Implicitly restarted lanczos ///////////////////////////////////////////////////////////// - - template - class ImplicitlyRestartedLanczos { +class ImplicitlyRestartedLanczos { - const RealD small = 1.0e-16; +private: + + int MaxIter; // Max iterations + int Nstop; // Number of evecs checked for convergence + int Nk; // Number of converged sought + int Nm; // Nm -- total number of vectors + RealD eresid; + IRLdiagonalisation diagonalisation; + //////////////////////////////////// + // Embedded objects + //////////////////////////////////// + SortEigen _sort; + LinearOperatorBase &_Linop; + OperatorFunction &_poly; + + ///////////////////////// + // Constructor + ///////////////////////// public: - int lock; - int get; - int Niter; - int converged; + ImplicitlyRestartedLanczos(LinearOperatorBase &Linop, // op + OperatorFunction & poly, // polynomial + int _Nstop, // really sought vecs + int _Nk, // sought vecs + int _Nm, // total vecs + RealD _eresid, // resid in lmd deficit + int _MaxIter, // Max iterations + IRLdiagonalisation _diagonalisation= IRLdiagonaliseWithEigen ) : + _Linop(Linop), _poly(poly), + Nstop(_Nstop), Nk(_Nk), Nm(_Nm), + eresid(_eresid), MaxIter(_MaxIter), + diagonalisation(_diagonalisation) + { }; - int Nstop; // Number of evecs checked for convergence - 
int Nk; // Number of converged sought - int Np; // Np -- Number of spare vecs in kryloc space - int Nm; // Nm -- total number of vectors - - RealD eresid; - - SortEigen _sort; - -// GridCartesian &_fgrid; - - LinearOperatorBase &_Linop; - - OperatorFunction &_poly; - - ///////////////////////// - // Constructor - ///////////////////////// - void init(void){}; - void Abort(int ff, DenseVector &evals, DenseVector > &evecs); - - ImplicitlyRestartedLanczos( - LinearOperatorBase &Linop, // op - OperatorFunction & poly, // polynmial - int _Nstop, // sought vecs - int _Nk, // sought vecs - int _Nm, // spare vecs - RealD _eresid, // resid in lmdue deficit - int _Niter) : // Max iterations - _Linop(Linop), - _poly(poly), - Nstop(_Nstop), - Nk(_Nk), - Nm(_Nm), - eresid(_eresid), - Niter(_Niter) - { - Np = Nm-Nk; assert(Np>0); - }; - - ImplicitlyRestartedLanczos( - LinearOperatorBase &Linop, // op - OperatorFunction & poly, // polynmial - int _Nk, // sought vecs - int _Nm, // spare vecs - RealD _eresid, // resid in lmdue deficit - int _Niter) : // Max iterations - _Linop(Linop), - _poly(poly), - Nstop(_Nk), - Nk(_Nk), - Nm(_Nm), - eresid(_eresid), - Niter(_Niter) - { - Np = Nm-Nk; assert(Np>0); - }; - - ///////////////////////// - // Sanity checked this routine (step) against Saad. - ///////////////////////// - void RitzMatrix(DenseVector& evec,int k){ - - if(1) return; - - GridBase *grid = evec[0]._grid; - Field w(grid); - std::cout << "RitzMatrix "<1 ) { - if (abs(in) >1.0e-9 ) { - std::cout<<"oops"<& lmd, - DenseVector& lme, - DenseVector& evec, - Field& w,int Nm,int k) - { - assert( k< Nm ); - - _poly(_Linop,evec[k],w); // 3. wk:=Avk−βkv_{k−1} - if(k>0){ - w -= lme[k-1] * evec[k-1]; - } - - ComplexD zalph = innerProduct(evec[k],w); // 4. αk:=(wk,vk) - RealD alph = real(zalph); - - w = w - alph * evec[k];// 5. wk:=wk−αkvk - - RealD beta = normalise(w); // 6. βk+1 := ∥wk∥2. If βk+1 = 0 then Stop - // 7. 
vk+1 := wk/βk+1 - -// std::cout << "alpha = " << zalph << " beta "<0) { - orthogonalize(w,evec,k); // orthonormalise - } - - if(k < Nm-1) evec[k+1] = w; - } - - void qr_decomp(DenseVector& lmd, - DenseVector& lme, - int Nk, - int Nm, - DenseVector& Qt, - RealD Dsh, - int kmin, - int kmax) - { - int k = kmin-1; - RealD x; - - RealD Fden = 1.0/hypot(lmd[k]-Dsh,lme[k]); - RealD c = ( lmd[k] -Dsh) *Fden; - RealD s = -lme[k] *Fden; - - RealD tmpa1 = lmd[k]; - RealD tmpa2 = lmd[k+1]; - RealD tmpb = lme[k]; - - lmd[k] = c*c*tmpa1 +s*s*tmpa2 -2.0*c*s*tmpb; - lmd[k+1] = s*s*tmpa1 +c*c*tmpa2 +2.0*c*s*tmpb; - lme[k] = c*s*(tmpa1-tmpa2) +(c*c-s*s)*tmpb; - x =-s*lme[k+1]; - lme[k+1] = c*lme[k+1]; - - for(int i=0; i& lmd, - DenseVector& lme, - int N1, - int N2, - DenseVector& Qt, - GridBase *grid){ - const int size = Nm; -// tevals.resize(size); -// tevecs.resize(size); - int NN = N1; - double evals_tmp[NN]; - double evec_tmp[NN][NN]; - memset(evec_tmp[0],0,sizeof(double)*NN*NN); -// double AA[NN][NN]; - double DD[NN]; - double EE[NN]; - for (int i = 0; i< NN; i++) - for (int j = i - 1; j <= i + 1; j++) - if ( j < NN && j >= 0 ) { - if (i==j) DD[i] = lmd[i]; - if (i==j) evals_tmp[i] = lmd[i]; - if (j==(i-1)) EE[j] = lme[j]; - } - int evals_found; - int lwork = ( (18*NN) > (1+4*NN+NN*NN)? 
(18*NN):(1+4*NN+NN*NN)) ; - int liwork = 3+NN*10 ; - int iwork[liwork]; - double work[lwork]; - int isuppz[2*NN]; - char jobz = 'V'; // calculate evals & evecs - char range = 'I'; // calculate all evals - // char range = 'A'; // calculate all evals - char uplo = 'U'; // refer to upper half of original matrix - char compz = 'I'; // Compute eigenvectors of tridiagonal matrix - int ifail[NN]; - int info; -// int total = QMP_get_number_of_nodes(); -// int node = QMP_get_node_number(); -// GridBase *grid = evec[0]._grid; - int total = grid->_Nprocessors; - int node = grid->_processor; - int interval = (NN/total)+1; - double vl = 0.0, vu = 0.0; - int il = interval*node+1 , iu = interval*(node+1); - if (iu > NN) iu=NN; - double tol = 0.0; - if (1) { - memset(evals_tmp,0,sizeof(double)*NN); - if ( il <= NN){ - printf("total=%d node=%d il=%d iu=%d\n",total,node,il,iu); - LAPACK_dstegr(&jobz, &range, &NN, - (double*)DD, (double*)EE, - &vl, &vu, &il, &iu, // these four are ignored if second parameteris 'A' - &tol, // tolerance - &evals_found, evals_tmp, (double*)evec_tmp, &NN, - isuppz, - work, &lwork, iwork, &liwork, - &info); - for (int i = iu-1; i>= il-1; i--){ - printf("node=%d evals_found=%d evals_tmp[%d] = %g\n",node,evals_found, i - (il-1),evals_tmp[i - (il-1)]); - evals_tmp[i] = evals_tmp[i - (il-1)]; - if (il>1) evals_tmp[i-(il-1)]=0.; - for (int j = 0; j< NN; j++){ - evec_tmp[i][j] = evec_tmp[i - (il-1)][j]; - if (il>1) evec_tmp[i-(il-1)][j]=0.; - } - } - } - { -// QMP_sum_double_array(evals_tmp,NN); -// QMP_sum_double_array((double *)evec_tmp,NN*NN); - grid->GlobalSumVector(evals_tmp,NN); - grid->GlobalSumVector((double*)evec_tmp,NN*NN); - } - } -// cheating a bit. It is better to sort instead of just reversing it, but the document of the routine says evals are sorted in increasing order. qr gives evals in decreasing order. 
- for(int i=0;i& lmd, - DenseVector& lme, - int N2, - int N1, - DenseVector& Qt, - GridBase *grid) - { - -#ifdef USE_LAPACK - const int check_lapack=0; // just use lapack if 0, check against lapack if 1 - - if(!check_lapack) - return diagonalize_lapack(lmd,lme,N2,N1,Qt,grid); - - DenseVector lmd2(N1); - DenseVector lme2(N1); - DenseVector Qt2(N1*N1); - for(int k=0; k= kmin; --j){ - RealD dds = fabs(lmd[j-1])+fabs(lmd[j]); - if(fabs(lme[j-1])+dds > dds){ - kmax = j+1; - goto continued; - } - } - Niter = iter; -#ifdef USE_LAPACK - if(check_lapack){ - const double SMALL=1e-8; - diagonalize_lapack(lmd2,lme2,N2,N1,Qt2,grid); - DenseVector lmd3(N2); - for(int k=0; kSMALL) std::cout <<"lmd(qr) lmd(lapack) "<< k << ": " << lmd2[k] <<" "<< lmd3[k] <SMALL) std::cout <<"lme(qr)-lme(lapack) "<< k << ": " << lme2[k] - lme[k] <SMALL) std::cout <<"Qt(qr)-Qt(lapack) "<< k << ": " << Qt2[k] - Qt[k] < dds){ - kmin = j+1; - break; - } - } - } - std::cout << "[QL method] Error - Too many iteration: "<& evec, - int k) - { - typedef typename Field::scalar_type MyComplex; - MyComplex ip; - - if ( 0 ) { - for(int j=0; j &Qt) { - for(int i=0; i& evec, int k) + { + typedef typename Field::scalar_type MyComplex; + MyComplex ip; + + for(int j=0; j& eval, - DenseVector& evec, - const Field& src, - int& Nconv) - { - - GridBase *grid = evec[0]._grid; - assert(grid == src._grid); - - std::cout << " -- Nk = " << Nk << " Np = "<< Np << std::endl; - std::cout << " -- Nm = " << Nm << std::endl; - std::cout << " -- size of eval = " << eval.size() << std::endl; - std::cout << " -- size of evec = " << evec.size() << std::endl; - - assert(Nm == evec.size() && Nm == eval.size()); - - DenseVector lme(Nm); - DenseVector lme2(Nm); - DenseVector eval2(Nm); - DenseVector Qt(Nm*Nm); - DenseVector Iconv(Nm); - - DenseVector B(Nm,grid); // waste of space replicating - - Field f(grid); - Field v(grid); - - int k1 = 1; - int k2 = Nk; - - Nconv = 0; - - RealD beta_k; - - // Set initial vector - // (uniform vector) 
Why not src?? - // evec[0] = 1.0; - evec[0] = src; - std:: cout <<"norm2(src)= " << norm2(src)<& eval, std::vector& evec, const Field& src, int& Nconv) + { - for(int i=0; i<(Nk+1); ++i) B[i] = 0.0; + GridBase *grid = evec[0]._grid; + assert(grid == src._grid); + + std::cout << GridLogMessage <<"**************************************************************************"<< std::endl; + std::cout << GridLogMessage <<" ImplicitlyRestartedLanczos::calc() starting iteration 0 / "<< MaxIter<< std::endl; + std::cout << GridLogMessage <<"**************************************************************************"<< std::endl; + std::cout << GridLogMessage <<" -- seek Nk = " << Nk <<" vectors"<< std::endl; + std::cout << GridLogMessage <<" -- accept Nstop = " << Nstop <<" vectors"<< std::endl; + std::cout << GridLogMessage <<" -- total Nm = " << Nm <<" vectors"<< std::endl; + std::cout << GridLogMessage <<" -- size of eval = " << eval.size() << std::endl; + std::cout << GridLogMessage <<" -- size of evec = " << evec.size() << std::endl; + if ( diagonalisation == IRLdiagonaliseWithDSTEGR ) { + std::cout << GridLogMessage << "Diagonalisation is DSTEGR "< lme(Nm); + std::vector lme2(Nm); + std::vector eval2(Nm); + + Eigen::MatrixXd Qt = Eigen::MatrixXd::Zero(Nm,Nm); + + std::vector Iconv(Nm); + std::vector B(Nm,grid); // waste of space replicating + + Field f(grid); + Field v(grid); + + int k1 = 1; + int k2 = Nk; + + Nconv = 0; + + RealD beta_k; + + // Set initial vector + evec[0] = src; + std::cout << GridLogMessage <<"norm2(src)= " << norm2(src)<=Nstop ){ - goto converged; - } - } // end of iter loop + if((vv=Nstop ){ + goto converged; + } + } // end of iter loop + + std::cout << GridLogMessage <<"**************************************************************************"<< std::endl; + std::cout<< GridLogError <<" ImplicitlyRestartedLanczos::calc() NOT converged."; + std::cout << GridLogMessage <<"**************************************************************************"<< 
std::endl; + abort(); - converged: - // Sorting - eval.resize(Nconv); - evec.resize(Nconv,grid); - for(int i=0; i & bq, - Field &bf, - DenseMatrix &H){ +private: +/* Saad PP. 195 +1. Choose an initial vector v1 of 2-norm unity. Set β1 ≡ 0, v0 ≡ 0 +2. For k = 1,2,...,m Do: +3. wk:=Avk−βkv_{k−1} +4. αk:=(wk,vk) // +5. wk:=wk−αkvk // wk orthog vk +6. βk+1 := ∥wk∥2. If βk+1 = 0 then Stop +7. vk+1 := wk/βk+1 +8. EndDo + */ + void step(std::vector& lmd, + std::vector& lme, + std::vector& evec, + Field& w,int Nm,int k) + { + const RealD tiny = 1.0e-20; + assert( k< Nm ); + + _poly(_Linop,evec[k],w); // 3. wk:=Avk−βkv_{k−1} + + if(k>0) w -= lme[k-1] * evec[k-1]; + + ComplexD zalph = innerProduct(evec[k],w); // 4. αk:=(wk,vk) + RealD alph = real(zalph); + + w = w - alph * evec[k];// 5. wk:=wk−αkvk + + RealD beta = normalise(w); // 6. βk+1 := ∥wk∥2. If βk+1 = 0 then Stop + // 7. vk+1 := wk/βk+1 + + lmd[k] = alph; + lme[k] = beta; + + if ( k > 0 ) orthogonalize(w,evec,k); // orthonormalise + if ( k < Nm-1) evec[k+1] = w; + + if ( beta < tiny ) std::cout << GridLogMessage << " beta is tiny "<& lmd, std::vector& lme, + int Nk, int Nm, + Eigen::MatrixXd & Qt, // Nm x Nm + GridBase *grid) + { + Eigen::MatrixXd TriDiag = Eigen::MatrixXd::Zero(Nk,Nk); - RealD beta; - RealD sqbt; - RealD alpha; + for(int i=0;i eigensolver(TriDiag); - for(int i=start;i& lmd, // Nm + std::vector& lme, // Nm + int Nk, int Nm, // Nk, Nm + Eigen::MatrixXd& Qt, // Nm x Nm matrix + RealD Dsh, int kmin, int kmax) + { + int k = kmin-1; + RealD x; + + RealD Fden = 1.0/hypot(lmd[k]-Dsh,lme[k]); + RealD c = ( lmd[k] -Dsh) *Fden; + RealD s = -lme[k] *Fden; + + RealD tmpa1 = lmd[k]; + RealD tmpa2 = lmd[k+1]; + RealD tmpb = lme[k]; - // Starting from scratch, bq[0] contains a random vector and |bq[0]| = 1 - int first; - if(start == 0){ + lmd[k] = c*c*tmpa1 +s*s*tmpa2 -2.0*c*s*tmpb; + lmd[k+1] = s*s*tmpa1 +c*c*tmpa2 +2.0*c*s*tmpb; + lme[k] = c*s*(tmpa1-tmpa2) +(c*c-s*s)*tmpb; + x =-s*lme[k+1]; + lme[k+1] = 
c*lme[k+1]; + + for(int i=0; i 1) std::cout << "orthagonality refined " << re << " times" < evals, - DenseVector evecs){ - int N= evals.size(); - _sort.push(evals,evecs, evals.size(),N); + void diagonalize(std::vector& lmd, std::vector& lme, + int Nk, int Nm, + Eigen::MatrixXd & Qt, + GridBase *grid) + { + Qt = Eigen::MatrixXd::Identity(Nm,Nm); + if ( diagonalisation == IRLdiagonaliseWithDSTEGR ) { + diagonalize_lapack(lmd,lme,Nk,Nm,Qt,grid); + } else if ( diagonalisation == IRLdiagonaliseWithQR ) { + diagonalize_QR(lmd,lme,Nk,Nm,Qt,grid); + } else if ( diagonalisation == IRLdiagonaliseWithEigen ) { + diagonalize_Eigen(lmd,lme,Nk,Nm,Qt,grid); + } else { + assert(0); } + } - void ImplicitRestart(int TM, DenseVector &evals, DenseVector > &evecs, DenseVector &bq, Field &bf, int cont) +#ifdef USE_LAPACK +void LAPACK_dstegr(char *jobz, char *range, int *n, double *d, double *e, + double *vl, double *vu, int *il, int *iu, double *abstol, + int *m, double *w, double *z, int *ldz, int *isuppz, + double *work, int *lwork, int *iwork, int *liwork, + int *info); +#endif + +void diagonalize_lapack(std::vector& lmd, + std::vector& lme, + int Nk, int Nm, + Eigen::MatrixXd& Qt, + GridBase *grid) +{ +#ifdef USE_LAPACK + const int size = Nm; + int NN = Nk; + double evals_tmp[NN]; + double evec_tmp[NN][NN]; + memset(evec_tmp[0],0,sizeof(double)*NN*NN); + double DD[NN]; + double EE[NN]; + for (int i = 0; i< NN; i++) { + for (int j = i - 1; j <= i + 1; j++) { + if ( j < NN && j >= 0 ) { + if (i==j) DD[i] = lmd[i]; + if (i==j) evals_tmp[i] = lmd[i]; + if (j==(i-1)) EE[j] = lme[j]; + } + } + } + int evals_found; + int lwork = ( (18*NN) > (1+4*NN+NN*NN)? 
(18*NN):(1+4*NN+NN*NN)) ; + int liwork = 3+NN*10 ; + int iwork[liwork]; + double work[lwork]; + int isuppz[2*NN]; + char jobz = 'V'; // calculate evals & evecs + char range = 'I'; // calculate all evals + // char range = 'A'; // calculate all evals + char uplo = 'U'; // refer to upper half of original matrix + char compz = 'I'; // Compute eigenvectors of tridiagonal matrix + int ifail[NN]; + int info; + int total = grid->_Nprocessors; + int node = grid->_processor; + int interval = (NN/total)+1; + double vl = 0.0, vu = 0.0; + int il = interval*node+1 , iu = interval*(node+1); + if (iu > NN) iu=NN; + double tol = 0.0; + if (1) { + memset(evals_tmp,0,sizeof(double)*NN); + if ( il <= NN){ + LAPACK_dstegr(&jobz, &range, &NN, + (double*)DD, (double*)EE, + &vl, &vu, &il, &iu, // these four are ignored if second parameteris 'A' + &tol, // tolerance + &evals_found, evals_tmp, (double*)evec_tmp, &NN, + isuppz, + work, &lwork, iwork, &liwork, + &info); + for (int i = iu-1; i>= il-1; i--){ + evals_tmp[i] = evals_tmp[i - (il-1)]; + if (il>1) evals_tmp[i-(il-1)]=0.; + for (int j = 0; j< NN; j++){ + evec_tmp[i][j] = evec_tmp[i - (il-1)][j]; + if (il>1) evec_tmp[i-(il-1)][j]=0.; + } + } + } { - std::cout << "ImplicitRestart begin. Eigensort starting\n"; + grid->GlobalSumVector(evals_tmp,NN); + grid->GlobalSumVector((double*)evec_tmp,NN*NN); + } + } + // Safer to sort instead of just reversing it, + // but the document of the routine says evals are sorted in increasing order. + // qr gives evals in decreasing order. 
+ for(int i=0;i H; Resize(H,Nm,Nm); + void diagonalize_QR(std::vector& lmd, std::vector& lme, + int Nk, int Nm, + Eigen::MatrixXd & Qt, + GridBase *grid) + { + int Niter = 100*Nm; + int kmin = 1; + int kmax = Nk; - EigenSort(evals, evecs); - - ///Assign shifts - int K=Nk; - int M=Nm; - int P=Np; - int converged=0; - if(K - converged < 4) P = (M - K-1); //one - // DenseVector shifts(P + shift_extra.size()); - DenseVector shifts(P); - for(int k = 0; k < P; ++k) - shifts[k] = evals[k]; - - /// Shift to form a new H and q - DenseMatrix Q; Resize(Q,TM,TM); - Unity(Q); - Shift(Q, shifts); // H is implicitly passed in in Rudy's Shift routine - - int ff = K; - - /// Shifted H defines a new K step Arnoldi factorization - RealD beta = H[ff][ff-1]; - RealD sig = Q[TM - 1][ff - 1]; - std::cout << "beta = " << beta << " sig = " << real(sig) < q Q - times_real(bq, Q, TM); - - std::cout << norm2(bq[0]) << " -- after " << ff < &bq, Field &bf, DenseVector > & evecs,DenseVector &evals) - { - init(); - - int M=Nm; - - DenseMatrix H; Resize(H,Nm,Nm); - Resize(evals,Nm); - Resize(evecs,Nm); - - int ff = Lanczos_Factor(0, M, cont, bq,bf,H); // 0--M to begin with - - if(ff < M) { - std::cout << "Krylov: aborting ff "< " << it << std::endl; - int lock_num = lock ? converged : 0; - DenseVector tevals(M - lock_num ); - DenseMatrix tevecs; Resize(tevecs,M - lock_num,M - lock_num); - - //check residual of polynominal - TestConv(H,M, tevals, tevecs); - - if(converged >= Nk) - break; - - ImplicitRestart(ff, tevals,tevecs,H); - } - Wilkinson(H, evals, evecs, small); - // Check(); - - std::cout << "Done "< & H,DenseMatrix &Q, DenseVector shifts) { - - int P; Size(shifts,P); - int M; SizeSquare(Q,M); - - Unity(Q); - - int lock_num = lock ? 
converged : 0; - - RealD t_Househoulder_vector(0.0); - RealD t_Househoulder_mult(0.0); - - for(int i=0;i ck(3), v(3); - - x = H[lock_num+0][lock_num+0]-shifts[i]; - y = H[lock_num+1][lock_num+0]; - ck[0] = x; ck[1] = y; ck[2] = 0; - - normalise(ck); ///Normalization cancels in PHP anyway - RealD beta; - - Householder_vector(ck, 0, 2, v, beta); - Householder_mult(H,v,beta,0,lock_num+0,lock_num+2,0); - Householder_mult(H,v,beta,0,lock_num+0,lock_num+2,1); - ///Accumulate eigenvector - Householder_mult(Q,v,beta,0,lock_num+0,lock_num+2,1); - - int sw = 0; - for(int k=lock_num+0;k(ck, 0, 2-sw, v, beta); - Householder_mult(H,v, beta,0,k+1,k+3-sw,0); - Householder_mult(H,v, beta,0,k+1,k+3-sw,1); - ///Accumulate eigenvector - Householder_mult(Q,v, beta,0,k+1,k+3-sw,1); + // determination of 2x2 leading submatrix + RealD dsub = lmd[kmax-1]-lmd[kmax-2]; + RealD dd = sqrt(dsub*dsub + 4.0*lme[kmax-2]*lme[kmax-2]); + RealD Dsh = 0.5*(lmd[kmax-2]+lmd[kmax-1] +dd*(dsub/fabs(dsub))); + // (Dsh: shift) + + // transformation + qr_decomp(lmd,lme,Nk,Nm,Qt,Dsh,kmin,kmax); // Nk, Nm + + // Convergence criterion (redef of kmin and kamx) + for(int j=kmax-1; j>= kmin; --j){ + RealD dds = fabs(lmd[j-1])+fabs(lmd[j]); + if(fabs(lme[j-1])+dds > dds){ + kmax = j+1; + goto continued; } } - } + Niter = iter; + return; - void TestConv(DenseMatrix & H,int SS, - DenseVector &bq, Field &bf, - DenseVector &tevals, DenseVector > &tevecs, - int lock, int converged) - { - std::cout << "Converged " << converged << " so far." << std::endl; - int lock_num = lock ? 
converged : 0; - int M = Nm; - - ///Active Factorization - DenseMatrix AH; Resize(AH,SS - lock_num,SS - lock_num ); - - AH = GetSubMtx(H,lock_num, SS, lock_num, SS); - - int NN=tevals.size(); - int AHsize=SS-lock_num; - - RealD small=1.0e-16; - Wilkinson(AH, tevals, tevecs, small); - - EigenSort(tevals, tevecs); - - RealD resid_nrm= norm2(bf); - - if(!lock) converged = 0; -#if 0 - for(int i = SS - lock_num - 1; i >= SS - Nk && i >= 0; --i){ - - RealD diff = 0; - diff = abs( tevecs[i][Nm - 1 - lock_num] ) * resid_nrm; - - std::cout << "residual estimate " << SS-1-i << " " << diff << " of (" << tevals[i] << ")" << std::endl; - - if(diff < converged) { - - if(lock) { - - DenseMatrix Q; Resize(Q,M,M); - bool herm = true; - - Lock(H, Q, tevals[i], converged, small, SS, herm); - - times_real(bq, Q, bq.size()); - bf = Q[M - 1][M - 1]* bf; - lock_num++; - } - converged++; - std::cout << " converged on eval " << converged << " of " << Nk << std::endl; - } else { + continued: + for(int j=0; j dds){ + kmin = j+1; break; } } -#endif - std::cout << "Got " << converged << " so far " < &evals, - DenseVector > &evecs) { - - DenseVector goodval(this->get); - - EigenSort(evals,evecs); - - int NM = Nm; - - DenseVector< DenseVector > V; Size(V,NM); - DenseVector QZ(NM*NM); - - for(int i = 0; i < NM; i++){ - for(int j = 0; j < NM; j++){ - // evecs[i][j]; - } - } - } - - -/** - There is some matrix Q such that for any vector y - Q.e_1 = y and Q is unitary. 
-**/ - template - static T orthQ(DenseMatrix &Q, DenseVector y){ - int N = y.size(); //Matrix Size - Fill(Q,0.0); - T tau; - for(int i=0;i 0.0){ - - T gam = conj( (y[j]/tau)/tau0 ); - for(int k=0;k<=j-1;k++){ - Q[k][j]=-gam*y[k]; - } - Q[j][j]=tau0/tau; - } else { - Q[j-1][j]=1.0; - } - tau0 = tau; - } - return tau; + std::cout << GridLogError << "[QL method] Error - Too many iteration: "< - static T orthU(DenseMatrix &Q, DenseVector y){ - T tau = orthQ(Q,y); - SL(Q); - return tau; - } - - -/** - Wind up with a matrix with the first con rows untouched - -say con = 2 - Q is such that Qdag H Q has {x, x, val, 0, 0, 0, 0, ...} as 1st colum - and the matrix is upper hessenberg - and with f and Q appropriately modidied with Q is the arnoldi factorization - -**/ - -template -static void Lock(DenseMatrix &H, // Hess mtx - DenseMatrix &Q, // Lock Transform - T val, // value to be locked - int con, // number already locked - RealD small, - int dfg, - bool herm) -{ - - - //ForceTridiagonal(H); - - int M = H.dim; - DenseVector vec; Resize(vec,M-con); - - DenseMatrix AH; Resize(AH,M-con,M-con); - AH = GetSubMtx(H,con, M, con, M); - - DenseMatrix QQ; Resize(QQ,M-con,M-con); - - Unity(Q); Unity(QQ); - - DenseVector evals; Resize(evals,M-con); - DenseMatrix evecs; Resize(evecs,M-con,M-con); - - Wilkinson(AH, evals, evecs, small); - - int k=0; - RealD cold = abs( val - evals[k]); - for(int i=1;icon+2; j--){ - - DenseMatrix U; Resize(U,j-1-con,j-1-con); - DenseVector z; Resize(z,j-1-con); - T nm = norm(z); - for(int k = con+0;k Hb; Resize(Hb,j-1-con,M); - - for(int a = 0;a Qb; Resize(Qb,M,M); - - for(int a = 0;a Hc; Resize(Hc,M,M); - - for(int a = 0;a - - This program is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. 
- - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License along - with this program; if not, write to the Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. - - See the full license in the file "LICENSE" in the top level distribution directory - *************************************************************************************/ - /* END LEGAL */ -#ifndef MATRIX_H -#define MATRIX_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -/** Sign function **/ -template T sign(T p){return ( p/abs(p) );} - -///////////////////////////////////////////////////////////////////////////////////////////////////////// -///////////////////// Hijack STL containers for our wicked means ///////////////////////////////////////// -///////////////////////////////////////////////////////////////////////////////////////////////////////// -template using Vector = Vector; -template using Matrix = Vector >; - -template void Resize(Vector & vec, int N) { vec.resize(N); } - -template void Resize(Matrix & mat, int N, int M) { - mat.resize(N); - for(int i=0;i void Size(Vector & vec, int &N) -{ - N= vec.size(); -} -template void Size(Matrix & mat, int &N,int &M) -{ - N= mat.size(); - M= mat[0].size(); -} -template void SizeSquare(Matrix & mat, int &N) -{ - int M; Size(mat,N,M); - assert(N==M); -} -template void SizeSame(Matrix & mat1,Matrix &mat2, int &N1,int &M1) -{ - int N2,M2; - Size(mat1,N1,M1); - Size(mat2,N2,M2); - assert(N1==N2); - assert(M1==M2); -} - -//***************************************** -//* (Complex) Vector operations * -//***************************************** - -/**Conj of a Vector **/ -template Vector conj(Vector p){ - Vector 
q(p.size()); - for(int i=0;i T norm(Vector p){ - T sum = 0; - for(int i=0;i T norm2(Vector p){ - T sum = 0; - for(int i=0;i T trace(Vector p){ - T sum = 0; - for(int i=0;i void Fill(Vector &p, T c){ - for(int i=0;i void normalize(Vector &p){ - T m = norm(p); - if( abs(m) > 0.0) for(int i=0;i Vector times(Vector p, U s){ - for(int i=0;i Vector times(U s, Vector p){ - for(int i=0;i T inner(Vector a, Vector b){ - T m = 0.; - for(int i=0;i Vector add(Vector a, Vector b){ - Vector m(a.size()); - for(int i=0;i Vector sub(Vector a, Vector b){ - Vector m(a.size()); - for(int i=0;i void Fill(Matrix & mat, T&val) { - int N,M; - Size(mat,N,M); - for(int i=0;i Transpose(Matrix & mat){ - int N,M; - Size(mat,N,M); - Matrix C; Resize(C,M,N); - for(int i=0;i void Unity(Matrix &mat){ - int N; SizeSquare(mat,N); - for(int i=0;i -void PlusUnit(Matrix & A,T c){ - int dim; SizeSquare(A,dim); - for(int i=0;i HermitianConj(Matrix &mat){ - - int dim; SizeSquare(mat,dim); - - Matrix C; Resize(C,dim,dim); - - for(int i=0;i diag(Matrix &A) -{ - int dim; SizeSquare(A,dim); - Vector d; Resize(d,dim); - - for(int i=0;i operator *(Vector &B,Matrix &A) -{ - int K,M,N; - Size(B,K); - Size(A,M,N); - assert(K==M); - - Vector C; Resize(C,N); - - for(int j=0;j inv_diag(Matrix & A){ - int dim; SizeSquare(A,dim); - Vector d; Resize(d,dim); - for(int i=0;i operator + (Matrix &A,Matrix &B) -{ - int N,M ; SizeSame(A,B,N,M); - Matrix C; Resize(C,N,M); - for(int i=0;i operator- (Matrix & A,Matrix &B){ - int N,M ; SizeSame(A,B,N,M); - Matrix C; Resize(C,N,M); - for(int i=0;i operator* (Matrix & A,T c){ - int N,M; Size(A,N,M); - Matrix C; Resize(C,N,M); - for(int i=0;i operator* (Matrix &A,Matrix &B){ - int K,L,N,M; - Size(A,K,L); - Size(B,N,M); assert(L==N); - Matrix C; Resize(C,K,M); - - for(int i=0;i operator* (Matrix &A,Vector &B){ - int M,N,K; - Size(A,N,M); - Size(B,K); assert(K==M); - Vector C; Resize(C,N); - for(int i=0;i T LargestDiag(Matrix &A) -{ - int dim ; SizeSquare(A,dim); - - T ld = 
abs(A[0][0]); - for(int i=1;i abs(ld) ){ld = cf;} - } - return ld; -} - -/** Look for entries on the leading subdiagonal that are smaller than 'small' **/ -template int Chop_subdiag(Matrix &A,T norm, int offset, U small) -{ - int dim; SizeSquare(A,dim); - for(int l = dim - 1 - offset; l >= 1; l--) { - if((U)abs(A[l][l - 1]) < (U)small) { - A[l][l-1]=(U)0.0; - return l; - } - } - return 0; -} - -/** Look for entries on the leading subdiagonal that are smaller than 'small' **/ -template int Chop_symm_subdiag(Matrix & A,T norm, int offset, U small) -{ - int dim; SizeSquare(A,dim); - for(int l = dim - 1 - offset; l >= 1; l--) { - if((U)abs(A[l][l - 1]) < (U)small) { - A[l][l - 1] = (U)0.0; - A[l - 1][l] = (U)0.0; - return l; - } - } - return 0; -} -/**Assign a submatrix to a larger one**/ -template -void AssignSubMtx(Matrix & A,int row_st, int row_end, int col_st, int col_end, Matrix &S) -{ - for(int i = row_st; i -Matrix GetSubMtx(Matrix &A,int row_st, int row_end, int col_st, int col_end) -{ - Matrix H; Resize(row_end - row_st,col_end-col_st); - - for(int i = row_st; i -void AssignSubMtx(Matrix & A,int row_st, int row_end, int col_st, int col_end, Matrix &S) -{ - for(int i = row_st; i T proj(Matrix A, Vector B){ - int dim; SizeSquare(A,dim); - int dimB; Size(B,dimB); - assert(dimB==dim); - T C = 0; - for(int i=0;i q Q -template void times(Vector &q, Matrix &Q) -{ - int M; SizeSquare(Q,M); - int N; Size(q,N); - assert(M==N); - - times(q,Q,N); -} - -/// q -> q Q -template void times(multi1d &q, Matrix &Q, int N) -{ - GridBase *grid = q[0]._grid; - int M; SizeSquare(Q,M); - int K; Size(q,K); - assert(N S(N,grid ); - for(int j=0;j -#include -#include - -struct Bisection { - -static void get_eig2(int row_num,std::vector &ALPHA,std::vector &BETA, std::vector & eig) -{ - int i,j; - std::vector evec1(row_num+3); - std::vector evec2(row_num+3); - RealD eps2; - ALPHA[1]=0.; - BETHA[1]=0.; - for(i=0;imag(evec2[i+1])) { - swap(evec2+i,evec2+i+1); - swapped=1; - } - } - end--; - 
for(i=end-1;i>=begin;i--){ - if(mag(evec2[i])>mag(evec2[i+1])) { - swap(evec2+i,evec2+i+1); - swapped=1; - } - } - begin++; - } - - for(i=0;i &c, - std::vector &b, - int n, - int m1, - int m2, - RealD eps1, - RealD relfeh, - std::vector &x, - RealD &eps2) -{ - std::vector wu(n+2); - - RealD h,q,x1,xu,x0,xmin,xmax; - int i,a,k; - - b[1]=0.0; - xmin=c[n]-fabs(b[n]); - xmax=c[n]+fabs(b[n]); - for(i=1;ixmax) xmax= c[i]+h; - if(c[i]-h0.0 ? xmax : -xmin); - if(eps1<=0.0) eps1=eps2; - eps2=0.5*eps1+7.0*(eps2); - x0=xmax; - for(i=m1;i<=m2;i++){ - x[i]=xmax; - wu[i]=xmin; - } - - for(k=m2;k>=m1;k--){ - xu=xmin; - i=k; - do{ - if(xu=m1); - if(x0>x[k]) x0=x[k]; - while((x0-xu)>2*relfeh*(fabs(xu)+fabs(x0))+eps1){ - x1=(xu+x0)/2; - - a=0; - q=1.0; - for(i=1;i<=n;i++){ - q=c[i]-x1-((q!=0.0)? b[i]*b[i]/q:fabs(b[i])/relfeh); - if(q<0) a++; - } - // printf("x1=%e a=%d\n",x1,a); - if(ax1) x[a]=x1; - } - }else x0=x1; - } - x[k]=(x0+xu)/2; - } -} -} diff --git a/lib/algorithms/iterative/get_eig.c b/lib/algorithms/iterative/get_eig.c deleted file mode 100644 index d3f5a12f..00000000 --- a/lib/algorithms/iterative/get_eig.c +++ /dev/null @@ -1 +0,0 @@ - diff --git a/lib/allocator/AlignedAllocator.cc b/lib/allocator/AlignedAllocator.cc new file mode 100644 index 00000000..967b2571 --- /dev/null +++ b/lib/allocator/AlignedAllocator.cc @@ -0,0 +1,97 @@ +#include +#include + +namespace Grid { + +int PointerCache::victim; + + PointerCache::PointerCacheEntry PointerCache::Entries[PointerCache::Ncache]; + +void *PointerCache::Insert(void *ptr,size_t bytes) { + + if (bytes < 4096 ) return ptr; + +#ifdef GRID_OMP + assert(omp_in_parallel()==0); +#endif + + void * ret = NULL; + int v = -1; + + for(int e=0;e= 0); + const int page_size = 4096; + uint64_t virt_pfn = (uint64_t)Buf / page_size; + off_t offset = sizeof(uint64_t) * virt_pfn; + uint64_t npages = (BYTES + page_size-1) / page_size; + uint64_t pagedata[npages]; + uint64_t ret = lseek(fd, offset, SEEK_SET); + assert(ret == offset); + ret = 
::read(fd, pagedata, sizeof(uint64_t)*npages); + assert(ret == sizeof(uint64_t) * npages); + int nhugepages = npages / 512; + int n4ktotal, nnothuge; + n4ktotal = 0; + nnothuge = 0; + for (int i = 0; i < nhugepages; ++i) { + uint64_t baseaddr = (pagedata[i*512] & 0x7fffffffffffffULL) * page_size; + for (int j = 0; j < 512; ++j) { + uint64_t pageaddr = (pagedata[i*512+j] & 0x7fffffffffffffULL) * page_size; + ++n4ktotal; + if (pageaddr != baseaddr + j * page_size) + ++nnothuge; + } + } + int rank = CartesianCommunicator::RankWorld(); + printf("rank %d Allocated %d 4k pages, %d not in huge pages\n", rank, n4ktotal, nnothuge); +#endif +} + +} diff --git a/lib/AlignedAllocator.h b/lib/allocator/AlignedAllocator.h similarity index 74% rename from lib/AlignedAllocator.h rename to lib/allocator/AlignedAllocator.h index a8b9c53b..62579587 100644 --- a/lib/AlignedAllocator.h +++ b/lib/allocator/AlignedAllocator.h @@ -1,4 +1,4 @@ - /************************************************************************************* +/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid @@ -42,9 +42,34 @@ Author: Peter Boyle namespace Grid { + class PointerCache { + private: + + static const int Ncache=8; + static int victim; + + typedef struct { + void *address; + size_t bytes; + int valid; + } PointerCacheEntry; + + static PointerCacheEntry Entries[Ncache]; + + public: + + + static void *Insert(void *ptr,size_t bytes) ; + static void *Lookup(size_t bytes) ; + + }; + + void check_huge_pages(void *Buf,uint64_t BYTES); + //////////////////////////////////////////////////////////////////// // A lattice of something, but assume the something is SIMDized. 
//////////////////////////////////////////////////////////////////// + template class alignedAllocator { public: @@ -66,27 +91,43 @@ public: pointer allocate(size_type __n, const void* _p= 0) { -#ifdef HAVE_MM_MALLOC_H - _Tp * ptr = (_Tp *) _mm_malloc(__n*sizeof(_Tp),128); -#else - _Tp * ptr = (_Tp *) memalign(128,__n*sizeof(_Tp)); -#endif + size_type bytes = __n*sizeof(_Tp); - _Tp tmp; -#ifdef GRID_NUMA -#pragma omp parallel for schedule(static) - for(int i=0;i<__n;i++){ - ptr[i]=tmp; - } -#endif + _Tp *ptr = (_Tp *) PointerCache::Lookup(bytes); + // if ( ptr != NULL ) + // std::cout << "alignedAllocator "<<__n << " cache hit "<< std::hex << ptr < -Author: paboyle + Author: Peter Boyle + Author: paboyle + Author: Guido Cossu This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -49,10 +50,9 @@ public: GridBase(const std::vector & processor_grid) : CartesianCommunicator(processor_grid) {}; - // Physics Grid information. std::vector _simd_layout;// Which dimensions get relayed out over simd lanes. - std::vector _fdimensions;// Global dimensions of array prior to cb removal + std::vector _fdimensions;// (full) Global dimensions of array prior to cb removal std::vector _gdimensions;// Global dimensions of array after cb removal std::vector _ldimensions;// local dimensions of array with processor images removed std::vector _rdimensions;// Reduced local dimensions with simd lane images and processor images removed @@ -62,13 +62,12 @@ public: int _isites; int _fsites; // _isites*_osites = product(dimensions). int _gsites; - std::vector _slice_block; // subslice information + std::vector _slice_block;// subslice information std::vector _slice_stride; std::vector _slice_nblock; - // Might need these at some point - // std::vector _lstart; // local start of array in gcoors. 
_processor_coor[d]*_ldimensions[d] - // std::vector _lend; // local end of array in gcoors _processor_coor[d]*_ldimensions[d]+_ldimensions_[d]-1 + std::vector _lstart; // local start of array in gcoors _processor_coor[d]*_ldimensions[d] + std::vector _lend ; // local end of array in gcoors _processor_coor[d]*_ldimensions[d]+_ldimensions_[d]-1 public: @@ -77,7 +76,7 @@ public: // GridCartesian / GridRedBlackCartesian //////////////////////////////////////////////////////////////// virtual int CheckerBoarded(int dim)=0; - virtual int CheckerBoard(std::vector &site)=0; + virtual int CheckerBoard(const std::vector &site)=0; virtual int CheckerBoardDestination(int source_cb,int shift,int dim)=0; virtual int CheckerBoardShift(int source_cb,int dim,int shift,int osite)=0; virtual int CheckerBoardShiftForCB(int source_cb,int dim,int shift,int cb)=0; @@ -99,7 +98,7 @@ public: virtual int oIndex(std::vector &coor) { int idx=0; - // Works with either global or local coordinates + // Works with either global or local coordinates for(int d=0;d<_ndimension;d++) idx+=_ostride[d]*(coor[d]%_rdimensions[d]); return idx; } @@ -121,6 +120,11 @@ public: Lexicographic::CoorFromIndex(coor,Oindex,_rdimensions); } + inline void InOutCoorToLocalCoor (std::vector &ocoor, std::vector &icoor, std::vector &lcoor) { + lcoor.resize(_ndimension); + for (int d = 0; d < _ndimension; d++) + lcoor[d] = ocoor[d] + _rdimensions[d] * icoor[d]; + } ////////////////////////////////////////////////////////// // SIMD lane addressing @@ -129,6 +133,7 @@ public: { Lexicographic::CoorFromIndex(coor,lane,_simd_layout); } + inline int PermuteDim(int dimension){ return _simd_layout[dimension]>1; } @@ -146,15 +151,15 @@ public: // Distance should be either 0,1,2.. // if ( _simd_layout[dimension] > 2 ) { - for(int d=0;d<_ndimension;d++){ - if ( d != dimension ) assert ( (_simd_layout[d]==1) ); - } - permute_type = RotateBit; // How to specify distance; this is not just direction. 
- return permute_type; + for(int d=0;d<_ndimension;d++){ + if ( d != dimension ) assert ( (_simd_layout[d]==1) ); + } + permute_type = RotateBit; // How to specify distance; this is not just direction. + return permute_type; } for(int d=_ndimension-1;d>dimension;d--){ - if (_simd_layout[d]>1 ) permute_type++; + if (_simd_layout[d]>1 ) permute_type++; } return permute_type; } @@ -169,26 +174,51 @@ public: inline int gSites(void) const { return _isites*_osites*_Nprocessors; }; inline int Nd (void) const { return _ndimension;}; + inline const std::vector LocalStarts(void) { return _lstart; }; inline const std::vector &FullDimensions(void) { return _fdimensions;}; inline const std::vector &GlobalDimensions(void) { return _gdimensions;}; inline const std::vector &LocalDimensions(void) { return _ldimensions;}; inline const std::vector &VirtualLocalDimensions(void) { return _ldimensions;}; + //////////////////////////////////////////////////////////////// + // Utility to print the full decomposition details + //////////////////////////////////////////////////////////////// + + void show_decomposition(){ + std::cout << GridLogMessage << "\tFull Dimensions : " << _fdimensions << std::endl; + std::cout << GridLogMessage << "\tSIMD layout : " << _simd_layout << std::endl; + std::cout << GridLogMessage << "\tGlobal Dimensions : " << _gdimensions << std::endl; + std::cout << GridLogMessage << "\tLocal Dimensions : " << _ldimensions << std::endl; + std::cout << GridLogMessage << "\tReduced Dimensions : " << _rdimensions << std::endl; + std::cout << GridLogMessage << "\tOuter strides : " << _ostride << std::endl; + std::cout << GridLogMessage << "\tInner strides : " << _istride << std::endl; + std::cout << GridLogMessage << "\tiSites : " << _isites << std::endl; + std::cout << GridLogMessage << "\toSites : " << _osites << std::endl; + std::cout << GridLogMessage << "\tlSites : " << lSites() << std::endl; + std::cout << GridLogMessage << "\tgSites : " << gSites() << std::endl; + 
std::cout << GridLogMessage << "\tNd : " << _ndimension << std::endl; + } + //////////////////////////////////////////////////////////////// // Global addressing //////////////////////////////////////////////////////////////// void GlobalIndexToGlobalCoor(int gidx,std::vector &gcoor){ + assert(gidx< gSites()); Lexicographic::CoorFromIndex(gcoor,gidx,_gdimensions); } void LocalIndexToLocalCoor(int lidx,std::vector &lcoor){ + assert(lidx & gcoor,int & gidx){ gidx=0; int mult=1; for(int mu=0;mu<_ndimension;mu++) { - gidx+=mult*gcoor[mu]; - mult*=_gdimensions[mu]; + gidx+=mult*gcoor[mu]; + mult*=_gdimensions[mu]; } } void GlobalCoorToProcessorCoorLocalCoor(std::vector &pcoor,std::vector &lcoor,const std::vector &gcoor) @@ -196,9 +226,9 @@ public: pcoor.resize(_ndimension); lcoor.resize(_ndimension); for(int mu=0;mu<_ndimension;mu++){ - int _fld = _fdimensions[mu]/_processors[mu]; - pcoor[mu] = gcoor[mu]/_fld; - lcoor[mu] = gcoor[mu]%_fld; + int _fld = _fdimensions[mu]/_processors[mu]; + pcoor[mu] = gcoor[mu]/_fld; + lcoor[mu] = gcoor[mu]%_fld; } } void GlobalCoorToRankIndex(int &rank, int &o_idx, int &i_idx ,const std::vector &gcoor) @@ -207,16 +237,16 @@ public: std::vector lcoor; GlobalCoorToProcessorCoorLocalCoor(pcoor,lcoor,gcoor); rank = RankFromProcessorCoor(pcoor); - + /* std::vector cblcoor(lcoor); for(int d=0;dCheckerBoarded(d) ) { - cblcoor[d] = lcoor[d]/2; - } + if( this->CheckerBoarded(d) ) { + cblcoor[d] = lcoor[d]/2; + } } - - i_idx= iIndex(cblcoor);// this does not imply divide by 2 on checker dim - o_idx= oIndex(lcoor); // this implies divide by 2 on checkerdim + */ + i_idx= iIndex(lcoor); + o_idx= oIndex(lcoor); } void RankIndexToGlobalCoor(int rank, int o_idx, int i_idx , std::vector &gcoor) @@ -238,7 +268,7 @@ public: { RankIndexToGlobalCoor(rank,o_idx,i_idx ,fcoor); if(CheckerBoarded(0)){ - fcoor[0] = fcoor[0]*2+cb; + fcoor[0] = fcoor[0]*2+cb; } } void ProcessorCoorLocalCoorToGlobalCoor(std::vector &Pcoor,std::vector &Lcoor,std::vector &gcoor) diff 
--git a/lib/cartesian/Cartesian_full.h b/lib/cartesian/Cartesian_full.h index b0d20441..815e3b22 100644 --- a/lib/cartesian/Cartesian_full.h +++ b/lib/cartesian/Cartesian_full.h @@ -49,7 +49,7 @@ public: virtual int CheckerBoarded(int dim){ return 0; } - virtual int CheckerBoard(std::vector &site){ + virtual int CheckerBoard(const std::vector &site){ return 0; } virtual int CheckerBoardDestination(int cb,int shift,int dim){ @@ -62,73 +62,81 @@ public: return shift; } GridCartesian(const std::vector &dimensions, - const std::vector &simd_layout, - const std::vector &processor_grid - ) : GridBase(processor_grid) + const std::vector &simd_layout, + const std::vector &processor_grid) : GridBase(processor_grid) { - /////////////////////// - // Grid information - /////////////////////// - _ndimension = dimensions.size(); - - _fdimensions.resize(_ndimension); - _gdimensions.resize(_ndimension); - _ldimensions.resize(_ndimension); - _rdimensions.resize(_ndimension); - _simd_layout.resize(_ndimension); - - _ostride.resize(_ndimension); - _istride.resize(_ndimension); - - _fsites = _gsites = _osites = _isites = 1; + /////////////////////// + // Grid information + /////////////////////// + _ndimension = dimensions.size(); - for(int d=0;d<_ndimension;d++){ - _fdimensions[d] = dimensions[d]; // Global dimensions - _gdimensions[d] = _fdimensions[d]; // Global dimensions - _simd_layout[d] = simd_layout[d]; - _fsites = _fsites * _fdimensions[d]; - _gsites = _gsites * _gdimensions[d]; + _fdimensions.resize(_ndimension); + _gdimensions.resize(_ndimension); + _ldimensions.resize(_ndimension); + _rdimensions.resize(_ndimension); + _simd_layout.resize(_ndimension); + _lstart.resize(_ndimension); + _lend.resize(_ndimension); - //FIXME check for exact division + _ostride.resize(_ndimension); + _istride.resize(_ndimension); - // Use a reduced simd grid - _ldimensions[d]= _gdimensions[d]/_processors[d]; //local dimensions - _rdimensions[d]= _ldimensions[d]/_simd_layout[d]; 
//overdecomposition - _osites *= _rdimensions[d]; - _isites *= _simd_layout[d]; - - // Addressing support - if ( d==0 ) { - _ostride[d] = 1; - _istride[d] = 1; - } else { - _ostride[d] = _ostride[d-1]*_rdimensions[d-1]; - _istride[d] = _istride[d-1]*_simd_layout[d-1]; - } + _fsites = _gsites = _osites = _isites = 1; + + for (int d = 0; d < _ndimension; d++) + { + _fdimensions[d] = dimensions[d]; // Global dimensions + _gdimensions[d] = _fdimensions[d]; // Global dimensions + _simd_layout[d] = simd_layout[d]; + _fsites = _fsites * _fdimensions[d]; + _gsites = _gsites * _gdimensions[d]; + + // Use a reduced simd grid + _ldimensions[d] = _gdimensions[d] / _processors[d]; //local dimensions + assert(_ldimensions[d] * _processors[d] == _gdimensions[d]); + + _rdimensions[d] = _ldimensions[d] / _simd_layout[d]; //overdecomposition + assert(_rdimensions[d] * _simd_layout[d] == _ldimensions[d]); + + _lstart[d] = _processor_coor[d] * _ldimensions[d]; + _lend[d] = _processor_coor[d] * _ldimensions[d] + _ldimensions[d] - 1; + _osites *= _rdimensions[d]; + _isites *= _simd_layout[d]; + + // Addressing support + if (d == 0) + { + _ostride[d] = 1; + _istride[d] = 1; } - - /////////////////////// - // subplane information - /////////////////////// - _slice_block.resize(_ndimension); - _slice_stride.resize(_ndimension); - _slice_nblock.resize(_ndimension); - - int block =1; - int nblock=1; - for(int d=0;d<_ndimension;d++) nblock*=_rdimensions[d]; - - for(int d=0;d<_ndimension;d++){ - nblock/=_rdimensions[d]; - _slice_block[d] =block; - _slice_stride[d]=_ostride[d]*_rdimensions[d]; - _slice_nblock[d]=nblock; - block = block*_rdimensions[d]; + else + { + _ostride[d] = _ostride[d - 1] * _rdimensions[d - 1]; + _istride[d] = _istride[d - 1] * _simd_layout[d - 1]; } + } + /////////////////////// + // subplane information + /////////////////////// + _slice_block.resize(_ndimension); + _slice_stride.resize(_ndimension); + _slice_nblock.resize(_ndimension); + + int block = 1; + int nblock = 
1; + for (int d = 0; d < _ndimension; d++) + nblock *= _rdimensions[d]; + + for (int d = 0; d < _ndimension; d++) + { + nblock /= _rdimensions[d]; + _slice_block[d] = block; + _slice_stride[d] = _ostride[d] * _rdimensions[d]; + _slice_nblock[d] = nblock; + block = block * _rdimensions[d]; + } }; }; - - } #endif diff --git a/lib/cartesian/Cartesian_red_black.h b/lib/cartesian/Cartesian_red_black.h index 6a4300d7..b1a5b9ef 100644 --- a/lib/cartesian/Cartesian_red_black.h +++ b/lib/cartesian/Cartesian_red_black.h @@ -49,7 +49,7 @@ public: if( dim==_checker_dim) return 1; else return 0; } - virtual int CheckerBoard(std::vector &site){ + virtual int CheckerBoard(const std::vector &site){ int linear=0; assert(site.size()==_ndimension); for(int d=0;d<_ndimension;d++){ @@ -131,132 +131,155 @@ public: Init(dimensions,simd_layout,processor_grid,checker_dim_mask,0); } void Init(const std::vector &dimensions, - const std::vector &simd_layout, - const std::vector &processor_grid, - const std::vector &checker_dim_mask, - int checker_dim) + const std::vector &simd_layout, + const std::vector &processor_grid, + const std::vector &checker_dim_mask, + int checker_dim) { - /////////////////////// - // Grid information - /////////////////////// + /////////////////////// + // Grid information + /////////////////////// _checker_dim = checker_dim; - assert(checker_dim_mask[checker_dim]==1); + assert(checker_dim_mask[checker_dim] == 1); _ndimension = dimensions.size(); - assert(checker_dim_mask.size()==_ndimension); - assert(processor_grid.size()==_ndimension); - assert(simd_layout.size()==_ndimension); - + assert(checker_dim_mask.size() == _ndimension); + assert(processor_grid.size() == _ndimension); + assert(simd_layout.size() == _ndimension); + _fdimensions.resize(_ndimension); _gdimensions.resize(_ndimension); _ldimensions.resize(_ndimension); _rdimensions.resize(_ndimension); _simd_layout.resize(_ndimension); - + _lstart.resize(_ndimension); + _lend.resize(_ndimension); + 
_ostride.resize(_ndimension); _istride.resize(_ndimension); - + _fsites = _gsites = _osites = _isites = 1; - - _checker_dim_mask=checker_dim_mask; - for(int d=0;d<_ndimension;d++){ - _fdimensions[d] = dimensions[d]; - _gdimensions[d] = _fdimensions[d]; - _fsites = _fsites * _fdimensions[d]; - _gsites = _gsites * _gdimensions[d]; - - if (d==_checker_dim) { - _gdimensions[d] = _gdimensions[d]/2; // Remove a checkerboard - } - _ldimensions[d] = _gdimensions[d]/_processors[d]; + _checker_dim_mask = checker_dim_mask; - // Use a reduced simd grid - _simd_layout[d] = simd_layout[d]; - _rdimensions[d]= _ldimensions[d]/_simd_layout[d]; - assert(_rdimensions[d]>0); + for (int d = 0; d < _ndimension; d++) + { + _fdimensions[d] = dimensions[d]; + _gdimensions[d] = _fdimensions[d]; + _fsites = _fsites * _fdimensions[d]; + _gsites = _gsites * _gdimensions[d]; - // all elements of a simd vector must have same checkerboard. - // If Ls vectorised, this must still be the case; e.g. dwf rb5d - if ( _simd_layout[d]>1 ) { - if ( checker_dim_mask[d] ) { - assert( (_rdimensions[d]&0x1) == 0 ); - } - } + if (d == _checker_dim) + { + assert((_gdimensions[d] & 0x1) == 0); + _gdimensions[d] = _gdimensions[d] / 2; // Remove a checkerboard + } + _ldimensions[d] = _gdimensions[d] / _processors[d]; + assert(_ldimensions[d] * _processors[d] == _gdimensions[d]); + _lstart[d] = _processor_coor[d] * _ldimensions[d]; + _lend[d] = _processor_coor[d] * _ldimensions[d] + _ldimensions[d] - 1; - _osites *= _rdimensions[d]; - _isites *= _simd_layout[d]; - - // Addressing support - if ( d==0 ) { - _ostride[d] = 1; - _istride[d] = 1; - } else { - _ostride[d] = _ostride[d-1]*_rdimensions[d-1]; - _istride[d] = _istride[d-1]*_simd_layout[d-1]; - } + // Use a reduced simd grid + _simd_layout[d] = simd_layout[d]; + _rdimensions[d] = _ldimensions[d] / _simd_layout[d]; // this is not checking if this is integer + assert(_rdimensions[d] * _simd_layout[d] == _ldimensions[d]); + assert(_rdimensions[d] > 0); + // all 
elements of a simd vector must have same checkerboard. + // If Ls vectorised, this must still be the case; e.g. dwf rb5d + if (_simd_layout[d] > 1) + { + if (checker_dim_mask[d]) + { + assert((_rdimensions[d] & 0x1) == 0); + } + } + _osites *= _rdimensions[d]; + _isites *= _simd_layout[d]; + + // Addressing support + if (d == 0) + { + _ostride[d] = 1; + _istride[d] = 1; + } + else + { + _ostride[d] = _ostride[d - 1] * _rdimensions[d - 1]; + _istride[d] = _istride[d - 1] * _simd_layout[d - 1]; + } } - + //////////////////////////////////////////////////////////////////////////////////////////// // subplane information //////////////////////////////////////////////////////////////////////////////////////////// _slice_block.resize(_ndimension); _slice_stride.resize(_ndimension); _slice_nblock.resize(_ndimension); - - int block =1; - int nblock=1; - for(int d=0;d<_ndimension;d++) nblock*=_rdimensions[d]; - - for(int d=0;d<_ndimension;d++){ - nblock/=_rdimensions[d]; - _slice_block[d] =block; - _slice_stride[d]=_ostride[d]*_rdimensions[d]; - _slice_nblock[d]=nblock; - block = block*_rdimensions[d]; + + int block = 1; + int nblock = 1; + for (int d = 0; d < _ndimension; d++) + nblock *= _rdimensions[d]; + + for (int d = 0; d < _ndimension; d++) + { + nblock /= _rdimensions[d]; + _slice_block[d] = block; + _slice_stride[d] = _ostride[d] * _rdimensions[d]; + _slice_nblock[d] = nblock; + block = block * _rdimensions[d]; } //////////////////////////////////////////////// // Create a checkerboard lookup table //////////////////////////////////////////////// int rvol = 1; - for(int d=0;d<_ndimension;d++){ - rvol=rvol * _rdimensions[d]; + for (int d = 0; d < _ndimension; d++) + { + rvol = rvol * _rdimensions[d]; } _checker_board.resize(rvol); - for(int osite=0;osite<_osites;osite++){ - _checker_board[osite] = CheckerBoardFromOindex (osite); + for (int osite = 0; osite < _osites; osite++) + { + _checker_board[osite] = CheckerBoardFromOindex(osite); } - }; -protected: + + 
protected: virtual int oIndex(std::vector &coor) { - int idx=0; - for(int d=0;d<_ndimension;d++) { - if( d==_checker_dim ) { - idx+=_ostride[d]*((coor[d]/2)%_rdimensions[d]); - } else { - idx+=_ostride[d]*(coor[d]%_rdimensions[d]); - } + int idx = 0; + for (int d = 0; d < _ndimension; d++) + { + if (d == _checker_dim) + { + idx += _ostride[d] * ((coor[d] / 2) % _rdimensions[d]); + } + else + { + idx += _ostride[d] * (coor[d] % _rdimensions[d]); + } } return idx; }; - + virtual int iIndex(std::vector &lcoor) { - int idx=0; - for(int d=0;d<_ndimension;d++) { - if( d==_checker_dim ) { - idx+=_istride[d]*(lcoor[d]/(2*_rdimensions[d])); - } else { - idx+=_istride[d]*(lcoor[d]/_rdimensions[d]); - } - } - return idx; + int idx = 0; + for (int d = 0; d < _ndimension; d++) + { + if (d == _checker_dim) + { + idx += _istride[d] * (lcoor[d] / (2 * _rdimensions[d])); + } + else + { + idx += _istride[d] * (lcoor[d] / _rdimensions[d]); + } + } + return idx; } }; - } #endif diff --git a/lib/communicator/.dirstamp b/lib/communicator/.dirstamp deleted file mode 100644 index e69de29b..00000000 diff --git a/lib/Communicator.h b/lib/communicator/Communicator.h similarity index 100% rename from lib/Communicator.h rename to lib/communicator/Communicator.h diff --git a/lib/communicator/Communicator_base.cc b/lib/communicator/Communicator_base.cc index b003d867..20c310c0 100644 --- a/lib/communicator/Communicator_base.cc +++ b/lib/communicator/Communicator_base.cc @@ -25,14 +25,23 @@ Author: Peter Boyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include "Grid.h" +#include +#include +#include +#include +#include + namespace Grid { /////////////////////////////////////////////////////////////// // Info that is setup once and indept of cartesian layout /////////////////////////////////////////////////////////////// void * 
CartesianCommunicator::ShmCommBuf; -uint64_t CartesianCommunicator::MAX_MPI_SHM_BYTES = 128*1024*1024; +uint64_t CartesianCommunicator::MAX_MPI_SHM_BYTES = 1024LL*1024LL*1024LL; +CartesianCommunicator::CommunicatorPolicy_t +CartesianCommunicator::CommunicatorPolicy= CartesianCommunicator::CommunicatorPolicyConcurrent; +int CartesianCommunicator::nCommThreads = -1; +int CartesianCommunicator::Hugepages = 0; ///////////////////////////////// // Alloc, free shmem region @@ -58,6 +67,7 @@ void CartesianCommunicator::ShmBufferFreeAll(void) { ///////////////////////////////// // Grid information queries ///////////////////////////////// +int CartesianCommunicator::Dimensions(void) { return _ndimension; }; int CartesianCommunicator::IsBoss(void) { return _processor==0; }; int CartesianCommunicator::BossRank(void) { return 0; }; int CartesianCommunicator::ThisRank(void) { return _processor; }; @@ -86,21 +96,43 @@ void CartesianCommunicator::GlobalSumVector(ComplexD *c,int N) GlobalSumVector((double *)c,2*N); } -#if !defined( GRID_COMMS_MPI3) && !defined (GRID_COMMS_MPI3L) +#if !defined( GRID_COMMS_MPI3) -void CartesianCommunicator::StencilSendToRecvFromBegin(std::vector &list, - void *xmit, - int xmit_to_rank, - void *recv, - int recv_from_rank, - int bytes) +int CartesianCommunicator::NodeCount(void) { return ProcessorCount();}; +int CartesianCommunicator::RankCount(void) { return ProcessorCount();}; +#endif +#if !defined( GRID_COMMS_MPI3) && !defined (GRID_COMMS_MPIT) +double CartesianCommunicator::StencilSendToRecvFrom( void *xmit, + int xmit_to_rank, + void *recv, + int recv_from_rank, + int bytes, int dir) { - SendToRecvFromBegin(list,xmit,xmit_to_rank,recv,recv_from_rank,bytes); + std::vector list; + // Discard the "dir" + SendToRecvFromBegin (list,xmit,xmit_to_rank,recv,recv_from_rank,bytes); + SendToRecvFromComplete(list); + return 2.0*bytes; } -void CartesianCommunicator::StencilSendToRecvFromComplete(std::vector &waitall) +double 
CartesianCommunicator::StencilSendToRecvFromBegin(std::vector &list, + void *xmit, + int xmit_to_rank, + void *recv, + int recv_from_rank, + int bytes, int dir) +{ + // Discard the "dir" + SendToRecvFromBegin(list,xmit,xmit_to_rank,recv,recv_from_rank,bytes); + return 2.0*bytes; +} +void CartesianCommunicator::StencilSendToRecvFromComplete(std::vector &waitall,int dir) { SendToRecvFromComplete(waitall); } +#endif + +#if !defined( GRID_COMMS_MPI3) + void CartesianCommunicator::StencilBarrier(void){}; commVector CartesianCommunicator::ShmBufStorageVector; @@ -114,8 +146,25 @@ void *CartesianCommunicator::ShmBufferTranslate(int rank,void * local_p) { return NULL; } void CartesianCommunicator::ShmInitGeneric(void){ +#if 1 + + int mmap_flag = MAP_SHARED | MAP_ANONYMOUS; +#ifdef MAP_HUGETLB + if ( Hugepages ) mmap_flag |= MAP_HUGETLB; +#endif + ShmCommBuf =(void *) mmap(NULL, MAX_MPI_SHM_BYTES, PROT_READ | PROT_WRITE, mmap_flag, -1, 0); + if (ShmCommBuf == (void *)MAP_FAILED) { + perror("mmap failed "); + exit(EXIT_FAILURE); + } +#ifdef MADV_HUGEPAGE + if (!Hugepages ) madvise(ShmCommBuf,MAX_MPI_SHM_BYTES,MADV_HUGEPAGE); +#endif +#else ShmBufStorageVector.resize(MAX_MPI_SHM_BYTES); ShmCommBuf=(void *)&ShmBufStorageVector[0]; +#endif + bzero(ShmCommBuf,MAX_MPI_SHM_BYTES); } #endif diff --git a/lib/communicator/Communicator_base.h b/lib/communicator/Communicator_base.h index 94ad1093..ac866ced 100644 --- a/lib/communicator/Communicator_base.h +++ b/lib/communicator/Communicator_base.h @@ -38,7 +38,7 @@ Author: Peter Boyle #ifdef GRID_COMMS_MPI3 #include #endif -#ifdef GRID_COMMS_MPI3L +#ifdef GRID_COMMS_MPIT #include #endif #ifdef GRID_COMMS_SHMEM @@ -50,12 +50,24 @@ namespace Grid { class CartesianCommunicator { public: - // 65536 ranks per node adequate for now + + //////////////////////////////////////////// + // Isend/Irecv/Wait, or Sendrecv blocking + //////////////////////////////////////////// + enum CommunicatorPolicy_t { CommunicatorPolicyConcurrent, 
CommunicatorPolicySequential }; + static CommunicatorPolicy_t CommunicatorPolicy; + static void SetCommunicatorPolicy(CommunicatorPolicy_t policy ) { CommunicatorPolicy = policy; } + + /////////////////////////////////////////// + // Up to 65536 ranks per node adequate for now // 128MB shared memory for comms enought for 48^4 local vol comms // Give external control (command line override?) of this - - static const int MAXLOG2RANKSPERNODE = 16; - static uint64_t MAX_MPI_SHM_BYTES; + /////////////////////////////////////////// + static const int MAXLOG2RANKSPERNODE = 16; + static uint64_t MAX_MPI_SHM_BYTES; + static int nCommThreads; + // use explicit huge pages + static int Hugepages; // Communicator should know nothing of the physics grid, only processor grid. int _Nprocessors; // How many in all @@ -64,14 +76,18 @@ class CartesianCommunicator { std::vector _processor_coor; // linear processor coordinate unsigned long _ndimension; -#if defined (GRID_COMMS_MPI) || defined (GRID_COMMS_MPI3) || defined (GRID_COMMS_MPI3L) +#if defined (GRID_COMMS_MPI) || defined (GRID_COMMS_MPI3) || defined (GRID_COMMS_MPIT) static MPI_Comm communicator_world; - MPI_Comm communicator; + + MPI_Comm communicator; + std::vector communicator_halo; + typedef MPI_Request CommsRequest_t; #else typedef int CommsRequest_t; #endif + //////////////////////////////////////////////////////////////////// // Helper functionality for SHM Windows common to all other impls //////////////////////////////////////////////////////////////////// @@ -116,6 +132,8 @@ class CartesianCommunicator { // Implemented in Communicator_base.C ///////////////////////////////// static void * ShmCommBuf; + + size_t heap_top; size_t heap_bytes; @@ -142,12 +160,15 @@ class CartesianCommunicator { int RankFromProcessorCoor(std::vector &coor); void ProcessorCoorFromRank(int rank,std::vector &coor); + int Dimensions(void) ; int IsBoss(void) ; int BossRank(void) ; int ThisRank(void) ; const std::vector & 
ThisProcessorCoor(void) ; const std::vector & ProcessorGrid(void) ; int ProcessorCount(void) ; + int NodeCount(void) ; + int RankCount(void) ; //////////////////////////////////////////////////////////////////////////////// // very VERY rarely (Log, serial RNG) we need world without a grid @@ -168,6 +189,8 @@ class CartesianCommunicator { void GlobalSumVector(ComplexF *c,int N); void GlobalSum(ComplexD &c); void GlobalSumVector(ComplexD *c,int N); + void GlobalXOR(uint32_t &); + void GlobalXOR(uint64_t &); template void GlobalSum(obj &o){ typedef typename obj::scalar_type scalar_type; @@ -200,14 +223,21 @@ class CartesianCommunicator { void SendToRecvFromComplete(std::vector &waitall); - void StencilSendToRecvFromBegin(std::vector &list, - void *xmit, - int xmit_to_rank, - void *recv, - int recv_from_rank, - int bytes); + double StencilSendToRecvFrom(void *xmit, + int xmit_to_rank, + void *recv, + int recv_from_rank, + int bytes,int dir); + + double StencilSendToRecvFromBegin(std::vector &list, + void *xmit, + int xmit_to_rank, + void *recv, + int recv_from_rank, + int bytes,int dir); - void StencilSendToRecvFromComplete(std::vector &waitall); + + void StencilSendToRecvFromComplete(std::vector &waitall,int i); void StencilBarrier(void); //////////////////////////////////////////////////////////// diff --git a/lib/communicator/Communicator_mpi.cc b/lib/communicator/Communicator_mpi.cc index 65ced9c7..bd2a62fb 100644 --- a/lib/communicator/Communicator_mpi.cc +++ b/lib/communicator/Communicator_mpi.cc @@ -25,7 +25,9 @@ Author: Peter Boyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include "Grid.h" +#include +#include +#include #include namespace Grid { @@ -39,9 +41,13 @@ MPI_Comm CartesianCommunicator::communicator_world; // Should error check all MPI calls. 
void CartesianCommunicator::Init(int *argc, char ***argv) { int flag; + int provided; MPI_Initialized(&flag); // needed to coexist with other libs apparently if ( !flag ) { - MPI_Init(argc,argv); + MPI_Init_thread(argc,argv,MPI_THREAD_MULTIPLE,&provided); + if ( provided != MPI_THREAD_MULTIPLE ) { + QCD::WilsonKernelsStatic::Comms = QCD::WilsonKernelsStatic::CommsThenCompute; + } } MPI_Comm_dup (MPI_COMM_WORLD,&communicator_world); ShmInitGeneric(); @@ -77,6 +83,14 @@ void CartesianCommunicator::GlobalSum(uint64_t &u){ int ierr=MPI_Allreduce(MPI_IN_PLACE,&u,1,MPI_UINT64_T,MPI_SUM,communicator); assert(ierr==0); } +void CartesianCommunicator::GlobalXOR(uint32_t &u){ + int ierr=MPI_Allreduce(MPI_IN_PLACE,&u,1,MPI_UINT32_T,MPI_BXOR,communicator); + assert(ierr==0); +} +void CartesianCommunicator::GlobalXOR(uint64_t &u){ + int ierr=MPI_Allreduce(MPI_IN_PLACE,&u,1,MPI_UINT64_T,MPI_BXOR,communicator); + assert(ierr==0); +} void CartesianCommunicator::GlobalSum(float &f){ int ierr=MPI_Allreduce(MPI_IN_PLACE,&f,1,MPI_FLOAT,MPI_SUM,communicator); assert(ierr==0); @@ -152,24 +166,34 @@ void CartesianCommunicator::SendToRecvFromBegin(std::vector &lis int from, int bytes) { - MPI_Request xrq; - MPI_Request rrq; - int rank = _processor; + int myrank = _processor; int ierr; - ierr =MPI_Isend(xmit, bytes, MPI_CHAR,dest,_processor,communicator,&xrq); - ierr|=MPI_Irecv(recv, bytes, MPI_CHAR,from,from,communicator,&rrq); - - assert(ierr==0); + if ( CommunicatorPolicy == CommunicatorPolicyConcurrent ) { + MPI_Request xrq; + MPI_Request rrq; - list.push_back(xrq); - list.push_back(rrq); + ierr =MPI_Irecv(recv, bytes, MPI_CHAR,from,from,communicator,&rrq); + ierr|=MPI_Isend(xmit, bytes, MPI_CHAR,dest,_processor,communicator,&xrq); + + assert(ierr==0); + list.push_back(xrq); + list.push_back(rrq); + } else { + // Give the CPU to MPI immediately; can use threads to overlap optionally + ierr=MPI_Sendrecv(xmit,bytes,MPI_CHAR,dest,myrank, + recv,bytes,MPI_CHAR,from, from, + 
communicator,MPI_STATUS_IGNORE); + assert(ierr==0); + } } void CartesianCommunicator::SendToRecvFromComplete(std::vector &list) { - int nreq=list.size(); - std::vector status(nreq); - int ierr = MPI_Waitall(nreq,&list[0],&status[0]); - assert(ierr==0); + if ( CommunicatorPolicy == CommunicatorPolicyConcurrent ) { + int nreq=list.size(); + std::vector status(nreq); + int ierr = MPI_Waitall(nreq,&list[0],&status[0]); + assert(ierr==0); + } } void CartesianCommunicator::Barrier(void) diff --git a/lib/communicator/Communicator_mpi3.cc b/lib/communicator/Communicator_mpi3.cc index c707ec1f..44aa1024 100644 --- a/lib/communicator/Communicator_mpi3.cc +++ b/lib/communicator/Communicator_mpi3.cc @@ -1,4 +1,4 @@ - /************************************************************************************* +/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid @@ -25,9 +25,24 @@ Author: Peter Boyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include "Grid.h" +#include + #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef HAVE_NUMAIF_H +#include +#endif + + namespace Grid { /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -50,6 +65,11 @@ std::vector CartesianCommunicator::GroupRanks; std::vector CartesianCommunicator::MyGroup; std::vector CartesianCommunicator::ShmCommBufs; +int CartesianCommunicator::NodeCount(void) { return GroupSize;}; +int CartesianCommunicator::RankCount(void) { return WorldSize;}; + + +#undef FORCE_COMMS void *CartesianCommunicator::ShmBufferSelf(void) { return ShmCommBufs[ShmRank]; @@ -57,6 +77,9 @@ void *CartesianCommunicator::ShmBufferSelf(void) void *CartesianCommunicator::ShmBuffer(int rank) { int gpeer = GroupRanks[rank]; 
+#ifdef FORCE_COMMS + return NULL; +#endif if (gpeer == MPI_UNDEFINED){ return NULL; } else { @@ -65,7 +88,13 @@ void *CartesianCommunicator::ShmBuffer(int rank) } void *CartesianCommunicator::ShmBufferTranslate(int rank,void * local_p) { + static int count =0; int gpeer = GroupRanks[rank]; + assert(gpeer!=ShmRank); // never send to self + assert(rank!=WorldRank);// never send to self +#ifdef FORCE_COMMS + return NULL; +#endif if (gpeer == MPI_UNDEFINED){ return NULL; } else { @@ -76,16 +105,27 @@ void *CartesianCommunicator::ShmBufferTranslate(int rank,void * local_p) } void CartesianCommunicator::Init(int *argc, char ***argv) { + int flag; + int provided; + // mtrace(); + MPI_Initialized(&flag); // needed to coexist with other libs apparently if ( !flag ) { - MPI_Init(argc,argv); + MPI_Init_thread(argc,argv,MPI_THREAD_MULTIPLE,&provided); + assert (provided == MPI_THREAD_MULTIPLE); } + Grid_quiesce_nodes(); + MPI_Comm_dup (MPI_COMM_WORLD,&communicator_world); MPI_Comm_rank(communicator_world,&WorldRank); MPI_Comm_size(communicator_world,&WorldSize); + if ( WorldRank == 0 ) { + std::cout << GridLogMessage<< "Initialising MPI "<< WorldRank <<"/"< - for(uint64_t page=0;page shmids(ShmSize); + + if ( ShmRank == 0 ) { + for(int r=0;r coor = _processor_coor; - + std::vector coor = _processor_coor; // my coord assert(std::abs(shift) <_processors[dim]); coor[dim] = (_processor_coor[dim] + shift + _processors[dim])%_processors[dim]; @@ -242,28 +430,38 @@ void CartesianCommunicator::ShiftedRanks(int dim,int shift,int &source,int &dest coor[dim] = (_processor_coor[dim] - shift + _processors[dim])%_processors[dim]; Lexicographic::IndexFromCoor(coor,dest,_processors); dest = LexicographicToWorldRank[dest]; -} + +}// rank is world rank. 
+ int CartesianCommunicator::RankFromProcessorCoor(std::vector &coor) { int rank; Lexicographic::IndexFromCoor(coor,rank,_processors); rank = LexicographicToWorldRank[rank]; return rank; -} +}// rank is world rank + void CartesianCommunicator::ProcessorCoorFromRank(int rank, std::vector &coor) { - Lexicographic::CoorFromIndex(coor,rank,_processors); - rank = LexicographicToWorldRank[rank]; + int lr=-1; + for(int r=0;r &processors) { int ierr; - communicator=communicator_world; _ndimension = processors.size(); - + + communicator_halo.resize (2*_ndimension); + for(int i=0;i<_ndimension*2;i++){ + MPI_Comm_dup(communicator,&communicator_halo[i]); + } + //////////////////////////////////////////////////////////////// // Assert power of two shm_size. //////////////////////////////////////////////////////////////// @@ -275,24 +473,22 @@ CartesianCommunicator::CartesianCommunicator(const std::vector &processors) } } assert(log2size != -1); - + //////////////////////////////////////////////////////////////// // Identify subblock of ranks on node spreading across dims // in a maximally symmetrical way //////////////////////////////////////////////////////////////// - int dim = 0; - std::vector WorldDims = processors; - ShmDims.resize(_ndimension,1); + ShmDims.resize (_ndimension,1); GroupDims.resize(_ndimension); - - ShmCoor.resize(_ndimension); + ShmCoor.resize (_ndimension); GroupCoor.resize(_ndimension); WorldCoor.resize(_ndimension); + int dim = 0; for(int l2=0;l2 &processors) GroupDims[d] = WorldDims[d]/ShmDims[d]; } + //////////////////////////////////////////////////////////////// + // Verbose + //////////////////////////////////////////////////////////////// +#if 0 + std::cout<< GridLogMessage << "MPI-3 usage "< &processors) //////////////////////////////////////////////////////////////// // Establish mapping between lexico physics coord and WorldRank - // //////////////////////////////////////////////////////////////// - LexicographicToWorldRank.resize(WorldSize,0); 
Lexicographic::CoorFromIndex(GroupCoor,GroupRank,GroupDims); Lexicographic::CoorFromIndex(ShmCoor,ShmRank,ShmDims); for(int d=0;d<_ndimension;d++){ WorldCoor[d] = GroupCoor[d]*ShmDims[d]+ShmCoor[d]; } _processor_coor = WorldCoor; - - int lexico; - Lexicographic::IndexFromCoor(WorldCoor,lexico,WorldDims); - LexicographicToWorldRank[lexico]=WorldRank; - _processor = lexico; + _processor = WorldRank; /////////////////////////////////////////////////////////////////// // global sum Lexico to World mapping /////////////////////////////////////////////////////////////////// + int lexico; + LexicographicToWorldRank.resize(WorldSize,0); + Lexicographic::IndexFromCoor(WorldCoor,lexico,WorldDims); + LexicographicToWorldRank[lexico] = WorldRank; ierr=MPI_Allreduce(MPI_IN_PLACE,&LexicographicToWorldRank[0],WorldSize,MPI_INT,MPI_SUM,communicator); assert(ierr==0); - -}; + for(int i=0;i coor(_ndimension); + ProcessorCoorFromRank(wr,coor); // from world rank + int ck = RankFromProcessorCoor(coor); + assert(ck==wr); + + if ( wr == WorldRank ) { + for(int j=0;j mcoor = coor; + this->Broadcast(0,(void *)&mcoor[0],mcoor.size()*sizeof(int)); + for(int d = 0 ; d< _ndimension; d++) { + assert(coor[d] == mcoor[d]); + } + } +}; void CartesianCommunicator::GlobalSum(uint32_t &u){ int ierr=MPI_Allreduce(MPI_IN_PLACE,&u,1,MPI_UINT32_T,MPI_SUM,communicator); assert(ierr==0); @@ -348,6 +595,14 @@ void CartesianCommunicator::GlobalSum(uint64_t &u){ int ierr=MPI_Allreduce(MPI_IN_PLACE,&u,1,MPI_UINT64_T,MPI_SUM,communicator); assert(ierr==0); } +void CartesianCommunicator::GlobalXOR(uint32_t &u){ + int ierr=MPI_Allreduce(MPI_IN_PLACE,&u,1,MPI_UINT32_T,MPI_BXOR,communicator); + assert(ierr==0); +} +void CartesianCommunicator::GlobalXOR(uint64_t &u){ + int ierr=MPI_Allreduce(MPI_IN_PLACE,&u,1,MPI_UINT64_T,MPI_BXOR,communicator); + assert(ierr==0); +} void CartesianCommunicator::GlobalSum(float &f){ int ierr=MPI_Allreduce(MPI_IN_PLACE,&f,1,MPI_FLOAT,MPI_SUM,communicator); assert(ierr==0); @@ -367,8 
+622,6 @@ void CartesianCommunicator::GlobalSumVector(double *d,int N) int ierr = MPI_Allreduce(MPI_IN_PLACE,d,N,MPI_DOUBLE,MPI_SUM,communicator); assert(ierr==0); } - - // Basic Halo comms primitive void CartesianCommunicator::SendToRecvFrom(void *xmit, int dest, @@ -377,10 +630,14 @@ void CartesianCommunicator::SendToRecvFrom(void *xmit, int bytes) { std::vector reqs(0); + // unsigned long xcrc = crc32(0L, Z_NULL, 0); + // unsigned long rcrc = crc32(0L, Z_NULL, 0); + // xcrc = crc32(xcrc,(unsigned char *)xmit,bytes); SendToRecvFromBegin(reqs,xmit,dest,recv,from,bytes); SendToRecvFromComplete(reqs); + // rcrc = crc32(rcrc,(unsigned char *)recv,bytes); + // printf("proc %d SendToRecvFrom %d bytes %lx %lx\n",_processor,bytes,xcrc,rcrc); } - void CartesianCommunicator::SendRecvPacket(void *xmit, void *recv, int sender, @@ -397,7 +654,6 @@ void CartesianCommunicator::SendRecvPacket(void *xmit, MPI_Recv(recv, bytes, MPI_CHAR,sender,tag,communicator,&stat); } } - // Basic Halo comms primitive void CartesianCommunicator::SendToRecvFromBegin(std::vector &list, void *xmit, @@ -406,156 +662,110 @@ void CartesianCommunicator::SendToRecvFromBegin(std::vector &lis int from, int bytes) { -#if 0 - this->StencilBarrier(); - - MPI_Request xrq; - MPI_Request rrq; - - static int sequence; - + int myrank = _processor; int ierr; - int tag; - int check; - assert(dest != _processor); - assert(from != _processor); - - int gdest = GroupRanks[dest]; - int gfrom = GroupRanks[from]; - int gme = GroupRanks[_processor]; + if ( CommunicatorPolicy == CommunicatorPolicyConcurrent ) { + MPI_Request xrq; + MPI_Request rrq; - sequence++; - - char *from_ptr = (char *)ShmCommBufs[ShmRank]; - - int small = (bytesStencilBarrier(); - - if (small && (gfrom !=MPI_UNDEFINED) ) { - T *ip = (T *)from_ptr; - T *op = (T *)recv; -PARALLEL_FOR_LOOP - for(int w=0;wStencilBarrier(); - -#else - MPI_Request xrq; - MPI_Request rrq; - int rank = _processor; - int ierr; - ierr =MPI_Isend(xmit, bytes, 
MPI_CHAR,dest,_processor,communicator,&xrq); - ierr|=MPI_Irecv(recv, bytes, MPI_CHAR,from,from,communicator,&rrq); - - assert(ierr==0); - - list.push_back(xrq); - list.push_back(rrq); -#endif } -void CartesianCommunicator::StencilSendToRecvFromBegin(std::vector &list, - void *xmit, - int dest, - void *recv, - int from, - int bytes) +double CartesianCommunicator::StencilSendToRecvFrom( void *xmit, + int dest, + void *recv, + int from, + int bytes,int dir) { + std::vector list; + double offbytes = StencilSendToRecvFromBegin(list,xmit,dest,recv,from,bytes,dir); + StencilSendToRecvFromComplete(list,dir); + return offbytes; +} + +double CartesianCommunicator::StencilSendToRecvFromBegin(std::vector &list, + void *xmit, + int dest, + void *recv, + int from, + int bytes,int dir) +{ + assert(dir < communicator_halo.size()); + MPI_Request xrq; MPI_Request rrq; int ierr; - - assert(dest != _processor); - assert(from != _processor); - int gdest = GroupRanks[dest]; int gfrom = GroupRanks[from]; int gme = GroupRanks[_processor]; - assert(gme == ShmRank); + assert(dest != _processor); + assert(from != _processor); + assert(gme == ShmRank); + double off_node_bytes=0.0; + +#ifdef FORCE_COMMS + gdest = MPI_UNDEFINED; + gfrom = MPI_UNDEFINED; +#endif + if ( gfrom ==MPI_UNDEFINED) { + ierr=MPI_Irecv(recv, bytes, MPI_CHAR,from,from,communicator_halo[dir],&rrq); + assert(ierr==0); + list.push_back(rrq); + off_node_bytes+=bytes; + } if ( gdest == MPI_UNDEFINED ) { - ierr =MPI_Isend(xmit, bytes, MPI_CHAR,dest,_processor,communicator,&xrq); + ierr =MPI_Isend(xmit, bytes, MPI_CHAR,dest,_processor,communicator_halo[dir],&xrq); assert(ierr==0); list.push_back(xrq); - } - - if ( gfrom ==MPI_UNDEFINED) { - ierr=MPI_Irecv(recv, bytes, MPI_CHAR,from,from,communicator,&rrq); - assert(ierr==0); - list.push_back(rrq); + off_node_bytes+=bytes; } + if ( CommunicatorPolicy == CommunicatorPolicySequential ) { + this->StencilSendToRecvFromComplete(list,dir); + } + + return off_node_bytes; } - - -void 
CartesianCommunicator::StencilSendToRecvFromComplete(std::vector &list) +void CartesianCommunicator::StencilSendToRecvFromComplete(std::vector &waitall,int dir) { - SendToRecvFromComplete(list); + SendToRecvFromComplete(waitall); } - void CartesianCommunicator::StencilBarrier(void) { - MPI_Win_sync (ShmWindow); MPI_Barrier (ShmComm); - MPI_Win_sync (ShmWindow); } - void CartesianCommunicator::SendToRecvFromComplete(std::vector &list) { int nreq=list.size(); + + if (nreq==0) return; + std::vector status(nreq); int ierr = MPI_Waitall(nreq,&list[0],&status[0]); assert(ierr==0); + list.resize(0); } - void CartesianCommunicator::Barrier(void) { int ierr = MPI_Barrier(communicator); assert(ierr==0); } - void CartesianCommunicator::Broadcast(int root,void* data, int bytes) { int ierr=MPI_Bcast(data, @@ -565,7 +775,11 @@ void CartesianCommunicator::Broadcast(int root,void* data, int bytes) communicator); assert(ierr==0); } - +int CartesianCommunicator::RankWorld(void){ + int r; + MPI_Comm_rank(communicator_world,&r); + return r; +} void CartesianCommunicator::BroadcastWorld(int root,void* data, int bytes) { int ierr= MPI_Bcast(data, diff --git a/lib/communicator/Communicator_mpi3_leader.cc b/lib/communicator/Communicator_mpi3_leader.cc index 71f1a913..6e26bd3e 100644 --- a/lib/communicator/Communicator_mpi3_leader.cc +++ b/lib/communicator/Communicator_mpi3_leader.cc @@ -27,6 +27,7 @@ Author: Peter Boyle /* END LEGAL */ #include "Grid.h" #include +//#include //////////////////////////////////////////////////////////////////////////////////////////////////////////////// /// Workarounds: @@ -42,19 +43,27 @@ Author: Peter Boyle #include #include #include - typedef sem_t *Grid_semaphore; + +#error /*THis is deprecated*/ + +#if 0 #define SEM_INIT(S) S = sem_open(sem_name,0,0600,0); assert ( S != SEM_FAILED ); #define SEM_INIT_EXCL(S) sem_unlink(sem_name); S = sem_open(sem_name,O_CREAT|O_EXCL,0600,0); assert ( S != SEM_FAILED ); #define SEM_POST(S) assert ( sem_post(S) == 0 ); 
#define SEM_WAIT(S) assert ( sem_wait(S) == 0 ); - +#else +#define SEM_INIT(S) ; +#define SEM_INIT_EXCL(S) ; +#define SEM_POST(S) ; +#define SEM_WAIT(S) ; +#endif #include namespace Grid { -enum { COMMAND_ISEND, COMMAND_IRECV, COMMAND_WAITALL }; +enum { COMMAND_ISEND, COMMAND_IRECV, COMMAND_WAITALL, COMMAND_SENDRECV }; struct Descriptor { uint64_t buf; @@ -62,6 +71,12 @@ struct Descriptor { int rank; int tag; int command; + uint64_t xbuf; + uint64_t rbuf; + int xtag; + int rtag; + int src; + int dest; MPI_Request request; }; @@ -94,18 +109,14 @@ public: void SemInit(void) { sprintf(sem_name,"/Grid_mpi3_sem_head_%d",universe_rank); - // printf("SEM_NAME: %s \n",sem_name); SEM_INIT(sem_head); sprintf(sem_name,"/Grid_mpi3_sem_tail_%d",universe_rank); - // printf("SEM_NAME: %s \n",sem_name); SEM_INIT(sem_tail); } void SemInitExcl(void) { sprintf(sem_name,"/Grid_mpi3_sem_head_%d",universe_rank); - // printf("SEM_INIT_EXCL: %s \n",sem_name); SEM_INIT_EXCL(sem_head); sprintf(sem_name,"/Grid_mpi3_sem_tail_%d",universe_rank); - // printf("SEM_INIT_EXCL: %s \n",sem_name); SEM_INIT_EXCL(sem_tail); } void WakeUpDMA(void) { @@ -125,6 +136,13 @@ public: while(1){ WaitForCommand(); // std::cout << "Getting command "<head,0,0); + int s=state->start; + if ( s != state->head ) { + _mm_mwait(0,0); + } +#endif Event(); } } @@ -132,6 +150,7 @@ public: int Event (void) ; uint64_t QueueCommand(int command,void *buf, int bytes, int hashtag, MPI_Comm comm,int u_rank) ; + void QueueSendRecv(void *xbuf, void *rbuf, int bytes, int xtag, int rtag, MPI_Comm comm,int dest,int src) ; void WaitAll() { // std::cout << "Queueing WAIT command "<tail == state->head ); + while ( state->tail != state->head ); } }; @@ -196,6 +215,12 @@ public: // std::cout << "Waking up DMA "<< slave< MPIoffloadEngine::VerticalShmBufs; std::vector > MPIoffloadEngine::UniverseRanks; std::vector MPIoffloadEngine::UserCommunicatorToWorldRanks; +int CartesianCommunicator::NodeCount(void) { return HorizontalSize;}; int 
MPIoffloadEngine::ShmSetup = 0; void MPIoffloadEngine::CommunicatorInit (MPI_Comm &communicator_world, @@ -370,12 +418,22 @@ void MPIoffloadEngine::CommunicatorInit (MPI_Comm &communicator_world, ftruncate(fd, size); VerticalShmBufs[r] = mmap(NULL,size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - if ( VerticalShmBufs[r] == MAP_FAILED ) { perror("failed mmap"); assert(0); } + /* + for(uint64_t page=0;pagehead ) { switch ( state->Descrs[s].command ) { case COMMAND_ISEND: - /* - std::cout<< " Send "<Descrs[s].buf<< "["<Descrs[s].bytes<<"]" - << " to " << state->Descrs[s].rank<< " tag" << state->Descrs[s].tag - << " Comm " << MPIoffloadEngine::communicator_universe<< " me " <Descrs[s].buf+base), state->Descrs[s].bytes, MPI_CHAR, @@ -568,11 +623,6 @@ int Slave::Event (void) { break; case COMMAND_IRECV: - /* - std::cout<< " Recv "<Descrs[s].buf<< "["<Descrs[s].bytes<<"]" - << " from " << state->Descrs[s].rank<< " tag" << state->Descrs[s].tag - << " Comm " << MPIoffloadEngine::communicator_universe<< " me "<< universe_rank<< std::endl; - */ ierr=MPI_Irecv((void *)(state->Descrs[s].buf+base), state->Descrs[s].bytes, MPI_CHAR, @@ -588,10 +638,32 @@ int Slave::Event (void) { return 1; break; + case COMMAND_SENDRECV: + + // fprintf(stderr,"Sendrecv ->%d %d : <-%d %d \n",state->Descrs[s].dest, state->Descrs[s].xtag+i*10,state->Descrs[s].src, state->Descrs[s].rtag+i*10); + + ierr=MPI_Sendrecv((void *)(state->Descrs[s].xbuf+base), state->Descrs[s].bytes, MPI_CHAR, state->Descrs[s].dest, state->Descrs[s].xtag+i*10, + (void *)(state->Descrs[s].rbuf+base), state->Descrs[s].bytes, MPI_CHAR, state->Descrs[s].src , state->Descrs[s].rtag+i*10, + MPIoffloadEngine::communicator_universe,MPI_STATUS_IGNORE); + + assert(ierr==0); + + // fprintf(stderr,"Sendrecv done %d %d\n",ierr,i); + // MPI_Barrier(MPIoffloadEngine::HorizontalComm); + // fprintf(stderr,"Barrier\n"); + i++; + + state->start = PERI_PLUS(s); + + return 1; + break; + case COMMAND_WAITALL: for(int t=state->tail;t!=s; 
t=PERI_PLUS(t) ){ - MPI_Wait((MPI_Request *)&state->Descrs[t].request,MPI_STATUS_IGNORE); + if ( state->Descrs[t].command != COMMAND_SENDRECV ) { + MPI_Wait((MPI_Request *)&state->Descrs[t].request,MPI_STATUS_IGNORE); + } }; s=PERI_PLUS(s); state->start = s; @@ -613,6 +685,45 @@ int Slave::Event (void) { // External interaction with the queue ////////////////////////////////////////////////////////////////////////////// +void Slave::QueueSendRecv(void *xbuf, void *rbuf, int bytes, int xtag, int rtag, MPI_Comm comm,int dest,int src) +{ + int head =state->head; + int next = PERI_PLUS(head); + + // Set up descriptor + int worldrank; + int hashtag; + MPI_Comm communicator; + MPI_Request request; + uint64_t relative; + + relative = (uint64_t)xbuf - base; + state->Descrs[head].xbuf = relative; + + relative= (uint64_t)rbuf - base; + state->Descrs[head].rbuf = relative; + + state->Descrs[head].bytes = bytes; + + MPIoffloadEngine::MapCommRankToWorldRank(hashtag,worldrank,xtag,comm,dest); + state->Descrs[head].dest = MPIoffloadEngine::UniverseRanks[worldrank][vertical_rank]; + state->Descrs[head].xtag = hashtag; + + MPIoffloadEngine::MapCommRankToWorldRank(hashtag,worldrank,rtag,comm,src); + state->Descrs[head].src = MPIoffloadEngine::UniverseRanks[worldrank][vertical_rank]; + state->Descrs[head].rtag = hashtag; + + state->Descrs[head].command= COMMAND_SENDRECV; + + // Block until FIFO has space + while( state->tail==next ); + + // Msync on weak order architectures + + // Advance pointer + state->head = next; + +}; uint64_t Slave::QueueCommand(int command,void *buf, int bytes, int tag, MPI_Comm comm,int commrank) { ///////////////////////////////////////// @@ -812,19 +923,22 @@ void CartesianCommunicator::StencilSendToRecvFromBegin(std::vector= shm) && (recv_i+bytes <= shm+MAX_MPI_SHM_BYTES) ); assert(from!=_processor); assert(dest!=_processor); - MPIoffloadEngine::QueueMultiplexedSend(xmit,bytes,_processor,communicator,dest); - 
MPIoffloadEngine::QueueMultiplexedRecv(recv,bytes,from,communicator,from); -} + MPIoffloadEngine::QueueMultiplexedSendRecv(xmit,recv,bytes,_processor,from,communicator,dest,from); + + //MPIoffloadEngine::QueueRoundRobinSendRecv(xmit,recv,bytes,_processor,from,communicator,dest,from); + + //MPIoffloadEngine::QueueMultiplexedSend(xmit,bytes,_processor,communicator,dest); + //MPIoffloadEngine::QueueMultiplexedRecv(recv,bytes,from,communicator,from); +} void CartesianCommunicator::StencilSendToRecvFromComplete(std::vector &list) { MPIoffloadEngine::WaitAll(); + //this->Barrier(); } -void CartesianCommunicator::StencilBarrier(void) -{ -} +void CartesianCommunicator::StencilBarrier(void) { } void CartesianCommunicator::SendToRecvFromComplete(std::vector &list) { diff --git a/lib/communicator/Communicator_mpit.cc b/lib/communicator/Communicator_mpit.cc new file mode 100644 index 00000000..eb6ef87d --- /dev/null +++ b/lib/communicator/Communicator_mpit.cc @@ -0,0 +1,286 @@ + /************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + Source file: ./lib/communicator/Communicator_mpi.cc + + Copyright (C) 2015 + +Author: Peter Boyle + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ + See the full license in the file "LICENSE" in the top level distribution directory + *************************************************************************************/ + /* END LEGAL */ +#include +#include +#include +#include + +namespace Grid { + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Info that is setup once and indept of cartesian layout +/////////////////////////////////////////////////////////////////////////////////////////////////// +MPI_Comm CartesianCommunicator::communicator_world; + +// Should error check all MPI calls. +void CartesianCommunicator::Init(int *argc, char ***argv) { + int flag; + int provided; + MPI_Initialized(&flag); // needed to coexist with other libs apparently + if ( !flag ) { + MPI_Init_thread(argc,argv,MPI_THREAD_MULTIPLE,&provided); + if ( provided != MPI_THREAD_MULTIPLE ) { + QCD::WilsonKernelsStatic::Comms = QCD::WilsonKernelsStatic::CommsThenCompute; + } + } + MPI_Comm_dup (MPI_COMM_WORLD,&communicator_world); + ShmInitGeneric(); +} + +CartesianCommunicator::CartesianCommunicator(const std::vector &processors) +{ + _ndimension = processors.size(); + std::vector periodic(_ndimension,1); + + _Nprocessors=1; + _processors = processors; + _processor_coor.resize(_ndimension); + + MPI_Cart_create(communicator_world, _ndimension,&_processors[0],&periodic[0],1,&communicator); + MPI_Comm_rank(communicator,&_processor); + MPI_Cart_coords(communicator,_processor,_ndimension,&_processor_coor[0]); + + for(int i=0;i<_ndimension;i++){ + _Nprocessors*=_processors[i]; + } + + communicator_halo.resize (2*_ndimension); + for(int i=0;i<_ndimension*2;i++){ + MPI_Comm_dup(communicator,&communicator_halo[i]); + } + + int Size; + MPI_Comm_size(communicator,&Size); + + assert(Size==_Nprocessors); +} +void CartesianCommunicator::GlobalSum(uint32_t &u){ + int ierr=MPI_Allreduce(MPI_IN_PLACE,&u,1,MPI_UINT32_T,MPI_SUM,communicator); + assert(ierr==0); +} +void 
CartesianCommunicator::GlobalSum(uint64_t &u){ + int ierr=MPI_Allreduce(MPI_IN_PLACE,&u,1,MPI_UINT64_T,MPI_SUM,communicator); + assert(ierr==0); +} +void CartesianCommunicator::GlobalXOR(uint32_t &u){ + int ierr=MPI_Allreduce(MPI_IN_PLACE,&u,1,MPI_UINT32_T,MPI_BXOR,communicator); + assert(ierr==0); +} +void CartesianCommunicator::GlobalXOR(uint64_t &u){ + int ierr=MPI_Allreduce(MPI_IN_PLACE,&u,1,MPI_UINT64_T,MPI_BXOR,communicator); + assert(ierr==0); +} +void CartesianCommunicator::GlobalSum(float &f){ + int ierr=MPI_Allreduce(MPI_IN_PLACE,&f,1,MPI_FLOAT,MPI_SUM,communicator); + assert(ierr==0); +} +void CartesianCommunicator::GlobalSumVector(float *f,int N) +{ + int ierr=MPI_Allreduce(MPI_IN_PLACE,f,N,MPI_FLOAT,MPI_SUM,communicator); + assert(ierr==0); +} +void CartesianCommunicator::GlobalSum(double &d) +{ + int ierr = MPI_Allreduce(MPI_IN_PLACE,&d,1,MPI_DOUBLE,MPI_SUM,communicator); + assert(ierr==0); +} +void CartesianCommunicator::GlobalSumVector(double *d,int N) +{ + int ierr = MPI_Allreduce(MPI_IN_PLACE,d,N,MPI_DOUBLE,MPI_SUM,communicator); + assert(ierr==0); +} +void CartesianCommunicator::ShiftedRanks(int dim,int shift,int &source,int &dest) +{ + int ierr=MPI_Cart_shift(communicator,dim,shift,&source,&dest); + assert(ierr==0); +} +int CartesianCommunicator::RankFromProcessorCoor(std::vector &coor) +{ + int rank; + int ierr=MPI_Cart_rank (communicator, &coor[0], &rank); + assert(ierr==0); + return rank; +} +void CartesianCommunicator::ProcessorCoorFromRank(int rank, std::vector &coor) +{ + coor.resize(_ndimension); + int ierr=MPI_Cart_coords (communicator, rank, _ndimension,&coor[0]); + assert(ierr==0); +} + +// Basic Halo comms primitive +void CartesianCommunicator::SendToRecvFrom(void *xmit, + int dest, + void *recv, + int from, + int bytes) +{ + std::vector reqs(0); + SendToRecvFromBegin(reqs,xmit,dest,recv,from,bytes); + SendToRecvFromComplete(reqs); +} + +void CartesianCommunicator::SendRecvPacket(void *xmit, + void *recv, + int sender, + int receiver, 
+ int bytes) +{ + MPI_Status stat; + assert(sender != receiver); + int tag = sender; + if ( _processor == sender ) { + MPI_Send(xmit, bytes, MPI_CHAR,receiver,tag,communicator); + } + if ( _processor == receiver ) { + MPI_Recv(recv, bytes, MPI_CHAR,sender,tag,communicator,&stat); + } +} + +// Basic Halo comms primitive +void CartesianCommunicator::SendToRecvFromBegin(std::vector &list, + void *xmit, + int dest, + void *recv, + int from, + int bytes) +{ + int myrank = _processor; + int ierr; + if ( CommunicatorPolicy == CommunicatorPolicyConcurrent ) { + MPI_Request xrq; + MPI_Request rrq; + + ierr =MPI_Irecv(recv, bytes, MPI_CHAR,from,from,communicator,&rrq); + ierr|=MPI_Isend(xmit, bytes, MPI_CHAR,dest,_processor,communicator,&xrq); + + assert(ierr==0); + list.push_back(xrq); + list.push_back(rrq); + } else { + // Give the CPU to MPI immediately; can use threads to overlap optionally + ierr=MPI_Sendrecv(xmit,bytes,MPI_CHAR,dest,myrank, + recv,bytes,MPI_CHAR,from, from, + communicator,MPI_STATUS_IGNORE); + assert(ierr==0); + } +} +void CartesianCommunicator::SendToRecvFromComplete(std::vector &list) +{ + if ( CommunicatorPolicy == CommunicatorPolicyConcurrent ) { + int nreq=list.size(); + std::vector status(nreq); + int ierr = MPI_Waitall(nreq,&list[0],&status[0]); + assert(ierr==0); + } +} + +void CartesianCommunicator::Barrier(void) +{ + int ierr = MPI_Barrier(communicator); + assert(ierr==0); +} + +void CartesianCommunicator::Broadcast(int root,void* data, int bytes) +{ + int ierr=MPI_Bcast(data, + bytes, + MPI_BYTE, + root, + communicator); + assert(ierr==0); +} + /////////////////////////////////////////////////////// + // Should only be used prior to Grid Init finished. + // Check for this? 
+ /////////////////////////////////////////////////////// +int CartesianCommunicator::RankWorld(void){ + int r; + MPI_Comm_rank(communicator_world,&r); + return r; +} +void CartesianCommunicator::BroadcastWorld(int root,void* data, int bytes) +{ + int ierr= MPI_Bcast(data, + bytes, + MPI_BYTE, + root, + communicator_world); + assert(ierr==0); +} + +double CartesianCommunicator::StencilSendToRecvFromBegin(std::vector &list, + void *xmit, + int xmit_to_rank, + void *recv, + int recv_from_rank, + int bytes,int dir) +{ + int myrank = _processor; + int ierr; + assert(dir < communicator_halo.size()); + + // std::cout << " sending on communicator "< &waitall,int dir) +{ + int nreq=waitall.size(); + MPI_Waitall(nreq, &waitall[0], MPI_STATUSES_IGNORE); +}; +double CartesianCommunicator::StencilSendToRecvFrom(void *xmit, + int xmit_to_rank, + void *recv, + int recv_from_rank, + int bytes,int dir) +{ + int myrank = _processor; + int ierr; + assert(dir < communicator_halo.size()); + + // std::cout << " sending on communicator "< See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include "Grid.h" +#include + namespace Grid { /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -58,6 +59,8 @@ void CartesianCommunicator::GlobalSum(double &){} void CartesianCommunicator::GlobalSum(uint32_t &){} void CartesianCommunicator::GlobalSum(uint64_t &){} void CartesianCommunicator::GlobalSumVector(double *,int N){} +void CartesianCommunicator::GlobalXOR(uint32_t &){} +void CartesianCommunicator::GlobalXOR(uint64_t &){} void CartesianCommunicator::SendRecvPacket(void *xmit, void *recv, diff --git a/lib/communicator/Communicator_shmem.cc b/lib/communicator/Communicator_shmem.cc index 56e03224..3c76c808 100644 --- a/lib/communicator/Communicator_shmem.cc +++ b/lib/communicator/Communicator_shmem.cc @@ -25,8 
+25,9 @@ Author: Peter Boyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include "Grid.h" +#include #include +#include namespace Grid { @@ -51,7 +52,7 @@ typedef struct HandShake_t { } HandShake; std::array make_psync_init(void) { - array ret; + std::array ret; ret.fill(SHMEM_SYNC_VALUE); return ret; } @@ -109,7 +110,7 @@ void CartesianCommunicator::GlobalSum(uint32_t &u){ source = u; dest = 0; - shmem_longlong_sum_to_all(&dest,&source,1,0,0,_Nprocessors,llwrk,psync); + shmem_longlong_sum_to_all(&dest,&source,1,0,0,_Nprocessors,llwrk,psync.data()); shmem_barrier_all(); // necessary? u = dest; } @@ -125,7 +126,7 @@ void CartesianCommunicator::GlobalSum(uint64_t &u){ source = u; dest = 0; - shmem_longlong_sum_to_all(&dest,&source,1,0,0,_Nprocessors,llwrk,psync); + shmem_longlong_sum_to_all(&dest,&source,1,0,0,_Nprocessors,llwrk,psync.data()); shmem_barrier_all(); // necessary? 
u = dest; } @@ -137,7 +138,8 @@ void CartesianCommunicator::GlobalSum(float &f){ source = f; dest =0.0; - shmem_float_sum_to_all(&dest,&source,1,0,0,_Nprocessors,llwrk,psync); + shmem_float_sum_to_all(&dest,&source,1,0,0,_Nprocessors,llwrk,psync.data()); + shmem_barrier_all(); f = dest; } void CartesianCommunicator::GlobalSumVector(float *f,int N) @@ -148,14 +150,16 @@ void CartesianCommunicator::GlobalSumVector(float *f,int N) static std::array psync = psync_init; if ( shmem_addr_accessible(f,_processor) ){ - shmem_float_sum_to_all(f,f,N,0,0,_Nprocessors,llwrk,psync); + shmem_float_sum_to_all(f,f,N,0,0,_Nprocessors,llwrk,psync.data()); + shmem_barrier_all(); return; } for(int i=0;i &lis SHMEM_VET(recv); // shmem_putmem_nb(recv,xmit,bytes,dest,NULL); shmem_putmem(recv,xmit,bytes,dest); + + if ( CommunicatorPolicy == CommunicatorPolicySequential ) shmem_barrier_all(); } void CartesianCommunicator::SendToRecvFromComplete(std::vector &list) { // shmem_quiet(); // I'm done - shmem_barrier_all();// He's done too + if( CommunicatorPolicy == CommunicatorPolicyConcurrent ) shmem_barrier_all();// He's done too } void CartesianCommunicator::Barrier(void) { @@ -301,13 +310,13 @@ void CartesianCommunicator::Broadcast(int root,void* data, int bytes) int words = bytes/4; if ( shmem_addr_accessible(data,_processor) ){ - shmem_broadcast32(data,data,words,root,0,0,shmem_n_pes(),psync); + shmem_broadcast32(data,data,words,root,0,0,shmem_n_pes(),psync.data()); return; } for(int w=0;w #include #endif -#ifdef GRID_COMMS_MPI3L +#ifdef GRID_COMMS_MPIT #include #endif diff --git a/lib/cshift/Cshift_common.h b/lib/cshift/Cshift_common.h index 2b146daa..1be672e8 100644 --- a/lib/cshift/Cshift_common.h +++ b/lib/cshift/Cshift_common.h @@ -1,5 +1,4 @@ - - /************************************************************************************* +/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid @@ -31,21 +30,11 
@@ Author: Peter Boyle namespace Grid { -template -class SimpleCompressor { -public: - void Point(int) {}; - - vobj operator() (const vobj &arg) { - return arg; - } -}; - /////////////////////////////////////////////////////////////////// -// Gather for when there is no need to SIMD split with compression +// Gather for when there is no need to SIMD split /////////////////////////////////////////////////////////////////// -template void -Gather_plane_simple (const Lattice &rhs,commVector &buffer,int dimension,int plane,int cbmask,compressor &compress, int off=0) +template void +Gather_plane_simple (const Lattice &rhs,commVector &buffer,int dimension,int plane,int cbmask, int off=0) { int rd = rhs._grid->_rdimensions[dimension]; @@ -53,19 +42,17 @@ Gather_plane_simple (const Lattice &rhs,commVector &buffer,int dimen cbmask = 0x3; } - int so = plane*rhs._grid->_ostride[dimension]; // base offset for start of plane - + int so=plane*rhs._grid->_ostride[dimension]; // base offset for start of plane int e1=rhs._grid->_slice_nblock[dimension]; int e2=rhs._grid->_slice_block[dimension]; int stride=rhs._grid->_slice_stride[dimension]; if ( cbmask == 0x3 ) { -PARALLEL_NESTED_LOOP2 - for(int n=0;nCheckerBoardFromOindexTable(o+b); + int ocb=1<CheckerBoardFromOindex(o+b); if ( ocb &cbmask ) { table.push_back(std::pair (bo++,o+b)); } } } -PARALLEL_FOR_LOOP - for(int i=0;i void -Gather_plane_extract(const Lattice &rhs,std::vector pointers,int dimension,int plane,int cbmask,compressor &compress) +template void +Gather_plane_extract(const Lattice &rhs,std::vector pointers,int dimension,int plane,int cbmask) { int rd = rhs._grid->_rdimensions[dimension]; @@ -105,57 +90,40 @@ Gather_plane_extract(const Lattice &rhs,std::vector_slice_nblock[dimension]; int e2=rhs._grid->_slice_block[dimension]; int n1=rhs._grid->_slice_stride[dimension]; - int n2=rhs._grid->_slice_block[dimension]; + if ( cbmask ==0x3){ -PARALLEL_NESTED_LOOP2 - for(int n=0;n(temp,pointers,offset); + int offset = 
b+n*e2; + + vobj temp =rhs._odata[so+o+b]; + extract(temp,pointers,offset); } } } else { - assert(0); //Fixme think this is buggy - - for(int n=0;n_slice_stride[dimension]; + + int o=n*n1; int ocb=1<CheckerBoardFromOindex(o+b); - int offset = b+n*rhs._grid->_slice_block[dimension]; + int offset = b+n*e2; if ( ocb & cbmask ) { - cobj temp =compress(rhs._odata[so+o+b]); - extract(temp,pointers,offset); + vobj temp =rhs._odata[so+o+b]; + extract(temp,pointers,offset); } } } } } -////////////////////////////////////////////////////// -// Gather for when there is no need to SIMD split -////////////////////////////////////////////////////// -template void Gather_plane_simple (const Lattice &rhs,commVector &buffer, int dimension,int plane,int cbmask) -{ - SimpleCompressor dontcompress; - Gather_plane_simple (rhs,buffer,dimension,plane,cbmask,dontcompress); -} - -////////////////////////////////////////////////////// -// Gather for when there *is* need to SIMD split -////////////////////////////////////////////////////// -template void Gather_plane_extract(const Lattice &rhs,std::vector pointers,int dimension,int plane,int cbmask) -{ - SimpleCompressor dontcompress; - Gather_plane_extract(rhs,pointers,dimension,plane,cbmask,dontcompress); -} - ////////////////////////////////////////////////////// // Scatter for when there is no need to SIMD split ////////////////////////////////////////////////////// @@ -171,10 +139,10 @@ template void Scatter_plane_simple (Lattice &rhs,commVector_slice_nblock[dimension]; int e2=rhs._grid->_slice_block[dimension]; + int stride=rhs._grid->_slice_stride[dimension]; if ( cbmask ==0x3 ) { -PARALLEL_NESTED_LOOP2 - for(int n=0;n_slice_stride[dimension]; int bo =n*rhs._grid->_slice_block[dimension]; @@ -182,24 +150,28 @@ PARALLEL_NESTED_LOOP2 } } } else { + std::vector > table; int bo=0; for(int n=0;n_slice_stride[dimension]; - int bo =n*rhs._grid->_slice_block[dimension]; int ocb=1<CheckerBoardFromOindex(o+b);// Could easily be a table lookup 
if ( ocb & cbmask ) { - rhs._odata[so+o+b]=buffer[bo++]; + table.push_back(std::pair (so+o+b,bo++)); } } } + parallel_for(int i=0;i void Scatter_plane_merge(Lattice &rhs,std::vector pointers,int dimension,int plane,int cbmask) +template void Scatter_plane_merge(Lattice &rhs,std::vector pointers,int dimension,int plane,int cbmask) { int rd = rhs._grid->_rdimensions[dimension]; @@ -213,8 +185,7 @@ PARALLEL_NESTED_LOOP2 int e2=rhs._grid->_slice_block[dimension]; if(cbmask ==0x3 ) { -PARALLEL_NESTED_LOOP2 - for(int n=0;n_slice_stride[dimension]; int offset = b+n*rhs._grid->_slice_block[dimension]; @@ -222,7 +193,11 @@ PARALLEL_NESTED_LOOP2 } } } else { - assert(0); // think this is buggy FIXME + + // Case of SIMD split AND checker dim cannot currently be hit, except in + // Test_cshift_red_black code. + // std::cout << "Scatter_plane merge assert(0); think this is buggy FIXME "<< std::endl;// think this is buggy FIXME + std::cout<<" Unthreaded warning -- buffer is not densely packed ??"<_slice_stride[dimension]; @@ -254,8 +229,7 @@ template void Copy_plane(Lattice& lhs,const Lattice &rhs int e2=rhs._grid->_slice_block[dimension]; int stride = rhs._grid->_slice_stride[dimension]; if(cbmask == 0x3 ){ -PARALLEL_NESTED_LOOP2 - for(int n=0;n void Copy_plane_permute(Lattice& lhs,const Lattice_slice_nblock[dimension]; int e2=rhs._grid->_slice_block [dimension]; int stride = rhs._grid->_slice_stride[dimension]; -PARALLEL_NESTED_LOOP2 - for(int n=0;n Lattice Cshift_local(Lattice &ret,const Lattice // Map to always positive shift modulo global full dimension. 
shift = (shift+fd)%fd; - ret.checkerboard = grid->CheckerBoardDestination(rhs.checkerboard,shift,dimension); // the permute type + ret.checkerboard = grid->CheckerBoardDestination(rhs.checkerboard,shift,dimension); int permute_dim =grid->PermuteDim(dimension); int permute_type=grid->PermuteType(dimension); int permute_type_dist; @@ -348,7 +321,6 @@ template Lattice Cshift_local(Lattice &ret,const Lattice int o = 0; int bo = x * grid->_ostride[dimension]; - int cb= (cbmask==0x2)? Odd : Even; int sshift = grid->CheckerBoardShiftForCB(rhs.checkerboard,dimension,shift,cb); @@ -361,9 +333,23 @@ template Lattice Cshift_local(Lattice &ret,const Lattice // wrap is whether sshift > rd. // num is sshift mod rd. // + // shift 7 + // + // XoXo YcYc + // oXoX cYcY + // XoXo YcYc + // oXoX cYcY + // + // sshift -- + // + // XX YY ; 3 + // XX YY ; 0 + // XX YY ; 3 + // XX YY ; 0 + // int permute_slice=0; if(permute_dim){ - int wrap = sshift/rd; + int wrap = sshift/rd; wrap=wrap % ly; int num = sshift%rd; if ( x< rd-num ) permute_slice=wrap; @@ -375,7 +361,6 @@ template Lattice Cshift_local(Lattice &ret,const Lattice } else { permute_type_dist = permute_type; } - } if ( permute_slice ) Copy_plane_permute(ret,rhs,dimension,x,sx,cbmask,permute_type_dist); diff --git a/lib/cshift/Cshift_mpi.h b/lib/cshift/Cshift_mpi.h index b3c07cd6..a66b49bf 100644 --- a/lib/cshift/Cshift_mpi.h +++ b/lib/cshift/Cshift_mpi.h @@ -74,7 +74,6 @@ template void Cshift_comms(Lattice& ret,const Lattice &r sshift[1] = rhs._grid->CheckerBoardShiftForCB(rhs.checkerboard,dimension,shift,Odd); // std::cout << "Cshift_comms dim "< void Cshift_comms(Lattice &ret,const Lattice &r (void *)&recv_buf[0], recv_from_rank, bytes); + grid->Barrier(); - // for(int i=0;i void Cshift_comms_simd(Lattice &ret,const LatticeBarrier(); rpointers[i] = &recv_buf_extract[i][0]; } else { rpointers[i] = &send_buf_extract[nbr_lane][0]; diff --git a/lib/json/json.hpp b/lib/json/json.hpp new file mode 100644 index 00000000..9d589120 --- 
/dev/null +++ b/lib/json/json.hpp @@ -0,0 +1,14750 @@ +/* + __ _____ _____ _____ + __| | __| | | | JSON for Modern C++ +| | |__ | | | | | | version 2.1.1 +|_____|_____|_____|_|___| https://github.com/nlohmann/json + +Licensed under the MIT License . +Copyright (c) 2013-2017 Niels Lohmann . + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+*/ + +#ifndef NLOHMANN_JSON_HPP +#define NLOHMANN_JSON_HPP + +#include // all_of, copy, fill, find, for_each, generate_n, none_of, remove, reverse, transform +#include // array +#include // assert +#include // and, not, or +#include // lconv, localeconv +#include // isfinite, labs, ldexp, signbit +#include // nullptr_t, ptrdiff_t, size_t +#include // int64_t, uint64_t +#include // abort, strtod, strtof, strtold, strtoul, strtoll, strtoull +#include // memcpy, strlen +#include // forward_list +#include // function, hash, less +#include // initializer_list +#include // hex +#include // istream, ostream +#include // advance, begin, back_inserter, bidirectional_iterator_tag, distance, end, inserter, iterator, iterator_traits, next, random_access_iterator_tag, reverse_iterator +#include // numeric_limits +#include // locale +#include // map +#include // addressof, allocator, allocator_traits, unique_ptr +#include // accumulate +#include // stringstream +#include // getline, stoi, string, to_string +#include // add_pointer, conditional, decay, enable_if, false_type, integral_constant, is_arithmetic, is_base_of, is_const, is_constructible, is_convertible, is_default_constructible, is_enum, is_floating_point, is_integral, is_nothrow_move_assignable, is_nothrow_move_constructible, is_pointer, is_reference, is_same, is_scalar, is_signed, remove_const, remove_cv, remove_pointer, remove_reference, true_type, underlying_type +#include // declval, forward, make_pair, move, pair, swap +#include // valarray +#include // vector + +// exclude unsupported compilers +#if defined(__clang__) + #if (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__) < 30400 + #error "unsupported Clang version - see https://github.com/nlohmann/json#supported-compilers" + #endif +#elif defined(__GNUC__) + #if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) < 40805 + #error "unsupported GCC version - see https://github.com/nlohmann/json#supported-compilers" + #endif 
+#endif + +// disable float-equal warnings on GCC/clang +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + +// disable documentation warnings on clang +#if defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wdocumentation" +#endif + +// allow for portable deprecation warnings +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) + #define JSON_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) + #define JSON_DEPRECATED __declspec(deprecated) +#else + #define JSON_DEPRECATED +#endif + +// allow to disable exceptions +#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) && not defined(JSON_NOEXCEPTION) + #define JSON_THROW(exception) throw exception + #define JSON_TRY try + #define JSON_CATCH(exception) catch(exception) +#else + #define JSON_THROW(exception) std::abort() + #define JSON_TRY if(true) + #define JSON_CATCH(exception) if(false) +#endif + +// manual branch prediction +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) + #define JSON_LIKELY(x) __builtin_expect(!!(x), 1) + #define JSON_UNLIKELY(x) __builtin_expect(!!(x), 0) +#else + #define JSON_LIKELY(x) x + #define JSON_UNLIKELY(x) x +#endif + +/*! 
+@brief namespace for Niels Lohmann +@see https://github.com/nlohmann +@since version 1.0.0 +*/ +namespace nlohmann +{ +template +struct adl_serializer; + +// forward declaration of basic_json (required to split the class) +template class ObjectType = + std::map, + template class ArrayType = std::vector, + class StringType = std::string, class BooleanType = bool, + class NumberIntegerType = std::int64_t, + class NumberUnsignedType = std::uint64_t, + class NumberFloatType = double, + template class AllocatorType = std::allocator, + template class JSONSerializer = + adl_serializer> +class basic_json; + +// Ugly macros to avoid uglier copy-paste when specializing basic_json +// This is only temporary and will be removed in 3.0 + +#define NLOHMANN_BASIC_JSON_TPL_DECLARATION \ + template class ObjectType, \ + template class ArrayType, \ + class StringType, class BooleanType, class NumberIntegerType, \ + class NumberUnsignedType, class NumberFloatType, \ + template class AllocatorType, \ + template class JSONSerializer> + +#define NLOHMANN_BASIC_JSON_TPL \ + basic_json + + +/*! +@brief unnamed namespace with internal helper functions + +This namespace collects some functions that could not be defined inside the +@ref basic_json class. + +@since version 2.1.0 +*/ +namespace detail +{ +//////////////// +// exceptions // +//////////////// + +/*! +@brief general exception of the @ref basic_json class + +This class is an extension of `std::exception` objects with a member @a id for +exception ids. It is used as the base class for all exceptions thrown by the +@ref basic_json class. This class can hence be used as "wildcard" to catch +exceptions. 
+ +Subclasses: +- @ref parse_error for exceptions indicating a parse error +- @ref invalid_iterator for exceptions indicating errors with iterators +- @ref type_error for exceptions indicating executing a member function with + a wrong type +- @ref out_of_range for exceptions indicating access out of the defined range +- @ref other_error for exceptions indicating other library errors + +@internal +@note To have nothrow-copy-constructible exceptions, we internally use + `std::runtime_error` which can cope with arbitrary-length error messages. + Intermediate strings are built with static functions and then passed to + the actual constructor. +@endinternal + +@liveexample{The following code shows how arbitrary library exceptions can be +caught.,exception} + +@since version 3.0.0 +*/ +class exception : public std::exception +{ + public: + /// returns the explanatory string + const char* what() const noexcept override + { + return m.what(); + } + + /// the id of the exception + const int id; + + protected: + exception(int id_, const char* what_arg) : id(id_), m(what_arg) {} + + static std::string name(const std::string& ename, int id) + { + return "[json.exception." + ename + "." + std::to_string(id) + "] "; + } + + private: + /// an exception object as storage for error messages + std::runtime_error m; +}; + +/*! +@brief exception indicating a parse error + +This excpetion is thrown by the library when a parse error occurs. Parse errors +can occur during the deserialization of JSON text, CBOR, MessagePack, as well +as when using JSON Patch. + +Member @a byte holds the byte index of the last read character in the input +file. + +Exceptions have ids 1xx. + +name / id | example message | description +------------------------------ | --------------- | ------------------------- +json.exception.parse_error.101 | parse error at 2: unexpected end of input; expected string literal | This error indicates a syntax error while deserializing a JSON text. 
The error message describes that an unexpected token (character) was encountered, and the member @a byte indicates the error position. +json.exception.parse_error.102 | parse error at 14: missing or wrong low surrogate | JSON uses the `\uxxxx` format to describe Unicode characters. Code points above 0xFFFF are split into two `\uxxxx` entries ("surrogate pairs"). This error indicates that the surrogate pair is incomplete or contains an invalid code point. +json.exception.parse_error.103 | parse error: code points above 0x10FFFF are invalid | Unicode supports code points up to 0x10FFFF. Code points above 0x10FFFF are invalid. +json.exception.parse_error.104 | parse error: JSON patch must be an array of objects | [RFC 6902](https://tools.ietf.org/html/rfc6902) requires a JSON Patch document to be a JSON document that represents an array of objects. +json.exception.parse_error.105 | parse error: operation must have string member 'op' | An operation of a JSON Patch document must contain exactly one "op" member, whose value indicates the operation to perform. Its value must be one of "add", "remove", "replace", "move", "copy", or "test"; other values are errors. +json.exception.parse_error.106 | parse error: array index '01' must not begin with '0' | An array index in a JSON Pointer ([RFC 6901](https://tools.ietf.org/html/rfc6901)) may be `0` or any number without a leading `0`. +json.exception.parse_error.107 | parse error: JSON pointer must be empty or begin with '/' - was: 'foo' | A JSON Pointer must be a Unicode string containing a sequence of zero or more reference tokens, each prefixed by a `/` character. +json.exception.parse_error.108 | parse error: escape character '~' must be followed with '0' or '1' | In a JSON Pointer, only `~0` and `~1` are valid escape sequences. +json.exception.parse_error.109 | parse error: array index 'one' is not a number | A JSON Pointer array index must be a number. 
+json.exception.parse_error.110 | parse error at 1: cannot read 2 bytes from vector | When parsing CBOR or MessagePack, the byte vector ends before the complete value has been read. +json.exception.parse_error.112 | parse error at 1: error reading CBOR; last byte: 0xf8 | Not all types of CBOR or MessagePack are supported. This exception occurs if an unsupported byte was read. +json.exception.parse_error.113 | parse error at 2: expected a CBOR string; last byte: 0x98 | While parsing a map key, a value that is not a string has been read. + +@note For an input with n bytes, 1 is the index of the first character and n+1 + is the index of the terminating null byte or the end of file. This also + holds true when reading a byte vector (CBOR or MessagePack). + +@liveexample{The following code shows how a `parse_error` exception can be +caught.,parse_error} + +@sa @ref exception for the base class of the library exceptions +@sa @ref invalid_iterator for exceptions indicating errors with iterators +@sa @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa @ref out_of_range for exceptions indicating access out of the defined range +@sa @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class parse_error : public exception +{ + public: + /*! + @brief create a parse error exception + @param[in] id the id of the exception + @param[in] byte_ the byte index where the error occurred (or 0 if the + position cannot be determined) + @param[in] what_arg the explanatory string + @return parse_error object + */ + static parse_error create(int id, std::size_t byte_, const std::string& what_arg) + { + std::string w = exception::name("parse_error", id) + "parse error" + + (byte_ != 0 ? (" at " + std::to_string(byte_)) : "") + + ": " + what_arg; + return parse_error(id, byte_, w.c_str()); + } + + /*! + @brief byte index of the parse error + + The byte index of the last read character in the input file. 
+ + @note For an input with n bytes, 1 is the index of the first character and + n+1 is the index of the terminating null byte or the end of file. + This also holds true when reading a byte vector (CBOR or MessagePack). + */ + const std::size_t byte; + + private: + parse_error(int id_, std::size_t byte_, const char* what_arg) + : exception(id_, what_arg), byte(byte_) {} +}; + +/*! +@brief exception indicating errors with iterators + +This exception is thrown if iterators passed to a library function do not match +the expected semantics. + +Exceptions have ids 2xx. + +name / id | example message | description +----------------------------------- | --------------- | ------------------------- +json.exception.invalid_iterator.201 | iterators are not compatible | The iterators passed to constructor @ref basic_json(InputIT first, InputIT last) are not compatible, meaning they do not belong to the same container. Therefore, the range (@a first, @a last) is invalid. +json.exception.invalid_iterator.202 | iterator does not fit current value | In an erase or insert function, the passed iterator @a pos does not belong to the JSON value for which the function was called. It hence does not define a valid position for the deletion/insertion. +json.exception.invalid_iterator.203 | iterators do not fit current value | Either iterator passed to function @ref erase(IteratorType first, IteratorType last) does not belong to the JSON value from which values shall be erased. It hence does not define a valid range to delete values from. +json.exception.invalid_iterator.204 | iterators out of range | When an iterator range for a primitive type (number, boolean, or string) is passed to a constructor or an erase function, this range has to be exactly (@ref begin(), @ref end()), because this is the only way the single stored value is expressed. All other ranges are invalid. 
+json.exception.invalid_iterator.205 | iterator out of range | When an iterator for a primitive type (number, boolean, or string) is passed to an erase function, the iterator has to be the @ref begin() iterator, because it is the only way to address the stored value. All other iterators are invalid. +json.exception.invalid_iterator.206 | cannot construct with iterators from null | The iterators passed to constructor @ref basic_json(InputIT first, InputIT last) belong to a JSON null value and hence do not define a valid range. +json.exception.invalid_iterator.207 | cannot use key() for non-object iterators | The key() member function can only be used on iterators belonging to a JSON object, because other types do not have a concept of a key. +json.exception.invalid_iterator.208 | cannot use operator[] for object iterators | The operator[] to specify a concrete offset cannot be used on iterators belonging to a JSON object, because JSON objects are unordered. +json.exception.invalid_iterator.209 | cannot use offsets with object iterators | The offset operators (+, -, +=, -=) cannot be used on iterators belonging to a JSON object, because JSON objects are unordered. +json.exception.invalid_iterator.210 | iterators do not fit | The iterators passed to the insert function are not compatible, meaning they do not belong to the same container. Therefore, the range (@a first, @a last) is invalid. +json.exception.invalid_iterator.211 | passed iterators may not belong to container | The iterator range passed to the insert function must not be a subrange of the container to insert to. +json.exception.invalid_iterator.212 | cannot compare iterators of different containers | When two iterators are compared, they must belong to the same container. +json.exception.invalid_iterator.213 | cannot compare order of object iterators | The order of object iterators cannot be compared, because JSON objects are unordered. 
+json.exception.invalid_iterator.214 | cannot get value | Cannot get value for iterator: Either the iterator belongs to a null value or it is an iterator to a primitive type (number, boolean, or string), but the iterator is different to @ref begin(). + +@liveexample{The following code shows how an `invalid_iterator` exception can be +caught.,invalid_iterator} + +@sa @ref exception for the base class of the library exceptions +@sa @ref parse_error for exceptions indicating a parse error +@sa @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa @ref out_of_range for exceptions indicating access out of the defined range +@sa @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class invalid_iterator : public exception +{ + public: + static invalid_iterator create(int id, const std::string& what_arg) + { + std::string w = exception::name("invalid_iterator", id) + what_arg; + return invalid_iterator(id, w.c_str()); + } + + private: + invalid_iterator(int id_, const char* what_arg) + : exception(id_, what_arg) {} +}; + +/*! +@brief exception indicating executing a member function with a wrong type + +This exception is thrown in case of a type error; that is, a library function is +executed on a JSON value whose type does not match the expected semantics. + +Exceptions have ids 3xx. + +name / id | example message | description +----------------------------- | --------------- | ------------------------- +json.exception.type_error.301 | cannot create object from initializer list | To create an object from an initializer list, the initializer list must consist only of a list of pairs whose first element is a string. When this constraint is violated, an array is created instead. +json.exception.type_error.302 | type must be object, but is array | During implicit or explicit value conversion, the JSON type must be compatible to the target type. 
For instance, a JSON string can only be converted into string types, but not into numbers or boolean types. +json.exception.type_error.303 | incompatible ReferenceType for get_ref, actual type is object | To retrieve a reference to a value stored in a @ref basic_json object with @ref get_ref, the type of the reference must match the value type. For instance, for a JSON array, the @a ReferenceType must be @ref array_t&. +json.exception.type_error.304 | cannot use at() with string | The @ref at() member functions can only be executed for certain JSON types. +json.exception.type_error.305 | cannot use operator[] with string | The @ref operator[] member functions can only be executed for certain JSON types. +json.exception.type_error.306 | cannot use value() with string | The @ref value() member functions can only be executed for certain JSON types. +json.exception.type_error.307 | cannot use erase() with string | The @ref erase() member functions can only be executed for certain JSON types. +json.exception.type_error.308 | cannot use push_back() with string | The @ref push_back() and @ref operator+= member functions can only be executed for certain JSON types. +json.exception.type_error.309 | cannot use insert() with | The @ref insert() member functions can only be executed for certain JSON types. +json.exception.type_error.310 | cannot use swap() with number | The @ref swap() member functions can only be executed for certain JSON types. +json.exception.type_error.311 | cannot use emplace_back() with string | The @ref emplace_back() member function can only be executed for certain JSON types. +json.exception.type_error.312 | cannot use update() with string | The @ref update() member functions can only be executed for certain JSON types. +json.exception.type_error.313 | invalid value to unflatten | The @ref unflatten function converts an object whose keys are JSON Pointers back into an arbitrary nested JSON value. 
The JSON Pointers must not overlap, because then the resulting value would not be well defined. +json.exception.type_error.314 | only objects can be unflattened | The @ref unflatten function only works for an object whose keys are JSON Pointers. +json.exception.type_error.315 | values in object must be primitive | The @ref unflatten function only works for an object whose keys are JSON Pointers and whose values are primitive. + +@liveexample{The following code shows how a `type_error` exception can be +caught.,type_error} + +@sa @ref exception for the base class of the library exceptions +@sa @ref parse_error for exceptions indicating a parse error +@sa @ref invalid_iterator for exceptions indicating errors with iterators +@sa @ref out_of_range for exceptions indicating access out of the defined range +@sa @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class type_error : public exception +{ + public: + static type_error create(int id, const std::string& what_arg) + { + std::string w = exception::name("type_error", id) + what_arg; + return type_error(id, w.c_str()); + } + + private: + type_error(int id_, const char* what_arg) : exception(id_, what_arg) {} +}; + +/*! +@brief exception indicating access out of the defined range + +This exception is thrown in case a library function is called on an input +parameter that exceeds the expected range, for instance in case of array +indices or nonexisting object keys. + +Exceptions have ids 4xx. + +name / id | example message | description +------------------------------- | --------------- | ------------------------- +json.exception.out_of_range.401 | array index 3 is out of range | The provided array index @a i is larger than @a size-1. +json.exception.out_of_range.402 | array index '-' (3) is out of range | The special array index `-` in a JSON Pointer never describes a valid element of the array, but the index past the end. 
That is, it can only be used to add elements at this position, but not to read it. +json.exception.out_of_range.403 | key 'foo' not found | The provided key was not found in the JSON object. +json.exception.out_of_range.404 | unresolved reference token 'foo' | A reference token in a JSON Pointer could not be resolved. +json.exception.out_of_range.405 | JSON pointer has no parent | The JSON Patch operations 'remove' and 'add' can not be applied to the root element of the JSON value. +json.exception.out_of_range.406 | number overflow parsing '10E1000' | A parsed number could not be stored as without changing it to NaN or INF. + +@liveexample{The following code shows how an `out_of_range` exception can be +caught.,out_of_range} + +@sa @ref exception for the base class of the library exceptions +@sa @ref parse_error for exceptions indicating a parse error +@sa @ref invalid_iterator for exceptions indicating errors with iterators +@sa @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class out_of_range : public exception +{ + public: + static out_of_range create(int id, const std::string& what_arg) + { + std::string w = exception::name("out_of_range", id) + what_arg; + return out_of_range(id, w.c_str()); + } + + private: + out_of_range(int id_, const char* what_arg) : exception(id_, what_arg) {} +}; + +/*! +@brief exception indicating other library errors + +This exception is thrown in case of errors that cannot be classified with the +other exception types. + +Exceptions have ids 5xx. + +name / id | example message | description +------------------------------ | --------------- | ------------------------- +json.exception.other_error.501 | unsuccessful: {"op":"test","path":"/baz", "value":"bar"} | A JSON Patch operation 'test' failed. The unsuccessful operation is also printed. 
+json.exception.other_error.502 | invalid object size for conversion | Some conversions to user-defined types impose constraints on the object size (e.g. std::pair) + +@sa @ref exception for the base class of the library exceptions +@sa @ref parse_error for exceptions indicating a parse error +@sa @ref invalid_iterator for exceptions indicating errors with iterators +@sa @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa @ref out_of_range for exceptions indicating access out of the defined range + +@liveexample{The following code shows how an `other_error` exception can be +caught.,other_error} + +@since version 3.0.0 +*/ +class other_error : public exception +{ + public: + static other_error create(int id, const std::string& what_arg) + { + std::string w = exception::name("other_error", id) + what_arg; + return other_error(id, w.c_str()); + } + + private: + other_error(int id_, const char* what_arg) : exception(id_, what_arg) {} +}; + + + +/////////////////////////// +// JSON type enumeration // +/////////////////////////// + +/*! +@brief the JSON type enumeration + +This enumeration collects the different JSON types. It is internally used to +distinguish the stored values, and the functions @ref basic_json::is_null(), +@ref basic_json::is_object(), @ref basic_json::is_array(), +@ref basic_json::is_string(), @ref basic_json::is_boolean(), +@ref basic_json::is_number() (with @ref basic_json::is_number_integer(), +@ref basic_json::is_number_unsigned(), and @ref basic_json::is_number_float()), +@ref basic_json::is_discarded(), @ref basic_json::is_primitive(), and +@ref basic_json::is_structured() rely on it. 
+ +@note There are three enumeration entries (number_integer, number_unsigned, and +number_float), because the library distinguishes these three types for numbers: +@ref basic_json::number_unsigned_t is used for unsigned integers, +@ref basic_json::number_integer_t is used for signed integers, and +@ref basic_json::number_float_t is used for floating-point numbers or to +approximate integers which do not fit in the limits of their respective type. + +@sa @ref basic_json::basic_json(const value_t value_type) -- create a JSON +value with the default value for a given type + +@since version 1.0.0 +*/ +enum class value_t : uint8_t +{ + null, ///< null value + object, ///< object (unordered set of name/value pairs) + array, ///< array (ordered collection of values) + string, ///< string value + boolean, ///< boolean value + number_integer, ///< number value (signed integer) + number_unsigned, ///< number value (unsigned integer) + number_float, ///< number value (floating-point) + discarded ///< discarded by the the parser callback function +}; + +/*! 
+@brief comparison operator for JSON types + +Returns an ordering that is similar to Python: +- order: null < boolean < number < object < array < string +- furthermore, each type is not smaller than itself + +@since version 1.0.0 +*/ +inline bool operator<(const value_t lhs, const value_t rhs) noexcept +{ + static constexpr std::array order = {{ + 0, // null + 3, // object + 4, // array + 5, // string + 1, // boolean + 2, // integer + 2, // unsigned + 2, // float + } + }; + + // discarded values are not comparable + return lhs != value_t::discarded and rhs != value_t::discarded and + order[static_cast(lhs)] < order[static_cast(rhs)]; +} + + +///////////// +// helpers // +///////////// + +template struct is_basic_json : std::false_type {}; + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +struct is_basic_json : std::true_type {}; + +// alias templates to reduce boilerplate +template +using enable_if_t = typename std::enable_if::type; + +template +using uncvref_t = typename std::remove_cv::type>::type; + +// implementation of C++14 index_sequence and affiliates +// source: https://stackoverflow.com/a/32223343 +template +struct index_sequence +{ + using type = index_sequence; + using value_type = std::size_t; + static constexpr std::size_t size() noexcept + { + return sizeof...(Ints); + } +}; + +template +struct merge_and_renumber; + +template +struct merge_and_renumber, index_sequence> + : index_sequence < I1..., (sizeof...(I1) + I2)... > + {}; + +template +struct make_index_sequence + : merge_and_renumber < typename make_index_sequence < N / 2 >::type, + typename make_index_sequence < N - N / 2 >::type > +{}; + +template<> struct make_index_sequence<0> : index_sequence<> { }; +template<> struct make_index_sequence<1> : index_sequence<0> { }; + +template +using index_sequence_for = make_index_sequence; + +/* +Implementation of two C++17 constructs: conjunction, negation. 
This is needed +to avoid evaluating all the traits in a condition + +For example: not std::is_same::value and has_value_type::value +will not compile when T = void (on MSVC at least). Whereas +conjunction>, has_value_type>::value will +stop evaluating if negation<...>::value == false + +Please note that those constructs must be used with caution, since symbols can +become very long quickly (which can slow down compilation and cause MSVC +internal compiler errors). Only use it when you have to (see example ahead). +*/ +template struct conjunction : std::true_type {}; +template struct conjunction : B1 {}; +template +struct conjunction : std::conditional, B1>::type {}; + +template struct negation : std::integral_constant < bool, !B::value > {}; + +// dispatch utility (taken from ranges-v3) +template struct priority_tag : priority_tag < N - 1 > {}; +template<> struct priority_tag<0> {}; + + +////////////////// +// constructors // +////////////////// + +template struct external_constructor; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename BasicJsonType::boolean_t b) noexcept + { + j.m_type = value_t::boolean; + j.m_value = b; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, const typename BasicJsonType::string_t& s) + { + j.m_type = value_t::string; + j.m_value = s; + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, typename BasicJsonType::string_t&& s) + { + j.m_type = value_t::string; + j.m_value = std::move(s); + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename BasicJsonType::number_float_t val) noexcept + { + j.m_type = value_t::number_float; + j.m_value = val; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename 
BasicJsonType::number_unsigned_t val) noexcept + { + j.m_type = value_t::number_unsigned; + j.m_value = val; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename BasicJsonType::number_integer_t val) noexcept + { + j.m_type = value_t::number_integer; + j.m_value = val; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, const typename BasicJsonType::array_t& arr) + { + j.m_type = value_t::array; + j.m_value = arr; + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, typename BasicJsonType::array_t&& arr) + { + j.m_type = value_t::array; + j.m_value = std::move(arr); + j.assert_invariant(); + } + + template::value, + int> = 0> + static void construct(BasicJsonType& j, const CompatibleArrayType& arr) + { + using std::begin; + using std::end; + j.m_type = value_t::array; + j.m_value.array = j.template create(begin(arr), end(arr)); + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, const std::vector& arr) + { + j.m_type = value_t::array; + j.m_value = value_t::array; + j.m_value.array->reserve(arr.size()); + for (bool x : arr) + { + j.m_value.array->push_back(x); + } + j.assert_invariant(); + } + + template::value, int> = 0> + static void construct(BasicJsonType& j, const std::valarray& arr) + { + using std::begin; + using std::end; + j.m_type = value_t::array; + j.m_value = value_t::array; + j.m_value.array = j.template create(begin(arr), end(arr)); + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, const typename BasicJsonType::object_t& obj) + { + j.m_type = value_t::object; + j.m_value = obj; + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, typename BasicJsonType::object_t&& obj) + { + j.m_type = value_t::object; + j.m_value = 
std::move(obj); + j.assert_invariant(); + } + + template::value, int> = 0> + static void construct(BasicJsonType& j, const CompatibleObjectType& obj) + { + using std::begin; + using std::end; + + j.m_type = value_t::object; + j.m_value.object = j.template create(begin(obj), end(obj)); + j.assert_invariant(); + } +}; + + +//////////////////////// +// has_/is_ functions // +//////////////////////// + +/*! +@brief Helper to determine whether there's a key_type for T. + +This helper is used to tell associative containers apart from other containers +such as sequence containers. For instance, `std::map` passes the test as it +contains a `mapped_type`, whereas `std::vector` fails the test. + +@sa http://stackoverflow.com/a/7728728/266378 +@since version 1.0.0, overworked in version 2.0.6 +*/ +#define NLOHMANN_JSON_HAS_HELPER(type) \ + template struct has_##type { \ + private: \ + template \ + static int detect(U &&); \ + static void detect(...); \ + public: \ + static constexpr bool value = \ + std::is_integral()))>::value; \ + } + +NLOHMANN_JSON_HAS_HELPER(mapped_type); +NLOHMANN_JSON_HAS_HELPER(key_type); +NLOHMANN_JSON_HAS_HELPER(value_type); +NLOHMANN_JSON_HAS_HELPER(iterator); + +#undef NLOHMANN_JSON_HAS_HELPER + + +template +struct is_compatible_object_type_impl : std::false_type {}; + +template +struct is_compatible_object_type_impl +{ + static constexpr auto value = + std::is_constructible::value and + std::is_constructible::value; +}; + +template +struct is_compatible_object_type +{ + static auto constexpr value = is_compatible_object_type_impl < + conjunction>, + has_mapped_type, + has_key_type>::value, + typename BasicJsonType::object_t, CompatibleObjectType >::value; +}; + +template +struct is_basic_json_nested_type +{ + static auto constexpr value = std::is_same::value or + std::is_same::value or + std::is_same::value or + std::is_same::value; +}; + +template +struct is_compatible_array_type +{ + static auto constexpr value = + conjunction>, + negation>, + 
negation>, + negation>, + has_value_type, + has_iterator>::value; +}; + +template +struct is_compatible_integer_type_impl : std::false_type {}; + +template +struct is_compatible_integer_type_impl +{ + // is there an assert somewhere on overflows? + using RealLimits = std::numeric_limits; + using CompatibleLimits = std::numeric_limits; + + static constexpr auto value = + std::is_constructible::value and + CompatibleLimits::is_integer and + RealLimits::is_signed == CompatibleLimits::is_signed; +}; + +template +struct is_compatible_integer_type +{ + static constexpr auto value = + is_compatible_integer_type_impl < + std::is_integral::value and + not std::is_same::value, + RealIntegerType, CompatibleNumberIntegerType > ::value; +}; + + +// trait checking if JSONSerializer::from_json(json const&, udt&) exists +template +struct has_from_json +{ + private: + // also check the return type of from_json + template::from_json( + std::declval(), std::declval()))>::value>> + static int detect(U&&); + static void detect(...); + + public: + static constexpr bool value = std::is_integral>()))>::value; +}; + +// This trait checks if JSONSerializer::from_json(json const&) exists +// this overload is used for non-default-constructible user-defined-types +template +struct has_non_default_from_json +{ + private: + template < + typename U, + typename = enable_if_t::from_json(std::declval()))>::value >> + static int detect(U&&); + static void detect(...); + + public: + static constexpr bool value = std::is_integral>()))>::value; +}; + +// This trait checks if BasicJsonType::json_serializer::to_json exists +template +struct has_to_json +{ + private: + template::to_json( + std::declval(), std::declval()))> + static int detect(U&&); + static void detect(...); + + public: + static constexpr bool value = std::is_integral>()))>::value; +}; + + +///////////// +// to_json // +///////////// + +template::value, int> = 0> +void to_json(BasicJsonType& j, T b) noexcept +{ + 
external_constructor::construct(j, b); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, const CompatibleString& s) +{ + external_constructor::construct(j, s); +} + +template +void to_json(BasicJsonType& j, typename BasicJsonType::string_t&& s) +{ + external_constructor::construct(j, std::move(s)); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, FloatType val) noexcept +{ + external_constructor::construct(j, static_cast(val)); +} + +template < + typename BasicJsonType, typename CompatibleNumberUnsignedType, + enable_if_t::value, int> = 0 > +void to_json(BasicJsonType& j, CompatibleNumberUnsignedType val) noexcept +{ + external_constructor::construct(j, static_cast(val)); +} + +template < + typename BasicJsonType, typename CompatibleNumberIntegerType, + enable_if_t::value, int> = 0 > +void to_json(BasicJsonType& j, CompatibleNumberIntegerType val) noexcept +{ + external_constructor::construct(j, static_cast(val)); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, EnumType e) noexcept +{ + using underlying_type = typename std::underlying_type::type; + external_constructor::construct(j, static_cast(e)); +} + +template +void to_json(BasicJsonType& j, const std::vector& e) +{ + external_constructor::construct(j, e); +} + +template < + typename BasicJsonType, typename CompatibleArrayType, + enable_if_t < + is_compatible_array_type::value or + std::is_same::value, + int > = 0 > +void to_json(BasicJsonType& j, const CompatibleArrayType& arr) +{ + external_constructor::construct(j, arr); +} + +template ::value, int> = 0> +void to_json(BasicJsonType& j, std::valarray arr) +{ + external_constructor::construct(j, std::move(arr)); +} + +template +void to_json(BasicJsonType& j, typename BasicJsonType::array_t&& arr) +{ + external_constructor::construct(j, std::move(arr)); +} + +template < + typename BasicJsonType, typename CompatibleObjectType, + enable_if_t::value, + int> = 0 > +void to_json(BasicJsonType& j, const 
CompatibleObjectType& obj) +{ + external_constructor::construct(j, obj); +} + +template +void to_json(BasicJsonType& j, typename BasicJsonType::object_t&& obj) +{ + external_constructor::construct(j, std::move(obj)); +} + +template::value, + int> = 0> +void to_json(BasicJsonType& j, T (&arr)[N]) +{ + external_constructor::construct(j, arr); +} + +template +void to_json(BasicJsonType& j, const std::pair& p) +{ + j = {p.first, p.second}; +} + +template +void to_json_tuple_impl(BasicJsonType& j, const Tuple& t, index_sequence) +{ + j = {std::get(t)...}; +} + +template +void to_json(BasicJsonType& j, const std::tuple& t) +{ + to_json_tuple_impl(j, t, index_sequence_for {}); +} + +/////////////// +// from_json // +/////////////// + +// overloads for basic_json template parameters +template::value and + not std::is_same::value, + int> = 0> +void get_arithmetic_value(const BasicJsonType& j, ArithmeticType& val) +{ + switch (static_cast(j)) + { + case value_t::number_unsigned: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_integer: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_float: + { + val = static_cast(*j.template get_ptr()); + break; + } + + default: + JSON_THROW(type_error::create(302, "type must be number, but is " + std::string(j.type_name()))); + } +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::boolean_t& b) +{ + if (JSON_UNLIKELY(not j.is_boolean())) + { + JSON_THROW(type_error::create(302, "type must be boolean, but is " + std::string(j.type_name()))); + } + b = *j.template get_ptr(); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::string_t& s) +{ + if (JSON_UNLIKELY(not j.is_string())) + { + JSON_THROW(type_error::create(302, "type must be string, but is " + std::string(j.type_name()))); + } + s = *j.template get_ptr(); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::number_float_t& val) +{ + 
get_arithmetic_value(j, val); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::number_unsigned_t& val) +{ + get_arithmetic_value(j, val); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::number_integer_t& val) +{ + get_arithmetic_value(j, val); +} + +template::value, int> = 0> +void from_json(const BasicJsonType& j, EnumType& e) +{ + typename std::underlying_type::type val; + get_arithmetic_value(j, val); + e = static_cast(val); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::array_t& arr) +{ + if (JSON_UNLIKELY(not j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); + } + arr = *j.template get_ptr(); +} + +// forward_list doesn't have an insert method +template::value, int> = 0> +void from_json(const BasicJsonType& j, std::forward_list& l) +{ + if (JSON_UNLIKELY(not j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); + } + std::transform(j.rbegin(), j.rend(), + std::front_inserter(l), [](const BasicJsonType & i) + { + return i.template get(); + }); +} + +// valarray doesn't have an insert method +template::value, int> = 0> +void from_json(const BasicJsonType& j, std::valarray& l) +{ + if (JSON_UNLIKELY(not j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); + } + l.resize(j.size()); + for (size_t i = 0; i < j.size(); ++i) + { + l[i] = j[i]; + } +} + +template +void from_json_array_impl(const BasicJsonType& j, CompatibleArrayType& arr, priority_tag<0> /*unused*/) +{ + using std::end; + + std::transform(j.begin(), j.end(), + std::inserter(arr, end(arr)), [](const BasicJsonType & i) + { + // get() returns *this, this won't call a from_json + // method when value_type is BasicJsonType + return i.template get(); + }); +} + +template +auto from_json_array_impl(const BasicJsonType& j, 
CompatibleArrayType& arr, priority_tag<1> /*unused*/) +-> decltype( + arr.reserve(std::declval()), + void()) +{ + using std::end; + + arr.reserve(j.size()); + std::transform(j.begin(), j.end(), + std::inserter(arr, end(arr)), [](const BasicJsonType & i) + { + // get() returns *this, this won't call a from_json + // method when value_type is BasicJsonType + return i.template get(); + }); +} + +template +void from_json_array_impl(const BasicJsonType& j, std::array& arr, priority_tag<2> /*unused*/) +{ + for (std::size_t i = 0; i < N; ++i) + { + arr[i] = j.at(i).template get(); + } +} + +template::value and + std::is_convertible::value and + not std::is_same::value, int> = 0> +void from_json(const BasicJsonType& j, CompatibleArrayType& arr) +{ + if (JSON_UNLIKELY(not j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); + } + + from_json_array_impl(j, arr, priority_tag<2> {}); +} + +template::value, int> = 0> +void from_json(const BasicJsonType& j, CompatibleObjectType& obj) +{ + if (JSON_UNLIKELY(not j.is_object())) + { + JSON_THROW(type_error::create(302, "type must be object, but is " + std::string(j.type_name()))); + } + + auto inner_object = j.template get_ptr(); + using value_type = typename CompatibleObjectType::value_type; + std::transform( + inner_object->begin(), inner_object->end(), + std::inserter(obj, obj.begin()), + [](typename BasicJsonType::object_t::value_type const & p) + { + return value_type(p.first, p.second.template get()); + }); +} + +// overload for arithmetic types, not chosen for basic_json template arguments +// (BooleanType, etc..); note: Is it really necessary to provide explicit +// overloads for boolean_t etc. in case of a custom BooleanType which is not +// an arithmetic type? 
+template::value and + not std::is_same::value and + not std::is_same::value and + not std::is_same::value and + not std::is_same::value, + int> = 0> +void from_json(const BasicJsonType& j, ArithmeticType& val) +{ + switch (static_cast(j)) + { + case value_t::number_unsigned: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_integer: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_float: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::boolean: + { + val = static_cast(*j.template get_ptr()); + break; + } + + default: + JSON_THROW(type_error::create(302, "type must be number, but is " + std::string(j.type_name()))); + } +} + +template +void from_json(const BasicJsonType& j, std::pair& p) +{ + p = {j.at(0).template get(), j.at(1).template get()}; +} + +template +void from_json_tuple_impl(const BasicJsonType& j, Tuple& t, index_sequence) +{ + t = std::make_tuple(j.at(Idx).template get::type>()...); +} + +template +void from_json(const BasicJsonType& j, std::tuple& t) +{ + from_json_tuple_impl(j, t, index_sequence_for {}); +} + +struct to_json_fn +{ + private: + template + auto call(BasicJsonType& j, T&& val, priority_tag<1> /*unused*/) const noexcept(noexcept(to_json(j, std::forward(val)))) + -> decltype(to_json(j, std::forward(val)), void()) + { + return to_json(j, std::forward(val)); + } + + template + void call(BasicJsonType& /*unused*/, T&& /*unused*/, priority_tag<0> /*unused*/) const noexcept + { + static_assert(sizeof(BasicJsonType) == 0, + "could not find to_json() method in T's namespace"); + } + + public: + template + void operator()(BasicJsonType& j, T&& val) const + noexcept(noexcept(std::declval().call(j, std::forward(val), priority_tag<1> {}))) + { + return call(j, std::forward(val), priority_tag<1> {}); + } +}; + +struct from_json_fn +{ + private: + template + auto call(const BasicJsonType& j, T& val, priority_tag<1> /*unused*/) const + 
noexcept(noexcept(from_json(j, val))) + -> decltype(from_json(j, val), void()) + { + return from_json(j, val); + } + + template + void call(const BasicJsonType& /*unused*/, T& /*unused*/, priority_tag<0> /*unused*/) const noexcept + { + static_assert(sizeof(BasicJsonType) == 0, + "could not find from_json() method in T's namespace"); + } + + public: + template + void operator()(const BasicJsonType& j, T& val) const + noexcept(noexcept(std::declval().call(j, val, priority_tag<1> {}))) + { + return call(j, val, priority_tag<1> {}); + } +}; + +// taken from ranges-v3 +template +struct static_const +{ + static constexpr T value{}; +}; + +template +constexpr T static_const::value; + +//////////////////// +// input adapters // +//////////////////// + +/// abstract input adapter interface +struct input_adapter_protocol +{ + virtual int get_character() = 0; + virtual std::string read(std::size_t offset, std::size_t length) = 0; + virtual ~input_adapter_protocol() = default; +}; + +/// a type to simplify interfaces +using input_adapter_t = std::shared_ptr; + +/// input adapter for cached stream input +template +class cached_input_stream_adapter : public input_adapter_protocol +{ + public: + explicit cached_input_stream_adapter(std::istream& i) + : is(i), start_position(is.tellg()) + { + fill_buffer(); + + // skip byte order mark + if (fill_size >= 3 and buffer[0] == '\xEF' and buffer[1] == '\xBB' and buffer[2] == '\xBF') + { + buffer_pos += 3; + processed_chars += 3; + } + } + + ~cached_input_stream_adapter() override + { + // clear stream flags + is.clear(); + // We initially read a lot of characters into the buffer, and we may + // not have processed all of them. Therefore, we need to "rewind" the + // stream after the last processed char. 
+ is.seekg(start_position); + is.ignore(static_cast(processed_chars)); + // clear stream flags + is.clear(); + } + + int get_character() override + { + // check if refilling is necessary and possible + if (buffer_pos == fill_size and not eof) + { + fill_buffer(); + + // check and remember that filling did not yield new input + if (fill_size == 0) + { + eof = true; + return std::char_traits::eof(); + } + + // the buffer is ready + buffer_pos = 0; + } + + ++processed_chars; + assert(buffer_pos < buffer.size()); + return buffer[buffer_pos++] & 0xFF; + } + + std::string read(std::size_t offset, std::size_t length) override + { + // create buffer + std::string result(length, '\0'); + + // save stream position + const auto current_pos = is.tellg(); + // save stream flags + const auto flags = is.rdstate(); + + // clear stream flags + is.clear(); + // set stream position + is.seekg(static_cast(offset)); + // read bytes + is.read(&result[0], static_cast(length)); + + // reset stream position + is.seekg(current_pos); + // reset stream flags + is.setstate(flags); + + return result; + } + + private: + void fill_buffer() + { + // fill + is.read(buffer.data(), static_cast(buffer.size())); + // store number of bytes in the buffer + fill_size = static_cast(is.gcount()); + } + + /// the associated input stream + std::istream& is; + + /// chars returned via get_character() + std::size_t processed_chars = 0; + /// chars processed in the current buffer + std::size_t buffer_pos = 0; + + /// whether stream reached eof + bool eof = false; + /// how many chars have been copied to the buffer by last (re)fill + std::size_t fill_size = 0; + + /// position of the stream when we started + const std::streampos start_position; + + /// internal buffer + std::array buffer{{}}; +}; + +/// input adapter for buffer input +class input_buffer_adapter : public input_adapter_protocol +{ + public: + input_buffer_adapter(const char* b, const std::size_t l) + : cursor(b), limit(b + l), start(b) + { + // 
skip byte order mark + if (l >= 3 and b[0] == '\xEF' and b[1] == '\xBB' and b[2] == '\xBF') + { + cursor += 3; + } + } + + // delete because of pointer members + input_buffer_adapter(const input_buffer_adapter&) = delete; + input_buffer_adapter& operator=(input_buffer_adapter&) = delete; + + int get_character() noexcept override + { + if (JSON_LIKELY(cursor < limit)) + { + return *(cursor++) & 0xFF; + } + + return std::char_traits::eof(); + } + + std::string read(std::size_t offset, std::size_t length) override + { + // avoid reading too many characters + const auto max_length = static_cast(limit - start); + return std::string(start + offset, (std::min)(length, max_length - offset)); + } + + private: + /// pointer to the current character + const char* cursor; + /// pointer past the last character + const char* limit; + /// pointer to the first character + const char* start; +}; + +class input_adapter +{ + public: + // native support + + /// input adapter for input stream + input_adapter(std::istream& i) + : ia(std::make_shared>(i)) {} + + /// input adapter for input stream + input_adapter(std::istream&& i) + : ia(std::make_shared>(i)) {} + + /// input adapter for buffer + template::value and + std::is_integral< + typename std::remove_pointer::type>::value and + sizeof(typename std::remove_pointer::type) == 1, + int>::type = 0> + input_adapter(CharT b, std::size_t l) + : ia(std::make_shared(reinterpret_cast(b), l)) {} + + // derived support + + /// input adapter for string literal + template::value and + std::is_integral< + typename std::remove_pointer::type>::value and + sizeof(typename std::remove_pointer::type) == 1, + int>::type = 0> + input_adapter(CharT b) + : input_adapter(reinterpret_cast(b), + std::strlen(reinterpret_cast(b))) {} + + /// input adapter for iterator range with contiguous storage + template::iterator_category, + std::random_access_iterator_tag>::value, + int>::type = 0> + input_adapter(IteratorType first, IteratorType last) + { + // assertion 
to check that the iterator range is indeed contiguous, + // see http://stackoverflow.com/a/35008842/266378 for more discussion + assert(std::accumulate( + first, last, std::pair(true, 0), + [&first](std::pair res, decltype(*first) val) + { + res.first &= (val == *(std::next(std::addressof(*first), res.second++))); + return res; + }).first); + + // assertion to check that each element is 1 byte long + static_assert( + sizeof(typename std::iterator_traits::value_type) == 1, + "each element in the iterator range must have the size of 1 byte"); + + const auto len = static_cast(std::distance(first, last)); + if (JSON_LIKELY(len > 0)) + { + // there is at least one element: use the address of first + ia = std::make_shared(reinterpret_cast(&(*first)), len); + } + else + { + // the address of first cannot be used: use nullptr + ia = std::make_shared(nullptr, len); + } + } + + /// input adapter for array + template + input_adapter(T (&array)[N]) + : input_adapter(std::begin(array), std::end(array)) {} + + /// input adapter for contiguous container + template < + class ContiguousContainer, + typename std::enable_if < + not std::is_pointer::value and + std::is_base_of()))>::iterator_category>::value, + int >::type = 0 > + input_adapter(const ContiguousContainer& c) + : input_adapter(std::begin(c), std::end(c)) {} + + operator input_adapter_t() + { + return ia; + } + + private: + /// the actual adapter + input_adapter_t ia = nullptr; +}; + +////////////////////// +// lexer and parser // +////////////////////// + +/*! +@brief lexical analysis + +This class organizes the lexical analysis during JSON deserialization. 
+*/ +template +class lexer +{ + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + + public: + /// token types for the parser + enum class token_type + { + uninitialized, ///< indicating the scanner is uninitialized + literal_true, ///< the `true` literal + literal_false, ///< the `false` literal + literal_null, ///< the `null` literal + value_string, ///< a string -- use get_string() for actual value + value_unsigned, ///< an unsigned integer -- use get_number_unsigned() for actual value + value_integer, ///< a signed integer -- use get_number_integer() for actual value + value_float, ///< an floating point number -- use get_number_float() for actual value + begin_array, ///< the character for array begin `[` + begin_object, ///< the character for object begin `{` + end_array, ///< the character for array end `]` + end_object, ///< the character for object end `}` + name_separator, ///< the name separator `:` + value_separator, ///< the value separator `,` + parse_error, ///< indicating a parse error + end_of_input, ///< indicating the end of the input buffer + literal_or_value ///< a literal or the begin of a value (only for diagnostics) + }; + + /// return name of values of type token_type (only used for errors) + static const char* token_type_name(const token_type t) noexcept + { + switch (t) + { + case token_type::uninitialized: + return ""; + case token_type::literal_true: + return "true literal"; + case token_type::literal_false: + return "false literal"; + case token_type::literal_null: + return "null literal"; + case token_type::value_string: + return "string literal"; + case lexer::token_type::value_unsigned: + case lexer::token_type::value_integer: + case lexer::token_type::value_float: + return "number literal"; + case token_type::begin_array: + return "'['"; + case token_type::begin_object: + return 
"'{'"; + case token_type::end_array: + return "']'"; + case token_type::end_object: + return "'}'"; + case token_type::name_separator: + return "':'"; + case token_type::value_separator: + return "','"; + case token_type::parse_error: + return ""; + case token_type::end_of_input: + return "end of input"; + case token_type::literal_or_value: + return "'[', '{', or a literal"; + default: // catch non-enum values + return "unknown token"; // LCOV_EXCL_LINE + } + } + + explicit lexer(detail::input_adapter_t adapter) + : ia(std::move(adapter)), decimal_point_char(get_decimal_point()) {} + + // delete because of pointer members + lexer(const lexer&) = delete; + lexer& operator=(lexer&) = delete; + + private: + ///////////////////// + // locales + ///////////////////// + + /// return the locale-dependent decimal point + static char get_decimal_point() noexcept + { + const auto loc = localeconv(); + assert(loc != nullptr); + return (loc->decimal_point == nullptr) ? '.' : loc->decimal_point[0]; + } + + ///////////////////// + // scan functions + ///////////////////// + + /*! + @brief get codepoint from 4 hex characters following `\u` + + For input "\u c1 c2 c3 c4" the codepoint is: + (c1 * 0x1000) + (c2 * 0x0100) + (c3 * 0x0010) + c4 + = (c1 << 12) + (c2 << 8) + (c3 << 4) + (c4 << 0) + + Furthermore, the possible characters '0'..'9', 'A'..'F', and 'a'..'f' + must be converted to the integers 0x0..0x9, 0xA..0xF, 0xA..0xF, resp. The + conversion is done by subtracting the offset (0x30, 0x37, and 0x57) + between the ASCII value of the character and the desired integer value. + + @return codepoint (0x0000..0xFFFF) or -1 in case of an error (e.g. 
EOF or + non-hex character) + */ + int get_codepoint() + { + // this function only makes sense after reading `\u` + assert(current == 'u'); + int codepoint = 0; + + const auto factors = { 12, 8, 4, 0 }; + for (const auto factor : factors) + { + get(); + + if (current >= '0' and current <= '9') + { + codepoint += ((current - 0x30) << factor); + } + else if (current >= 'A' and current <= 'F') + { + codepoint += ((current - 0x37) << factor); + } + else if (current >= 'a' and current <= 'f') + { + codepoint += ((current - 0x57) << factor); + } + else + { + return -1; + } + } + + assert(0x0000 <= codepoint and codepoint <= 0xFFFF); + return codepoint; + } + + /*! + @brief check if the next byte(s) are inside a given range + + Adds the current byte and, for each passed range, reads a new byte and + checks if it is inside the range. If a violation was detected, set up an + error message and return false. Otherwise, return true. + + @return true if and only if no range violation was detected + */ + bool next_byte_in_range(std::initializer_list ranges) + { + assert(ranges.size() == 2 or ranges.size() == 4 or ranges.size() == 6); + add(current); + + for (auto range = ranges.begin(); range != ranges.end(); ++range) + { + get(); + if (JSON_LIKELY(*range <= current and current <= *(++range))) + { + add(current); + } + else + { + error_message = "invalid string: ill-formed UTF-8 byte"; + return false; + } + } + + return true; + } + + /*! + @brief scan a string literal + + This function scans a string according to Sect. 7 of RFC 7159. While + scanning, bytes are escaped and copied into buffer yytext. Then the + function returns successfully, yytext is null-terminated and yylen + contains the number of bytes in the string. + + @return token_type::value_string if string could be successfully scanned, + token_type::parse_error otherwise + + @note In case of errors, variable error_message contains a textual + description. 
+ */ + token_type scan_string() + { + // reset yytext (ignore opening quote) + reset(); + + // we entered the function by reading an open quote + assert(current == '\"'); + + while (true) + { + // get next character + switch (get()) + { + // end of file while parsing string + case std::char_traits::eof(): + { + error_message = "invalid string: missing closing quote"; + return token_type::parse_error; + } + + // closing quote + case '\"': + { + // terminate yytext + add('\0'); + --yylen; + return token_type::value_string; + } + + // escapes + case '\\': + { + switch (get()) + { + // quotation mark + case '\"': + add('\"'); + break; + // reverse solidus + case '\\': + add('\\'); + break; + // solidus + case '/': + add('/'); + break; + // backspace + case 'b': + add('\b'); + break; + // form feed + case 'f': + add('\f'); + break; + // line feed + case 'n': + add('\n'); + break; + // carriage return + case 'r': + add('\r'); + break; + // tab + case 't': + add('\t'); + break; + + // unicode escapes + case 'u': + { + int codepoint; + const int codepoint1 = get_codepoint(); + + if (JSON_UNLIKELY(codepoint1 == -1)) + { + error_message = "invalid string: '\\u' must be followed by 4 hex digits"; + return token_type::parse_error; + } + + // check if code point is a high surrogate + if (0xD800 <= codepoint1 and codepoint1 <= 0xDBFF) + { + // expect next \uxxxx entry + if (JSON_LIKELY(get() == '\\' and get() == 'u')) + { + const int codepoint2 = get_codepoint(); + + if (JSON_UNLIKELY(codepoint2 == -1)) + { + error_message = "invalid string: '\\u' must be followed by 4 hex digits"; + return token_type::parse_error; + } + + // check if codepoint2 is a low surrogate + if (JSON_LIKELY(0xDC00 <= codepoint2 and codepoint2 <= 0xDFFF)) + { + codepoint = + // high surrogate occupies the most significant 22 bits + (codepoint1 << 10) + // low surrogate occupies the least significant 15 bits + + codepoint2 + // there is still the 0xD800, 0xDC00 and 0x10000 noise + // in the result so we 
have to subtract with: + // (0xD800 << 10) + DC00 - 0x10000 = 0x35FDC00 + - 0x35FDC00; + } + else + { + error_message = "invalid string: surrogate U+DC00..U+DFFF must be followed by U+DC00..U+DFFF"; + return token_type::parse_error; + } + } + else + { + error_message = "invalid string: surrogate U+DC00..U+DFFF must be followed by U+DC00..U+DFFF"; + return token_type::parse_error; + } + } + else + { + if (JSON_UNLIKELY(0xDC00 <= codepoint1 and codepoint1 <= 0xDFFF)) + { + error_message = "invalid string: surrogate U+DC00..U+DFFF must follow U+D800..U+DBFF"; + return token_type::parse_error; + } + + // only work with first code point + codepoint = codepoint1; + } + + // result of the above calculation yields a proper codepoint + assert(0x00 <= codepoint and codepoint <= 0x10FFFF); + + // translate code point to bytes + if (codepoint < 0x80) + { + // 1-byte characters: 0xxxxxxx (ASCII) + add(codepoint); + } + else if (codepoint <= 0x7ff) + { + // 2-byte characters: 110xxxxx 10xxxxxx + add(0xC0 | (codepoint >> 6)); + add(0x80 | (codepoint & 0x3F)); + } + else if (codepoint <= 0xffff) + { + // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + add(0xE0 | (codepoint >> 12)); + add(0x80 | ((codepoint >> 6) & 0x3F)); + add(0x80 | (codepoint & 0x3F)); + } + else + { + // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + add(0xF0 | (codepoint >> 18)); + add(0x80 | ((codepoint >> 12) & 0x3F)); + add(0x80 | ((codepoint >> 6) & 0x3F)); + add(0x80 | (codepoint & 0x3F)); + } + + break; + } + + // other characters after escape + default: + error_message = "invalid string: forbidden character after backslash"; + return token_type::parse_error; + } + + break; + } + + // invalid control characters + case 0x00: + case 0x01: + case 0x02: + case 0x03: + case 0x04: + case 0x05: + case 0x06: + case 0x07: + case 0x08: + case 0x09: + case 0x0a: + case 0x0b: + case 0x0c: + case 0x0d: + case 0x0e: + case 0x0f: + case 0x10: + case 0x11: + case 0x12: + case 0x13: + case 0x14: + case 0x15: + 
case 0x16: + case 0x17: + case 0x18: + case 0x19: + case 0x1a: + case 0x1b: + case 0x1c: + case 0x1d: + case 0x1e: + case 0x1f: + { + error_message = "invalid string: control character must be escaped"; + return token_type::parse_error; + } + + // U+0020..U+007F (except U+0022 (quote) and U+005C (backspace)) + case 0x20: + case 0x21: + case 0x23: + case 0x24: + case 0x25: + case 0x26: + case 0x27: + case 0x28: + case 0x29: + case 0x2a: + case 0x2b: + case 0x2c: + case 0x2d: + case 0x2e: + case 0x2f: + case 0x30: + case 0x31: + case 0x32: + case 0x33: + case 0x34: + case 0x35: + case 0x36: + case 0x37: + case 0x38: + case 0x39: + case 0x3a: + case 0x3b: + case 0x3c: + case 0x3d: + case 0x3e: + case 0x3f: + case 0x40: + case 0x41: + case 0x42: + case 0x43: + case 0x44: + case 0x45: + case 0x46: + case 0x47: + case 0x48: + case 0x49: + case 0x4a: + case 0x4b: + case 0x4c: + case 0x4d: + case 0x4e: + case 0x4f: + case 0x50: + case 0x51: + case 0x52: + case 0x53: + case 0x54: + case 0x55: + case 0x56: + case 0x57: + case 0x58: + case 0x59: + case 0x5a: + case 0x5b: + case 0x5d: + case 0x5e: + case 0x5f: + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6a: + case 0x6b: + case 0x6c: + case 0x6d: + case 0x6e: + case 0x6f: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + case 0x78: + case 0x79: + case 0x7a: + case 0x7b: + case 0x7c: + case 0x7d: + case 0x7e: + case 0x7f: + { + add(current); + break; + } + + // U+0080..U+07FF: bytes C2..DF 80..BF + case 0xc2: + case 0xc3: + case 0xc4: + case 0xc5: + case 0xc6: + case 0xc7: + case 0xc8: + case 0xc9: + case 0xca: + case 0xcb: + case 0xcc: + case 0xcd: + case 0xce: + case 0xcf: + case 0xd0: + case 0xd1: + case 0xd2: + case 0xd3: + case 0xd4: + case 0xd5: + case 0xd6: + case 0xd7: + case 0xd8: + case 0xd9: + case 0xda: + case 0xdb: + case 0xdc: + case 0xdd: + case 0xde: + case 0xdf: + 
{ + if (JSON_UNLIKELY(not next_byte_in_range({0x80, 0xBF}))) + { + return token_type::parse_error; + } + break; + } + + // U+0800..U+0FFF: bytes E0 A0..BF 80..BF + case 0xe0: + { + if (JSON_UNLIKELY(not (next_byte_in_range({0xA0, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+1000..U+CFFF: bytes E1..EC 80..BF 80..BF + // U+E000..U+FFFF: bytes EE..EF 80..BF 80..BF + case 0xe1: + case 0xe2: + case 0xe3: + case 0xe4: + case 0xe5: + case 0xe6: + case 0xe7: + case 0xe8: + case 0xe9: + case 0xea: + case 0xeb: + case 0xec: + case 0xee: + case 0xef: + { + if (JSON_UNLIKELY(not (next_byte_in_range({0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+D000..U+D7FF: bytes ED 80..9F 80..BF + case 0xed: + { + if (JSON_UNLIKELY(not (next_byte_in_range({0x80, 0x9F, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+10000..U+3FFFF F0 90..BF 80..BF 80..BF + case 0xf0: + { + if (JSON_UNLIKELY(not (next_byte_in_range({0x90, 0xBF, 0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+40000..U+FFFFF F1..F3 80..BF 80..BF 80..BF + case 0xf1: + case 0xf2: + case 0xf3: + { + if (JSON_UNLIKELY(not (next_byte_in_range({0x80, 0xBF, 0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+100000..U+10FFFF F4 80..8F 80..BF 80..BF + case 0xf4: + { + if (JSON_UNLIKELY(not (next_byte_in_range({0x80, 0x8F, 0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // remaining bytes (80..C1 and F5..FF) are ill-formed + default: + { + error_message = "invalid string: ill-formed UTF-8 byte"; + return token_type::parse_error; + } + } + } + } + + static void strtof(float& f, const char* str, char** endptr) noexcept + { + f = std::strtof(str, endptr); + } + + static void strtof(double& f, const char* str, char** endptr) noexcept + { + f = std::strtod(str, endptr); + } + + static void strtof(long 
double& f, const char* str, char** endptr) noexcept + { + f = std::strtold(str, endptr); + } + + /*! + @brief scan a number literal + + This function scans a string according to Sect. 6 of RFC 7159. + + The function is realized with a deterministic finite state machine derived + from the grammar described in RFC 7159. Starting in state "init", the + input is read and used to determined the next state. Only state "done" + accepts the number. State "error" is a trap state to model errors. In the + table below, "anything" means any character but the ones listed before. + + state | 0 | 1-9 | e E | + | - | . | anything + ---------|----------|----------|----------|---------|---------|----------|----------- + init | zero | any1 | [error] | [error] | minus | [error] | [error] + minus | zero | any1 | [error] | [error] | [error] | [error] | [error] + zero | done | done | exponent | done | done | decimal1 | done + any1 | any1 | any1 | exponent | done | done | decimal1 | done + decimal1 | decimal2 | [error] | [error] | [error] | [error] | [error] | [error] + decimal2 | decimal2 | decimal2 | exponent | done | done | done | done + exponent | any2 | any2 | [error] | sign | sign | [error] | [error] + sign | any2 | any2 | [error] | [error] | [error] | [error] | [error] + any2 | any2 | any2 | done | done | done | done | done + + The state machine is realized with one label per state (prefixed with + "scan_number_") and `goto` statements between them. The state machine + contains cycles, but any cycle can be left when EOF is read. Therefore, + the function is guaranteed to terminate. + + During scanning, the read bytes are stored in yytext. This string is + then converted to a signed integer, an unsigned integer, or a + floating-point number. + + @return token_type::value_unsigned, token_type::value_integer, or + token_type::value_float if number could be successfully scanned, + token_type::parse_error otherwise + + @note The scanner is independent of the current locale. 
Internally, the + locale's decimal point is used instead of `.` to work with the + locale-dependent converters. + */ + token_type scan_number() + { + // reset yytext to store the number's bytes + reset(); + + // the type of the parsed number; initially set to unsigned; will be + // changed if minus sign, decimal point or exponent is read + token_type number_type = token_type::value_unsigned; + + // state (init): we just found out we need to scan a number + switch (current) + { + case '-': + { + add(current); + goto scan_number_minus; + } + + case '0': + { + add(current); + goto scan_number_zero; + } + + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any1; + } + + default: + { + // all other characters are rejected outside scan_number() + assert(false); // LCOV_EXCL_LINE + } + } + +scan_number_minus: + // state: we just parsed a leading minus sign + number_type = token_type::value_integer; + switch (get()) + { + case '0': + { + add(current); + goto scan_number_zero; + } + + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any1; + } + + default: + { + error_message = "invalid number; expected digit after '-'"; + return token_type::parse_error; + } + } + +scan_number_zero: + // state: we just parse a zero (maybe with a leading minus sign) + switch (get()) + { + case '.': + { + add(decimal_point_char); + goto scan_number_decimal1; + } + + case 'e': + case 'E': + { + add(current); + goto scan_number_exponent; + } + + default: + goto scan_number_done; + } + +scan_number_any1: + // state: we just parsed a number 0-9 (maybe with a leading minus sign) + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any1; + } + + case '.': + { + add(decimal_point_char); + 
goto scan_number_decimal1; + } + + case 'e': + case 'E': + { + add(current); + goto scan_number_exponent; + } + + default: + goto scan_number_done; + } + +scan_number_decimal1: + // state: we just parsed a decimal point + number_type = token_type::value_float; + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_decimal2; + } + + default: + { + error_message = "invalid number; expected digit after '.'"; + return token_type::parse_error; + } + } + +scan_number_decimal2: + // we just parsed at least one number after a decimal point + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_decimal2; + } + + case 'e': + case 'E': + { + add(current); + goto scan_number_exponent; + } + + default: + goto scan_number_done; + } + +scan_number_exponent: + // we just parsed an exponent + number_type = token_type::value_float; + switch (get()) + { + case '+': + case '-': + { + add(current); + goto scan_number_sign; + } + + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any2; + } + + default: + { + error_message = + "invalid number; expected '+', '-', or digit after exponent"; + return token_type::parse_error; + } + } + +scan_number_sign: + // we just parsed an exponent sign + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any2; + } + + default: + { + error_message = "invalid number; expected digit after exponent sign"; + return token_type::parse_error; + } + } + +scan_number_any2: + // we just parsed a number after the exponent or exponent sign + switch (get()) + { + case '0': + case '1': + case 
'2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any2; + } + + default: + goto scan_number_done; + } + +scan_number_done: + // unget the character after the number (we only read it to know that + // we are done scanning a number) + --chars_read; + next_unget = true; + + // terminate token + add('\0'); + --yylen; + + char* endptr = nullptr; + errno = 0; + + // try to parse integers first and fall back to floats + if (number_type == token_type::value_unsigned) + { + const auto x = std::strtoull(yytext.data(), &endptr, 10); + + // we checked the number format before + assert(endptr == yytext.data() + yylen); + + if (errno == 0) + { + value_unsigned = static_cast(x); + if (value_unsigned == x) + { + return token_type::value_unsigned; + } + } + } + else if (number_type == token_type::value_integer) + { + const auto x = std::strtoll(yytext.data(), &endptr, 10); + + // we checked the number format before + assert(endptr == yytext.data() + yylen); + + if (errno == 0) + { + value_integer = static_cast(x); + if (value_integer == x) + { + return token_type::value_integer; + } + } + } + + // this code is reached if we parse a floating-point number or if an + // integer conversion above failed + strtof(value_float, yytext.data(), &endptr); + + // we checked the number format before + assert(endptr == yytext.data() + yylen); + + return token_type::value_float; + } + + /*! 
+ @param[in] literal_text the literal text to expect + @param[in] length the length of the passed literal text + @param[in] return_type the token type to return on success + */ + token_type scan_literal(const char* literal_text, const std::size_t length, + token_type return_type) + { + assert(current == literal_text[0]); + for (std::size_t i = 1; i < length; ++i) + { + if (JSON_UNLIKELY(get() != literal_text[i])) + { + error_message = "invalid literal"; + return token_type::parse_error; + } + } + return return_type; + } + + ///////////////////// + // input management + ///////////////////// + + /// reset yytext + void reset() noexcept + { + yylen = 0; + start_pos = chars_read - 1; + } + + /// get a character from the input + int get() + { + ++chars_read; + return next_unget ? (next_unget = false, current) + : (current = ia->get_character()); + } + + /// add a character to yytext + void add(int c) + { + // resize yytext if necessary; this condition is deemed unlikely, + // because we start with a 1024-byte buffer + if (JSON_UNLIKELY((yylen + 1 > yytext.capacity()))) + { + yytext.resize(2 * yytext.capacity(), '\0'); + } + assert(yylen < yytext.size()); + yytext[yylen++] = static_cast(c); + } + + public: + ///////////////////// + // value getters + ///////////////////// + + /// return integer value + constexpr number_integer_t get_number_integer() const noexcept + { + return value_integer; + } + + /// return unsigned integer value + constexpr number_unsigned_t get_number_unsigned() const noexcept + { + return value_unsigned; + } + + /// return floating-point value + constexpr number_float_t get_number_float() const noexcept + { + return value_float; + } + + /// return string value + const std::string get_string() + { + // yytext cannot be returned as char*, because it may contain a null + // byte (parsed as "\u0000") + return std::string(yytext.data(), yylen); + } + + ///////////////////// + // diagnostics + ///////////////////// + + /// return position of last read 
token + constexpr std::size_t get_position() const noexcept + { + return chars_read; + } + + /// return the last read token (for errors only) + std::string get_token_string() const + { + // get the raw byte sequence of the last token + std::string s = ia->read(start_pos, chars_read - start_pos); + + // escape control characters + std::string result; + for (auto c : s) + { + if (c == '\0' or c == std::char_traits::eof()) + { + // ignore EOF + continue; + } + else if ('\x00' <= c and c <= '\x1f') + { + // escape control characters + std::stringstream ss; + ss << "(c) << ">"; + result += ss.str(); + } + else + { + // add character as is + result.push_back(c); + } + } + + return result; + } + + /// return syntax error message + constexpr const char* get_error_message() const noexcept + { + return error_message; + } + + ///////////////////// + // actual scanner + ///////////////////// + + token_type scan() + { + // read next character and ignore whitespace + do + { + get(); + } + while (current == ' ' or current == '\t' or current == '\n' or current == '\r'); + + switch (current) + { + // structural characters + case '[': + return token_type::begin_array; + case ']': + return token_type::end_array; + case '{': + return token_type::begin_object; + case '}': + return token_type::end_object; + case ':': + return token_type::name_separator; + case ',': + return token_type::value_separator; + + // literals + case 't': + return scan_literal("true", 4, token_type::literal_true); + case 'f': + return scan_literal("false", 5, token_type::literal_false); + case 'n': + return scan_literal("null", 4, token_type::literal_null); + + // string + case '\"': + return scan_string(); + + // number + case '-': + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + return scan_number(); + + // end of input (the null byte is needed when parsing from + // string literals) + case '\0': + case std::char_traits::eof(): + return 
token_type::end_of_input; + + // error + default: + error_message = "invalid literal"; + return token_type::parse_error; + } + } + + private: + /// input adapter + detail::input_adapter_t ia = nullptr; + + /// the current character + int current = std::char_traits::eof(); + + /// whether get() should return the last character again + bool next_unget = false; + + /// the number of characters read + std::size_t chars_read = 0; + /// the start position of the current token + std::size_t start_pos = 0; + + /// buffer for variable-length tokens (numbers, strings) + std::vector yytext = std::vector(1024, '\0'); + /// current index in yytext + std::size_t yylen = 0; + + /// a description of occurred lexer errors + const char* error_message = ""; + + // number values + number_integer_t value_integer = 0; + number_unsigned_t value_unsigned = 0; + number_float_t value_float = 0; + + /// the decimal point + const char decimal_point_char = '.'; +}; + +/*! +@brief syntax analysis + +This class implements a recursive decent parser. 
+*/ +template +class parser +{ + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using lexer_t = lexer; + using token_type = typename lexer_t::token_type; + + public: + enum class parse_event_t : uint8_t + { + /// the parser read `{` and started to process a JSON object + object_start, + /// the parser read `}` and finished processing a JSON object + object_end, + /// the parser read `[` and started to process a JSON array + array_start, + /// the parser read `]` and finished processing a JSON array + array_end, + /// the parser read a key of a value in an object + key, + /// the parser finished reading a JSON value + value + }; + + using parser_callback_t = + std::function; + + /// a parser reading from an input adapter + explicit parser(detail::input_adapter_t adapter, + const parser_callback_t cb = nullptr, + const bool allow_exceptions_ = true) + : callback(cb), m_lexer(adapter), allow_exceptions(allow_exceptions_) + {} + + /*! + @brief public parser interface + + @param[in] strict whether to expect the last token to be EOF + @param[in,out] result parsed JSON value + + @throw parse_error.101 in case of an unexpected token + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + */ + void parse(const bool strict, BasicJsonType& result) + { + // read first token + get_token(); + + parse_internal(true, result); + result.assert_invariant(); + + // in strict mode, input must be completely read + if (strict) + { + get_token(); + expect(token_type::end_of_input); + } + + // in case of an error, return discarded value + if (errored) + { + result = value_t::discarded; + return; + } + + // set top-level value to null if it was discarded by the callback + // function + if (result.is_discarded()) + { + result = nullptr; + } + } + + /*! 
+ @brief public accept interface + + @param[in] strict whether to expect the last token to be EOF + @return whether the input is a proper JSON text + */ + bool accept(const bool strict = true) + { + // read first token + get_token(); + + if (not accept_internal()) + { + return false; + } + + // strict => last token must be EOF + return not strict or (get_token() == token_type::end_of_input); + } + + private: + /*! + @brief the actual parser + @throw parse_error.101 in case of an unexpected token + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + */ + void parse_internal(bool keep, BasicJsonType& result) + { + // never parse after a parse error was detected + assert(not errored); + + // start with a discarded value + if (not result.is_discarded()) + { + result.m_value.destroy(result.m_type); + result.m_type = value_t::discarded; + } + + switch (last_token) + { + case token_type::begin_object: + { + if (keep and (not callback or ((keep = callback(depth++, parse_event_t::object_start, result))))) + { + // explicitly set result to object to cope with {} + result.m_type = value_t::object; + result.m_value = value_t::object; + } + + // read next token + get_token(); + + // closing } -> we are done + if (last_token == token_type::end_object) + { + if (keep and callback and not callback(--depth, parse_event_t::object_end, result)) + { + result.m_value.destroy(result.m_type); + result.m_type = value_t::discarded; + } + break; + } + + // parse values + std::string key; + BasicJsonType value; + while (true) + { + // store key + if (not expect(token_type::value_string)) + { + return; + } + key = m_lexer.get_string(); + + bool keep_tag = false; + if (keep) + { + if (callback) + { + BasicJsonType k(key); + keep_tag = callback(depth, parse_event_t::key, k); + } + else + { + keep_tag = true; + } + } + + // parse separator (:) + get_token(); + if (not expect(token_type::name_separator)) + { + return; + } + + // parse and 
add value + get_token(); + value.m_value.destroy(value.m_type); + value.m_type = value_t::discarded; + parse_internal(keep, value); + + if (JSON_UNLIKELY(errored)) + { + return; + } + + if (keep and keep_tag and not value.is_discarded()) + { + result.m_value.object->emplace(std::move(key), std::move(value)); + } + + // comma -> next value + get_token(); + if (last_token == token_type::value_separator) + { + get_token(); + continue; + } + + // closing } + if (not expect(token_type::end_object)) + { + return; + } + break; + } + + if (keep and callback and not callback(--depth, parse_event_t::object_end, result)) + { + result.m_value.destroy(result.m_type); + result.m_type = value_t::discarded; + } + break; + } + + case token_type::begin_array: + { + if (keep and (not callback or ((keep = callback(depth++, parse_event_t::array_start, result))))) + { + // explicitly set result to object to cope with [] + result.m_type = value_t::array; + result.m_value = value_t::array; + } + + // read next token + get_token(); + + // closing ] -> we are done + if (last_token == token_type::end_array) + { + if (callback and not callback(--depth, parse_event_t::array_end, result)) + { + result.m_value.destroy(result.m_type); + result.m_type = value_t::discarded; + } + break; + } + + // parse values + BasicJsonType value; + while (true) + { + // parse value + value.m_value.destroy(value.m_type); + value.m_type = value_t::discarded; + parse_internal(keep, value); + + if (JSON_UNLIKELY(errored)) + { + return; + } + + if (keep and not value.is_discarded()) + { + result.m_value.array->push_back(std::move(value)); + } + + // comma -> next value + get_token(); + if (last_token == token_type::value_separator) + { + get_token(); + continue; + } + + // closing ] + if (not expect(token_type::end_array)) + { + return; + } + break; + } + + if (keep and callback and not callback(--depth, parse_event_t::array_end, result)) + { + result.m_value.destroy(result.m_type); + result.m_type = 
value_t::discarded; + } + break; + } + + case token_type::literal_null: + { + result.m_type = value_t::null; + break; + } + + case token_type::value_string: + { + result.m_type = value_t::string; + result.m_value = m_lexer.get_string(); + break; + } + + case token_type::literal_true: + { + result.m_type = value_t::boolean; + result.m_value = true; + break; + } + + case token_type::literal_false: + { + result.m_type = value_t::boolean; + result.m_value = false; + break; + } + + case token_type::value_unsigned: + { + result.m_type = value_t::number_unsigned; + result.m_value = m_lexer.get_number_unsigned(); + break; + } + + case token_type::value_integer: + { + result.m_type = value_t::number_integer; + result.m_value = m_lexer.get_number_integer(); + break; + } + + case token_type::value_float: + { + result.m_type = value_t::number_float; + result.m_value = m_lexer.get_number_float(); + + // throw in case of infinity or NAN + if (JSON_UNLIKELY(not std::isfinite(result.m_value.number_float))) + { + if (allow_exceptions) + { + JSON_THROW(out_of_range::create(406, "number overflow parsing '" + + m_lexer.get_token_string() + "'")); + } + expect(token_type::uninitialized); + } + break; + } + + case token_type::parse_error: + { + // using "uninitialized" to avoid "expected" message + if (not expect(token_type::uninitialized)) + { + return; + } + break; // LCOV_EXCL_LINE + } + + default: + { + // the last token was unexpected; we expected a value + if (not expect(token_type::literal_or_value)) + { + return; + } + break; // LCOV_EXCL_LINE + } + } + + if (keep and callback and not callback(depth, parse_event_t::value, result)) + { + result.m_type = value_t::discarded; + } + } + + /*! + @brief the acutal acceptor + + @invariant 1. The last token is not yet processed. Therefore, the caller + of this function must make sure a token has been read. + 2. When this function returns, the last token is processed. + That is, the last read character was already considered. 
+ + This invariant makes sure that no token needs to be "unput". + */ + bool accept_internal() + { + switch (last_token) + { + case token_type::begin_object: + { + // read next token + get_token(); + + // closing } -> we are done + if (last_token == token_type::end_object) + { + return true; + } + + // parse values + while (true) + { + // parse key + if (last_token != token_type::value_string) + { + return false; + } + + // parse separator (:) + get_token(); + if (last_token != token_type::name_separator) + { + return false; + } + + // parse value + get_token(); + if (not accept_internal()) + { + return false; + } + + // comma -> next value + get_token(); + if (last_token == token_type::value_separator) + { + get_token(); + continue; + } + + // closing } + return (last_token == token_type::end_object); + } + } + + case token_type::begin_array: + { + // read next token + get_token(); + + // closing ] -> we are done + if (last_token == token_type::end_array) + { + return true; + } + + // parse values + while (true) + { + // parse value + if (not accept_internal()) + { + return false; + } + + // comma -> next value + get_token(); + if (last_token == token_type::value_separator) + { + get_token(); + continue; + } + + // closing ] + return (last_token == token_type::end_array); + } + } + + case token_type::value_float: + { + // reject infinity or NAN + return std::isfinite(m_lexer.get_number_float()); + } + + case token_type::literal_false: + case token_type::literal_null: + case token_type::literal_true: + case token_type::value_integer: + case token_type::value_string: + case token_type::value_unsigned: + return true; + + default: // the last token was unexpected + return false; + } + } + + /// get next token from lexer + token_type get_token() + { + return (last_token = m_lexer.scan()); + } + + /*! 
+ @throw parse_error.101 if expected token did not occur + */ + bool expect(token_type t) + { + if (JSON_UNLIKELY(t != last_token)) + { + errored = true; + expected = t; + if (allow_exceptions) + { + throw_exception(); + } + else + { + return false; + } + } + + return true; + } + + [[noreturn]] void throw_exception() const + { + std::string error_msg = "syntax error - "; + if (last_token == token_type::parse_error) + { + error_msg += std::string(m_lexer.get_error_message()) + "; last read: '" + + m_lexer.get_token_string() + "'"; + } + else + { + error_msg += "unexpected " + std::string(lexer_t::token_type_name(last_token)); + } + + if (expected != token_type::uninitialized) + { + error_msg += "; expected " + std::string(lexer_t::token_type_name(expected)); + } + + JSON_THROW(parse_error::create(101, m_lexer.get_position(), error_msg)); + } + + private: + /// current level of recursion + int depth = 0; + /// callback function + const parser_callback_t callback = nullptr; + /// the type of the last read token + token_type last_token = token_type::uninitialized; + /// the lexer + lexer_t m_lexer; + /// whether a syntax error occurred + bool errored = false; + /// possible reason for the syntax error + token_type expected = token_type::uninitialized; + /// whether to throw exceptions in case of errors + const bool allow_exceptions = true; +}; + +/////////////// +// iterators // +/////////////// + +/*! +@brief an iterator for primitive JSON types + +This class models an iterator for primitive JSON types (boolean, number, +string). It's only purpose is to allow the iterator/const_iterator classes +to "iterate" over primitive values. Internally, the iterator is modeled by +a `difference_type` variable. Value begin_value (`0`) models the begin, +end_value (`1`) models past the end. 
+*/ +class primitive_iterator_t +{ + public: + using difference_type = std::ptrdiff_t; + + constexpr difference_type get_value() const noexcept + { + return m_it; + } + + /// set iterator to a defined beginning + void set_begin() noexcept + { + m_it = begin_value; + } + + /// set iterator to a defined past the end + void set_end() noexcept + { + m_it = end_value; + } + + /// return whether the iterator can be dereferenced + constexpr bool is_begin() const noexcept + { + return (m_it == begin_value); + } + + /// return whether the iterator is at end + constexpr bool is_end() const noexcept + { + return (m_it == end_value); + } + + friend constexpr bool operator==(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return (lhs.m_it == rhs.m_it); + } + + friend constexpr bool operator!=(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return not(lhs == rhs); + } + + friend constexpr bool operator<(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return lhs.m_it < rhs.m_it; + } + + friend constexpr bool operator<=(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return lhs.m_it <= rhs.m_it; + } + + friend constexpr bool operator>(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return lhs.m_it > rhs.m_it; + } + + friend constexpr bool operator>=(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return lhs.m_it >= rhs.m_it; + } + + primitive_iterator_t operator+(difference_type i) + { + auto result = *this; + result += i; + return result; + } + + friend constexpr difference_type operator-(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return lhs.m_it - rhs.m_it; + } + + friend std::ostream& operator<<(std::ostream& os, primitive_iterator_t it) + { + return os << it.m_it; + } + + primitive_iterator_t& operator++() + { + ++m_it; + return *this; + } + + primitive_iterator_t operator++(int) + { + auto result = *this; + m_it++; + return result; + } + + 
primitive_iterator_t& operator--() + { + --m_it; + return *this; + } + + primitive_iterator_t operator--(int) + { + auto result = *this; + m_it--; + return result; + } + + primitive_iterator_t& operator+=(difference_type n) + { + m_it += n; + return *this; + } + + primitive_iterator_t& operator-=(difference_type n) + { + m_it -= n; + return *this; + } + + private: + static constexpr difference_type begin_value = 0; + static constexpr difference_type end_value = begin_value + 1; + + /// iterator as signed integer type + difference_type m_it = std::numeric_limits::denorm_min(); +}; + +/*! +@brief an iterator value + +@note This structure could easily be a union, but MSVC currently does not allow +unions members with complex constructors, see https://github.com/nlohmann/json/pull/105. +*/ +template struct internal_iterator +{ + /// iterator for JSON objects + typename BasicJsonType::object_t::iterator object_iterator {}; + /// iterator for JSON arrays + typename BasicJsonType::array_t::iterator array_iterator {}; + /// generic iterator for all other types + primitive_iterator_t primitive_iterator {}; +}; + +template class iteration_proxy; + +/*! +@brief a template for a random access iterator for the @ref basic_json class + +This class implements a both iterators (iterator and const_iterator) for the +@ref basic_json class. + +@note An iterator is called *initialized* when a pointer to a JSON value has + been set (e.g., by a constructor or a copy assignment). If the iterator is + default-constructed, it is *uninitialized* and most methods are undefined. + **The library uses assertions to detect calls on uninitialized iterators.** + +@requirement The class satisfies the following concept requirements: +- +[RandomAccessIterator](http://en.cppreference.com/w/cpp/concept/RandomAccessIterator): + The iterator that can be moved to point (forward and backward) to any + element in constant time. 
+ +@since version 1.0.0, simplified in version 2.0.9 +*/ +template +class iter_impl : public std::iterator +{ + /// allow basic_json to access private members + friend iter_impl::value, typename std::remove_const::type, const BasicJsonType>::type>; + friend BasicJsonType; + friend iteration_proxy; + + using object_t = typename BasicJsonType::object_t; + using array_t = typename BasicJsonType::array_t; + // make sure BasicJsonType is basic_json or const basic_json + static_assert(is_basic_json::type>::value, + "iter_impl only accepts (const) basic_json"); + + public: + /// the type of the values when the iterator is dereferenced + using value_type = typename BasicJsonType::value_type; + /// a type to represent differences between iterators + using difference_type = typename BasicJsonType::difference_type; + /// defines a pointer to the type iterated over (value_type) + using pointer = typename std::conditional::value, + typename BasicJsonType::const_pointer, + typename BasicJsonType::pointer>::type; + /// defines a reference to the type iterated over (value_type) + using reference = + typename std::conditional::value, + typename BasicJsonType::const_reference, + typename BasicJsonType::reference>::type; + /// the category of the iterator + using iterator_category = std::bidirectional_iterator_tag; + + /// default constructor + iter_impl() = default; + + /*! + @brief constructor for a given JSON instance + @param[in] object pointer to a JSON object for this iterator + @pre object != nullptr + @post The iterator is initialized; i.e. `m_object != nullptr`. 
+ */ + explicit iter_impl(pointer object) noexcept : m_object(object) + { + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + m_it.object_iterator = typename object_t::iterator(); + break; + } + + case value_t::array: + { + m_it.array_iterator = typename array_t::iterator(); + break; + } + + default: + { + m_it.primitive_iterator = primitive_iterator_t(); + break; + } + } + } + + /*! + @note The conventional copy constructor and copy assignment are implicitly + defined. Combined with the following converting constructor and + assignment, they support: (1) copy from iterator to iterator, (2) + copy from const iterator to const iterator, and (3) conversion from + iterator to const iterator. However conversion from const iterator + to iterator is not defined. + */ + + /*! + @brief converting constructor + @param[in] other non-const iterator to copy from + @note It is not checked whether @a other is initialized. + */ + iter_impl(const iter_impl::type>& other) noexcept + : m_object(other.m_object), m_it(other.m_it) {} + + /*! + @brief converting assignment + @param[in,out] other non-const iterator to copy from + @return const/non-const iterator + @note It is not checked whether @a other is initialized. + */ + iter_impl& operator=(const iter_impl::type>& other) noexcept + { + m_object = other.m_object; + m_it = other.m_it; + return *this; + } + + private: + /*! + @brief set the iterator to the first value + @pre The iterator is initialized; i.e. `m_object != nullptr`. 
+ */ + void set_begin() noexcept + { + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + m_it.object_iterator = m_object->m_value.object->begin(); + break; + } + + case value_t::array: + { + m_it.array_iterator = m_object->m_value.array->begin(); + break; + } + + case value_t::null: + { + // set to end so begin()==end() is true: null is empty + m_it.primitive_iterator.set_end(); + break; + } + + default: + { + m_it.primitive_iterator.set_begin(); + break; + } + } + } + + /*! + @brief set the iterator past the last value + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + void set_end() noexcept + { + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + m_it.object_iterator = m_object->m_value.object->end(); + break; + } + + case value_t::array: + { + m_it.array_iterator = m_object->m_value.array->end(); + break; + } + + default: + { + m_it.primitive_iterator.set_end(); + break; + } + } + } + + public: + /*! + @brief return a reference to the value pointed to by the iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + reference operator*() const + { + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + assert(m_it.object_iterator != m_object->m_value.object->end()); + return m_it.object_iterator->second; + } + + case value_t::array: + { + assert(m_it.array_iterator != m_object->m_value.array->end()); + return *m_it.array_iterator; + } + + case value_t::null: + JSON_THROW(invalid_iterator::create(214, "cannot get value")); + + default: + { + if (JSON_LIKELY(m_it.primitive_iterator.is_begin())) + { + return *m_object; + } + + JSON_THROW(invalid_iterator::create(214, "cannot get value")); + } + } + } + + /*! + @brief dereference the iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. 
+ */ + pointer operator->() const + { + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + assert(m_it.object_iterator != m_object->m_value.object->end()); + return &(m_it.object_iterator->second); + } + + case value_t::array: + { + assert(m_it.array_iterator != m_object->m_value.array->end()); + return &*m_it.array_iterator; + } + + default: + { + if (JSON_LIKELY(m_it.primitive_iterator.is_begin())) + { + return m_object; + } + + JSON_THROW(invalid_iterator::create(214, "cannot get value")); + } + } + } + + /*! + @brief post-increment (it++) + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl operator++(int) + { + auto result = *this; + ++(*this); + return result; + } + + /*! + @brief pre-increment (++it) + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl& operator++() + { + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + std::advance(m_it.object_iterator, 1); + break; + } + + case value_t::array: + { + std::advance(m_it.array_iterator, 1); + break; + } + + default: + { + ++m_it.primitive_iterator; + break; + } + } + + return *this; + } + + /*! + @brief post-decrement (it--) + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl operator--(int) + { + auto result = *this; + --(*this); + return result; + } + + /*! + @brief pre-decrement (--it) + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl& operator--() + { + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + std::advance(m_it.object_iterator, -1); + break; + } + + case value_t::array: + { + std::advance(m_it.array_iterator, -1); + break; + } + + default: + { + --m_it.primitive_iterator; + break; + } + } + + return *this; + } + + /*! + @brief comparison: equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. 
+ */ + bool operator==(const iter_impl& other) const + { + // if objects are not the same, the comparison is undefined + if (JSON_UNLIKELY(m_object != other.m_object)) + { + JSON_THROW(invalid_iterator::create(212, "cannot compare iterators of different containers")); + } + + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + return (m_it.object_iterator == other.m_it.object_iterator); + + case value_t::array: + return (m_it.array_iterator == other.m_it.array_iterator); + + default: + return (m_it.primitive_iterator == other.m_it.primitive_iterator); + } + } + + /*! + @brief comparison: not equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator!=(const iter_impl& other) const + { + return not operator==(other); + } + + /*! + @brief comparison: smaller + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator<(const iter_impl& other) const + { + // if objects are not the same, the comparison is undefined + if (JSON_UNLIKELY(m_object != other.m_object)) + { + JSON_THROW(invalid_iterator::create(212, "cannot compare iterators of different containers")); + } + + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(213, "cannot compare order of object iterators")); + + case value_t::array: + return (m_it.array_iterator < other.m_it.array_iterator); + + default: + return (m_it.primitive_iterator < other.m_it.primitive_iterator); + } + } + + /*! + @brief comparison: less than or equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator<=(const iter_impl& other) const + { + return not other.operator < (*this); + } + + /*! + @brief comparison: greater than + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator>(const iter_impl& other) const + { + return not operator<=(other); + } + + /*! 
+ @brief comparison: greater than or equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator>=(const iter_impl& other) const + { + return not operator<(other); + } + + /*! + @brief add to iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl& operator+=(difference_type i) + { + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(209, "cannot use offsets with object iterators")); + + case value_t::array: + { + std::advance(m_it.array_iterator, i); + break; + } + + default: + { + m_it.primitive_iterator += i; + break; + } + } + + return *this; + } + + /*! + @brief subtract from iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl& operator-=(difference_type i) + { + return operator+=(-i); + } + + /*! + @brief add to iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl operator+(difference_type i) const + { + auto result = *this; + result += i; + return result; + } + + /*! + @brief addition of distance and iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + friend iter_impl operator+(difference_type i, const iter_impl& it) + { + auto result = it; + result += i; + return result; + } + + /*! + @brief subtract from iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl operator-(difference_type i) const + { + auto result = *this; + result -= i; + return result; + } + + /*! + @brief return difference + @pre The iterator is initialized; i.e. `m_object != nullptr`. 
+ */ + difference_type operator-(const iter_impl& other) const + { + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(209, "cannot use offsets with object iterators")); + + case value_t::array: + return m_it.array_iterator - other.m_it.array_iterator; + + default: + return m_it.primitive_iterator - other.m_it.primitive_iterator; + } + } + + /*! + @brief access to successor + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + reference operator[](difference_type n) const + { + assert(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(208, "cannot use operator[] for object iterators")); + + case value_t::array: + return *std::next(m_it.array_iterator, n); + + case value_t::null: + JSON_THROW(invalid_iterator::create(214, "cannot get value")); + + default: + { + if (JSON_LIKELY(m_it.primitive_iterator.get_value() == -n)) + { + return *m_object; + } + + JSON_THROW(invalid_iterator::create(214, "cannot get value")); + } + } + } + + /*! + @brief return the key of an object iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + typename object_t::key_type key() const + { + assert(m_object != nullptr); + + if (JSON_LIKELY(m_object->is_object())) + { + return m_it.object_iterator->first; + } + + JSON_THROW(invalid_iterator::create(207, "cannot use key() for non-object iterators")); + } + + /*! + @brief return the value of an iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. 
+ */ + reference value() const + { + return operator*(); + } + + private: + /// associated JSON instance + pointer m_object = nullptr; + /// the actual iterator of the associated instance + internal_iterator::type> m_it = {}; +}; + +/// proxy class for the iterator_wrapper functions +template class iteration_proxy +{ + private: + /// helper class for iteration + class iteration_proxy_internal + { + private: + /// the iterator + IteratorType anchor; + /// an index for arrays (used to create key names) + std::size_t array_index = 0; + + public: + explicit iteration_proxy_internal(IteratorType it) noexcept : anchor(it) {} + + /// dereference operator (needed for range-based for) + iteration_proxy_internal& operator*() + { + return *this; + } + + /// increment operator (needed for range-based for) + iteration_proxy_internal& operator++() + { + ++anchor; + ++array_index; + + return *this; + } + + /// inequality operator (needed for range-based for) + bool operator!=(const iteration_proxy_internal& o) const noexcept + { + return anchor != o.anchor; + } + + /// return key of the iterator + std::string key() const + { + assert(anchor.m_object != nullptr); + + switch (anchor.m_object->type()) + { + // use integer array index as key + case value_t::array: + return std::to_string(array_index); + + // use key from the object + case value_t::object: + return anchor.key(); + + // use an empty key for all primitive types + default: + return ""; + } + } + + /// return value of the iterator + typename IteratorType::reference value() const + { + return anchor.value(); + } + }; + + /// the container to iterate + typename IteratorType::reference container; + + public: + /// construct iteration proxy from a container + explicit iteration_proxy(typename IteratorType::reference cont) + : container(cont) {} + + /// return iterator begin (needed for range-based for) + iteration_proxy_internal begin() noexcept + { + return iteration_proxy_internal(container.begin()); + } + + /// return 
iterator end (needed for range-based for) + iteration_proxy_internal end() noexcept + { + return iteration_proxy_internal(container.end()); + } +}; + +/*! +@brief a template for a reverse iterator class + +@tparam Base the base iterator type to reverse. Valid types are @ref +iterator (to create @ref reverse_iterator) and @ref const_iterator (to +create @ref const_reverse_iterator). + +@requirement The class satisfies the following concept requirements: +- +[RandomAccessIterator](http://en.cppreference.com/w/cpp/concept/RandomAccessIterator): + The iterator that can be moved to point (forward and backward) to any + element in constant time. +- [OutputIterator](http://en.cppreference.com/w/cpp/concept/OutputIterator): + It is possible to write to the pointed-to element (only if @a Base is + @ref iterator). + +@since version 1.0.0 +*/ +template +class json_reverse_iterator : public std::reverse_iterator +{ + public: + using difference_type = std::ptrdiff_t; + /// shortcut to the reverse iterator adaptor + using base_iterator = std::reverse_iterator; + /// the reference type for the pointed-to element + using reference = typename Base::reference; + + /// create reverse iterator from iterator + json_reverse_iterator(const typename base_iterator::iterator_type& it) noexcept + : base_iterator(it) {} + + /// create reverse iterator from base class + json_reverse_iterator(const base_iterator& it) noexcept : base_iterator(it) {} + + /// post-increment (it++) + json_reverse_iterator operator++(int) + { + return static_cast(base_iterator::operator++(1)); + } + + /// pre-increment (++it) + json_reverse_iterator& operator++() + { + return static_cast(base_iterator::operator++()); + } + + /// post-decrement (it--) + json_reverse_iterator operator--(int) + { + return static_cast(base_iterator::operator--(1)); + } + + /// pre-decrement (--it) + json_reverse_iterator& operator--() + { + return static_cast(base_iterator::operator--()); + } + + /// add to iterator + 
json_reverse_iterator& operator+=(difference_type i) + { + return static_cast(base_iterator::operator+=(i)); + } + + /// add to iterator + json_reverse_iterator operator+(difference_type i) const + { + return static_cast(base_iterator::operator+(i)); + } + + /// subtract from iterator + json_reverse_iterator operator-(difference_type i) const + { + return static_cast(base_iterator::operator-(i)); + } + + /// return difference + difference_type operator-(const json_reverse_iterator& other) const + { + return base_iterator(*this) - base_iterator(other); + } + + /// access to successor + reference operator[](difference_type n) const + { + return *(this->operator+(n)); + } + + /// return the key of an object iterator + auto key() const -> decltype(std::declval().key()) + { + auto it = --this->base(); + return it.key(); + } + + /// return the value of an iterator + reference value() const + { + auto it = --this->base(); + return it.operator * (); + } +}; + +///////////////////// +// output adapters // +///////////////////// + +/// abstract output adapter interface +template struct output_adapter_protocol +{ + virtual void write_character(CharType c) = 0; + virtual void write_characters(const CharType* s, std::size_t length) = 0; + virtual ~output_adapter_protocol() = default; +}; + +/// a type to simplify interfaces +template +using output_adapter_t = std::shared_ptr>; + +/// output adapter for byte vectors +template +class output_vector_adapter : public output_adapter_protocol +{ + public: + explicit output_vector_adapter(std::vector& vec) : v(vec) {} + + void write_character(CharType c) override + { + v.push_back(c); + } + + void write_characters(const CharType* s, std::size_t length) override + { + std::copy(s, s + length, std::back_inserter(v)); + } + + private: + std::vector& v; +}; + +/// output adapter for output streams +template +class output_stream_adapter : public output_adapter_protocol +{ + public: + explicit output_stream_adapter(std::basic_ostream& s) : 
stream(s) {} + + void write_character(CharType c) override + { + stream.put(c); + } + + void write_characters(const CharType* s, std::size_t length) override + { + stream.write(s, static_cast(length)); + } + + private: + std::basic_ostream& stream; +}; + +/// output adapter for basic_string +template +class output_string_adapter : public output_adapter_protocol +{ + public: + explicit output_string_adapter(std::basic_string& s) : str(s) {} + + void write_character(CharType c) override + { + str.push_back(c); + } + + void write_characters(const CharType* s, std::size_t length) override + { + str.append(s, length); + } + + private: + std::basic_string& str; +}; + +template +class output_adapter +{ + public: + output_adapter(std::vector& vec) + : oa(std::make_shared>(vec)) {} + + output_adapter(std::basic_ostream& s) + : oa(std::make_shared>(s)) {} + + output_adapter(std::basic_string& s) + : oa(std::make_shared>(s)) {} + + operator output_adapter_t() + { + return oa; + } + + private: + output_adapter_t oa = nullptr; +}; + +////////////////////////////// +// binary reader and writer // +////////////////////////////// + +/*! +@brief deserialization of CBOR and MessagePack values +*/ +template +class binary_reader +{ + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + + public: + /*! + @brief create a binary reader + + @param[in] adapter input adapter to read from + */ + explicit binary_reader(input_adapter_t adapter) : ia(std::move(adapter)) + { + assert(ia); + } + + /*! 
+ @brief create a JSON value from CBOR input + + @param[in] strict whether to expect the input to be consumed completed + @return JSON value created from CBOR input + + @throw parse_error.110 if input ended unexpectedly or the end of file was + not reached when @a strict was set to true + @throw parse_error.112 if unsupported byte was read + */ + BasicJsonType parse_cbor(const bool strict) + { + const auto res = parse_cbor_internal(); + if (strict) + { + get(); + check_eof(true); + } + return res; + } + + /*! + @brief create a JSON value from MessagePack input + + @param[in] strict whether to expect the input to be consumed completed + @return JSON value created from MessagePack input + + @throw parse_error.110 if input ended unexpectedly or the end of file was + not reached when @a strict was set to true + @throw parse_error.112 if unsupported byte was read + */ + BasicJsonType parse_msgpack(const bool strict) + { + const auto res = parse_msgpack_internal(); + if (strict) + { + get(); + check_eof(true); + } + return res; + } + + /*! + @brief determine system byte order + + @return true if and only if system's byte order is little endian + + @note from http://stackoverflow.com/a/1001328/266378 + */ + static constexpr bool little_endianess(int num = 1) noexcept + { + return (*reinterpret_cast(&num) == 1); + } + + private: + /*! + @param[in] get_char whether a new character should be retrieved from the + input (true, default) or whether the last read + character should be considered instead + */ + BasicJsonType parse_cbor_internal(const bool get_char = true) + { + switch (get_char ? 
get() : current) + { + // EOF + case std::char_traits::eof(): + JSON_THROW(parse_error::create(110, chars_read, "unexpected end of input")); + + // Integer 0x00..0x17 (0..23) + case 0x00: + case 0x01: + case 0x02: + case 0x03: + case 0x04: + case 0x05: + case 0x06: + case 0x07: + case 0x08: + case 0x09: + case 0x0a: + case 0x0b: + case 0x0c: + case 0x0d: + case 0x0e: + case 0x0f: + case 0x10: + case 0x11: + case 0x12: + case 0x13: + case 0x14: + case 0x15: + case 0x16: + case 0x17: + return static_cast(current); + + case 0x18: // Unsigned integer (one-byte uint8_t follows) + return get_number(); + + case 0x19: // Unsigned integer (two-byte uint16_t follows) + return get_number(); + + case 0x1a: // Unsigned integer (four-byte uint32_t follows) + return get_number(); + + case 0x1b: // Unsigned integer (eight-byte uint64_t follows) + return get_number(); + + // Negative integer -1-0x00..-1-0x17 (-1..-24) + case 0x20: + case 0x21: + case 0x22: + case 0x23: + case 0x24: + case 0x25: + case 0x26: + case 0x27: + case 0x28: + case 0x29: + case 0x2a: + case 0x2b: + case 0x2c: + case 0x2d: + case 0x2e: + case 0x2f: + case 0x30: + case 0x31: + case 0x32: + case 0x33: + case 0x34: + case 0x35: + case 0x36: + case 0x37: + return static_cast(0x20 - 1 - current); + + case 0x38: // Negative integer (one-byte uint8_t follows) + { + // must be uint8_t ! 
+ return static_cast(-1) - get_number(); + } + + case 0x39: // Negative integer -1-n (two-byte uint16_t follows) + { + return static_cast(-1) - get_number(); + } + + case 0x3a: // Negative integer -1-n (four-byte uint32_t follows) + { + return static_cast(-1) - get_number(); + } + + case 0x3b: // Negative integer -1-n (eight-byte uint64_t follows) + { + return static_cast(-1) - + static_cast(get_number()); + } + + // UTF-8 string (0x00..0x17 bytes follow) + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6a: + case 0x6b: + case 0x6c: + case 0x6d: + case 0x6e: + case 0x6f: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + case 0x78: // UTF-8 string (one-byte uint8_t for n follows) + case 0x79: // UTF-8 string (two-byte uint16_t for n follow) + case 0x7a: // UTF-8 string (four-byte uint32_t for n follow) + case 0x7b: // UTF-8 string (eight-byte uint64_t for n follow) + case 0x7f: // UTF-8 string (indefinite length) + { + return get_cbor_string(); + } + + // array (0x00..0x17 data items follow) + case 0x80: + case 0x81: + case 0x82: + case 0x83: + case 0x84: + case 0x85: + case 0x86: + case 0x87: + case 0x88: + case 0x89: + case 0x8a: + case 0x8b: + case 0x8c: + case 0x8d: + case 0x8e: + case 0x8f: + case 0x90: + case 0x91: + case 0x92: + case 0x93: + case 0x94: + case 0x95: + case 0x96: + case 0x97: + { + return get_cbor_array(current & 0x1f); + } + + case 0x98: // array (one-byte uint8_t for n follows) + { + return get_cbor_array(get_number()); + } + + case 0x99: // array (two-byte uint16_t for n follow) + { + return get_cbor_array(get_number()); + } + + case 0x9a: // array (four-byte uint32_t for n follow) + { + return get_cbor_array(get_number()); + } + + case 0x9b: // array (eight-byte uint64_t for n follow) + { + return get_cbor_array(get_number()); + } + + case 0x9f: // array (indefinite length) + { + BasicJsonType 
result = value_t::array; + while (get() != 0xff) + { + result.push_back(parse_cbor_internal(false)); + } + return result; + } + + // map (0x00..0x17 pairs of data items follow) + case 0xa0: + case 0xa1: + case 0xa2: + case 0xa3: + case 0xa4: + case 0xa5: + case 0xa6: + case 0xa7: + case 0xa8: + case 0xa9: + case 0xaa: + case 0xab: + case 0xac: + case 0xad: + case 0xae: + case 0xaf: + case 0xb0: + case 0xb1: + case 0xb2: + case 0xb3: + case 0xb4: + case 0xb5: + case 0xb6: + case 0xb7: + { + return get_cbor_object(current & 0x1f); + } + + case 0xb8: // map (one-byte uint8_t for n follows) + { + return get_cbor_object(get_number()); + } + + case 0xb9: // map (two-byte uint16_t for n follow) + { + return get_cbor_object(get_number()); + } + + case 0xba: // map (four-byte uint32_t for n follow) + { + return get_cbor_object(get_number()); + } + + case 0xbb: // map (eight-byte uint64_t for n follow) + { + return get_cbor_object(get_number()); + } + + case 0xbf: // map (indefinite length) + { + BasicJsonType result = value_t::object; + while (get() != 0xff) + { + auto key = get_cbor_string(); + result[key] = parse_cbor_internal(); + } + return result; + } + + case 0xf4: // false + { + return false; + } + + case 0xf5: // true + { + return true; + } + + case 0xf6: // null + { + return value_t::null; + } + + case 0xf9: // Half-Precision Float (two-byte IEEE 754) + { + const int byte1 = get(); + check_eof(); + const int byte2 = get(); + check_eof(); + + // code from RFC 7049, Appendix D, Figure 3: + // As half-precision floating-point numbers were only added + // to IEEE 754 in 2008, today's programming platforms often + // still only have limited support for them. It is very + // easy to include at least decoding support for them even + // without such support. An example of a small decoder for + // half-precision floating-point numbers in the C language + // is shown in Fig. 3. 
+ const int half = (byte1 << 8) + byte2; + const int exp = (half >> 10) & 0x1f; + const int mant = half & 0x3ff; + double val; + if (exp == 0) + { + val = std::ldexp(mant, -24); + } + else if (exp != 31) + { + val = std::ldexp(mant + 1024, exp - 25); + } + else + { + val = (mant == 0) ? std::numeric_limits::infinity() + : std::numeric_limits::quiet_NaN(); + } + return (half & 0x8000) != 0 ? -val : val; + } + + case 0xfa: // Single-Precision Float (four-byte IEEE 754) + { + return get_number(); + } + + case 0xfb: // Double-Precision Float (eight-byte IEEE 754) + { + return get_number(); + } + + default: // anything else (0xFF is handled inside the other types) + { + std::stringstream ss; + ss << std::setw(2) << std::setfill('0') << std::hex << current; + JSON_THROW(parse_error::create(112, chars_read, "error reading CBOR; last byte: 0x" + ss.str())); + } + } + } + + BasicJsonType parse_msgpack_internal() + { + switch (get()) + { + // EOF + case std::char_traits::eof(): + JSON_THROW(parse_error::create(110, chars_read, "unexpected end of input")); + + // positive fixint + case 0x00: + case 0x01: + case 0x02: + case 0x03: + case 0x04: + case 0x05: + case 0x06: + case 0x07: + case 0x08: + case 0x09: + case 0x0a: + case 0x0b: + case 0x0c: + case 0x0d: + case 0x0e: + case 0x0f: + case 0x10: + case 0x11: + case 0x12: + case 0x13: + case 0x14: + case 0x15: + case 0x16: + case 0x17: + case 0x18: + case 0x19: + case 0x1a: + case 0x1b: + case 0x1c: + case 0x1d: + case 0x1e: + case 0x1f: + case 0x20: + case 0x21: + case 0x22: + case 0x23: + case 0x24: + case 0x25: + case 0x26: + case 0x27: + case 0x28: + case 0x29: + case 0x2a: + case 0x2b: + case 0x2c: + case 0x2d: + case 0x2e: + case 0x2f: + case 0x30: + case 0x31: + case 0x32: + case 0x33: + case 0x34: + case 0x35: + case 0x36: + case 0x37: + case 0x38: + case 0x39: + case 0x3a: + case 0x3b: + case 0x3c: + case 0x3d: + case 0x3e: + case 0x3f: + case 0x40: + case 0x41: + case 0x42: + case 0x43: + case 0x44: + case 0x45: + 
case 0x46: + case 0x47: + case 0x48: + case 0x49: + case 0x4a: + case 0x4b: + case 0x4c: + case 0x4d: + case 0x4e: + case 0x4f: + case 0x50: + case 0x51: + case 0x52: + case 0x53: + case 0x54: + case 0x55: + case 0x56: + case 0x57: + case 0x58: + case 0x59: + case 0x5a: + case 0x5b: + case 0x5c: + case 0x5d: + case 0x5e: + case 0x5f: + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6a: + case 0x6b: + case 0x6c: + case 0x6d: + case 0x6e: + case 0x6f: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + case 0x78: + case 0x79: + case 0x7a: + case 0x7b: + case 0x7c: + case 0x7d: + case 0x7e: + case 0x7f: + return static_cast(current); + + // fixmap + case 0x80: + case 0x81: + case 0x82: + case 0x83: + case 0x84: + case 0x85: + case 0x86: + case 0x87: + case 0x88: + case 0x89: + case 0x8a: + case 0x8b: + case 0x8c: + case 0x8d: + case 0x8e: + case 0x8f: + { + return get_msgpack_object(current & 0x0f); + } + + // fixarray + case 0x90: + case 0x91: + case 0x92: + case 0x93: + case 0x94: + case 0x95: + case 0x96: + case 0x97: + case 0x98: + case 0x99: + case 0x9a: + case 0x9b: + case 0x9c: + case 0x9d: + case 0x9e: + case 0x9f: + { + return get_msgpack_array(current & 0x0f); + } + + // fixstr + case 0xa0: + case 0xa1: + case 0xa2: + case 0xa3: + case 0xa4: + case 0xa5: + case 0xa6: + case 0xa7: + case 0xa8: + case 0xa9: + case 0xaa: + case 0xab: + case 0xac: + case 0xad: + case 0xae: + case 0xaf: + case 0xb0: + case 0xb1: + case 0xb2: + case 0xb3: + case 0xb4: + case 0xb5: + case 0xb6: + case 0xb7: + case 0xb8: + case 0xb9: + case 0xba: + case 0xbb: + case 0xbc: + case 0xbd: + case 0xbe: + case 0xbf: + return get_msgpack_string(); + + case 0xc0: // nil + return value_t::null; + + case 0xc2: // false + return false; + + case 0xc3: // true + return true; + + case 0xca: // float 32 + return get_number(); + + case 0xcb: // float 64 + 
return get_number(); + + case 0xcc: // uint 8 + return get_number(); + + case 0xcd: // uint 16 + return get_number(); + + case 0xce: // uint 32 + return get_number(); + + case 0xcf: // uint 64 + return get_number(); + + case 0xd0: // int 8 + return get_number(); + + case 0xd1: // int 16 + return get_number(); + + case 0xd2: // int 32 + return get_number(); + + case 0xd3: // int 64 + return get_number(); + + case 0xd9: // str 8 + case 0xda: // str 16 + case 0xdb: // str 32 + return get_msgpack_string(); + + case 0xdc: // array 16 + { + return get_msgpack_array(get_number()); + } + + case 0xdd: // array 32 + { + return get_msgpack_array(get_number()); + } + + case 0xde: // map 16 + { + return get_msgpack_object(get_number()); + } + + case 0xdf: // map 32 + { + return get_msgpack_object(get_number()); + } + + // positive fixint + case 0xe0: + case 0xe1: + case 0xe2: + case 0xe3: + case 0xe4: + case 0xe5: + case 0xe6: + case 0xe7: + case 0xe8: + case 0xe9: + case 0xea: + case 0xeb: + case 0xec: + case 0xed: + case 0xee: + case 0xef: + case 0xf0: + case 0xf1: + case 0xf2: + case 0xf3: + case 0xf4: + case 0xf5: + case 0xf6: + case 0xf7: + case 0xf8: + case 0xf9: + case 0xfa: + case 0xfb: + case 0xfc: + case 0xfd: + case 0xfe: + case 0xff: + return static_cast(current); + + default: // anything else + { + std::stringstream ss; + ss << std::setw(2) << std::setfill('0') << std::hex << current; + JSON_THROW(parse_error::create(112, chars_read, + "error reading MessagePack; last byte: 0x" + ss.str())); + } + } + } + + /*! + @brief get next character from the input + + This function provides the interface to the used input adapter. It does + not throw in case the input reached EOF, but returns + `std::char_traits::eof()` in that case. 
+ + @return character read from the input + */ + int get() + { + ++chars_read; + return (current = ia->get_character()); + } + + /* + @brief read a number from the input + + @tparam NumberType the type of the number + + @return number of type @a NumberType + + @note This function needs to respect the system's endianess, because + bytes in CBOR and MessagePack are stored in network order (big + endian) and therefore need reordering on little endian systems. + + @throw parse_error.110 if input has less than `sizeof(NumberType)` bytes + */ + template NumberType get_number() + { + // step 1: read input into array with system's byte order + std::array vec; + for (std::size_t i = 0; i < sizeof(NumberType); ++i) + { + get(); + check_eof(); + + // reverse byte order prior to conversion if necessary + if (is_little_endian) + { + vec[sizeof(NumberType) - i - 1] = static_cast(current); + } + else + { + vec[i] = static_cast(current); // LCOV_EXCL_LINE + } + } + + // step 2: convert array into number of type T and return + NumberType result; + std::memcpy(&result, vec.data(), sizeof(NumberType)); + return result; + } + + /*! + @brief create a string by reading characters from the input + + @param[in] len number of bytes to read + + @note We can not reserve @a len bytes for the result, because @a len + may be too large. Usually, @ref check_eof() detects the end of + the input before we run out of string memory. + + @return string created by reading @a len bytes + + @throw parse_error.110 if input has less than @a len bytes + */ + template + std::string get_string(const NumberType len) + { + std::string result; + std::generate_n(std::back_inserter(result), len, [this]() + { + get(); + check_eof(); + return current; + }); + return result; + } + + /*! + @brief reads a CBOR string + + This function first reads starting bytes to determine the expected + string length and then copies this number of bytes into a string. 
+ Additionally, CBOR's strings with indefinite lengths are supported. + + @return string + + @throw parse_error.110 if input ended + @throw parse_error.113 if an unexpected byte is read + */ + std::string get_cbor_string() + { + check_eof(); + + switch (current) + { + // UTF-8 string (0x00..0x17 bytes follow) + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6a: + case 0x6b: + case 0x6c: + case 0x6d: + case 0x6e: + case 0x6f: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + { + return get_string(current & 0x1f); + } + + case 0x78: // UTF-8 string (one-byte uint8_t for n follows) + { + return get_string(get_number()); + } + + case 0x79: // UTF-8 string (two-byte uint16_t for n follow) + { + return get_string(get_number()); + } + + case 0x7a: // UTF-8 string (four-byte uint32_t for n follow) + { + return get_string(get_number()); + } + + case 0x7b: // UTF-8 string (eight-byte uint64_t for n follow) + { + return get_string(get_number()); + } + + case 0x7f: // UTF-8 string (indefinite length) + { + std::string result; + while (get() != 0xff) + { + check_eof(); + result.push_back(static_cast(current)); + } + return result; + } + + default: + { + std::stringstream ss; + ss << std::setw(2) << std::setfill('0') << std::hex << current; + JSON_THROW(parse_error::create(113, chars_read, "expected a CBOR string; last byte: 0x" + ss.str())); + } + } + } + + template + BasicJsonType get_cbor_array(const NumberType len) + { + BasicJsonType result = value_t::array; + std::generate_n(std::back_inserter(*result.m_value.array), len, [this]() + { + return parse_cbor_internal(); + }); + return result; + } + + template + BasicJsonType get_cbor_object(const NumberType len) + { + BasicJsonType result = value_t::object; + std::generate_n(std::inserter(*result.m_value.object, + result.m_value.object->end()), + len, [this]() + { + get(); + auto key 
= get_cbor_string(); + auto val = parse_cbor_internal(); + return std::make_pair(std::move(key), std::move(val)); + }); + return result; + } + + /*! + @brief reads a MessagePack string + + This function first reads starting bytes to determine the expected + string length and then copies this number of bytes into a string. + + @return string + + @throw parse_error.110 if input ended + @throw parse_error.113 if an unexpected byte is read + */ + std::string get_msgpack_string() + { + check_eof(); + + switch (current) + { + // fixstr + case 0xa0: + case 0xa1: + case 0xa2: + case 0xa3: + case 0xa4: + case 0xa5: + case 0xa6: + case 0xa7: + case 0xa8: + case 0xa9: + case 0xaa: + case 0xab: + case 0xac: + case 0xad: + case 0xae: + case 0xaf: + case 0xb0: + case 0xb1: + case 0xb2: + case 0xb3: + case 0xb4: + case 0xb5: + case 0xb6: + case 0xb7: + case 0xb8: + case 0xb9: + case 0xba: + case 0xbb: + case 0xbc: + case 0xbd: + case 0xbe: + case 0xbf: + { + return get_string(current & 0x1f); + } + + case 0xd9: // str 8 + { + return get_string(get_number()); + } + + case 0xda: // str 16 + { + return get_string(get_number()); + } + + case 0xdb: // str 32 + { + return get_string(get_number()); + } + + default: + { + std::stringstream ss; + ss << std::setw(2) << std::setfill('0') << std::hex << current; + JSON_THROW(parse_error::create(113, chars_read, + "expected a MessagePack string; last byte: 0x" + ss.str())); + } + } + } + + template + BasicJsonType get_msgpack_array(const NumberType len) + { + BasicJsonType result = value_t::array; + std::generate_n(std::back_inserter(*result.m_value.array), len, [this]() + { + return parse_msgpack_internal(); + }); + return result; + } + + template + BasicJsonType get_msgpack_object(const NumberType len) + { + BasicJsonType result = value_t::object; + std::generate_n(std::inserter(*result.m_value.object, + result.m_value.object->end()), + len, [this]() + { + get(); + auto key = get_msgpack_string(); + auto val = parse_msgpack_internal(); + 
return std::make_pair(std::move(key), std::move(val)); + }); + return result; + } + + /*! + @brief check if input ended + @throw parse_error.110 if input ended + */ + void check_eof(const bool expect_eof = false) const + { + if (expect_eof) + { + if (JSON_UNLIKELY(current != std::char_traits::eof())) + { + JSON_THROW(parse_error::create(110, chars_read, "expected end of input")); + } + } + else + { + if (JSON_UNLIKELY(current == std::char_traits::eof())) + { + JSON_THROW(parse_error::create(110, chars_read, "unexpected end of input")); + } + } + } + + private: + /// input adapter + input_adapter_t ia = nullptr; + + /// the current character + int current = std::char_traits::eof(); + + /// the number of characters read + std::size_t chars_read = 0; + + /// whether we can assume little endianess + const bool is_little_endian = little_endianess(); +}; + +/*! +@brief serialization to CBOR and MessagePack values +*/ +template +class binary_writer +{ + public: + /*! + @brief create a binary writer + + @param[in] adapter output adapter to write to + */ + explicit binary_writer(output_adapter_t adapter) : oa(adapter) + { + assert(oa); + } + + /*! + @brief[in] j JSON value to serialize + */ + void write_cbor(const BasicJsonType& j) + { + switch (j.type()) + { + case value_t::null: + { + oa->write_character(static_cast(0xf6)); + break; + } + + case value_t::boolean: + { + oa->write_character(j.m_value.boolean + ? static_cast(0xf5) + : static_cast(0xf4)); + break; + } + + case value_t::number_integer: + { + if (j.m_value.number_integer >= 0) + { + // CBOR does not differentiate between positive signed + // integers and unsigned integers. Therefore, we used the + // code from the value_t::number_unsigned case here. 
+ if (j.m_value.number_integer <= 0x17) + { + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer <= (std::numeric_limits::max)()) + { + oa->write_character(static_cast(0x18)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer <= (std::numeric_limits::max)()) + { + oa->write_character(static_cast(0x19)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer <= (std::numeric_limits::max)()) + { + oa->write_character(static_cast(0x1a)); + write_number(static_cast(j.m_value.number_integer)); + } + else + { + oa->write_character(static_cast(0x1b)); + write_number(static_cast(j.m_value.number_integer)); + } + } + else + { + // The conversions below encode the sign in the first + // byte, and the value is converted to a positive number. + const auto positive_number = -1 - j.m_value.number_integer; + if (j.m_value.number_integer >= -24) + { + write_number(static_cast(0x20 + positive_number)); + } + else if (positive_number <= (std::numeric_limits::max)()) + { + oa->write_character(static_cast(0x38)); + write_number(static_cast(positive_number)); + } + else if (positive_number <= (std::numeric_limits::max)()) + { + oa->write_character(static_cast(0x39)); + write_number(static_cast(positive_number)); + } + else if (positive_number <= (std::numeric_limits::max)()) + { + oa->write_character(static_cast(0x3a)); + write_number(static_cast(positive_number)); + } + else + { + oa->write_character(static_cast(0x3b)); + write_number(static_cast(positive_number)); + } + } + break; + } + + case value_t::number_unsigned: + { + if (j.m_value.number_unsigned <= 0x17) + { + write_number(static_cast(j.m_value.number_unsigned)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + oa->write_character(static_cast(0x18)); + write_number(static_cast(j.m_value.number_unsigned)); + } + else if (j.m_value.number_unsigned <= 
(std::numeric_limits::max)()) + { + oa->write_character(static_cast(0x19)); + write_number(static_cast(j.m_value.number_unsigned)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + oa->write_character(static_cast(0x1a)); + write_number(static_cast(j.m_value.number_unsigned)); + } + else + { + oa->write_character(static_cast(0x1b)); + write_number(static_cast(j.m_value.number_unsigned)); + } + break; + } + + case value_t::number_float: // Double-Precision Float + { + oa->write_character(static_cast(0xfb)); + write_number(j.m_value.number_float); + break; + } + + case value_t::string: + { + // step 1: write control byte and the string length + const auto N = j.m_value.string->size(); + if (N <= 0x17) + { + write_number(static_cast(0x60 + N)); + } + else if (N <= 0xff) + { + oa->write_character(static_cast(0x78)); + write_number(static_cast(N)); + } + else if (N <= 0xffff) + { + oa->write_character(static_cast(0x79)); + write_number(static_cast(N)); + } + else if (N <= 0xffffffff) + { + oa->write_character(static_cast(0x7a)); + write_number(static_cast(N)); + } + // LCOV_EXCL_START + else if (N <= 0xffffffffffffffff) + { + oa->write_character(static_cast(0x7b)); + write_number(static_cast(N)); + } + // LCOV_EXCL_STOP + + // step 2: write the string + oa->write_characters( + reinterpret_cast(j.m_value.string->c_str()), + j.m_value.string->size()); + break; + } + + case value_t::array: + { + // step 1: write control byte and the array size + const auto N = j.m_value.array->size(); + if (N <= 0x17) + { + write_number(static_cast(0x80 + N)); + } + else if (N <= 0xff) + { + oa->write_character(static_cast(0x98)); + write_number(static_cast(N)); + } + else if (N <= 0xffff) + { + oa->write_character(static_cast(0x99)); + write_number(static_cast(N)); + } + else if (N <= 0xffffffff) + { + oa->write_character(static_cast(0x9a)); + write_number(static_cast(N)); + } + // LCOV_EXCL_START + else if (N <= 0xffffffffffffffff) + { + 
oa->write_character(static_cast(0x9b)); + write_number(static_cast(N)); + } + // LCOV_EXCL_STOP + + // step 2: write each element + for (const auto& el : *j.m_value.array) + { + write_cbor(el); + } + break; + } + + case value_t::object: + { + // step 1: write control byte and the object size + const auto N = j.m_value.object->size(); + if (N <= 0x17) + { + write_number(static_cast(0xa0 + N)); + } + else if (N <= 0xff) + { + oa->write_character(static_cast(0xb8)); + write_number(static_cast(N)); + } + else if (N <= 0xffff) + { + oa->write_character(static_cast(0xb9)); + write_number(static_cast(N)); + } + else if (N <= 0xffffffff) + { + oa->write_character(static_cast(0xba)); + write_number(static_cast(N)); + } + // LCOV_EXCL_START + else if (N <= 0xffffffffffffffff) + { + oa->write_character(static_cast(0xbb)); + write_number(static_cast(N)); + } + // LCOV_EXCL_STOP + + // step 2: write each element + for (const auto& el : *j.m_value.object) + { + write_cbor(el.first); + write_cbor(el.second); + } + break; + } + + default: + break; + } + } + + /*! + @brief[in] j JSON value to serialize + */ + void write_msgpack(const BasicJsonType& j) + { + switch (j.type()) + { + case value_t::null: // nil + { + oa->write_character(static_cast(0xc0)); + break; + } + + case value_t::boolean: // true and false + { + oa->write_character(j.m_value.boolean + ? static_cast(0xc3) + : static_cast(0xc2)); + break; + } + + case value_t::number_integer: + { + if (j.m_value.number_integer >= 0) + { + // MessagePack does not differentiate between positive + // signed integers and unsigned integers. Therefore, we used + // the code from the value_t::number_unsigned case here. 
+ if (j.m_value.number_unsigned < 128) + { + // positive fixnum + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 8 + oa->write_character(static_cast(0xcc)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 16 + oa->write_character(static_cast(0xcd)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 32 + oa->write_character(static_cast(0xce)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 64 + oa->write_character(static_cast(0xcf)); + write_number(static_cast(j.m_value.number_integer)); + } + } + else + { + if (j.m_value.number_integer >= -32) + { + // negative fixnum + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() and + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 8 + oa->write_character(static_cast(0xd0)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() and + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 16 + oa->write_character(static_cast(0xd1)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() and + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 32 + oa->write_character(static_cast(0xd2)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() and + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 64 + oa->write_character(static_cast(0xd3)); + 
write_number(static_cast(j.m_value.number_integer)); + } + } + break; + } + + case value_t::number_unsigned: + { + if (j.m_value.number_unsigned < 128) + { + // positive fixnum + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 8 + oa->write_character(static_cast(0xcc)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 16 + oa->write_character(static_cast(0xcd)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 32 + oa->write_character(static_cast(0xce)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 64 + oa->write_character(static_cast(0xcf)); + write_number(static_cast(j.m_value.number_integer)); + } + break; + } + + case value_t::number_float: // float 64 + { + oa->write_character(static_cast(0xcb)); + write_number(j.m_value.number_float); + break; + } + + case value_t::string: + { + // step 1: write control byte and the string length + const auto N = j.m_value.string->size(); + if (N <= 31) + { + // fixstr + write_number(static_cast(0xa0 | N)); + } + else if (N <= 255) + { + // str 8 + oa->write_character(static_cast(0xd9)); + write_number(static_cast(N)); + } + else if (N <= 65535) + { + // str 16 + oa->write_character(static_cast(0xda)); + write_number(static_cast(N)); + } + else if (N <= 4294967295) + { + // str 32 + oa->write_character(static_cast(0xdb)); + write_number(static_cast(N)); + } + + // step 2: write the string + oa->write_characters( + reinterpret_cast(j.m_value.string->c_str()), + j.m_value.string->size()); + break; + } + + case value_t::array: + { + // step 1: write control byte and the array size + const auto N = j.m_value.array->size(); + if (N <= 15) + { + // fixarray + 
write_number(static_cast(0x90 | N)); + } + else if (N <= 0xffff) + { + // array 16 + oa->write_character(static_cast(0xdc)); + write_number(static_cast(N)); + } + else if (N <= 0xffffffff) + { + // array 32 + oa->write_character(static_cast(0xdd)); + write_number(static_cast(N)); + } + + // step 2: write each element + for (const auto& el : *j.m_value.array) + { + write_msgpack(el); + } + break; + } + + case value_t::object: + { + // step 1: write control byte and the object size + const auto N = j.m_value.object->size(); + if (N <= 15) + { + // fixmap + write_number(static_cast(0x80 | (N & 0xf))); + } + else if (N <= 65535) + { + // map 16 + oa->write_character(static_cast(0xde)); + write_number(static_cast(N)); + } + else if (N <= 4294967295) + { + // map 32 + oa->write_character(static_cast(0xdf)); + write_number(static_cast(N)); + } + + // step 2: write each element + for (const auto& el : *j.m_value.object) + { + write_msgpack(el.first); + write_msgpack(el.second); + } + break; + } + + default: + break; + } + } + + private: + /* + @brief write a number to output input + + @param[in] n number of type @a NumberType + @tparam NumberType the type of the number + + @note This function needs to respect the system's endianess, because bytes + in CBOR and MessagePack are stored in network order (big endian) and + therefore need reordering on little endian systems. 
+ */ + template void write_number(NumberType n) + { + // step 1: write number to array of length NumberType + std::array vec; + std::memcpy(vec.data(), &n, sizeof(NumberType)); + + // step 2: write array to output (with possible reordering) + if (is_little_endian) + { + // reverse byte order prior to conversion if necessary + std::reverse(vec.begin(), vec.end()); + } + + oa->write_characters(vec.data(), sizeof(NumberType)); + } + + private: + /// whether we can assume little endianess + const bool is_little_endian = binary_reader::little_endianess(); + + /// the output + output_adapter_t oa = nullptr; +}; + +/////////////////// +// serialization // +/////////////////// + +template +class serializer +{ + using string_t = typename BasicJsonType::string_t; + using number_float_t = typename BasicJsonType::number_float_t; + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + public: + /*! + @param[in] s output stream to serialize to + @param[in] ichar indentation character to use + */ + serializer(output_adapter_t s, const char ichar) + : o(std::move(s)), loc(std::localeconv()), + thousands_sep(loc->thousands_sep == nullptr ? '\0' : loc->thousands_sep[0]), + decimal_point(loc->decimal_point == nullptr ? '\0' : loc->decimal_point[0]), + indent_char(ichar), indent_string(512, indent_char) {} + + // delete because of pointer members + serializer(const serializer&) = delete; + serializer& operator=(const serializer&) = delete; + + /*! + @brief internal implementation of the serialization function + + This function is called by the public member function dump and organizes + the serialization internally. The indentation level is propagated as + additional parameter. In case of arrays and objects, the function is + called recursively. 
+ + - strings and object keys are escaped using `escape_string()` + - integer numbers are converted implicitly via `operator<<` + - floating-point numbers are converted to a string using `"%g"` format + + @param[in] val value to serialize + @param[in] pretty_print whether the output shall be pretty-printed + @param[in] indent_step the indent level + @param[in] current_indent the current indent level (only used internally) + */ + void dump(const BasicJsonType& val, const bool pretty_print, + const bool ensure_ascii, + const unsigned int indent_step, + const unsigned int current_indent = 0) + { + switch (val.m_type) + { + case value_t::object: + { + if (val.m_value.object->empty()) + { + o->write_characters("{}", 2); + return; + } + + if (pretty_print) + { + o->write_characters("{\n", 2); + + // variable to hold indentation for recursive calls + const auto new_indent = current_indent + indent_step; + if (JSON_UNLIKELY(indent_string.size() < new_indent)) + { + indent_string.resize(indent_string.size() * 2, ' '); + } + + // first n-1 elements + auto i = val.m_value.object->cbegin(); + for (std::size_t cnt = 0; cnt < val.m_value.object->size() - 1; ++cnt, ++i) + { + o->write_characters(indent_string.c_str(), new_indent); + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\": ", 3); + dump(i->second, true, ensure_ascii, indent_step, new_indent); + o->write_characters(",\n", 2); + } + + // last element + assert(i != val.m_value.object->cend()); + assert(std::next(i) == val.m_value.object->cend()); + o->write_characters(indent_string.c_str(), new_indent); + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\": ", 3); + dump(i->second, true, ensure_ascii, indent_step, new_indent); + + o->write_character('\n'); + o->write_characters(indent_string.c_str(), current_indent); + o->write_character('}'); + } + else + { + o->write_character('{'); + + // first n-1 elements + auto i = 
val.m_value.object->cbegin(); + for (std::size_t cnt = 0; cnt < val.m_value.object->size() - 1; ++cnt, ++i) + { + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\":", 2); + dump(i->second, false, ensure_ascii, indent_step, current_indent); + o->write_character(','); + } + + // last element + assert(i != val.m_value.object->cend()); + assert(std::next(i) == val.m_value.object->cend()); + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\":", 2); + dump(i->second, false, ensure_ascii, indent_step, current_indent); + + o->write_character('}'); + } + + return; + } + + case value_t::array: + { + if (val.m_value.array->empty()) + { + o->write_characters("[]", 2); + return; + } + + if (pretty_print) + { + o->write_characters("[\n", 2); + + // variable to hold indentation for recursive calls + const auto new_indent = current_indent + indent_step; + if (JSON_UNLIKELY(indent_string.size() < new_indent)) + { + indent_string.resize(indent_string.size() * 2, ' '); + } + + // first n-1 elements + for (auto i = val.m_value.array->cbegin(); + i != val.m_value.array->cend() - 1; ++i) + { + o->write_characters(indent_string.c_str(), new_indent); + dump(*i, true, ensure_ascii, indent_step, new_indent); + o->write_characters(",\n", 2); + } + + // last element + assert(not val.m_value.array->empty()); + o->write_characters(indent_string.c_str(), new_indent); + dump(val.m_value.array->back(), true, ensure_ascii, indent_step, new_indent); + + o->write_character('\n'); + o->write_characters(indent_string.c_str(), current_indent); + o->write_character(']'); + } + else + { + o->write_character('['); + + // first n-1 elements + for (auto i = val.m_value.array->cbegin(); + i != val.m_value.array->cend() - 1; ++i) + { + dump(*i, false, ensure_ascii, indent_step, current_indent); + o->write_character(','); + } + + // last element + assert(not val.m_value.array->empty()); + dump(val.m_value.array->back(), false, 
ensure_ascii, indent_step, current_indent); + + o->write_character(']'); + } + + return; + } + + case value_t::string: + { + o->write_character('\"'); + dump_escaped(*val.m_value.string, ensure_ascii); + o->write_character('\"'); + return; + } + + case value_t::boolean: + { + if (val.m_value.boolean) + { + o->write_characters("true", 4); + } + else + { + o->write_characters("false", 5); + } + return; + } + + case value_t::number_integer: + { + dump_integer(val.m_value.number_integer); + return; + } + + case value_t::number_unsigned: + { + dump_integer(val.m_value.number_unsigned); + return; + } + + case value_t::number_float: + { + dump_float(val.m_value.number_float); + return; + } + + case value_t::discarded: + { + o->write_characters("", 11); + return; + } + + case value_t::null: + { + o->write_characters("null", 4); + return; + } + } + } + + private: + /*! + @brief returns the number of expected bytes following in UTF-8 string + + @param[in] u the first byte of a UTF-8 string + @return the number of expected bytes following + */ + static constexpr std::size_t bytes_following(const uint8_t u) + { + return ((0 <= u and u <= 127) ? 0 + : ((192 <= u and u <= 223) ? 1 + : ((224 <= u and u <= 239) ? 2 + : ((240 <= u and u <= 247) ? 3 : std::string::npos)))); + } + + /*! + @brief calculates the extra space to escape a JSON string + + @param[in] s the string to escape + @param[in] ensure_ascii whether to escape non-ASCII characters with + \uXXXX sequences + @return the number of characters required to escape string @a s + + @complexity Linear in the length of string @a s. 
+ */ + static std::size_t extra_space(const string_t& s, + const bool ensure_ascii) noexcept + { + std::size_t res = 0; + + for (std::size_t i = 0; i < s.size(); ++i) + { + switch (s[i]) + { + // control characters that can be escaped with a backslash + case '"': + case '\\': + case '\b': + case '\f': + case '\n': + case '\r': + case '\t': + { + // from c (1 byte) to \x (2 bytes) + res += 1; + break; + } + + // control characters that need \uxxxx escaping + case 0x00: + case 0x01: + case 0x02: + case 0x03: + case 0x04: + case 0x05: + case 0x06: + case 0x07: + case 0x0b: + case 0x0e: + case 0x0f: + case 0x10: + case 0x11: + case 0x12: + case 0x13: + case 0x14: + case 0x15: + case 0x16: + case 0x17: + case 0x18: + case 0x19: + case 0x1a: + case 0x1b: + case 0x1c: + case 0x1d: + case 0x1e: + case 0x1f: + { + // from c (1 byte) to \uxxxx (6 bytes) + res += 5; + break; + } + + default: + { + if (ensure_ascii and (s[i] & 0x80 or s[i] == 0x7F)) + { + const auto bytes = bytes_following(static_cast(s[i])); + if (bytes == std::string::npos) + { + // invalid characters are treated as is, so no + // additional space will be used + break; + } + + if (bytes == 3) + { + // codepoints that need 4 bytes (i.e., 3 additional + // bytes) in UTF-8 need a surrogate pair when \u + // escaping is used: from 4 bytes to \uxxxx\uxxxx + // (12 bytes) + res += (12 - bytes - 1); + } + else + { + // from x bytes to \uxxxx (6 bytes) + res += (6 - bytes - 1); + } + + // skip the additional bytes + i += bytes; + } + break; + } + } + } + + return res; + } + + static void escape_codepoint(int codepoint, string_t& result, std::size_t& pos) + { + // expecting a proper codepoint + assert(0x00 <= codepoint and codepoint <= 0x10FFFF); + + // the last written character was the backslash before the 'u' + assert(result[pos] == '\\'); + + // write the 'u' + result[++pos] = 'u'; + + // convert a number 0..15 to its hex representation (0..f) + static const std::array hexify = + { + { + '0', '1', '2', '3', '4', 
'5', '6', '7', + '8', '9', 'a', 'b', 'c', 'd', 'e', 'f' + } + }; + + if (codepoint < 0x10000) + { + // codepoints U+0000..U+FFFF can be represented as \uxxxx. + result[++pos] = hexify[(codepoint >> 12) & 0x0F]; + result[++pos] = hexify[(codepoint >> 8) & 0x0F]; + result[++pos] = hexify[(codepoint >> 4) & 0x0F]; + result[++pos] = hexify[codepoint & 0x0F]; + } + else + { + // codepoints U+10000..U+10FFFF need a surrogate pair to be + // represented as \uxxxx\uxxxx. + // http://www.unicode.org/faq/utf_bom.html#utf16-4 + codepoint -= 0x10000; + const int high_surrogate = 0xD800 | ((codepoint >> 10) & 0x3FF); + const int low_surrogate = 0xDC00 | (codepoint & 0x3FF); + result[++pos] = hexify[(high_surrogate >> 12) & 0x0F]; + result[++pos] = hexify[(high_surrogate >> 8) & 0x0F]; + result[++pos] = hexify[(high_surrogate >> 4) & 0x0F]; + result[++pos] = hexify[high_surrogate & 0x0F]; + ++pos; // backslash is already in output + result[++pos] = 'u'; + result[++pos] = hexify[(low_surrogate >> 12) & 0x0F]; + result[++pos] = hexify[(low_surrogate >> 8) & 0x0F]; + result[++pos] = hexify[(low_surrogate >> 4) & 0x0F]; + result[++pos] = hexify[low_surrogate & 0x0F]; + } + + ++pos; + } + + /*! + @brief dump escaped string + + Escape a string by replacing certain special characters by a sequence of an + escape character (backslash) and another character and other control + characters by a sequence of "\u" followed by a four-digit hex + representation. The escaped string is written to output stream @a o. + + @param[in] s the string to escape + @param[in] ensure_ascii whether to escape non-ASCII characters with + \uXXXX sequences + + @complexity Linear in the length of string @a s. 
+ */ + void dump_escaped(const string_t& s, const bool ensure_ascii) const + { + const auto space = extra_space(s, ensure_ascii); + if (space == 0) + { + o->write_characters(s.c_str(), s.size()); + return; + } + + // create a result string of necessary size + string_t result(s.size() + space, '\\'); + std::size_t pos = 0; + + for (std::size_t i = 0; i < s.size(); ++i) + { + switch (s[i]) + { + case '"': // quotation mark (0x22) + { + result[pos + 1] = '"'; + pos += 2; + break; + } + + case '\\': // reverse solidus (0x5c) + { + // nothing to change + pos += 2; + break; + } + + case '\b': // backspace (0x08) + { + result[pos + 1] = 'b'; + pos += 2; + break; + } + + case '\f': // formfeed (0x0c) + { + result[pos + 1] = 'f'; + pos += 2; + break; + } + + case '\n': // newline (0x0a) + { + result[pos + 1] = 'n'; + pos += 2; + break; + } + + case '\r': // carriage return (0x0d) + { + result[pos + 1] = 'r'; + pos += 2; + break; + } + + case '\t': // horizontal tab (0x09) + { + result[pos + 1] = 't'; + pos += 2; + break; + } + + default: + { + // escape control characters (0x00..0x1F) or, if + // ensure_ascii parameter is used, non-ASCII characters + if ((0x00 <= s[i] and s[i] <= 0x1F) or + (ensure_ascii and (s[i] & 0x80 or s[i] == 0x7F))) + { + const auto bytes = bytes_following(static_cast(s[i])); + if (bytes == std::string::npos) + { + // copy invalid character as is + result[pos++] = s[i]; + break; + } + + // check that the additional bytes are present + assert(i + bytes < s.size()); + + // to use \uxxxx escaping, we first need to caluclate + // the codepoint from the UTF-8 bytes + int codepoint = 0; + + assert(0 <= bytes and bytes <= 3); + switch (bytes) + { + case 0: + { + codepoint = s[i] & 0xFF; + break; + } + + case 1: + { + codepoint = ((s[i] & 0x3F) << 6) + + (s[i + 1] & 0x7F); + break; + } + + case 2: + { + codepoint = ((s[i] & 0x1F) << 12) + + ((s[i + 1] & 0x7F) << 6) + + (s[i + 2] & 0x7F); + break; + } + + case 3: + { + codepoint = ((s[i] & 0xF) << 18) + + 
((s[i + 1] & 0x7F) << 12) + + ((s[i + 2] & 0x7F) << 6) + + (s[i + 3] & 0x7F); + break; + } + + default: + break; // LCOV_EXCL_LINE + } + + escape_codepoint(codepoint, result, pos); + i += bytes; + } + else + { + // all other characters are added as-is + result[pos++] = s[i]; + } + break; + } + } + } + + assert(pos == result.size()); + o->write_characters(result.c_str(), result.size()); + } + + /*! + @brief dump an integer + + Dump a given integer to output stream @a o. Works internally with + @a number_buffer. + + @param[in] x integer number (signed or unsigned) to dump + @tparam NumberType either @a number_integer_t or @a number_unsigned_t + */ + template < + typename NumberType, + detail::enable_if_t::value or + std::is_same::value, + int> = 0 > + void dump_integer(NumberType x) + { + // special case for "0" + if (x == 0) + { + o->write_character('0'); + return; + } + + const bool is_negative = x < 0; + std::size_t i = 0; + + // spare 1 byte for '\0' + while (x != 0 and i < number_buffer.size() - 1) + { + const auto digit = std::labs(static_cast(x % 10)); + number_buffer[i++] = static_cast('0' + digit); + x /= 10; + } + + // make sure the number has been processed completely + assert(x == 0); + + if (is_negative) + { + // make sure there is capacity for the '-' + assert(i < number_buffer.size() - 2); + number_buffer[i++] = '-'; + } + + std::reverse(number_buffer.begin(), number_buffer.begin() + i); + o->write_characters(number_buffer.data(), i); + } + + /*! + @brief dump a floating-point number + + Dump a given floating-point number to output stream @a o. Works internally + with @a number_buffer. 
+ + @param[in] x floating-point number to dump + */ + void dump_float(number_float_t x) + { + // NaN / inf + if (not std::isfinite(x) or std::isnan(x)) + { + o->write_characters("null", 4); + return; + } + + // special case for 0.0 and -0.0 + if (x == 0) + { + if (std::signbit(x)) + { + o->write_characters("-0.0", 4); + } + else + { + o->write_characters("0.0", 3); + } + return; + } + + // get number of digits for a text -> float -> text round-trip + static constexpr auto d = std::numeric_limits::digits10; + + // the actual conversion + std::ptrdiff_t len = snprintf(number_buffer.data(), number_buffer.size(), "%.*g", d, x); + + // negative value indicates an error + assert(len > 0); + // check if buffer was large enough + assert(static_cast(len) < number_buffer.size()); + + // erase thousands separator + if (thousands_sep != '\0') + { + const auto end = std::remove(number_buffer.begin(), + number_buffer.begin() + len, thousands_sep); + std::fill(end, number_buffer.end(), '\0'); + assert((end - number_buffer.begin()) <= len); + len = (end - number_buffer.begin()); + } + + // convert decimal point to '.' + if (decimal_point != '\0' and decimal_point != '.') + { + const auto dec_pos = std::find(number_buffer.begin(), number_buffer.end(), decimal_point); + if (dec_pos != number_buffer.end()) + { + *dec_pos = '.'; + } + } + + o->write_characters(number_buffer.data(), static_cast(len)); + + // determine if need to append ".0" + const bool value_is_int_like = + std::none_of(number_buffer.begin(), number_buffer.begin() + len + 1, + [](char c) + { + return (c == '.' 
or c == 'e'); + }); + + if (value_is_int_like) + { + o->write_characters(".0", 2); + } + } + + private: + /// the output of the serializer + output_adapter_t o = nullptr; + + /// a (hopefully) large enough character buffer + std::array number_buffer{{}}; + + /// the locale + const std::lconv* loc = nullptr; + /// the locale's thousand separator character + const char thousands_sep = '\0'; + /// the locale's decimal point character + const char decimal_point = '\0'; + + /// the indentation character + const char indent_char; + + /// the indentation string + string_t indent_string; +}; + +template +class json_ref +{ + public: + using value_type = BasicJsonType; + + json_ref(value_type&& value) + : owned_value(std::move(value)), + value_ref(&owned_value), + is_rvalue(true) + {} + + json_ref(const value_type& value) + : value_ref(const_cast(&value)), + is_rvalue(false) + {} + + json_ref(std::initializer_list init) + : owned_value(init), + value_ref(&owned_value), + is_rvalue(true) + {} + + template + json_ref(Args... args) + : owned_value(std::forward(args)...), + value_ref(&owned_value), + is_rvalue(true) + {} + + // class should be movable only + json_ref(json_ref&&) = default; + json_ref(const json_ref&) = delete; + json_ref& operator=(const json_ref&) = delete; + + value_type moved_or_copied() const + { + if (is_rvalue) + { + return std::move(*value_ref); + } + return *value_ref; + } + + value_type const& operator*() const + { + return *static_cast(value_ref); + } + + value_type const* operator->() const + { + return static_cast(value_ref); + } + + private: + mutable value_type owned_value = nullptr; + value_type* value_ref = nullptr; + const bool is_rvalue; +}; + +} // namespace detail + +/// namespace to hold default `to_json` / `from_json` functions +namespace +{ +constexpr const auto& to_json = detail::static_const::value; +constexpr const auto& from_json = detail::static_const::value; +} + + +/*! 
+@brief default JSONSerializer template argument + +This serializer ignores the template arguments and uses ADL +([argument-dependent lookup](http://en.cppreference.com/w/cpp/language/adl)) +for serialization. +*/ +template +struct adl_serializer +{ + /*! + @brief convert a JSON value to any value type + + This function is usually called by the `get()` function of the + @ref basic_json class (either explicit or via conversion operators). + + @param[in] j JSON value to read from + @param[in,out] val value to write to + */ + template + static void from_json(BasicJsonType&& j, ValueType& val) noexcept( + noexcept(::nlohmann::from_json(std::forward(j), val))) + { + ::nlohmann::from_json(std::forward(j), val); + } + + /*! + @brief convert any value type to a JSON value + + This function is usually called by the constructors of the @ref basic_json + class. + + @param[in,out] j JSON value to write to + @param[in] val value to read from + */ + template + static void to_json(BasicJsonType& j, ValueType&& val) noexcept( + noexcept(::nlohmann::to_json(j, std::forward(val)))) + { + ::nlohmann::to_json(j, std::forward(val)); + } +}; + +/*! +@brief JSON Pointer + +A JSON pointer defines a string syntax for identifying a specific value +within a JSON document. It can be used with functions `at` and +`operator[]`. Furthermore, JSON pointers are the base for JSON patches. + +@sa [RFC 6901](https://tools.ietf.org/html/rfc6901) + +@since version 2.0.0 +*/ +class json_pointer +{ + /// allow basic_json to access private members + NLOHMANN_BASIC_JSON_TPL_DECLARATION + friend class basic_json; + + public: + /*! + @brief create JSON pointer + + Create a JSON pointer according to the syntax described in + [Section 3 of RFC6901](https://tools.ietf.org/html/rfc6901#section-3). 
+ + @param[in] s string representing the JSON pointer; if omitted, the empty + string is assumed which references the whole JSON value + + @throw parse_error.107 if the given JSON pointer @a s is nonempty and + does not begin with a slash (`/`); see example below + + @throw parse_error.108 if a tilde (`~`) in the given JSON pointer @a s + is not followed by `0` (representing `~`) or `1` (representing `/`); + see example below + + @liveexample{The example shows the construction several valid JSON + pointers as well as the exceptional behavior.,json_pointer} + + @since version 2.0.0 + */ + explicit json_pointer(const std::string& s = "") : reference_tokens(split(s)) {} + + /*! + @brief return a string representation of the JSON pointer + + @invariant For each JSON pointer `ptr`, it holds: + @code {.cpp} + ptr == json_pointer(ptr.to_string()); + @endcode + + @return a string representation of the JSON pointer + + @liveexample{The example shows the result of `to_string`., + json_pointer__to_string} + + @since version 2.0.0 + */ + std::string to_string() const noexcept + { + return std::accumulate(reference_tokens.begin(), reference_tokens.end(), + std::string{}, + [](const std::string & a, const std::string & b) + { + return a + "/" + escape(b); + }); + } + + /// @copydoc to_string() + operator std::string() const + { + return to_string(); + } + + private: + /*! 
+ @brief remove and return last reference pointer + @throw out_of_range.405 if JSON pointer has no parent + */ + std::string pop_back() + { + if (JSON_UNLIKELY(is_root())) + { + JSON_THROW(detail::out_of_range::create(405, "JSON pointer has no parent")); + } + + auto last = reference_tokens.back(); + reference_tokens.pop_back(); + return last; + } + + /// return whether pointer points to the root document + bool is_root() const + { + return reference_tokens.empty(); + } + + json_pointer top() const + { + if (JSON_UNLIKELY(is_root())) + { + JSON_THROW(detail::out_of_range::create(405, "JSON pointer has no parent")); + } + + json_pointer result = *this; + result.reference_tokens = {reference_tokens[0]}; + return result; + } + + + /*! + @brief create and return a reference to the pointed to value + + @complexity Linear in the number of reference tokens. + + @throw parse_error.109 if array index is not a number + @throw type_error.313 if value cannot be unflattened + */ + NLOHMANN_BASIC_JSON_TPL_DECLARATION + NLOHMANN_BASIC_JSON_TPL& get_and_create(NLOHMANN_BASIC_JSON_TPL& j) const; + + /*! + @brief return a reference to the pointed to value + + @note This version does not throw if a value is not present, but tries to + create nested values instead. For instance, calling this function + with pointer `"/this/that"` on a null value is equivalent to calling + `operator[]("this").operator[]("that")` on that value, effectively + changing the null value to an object. + + @param[in] ptr a JSON value + + @return reference to the JSON value pointed to by the JSON pointer + + @complexity Linear in the length of the JSON pointer. + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + NLOHMANN_BASIC_JSON_TPL_DECLARATION + NLOHMANN_BASIC_JSON_TPL& get_unchecked(NLOHMANN_BASIC_JSON_TPL* ptr) const; + + /*! 
+ @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + NLOHMANN_BASIC_JSON_TPL_DECLARATION + NLOHMANN_BASIC_JSON_TPL& get_checked(NLOHMANN_BASIC_JSON_TPL* ptr) const; + + /*! + @brief return a const reference to the pointed to value + + @param[in] ptr a JSON value + + @return const reference to the JSON value pointed to by the JSON + pointer + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + NLOHMANN_BASIC_JSON_TPL_DECLARATION + const NLOHMANN_BASIC_JSON_TPL& get_unchecked(const NLOHMANN_BASIC_JSON_TPL* ptr) const; + + /*! + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + NLOHMANN_BASIC_JSON_TPL_DECLARATION + const NLOHMANN_BASIC_JSON_TPL& get_checked(const NLOHMANN_BASIC_JSON_TPL* ptr) const; + + /*! + @brief split the string input to reference tokens + + @note This function is only called by the json_pointer constructor. + All exceptions below are documented there. 
+ + @throw parse_error.107 if the pointer is not empty and does not begin with '/' + @throw parse_error.108 if character '~' is not followed by '0' or '1' + */ + static std::vector split(const std::string& reference_string) + { + std::vector result; + + // special case: empty reference string -> no reference tokens + if (reference_string.empty()) + { + return result; + } + + // check if nonempty reference string begins with slash + if (JSON_UNLIKELY(reference_string[0] != '/')) + { + JSON_THROW(detail::parse_error::create(107, 1, + "JSON pointer must be empty or begin with '/' - was: '" + + reference_string + "'")); + } + + // extract the reference tokens: + // - slash: position of the last read slash (or end of string) + // - start: position after the previous slash + for ( + // search for the first slash after the first character + std::size_t slash = reference_string.find_first_of('/', 1), + // set the beginning of the first reference token + start = 1; + // we can stop if start == string::npos+1 = 0 + start != 0; + // set the beginning of the next reference token + // (will eventually be 0 if slash == std::string::npos) + start = slash + 1, + // find next slash + slash = reference_string.find_first_of('/', start)) + { + // use the text between the beginning of the reference token + // (start) and the last slash (slash).
+ auto reference_token = reference_string.substr(start, slash - start); + + // check reference tokens are properly escaped + for (std::size_t pos = reference_token.find_first_of('~'); + pos != std::string::npos; + pos = reference_token.find_first_of('~', pos + 1)) + { + assert(reference_token[pos] == '~'); + + // ~ must be followed by 0 or 1 + if (JSON_UNLIKELY(pos == reference_token.size() - 1 or + (reference_token[pos + 1] != '0' and + reference_token[pos + 1] != '1'))) + { + JSON_THROW(detail::parse_error::create(108, 0, "escape character '~' must be followed with '0' or '1'")); + } + } + + // finally, store the reference token + unescape(reference_token); + result.push_back(reference_token); + } + + return result; + } + + /*! + @brief replace all occurrences of a substring by another string + + @param[in,out] s the string to manipulate; changed so that all + occurrences of @a f are replaced with @a t + @param[in] f the substring to replace with @a t + @param[in] t the string to replace @a f + + @pre The search string @a f must not be empty. **This precondition is + enforced with an assertion.** + + @since version 2.0.0 + */ + static void replace_substring(std::string& s, const std::string& f, + const std::string& t) + { + assert(not f.empty()); + for (auto pos = s.find(f); // find first occurrence of f + pos != std::string::npos; // make sure f was found + s.replace(pos, f.size(), t), // replace with t, and + pos = s.find(f, pos + t.size())) // find next occurrence of f + {} + } + + /// escape "~" to "~0" and "/" to "~1" (tilde first, so that the "~" of a + /// generated "~1" is not escaped a second time) + static std::string escape(std::string s) + { + replace_substring(s, "~", "~0"); + replace_substring(s, "/", "~1"); + return s; + } + + /// unescape "~1" to slash and "~0" to tilde (order is important!) + static void unescape(std::string& s) + { + replace_substring(s, "~1", "/"); + replace_substring(s, "~0", "~"); + } + + /*!
+ @param[in] reference_string the reference string to the current value + @param[in] value the value to consider + @param[in,out] result the result object to insert values to + + @note Empty objects or arrays are flattened to `null`. + */ + NLOHMANN_BASIC_JSON_TPL_DECLARATION + static void flatten(const std::string& reference_string, + const NLOHMANN_BASIC_JSON_TPL& value, + NLOHMANN_BASIC_JSON_TPL& result); + + /*! + @param[in] value flattened JSON + + @return unflattened JSON + + @throw parse_error.109 if array index is not a number + @throw type_error.314 if value is not an object + @throw type_error.315 if object values are not primitive + @throw type_error.313 if value cannot be unflattened + */ + NLOHMANN_BASIC_JSON_TPL_DECLARATION + static NLOHMANN_BASIC_JSON_TPL + unflatten(const NLOHMANN_BASIC_JSON_TPL& value); + + friend bool operator==(json_pointer const& lhs, + json_pointer const& rhs) noexcept; + + friend bool operator!=(json_pointer const& lhs, + json_pointer const& rhs) noexcept; + + /// the reference tokens + std::vector reference_tokens; +}; + +/*! 
+@brief a class to store JSON values + +@tparam ObjectType type for JSON objects (`std::map` by default; will be used +in @ref object_t) +@tparam ArrayType type for JSON arrays (`std::vector` by default; will be used +in @ref array_t) +@tparam StringType type for JSON strings and object keys (`std::string` by +default; will be used in @ref string_t) +@tparam BooleanType type for JSON booleans (`bool` by default; will be used +in @ref boolean_t) +@tparam NumberIntegerType type for JSON integer numbers (`int64_t` by +default; will be used in @ref number_integer_t) +@tparam NumberUnsignedType type for JSON unsigned integer numbers (@c +`uint64_t` by default; will be used in @ref number_unsigned_t) +@tparam NumberFloatType type for JSON floating-point numbers (`double` by +default; will be used in @ref number_float_t) +@tparam AllocatorType type of the allocator to use (`std::allocator` by +default) +@tparam JSONSerializer the serializer to resolve internal calls to `to_json()` +and `from_json()` (@ref adl_serializer by default) + +@requirement The class satisfies the following concept requirements: +- Basic + - [DefaultConstructible](http://en.cppreference.com/w/cpp/concept/DefaultConstructible): + JSON values can be default constructed. The result will be a JSON null + value. + - [MoveConstructible](http://en.cppreference.com/w/cpp/concept/MoveConstructible): + A JSON value can be constructed from an rvalue argument. + - [CopyConstructible](http://en.cppreference.com/w/cpp/concept/CopyConstructible): + A JSON value can be copy-constructed from an lvalue expression. + - [MoveAssignable](http://en.cppreference.com/w/cpp/concept/MoveAssignable): + A JSON value can be assigned from an rvalue argument. + - [CopyAssignable](http://en.cppreference.com/w/cpp/concept/CopyAssignable): + A JSON value can be copy-assigned from an lvalue expression. + - [Destructible](http://en.cppreference.com/w/cpp/concept/Destructible): + JSON values can be destructed.
+- Layout + - [StandardLayoutType](http://en.cppreference.com/w/cpp/concept/StandardLayoutType): + JSON values have + [standard layout](http://en.cppreference.com/w/cpp/language/data_members#Standard_layout): + All non-static data members are private and standard layout types, the + class has no virtual functions or (virtual) base classes. +- Library-wide + - [EqualityComparable](http://en.cppreference.com/w/cpp/concept/EqualityComparable): + JSON values can be compared with `==`, see @ref + operator==(const_reference,const_reference). + - [LessThanComparable](http://en.cppreference.com/w/cpp/concept/LessThanComparable): + JSON values can be compared with `<`, see @ref + operator<(const_reference,const_reference). + - [Swappable](http://en.cppreference.com/w/cpp/concept/Swappable): + Any JSON lvalue or rvalue can be swapped with any lvalue or rvalue of + other compatible types, using unqualified function call @ref swap(). + - [NullablePointer](http://en.cppreference.com/w/cpp/concept/NullablePointer): + JSON values can be compared against `std::nullptr_t` objects which are used + to model the `null` value. +- Container + - [Container](http://en.cppreference.com/w/cpp/concept/Container): + JSON values can be used like STL containers and provide iterator access. + - [ReversibleContainer](http://en.cppreference.com/w/cpp/concept/ReversibleContainer): + JSON values can be used like STL containers and provide reverse iterator + access. + +@invariant The member variables @a m_value and @a m_type have the following +relationship: +- If `m_type == value_t::object`, then `m_value.object != nullptr`. +- If `m_type == value_t::array`, then `m_value.array != nullptr`. +- If `m_type == value_t::string`, then `m_value.string != nullptr`. +The invariants are checked by member function assert_invariant().
+ +@internal +@note ObjectType trick from http://stackoverflow.com/a/9860911 +@endinternal + +@see [RFC 7159: The JavaScript Object Notation (JSON) Data Interchange +Format](http://rfc7159.net/rfc7159) + +@since version 1.0.0 + +@nosubgrouping +*/ +NLOHMANN_BASIC_JSON_TPL_DECLARATION +class basic_json +{ + private: + template friend struct detail::external_constructor; + friend ::nlohmann::json_pointer; + friend ::nlohmann::detail::parser; + friend ::nlohmann::detail::serializer; + template + friend class ::nlohmann::detail::iter_impl; + template + friend class ::nlohmann::detail::binary_writer; + template + friend class ::nlohmann::detail::binary_reader; + + /// workaround type for MSVC + using basic_json_t = NLOHMANN_BASIC_JSON_TPL; + + // convenience aliases for types residing in namespace detail; + using lexer = ::nlohmann::detail::lexer; + using parser = ::nlohmann::detail::parser; + + using primitive_iterator_t = ::nlohmann::detail::primitive_iterator_t; + template + using internal_iterator = ::nlohmann::detail::internal_iterator; + template + using iter_impl = ::nlohmann::detail::iter_impl; + template + using iteration_proxy = ::nlohmann::detail::iteration_proxy; + template using json_reverse_iterator = ::nlohmann::detail::json_reverse_iterator; + + template + using output_adapter_t = ::nlohmann::detail::output_adapter_t; + + using binary_reader = ::nlohmann::detail::binary_reader; + template using binary_writer = ::nlohmann::detail::binary_writer; + + using serializer = ::nlohmann::detail::serializer; + + public: + using value_t = detail::value_t; + // forward declarations + using json_pointer = ::nlohmann::json_pointer; + template + using json_serializer = JSONSerializer; + + using initializer_list_t = std::initializer_list>; + + //////////////// + // exceptions // + //////////////// + + /// @name exceptions + /// Classes to implement user-defined exceptions. 
+ /// @{ + + /// @copydoc detail::exception + using exception = detail::exception; + /// @copydoc detail::parse_error + using parse_error = detail::parse_error; + /// @copydoc detail::invalid_iterator + using invalid_iterator = detail::invalid_iterator; + /// @copydoc detail::type_error + using type_error = detail::type_error; + /// @copydoc detail::out_of_range + using out_of_range = detail::out_of_range; + /// @copydoc detail::other_error + using other_error = detail::other_error; + + /// @} + + + ///////////////////// + // container types // + ///////////////////// + + /// @name container types + /// The canonical container types to use @ref basic_json like any other STL + /// container. + /// @{ + + /// the type of elements in a basic_json container + using value_type = basic_json; + + /// the type of an element reference + using reference = value_type&; + /// the type of an element const reference + using const_reference = const value_type&; + + /// a type to represent differences between iterators + using difference_type = std::ptrdiff_t; + /// a type to represent container sizes + using size_type = std::size_t; + + /// the allocator type + using allocator_type = AllocatorType; + + /// the type of an element pointer + using pointer = typename std::allocator_traits::pointer; + /// the type of an element const pointer + using const_pointer = typename std::allocator_traits::const_pointer; + + /// an iterator for a basic_json container + using iterator = iter_impl; + /// a const iterator for a basic_json container + using const_iterator = iter_impl; + /// a reverse iterator for a basic_json container + using reverse_iterator = json_reverse_iterator; + /// a const reverse iterator for a basic_json container + using const_reverse_iterator = json_reverse_iterator; + + /// @} + + + /*! + @brief returns the allocator associated with the container + + @return a value-initialized @ref allocator_type; the function is static, + so the result does not depend on any particular JSON value + */ + static allocator_type get_allocator() + { + return allocator_type(); + } + + /*!
+ @brief returns version information on the library + + This function returns a JSON object with information about the library, + including the version number and information on the platform and compiler. + + @return JSON object holding version information + key | description + ----------- | --------------- + `compiler` | Information on the used compiler. It is an object with the following keys: `c++` (the used C++ standard), `family` (the compiler family; possible values are `clang`, `icc`, `gcc`, `hp`, `ilecpp`, `msvc`, `pgcpp`, `sunpro`, and `unknown`), and `version` (the compiler version). + `copyright` | The copyright line for the library as string. + `name` | The name of the library as string. + `platform` | The used platform as string. Possible values are `win32`, `linux`, `apple`, `unix`, and `unknown`. + `url` | The URL of the project as string. + `version` | The version of the library. It is an object with the following keys: `major`, `minor`, and `patch` as defined by [Semantic Versioning](http://semver.org), and `string` (the version string). + + @liveexample{The following code shows an example output of the `meta()` + function.,meta} + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @complexity Constant.
+ + @since 2.1.0 + */ + static basic_json meta() + { + basic_json result; + + result["copyright"] = "(C) 2013-2017 Niels Lohmann"; + result["name"] = "JSON for Modern C++"; + result["url"] = "https://github.com/nlohmann/json"; + result["version"] = + { + {"string", "2.1.1"}, {"major", 2}, {"minor", 1}, {"patch", 1} + }; + +#ifdef _WIN32 + result["platform"] = "win32"; +#elif defined __linux__ + result["platform"] = "linux"; +#elif defined __APPLE__ + result["platform"] = "apple"; +#elif defined __unix__ + result["platform"] = "unix"; +#else + result["platform"] = "unknown"; +#endif + +#if defined(__ICC) || defined(__INTEL_COMPILER) + result["compiler"] = {{"family", "icc"}, {"version", __INTEL_COMPILER}}; +#elif defined(__clang__) + result["compiler"] = {{"family", "clang"}, {"version", __clang_version__}}; +#elif defined(__GNUC__) || defined(__GNUG__) + result["compiler"] = {{"family", "gcc"}, {"version", std::to_string(__GNUC__) + "." + std::to_string(__GNUC_MINOR__) + "." + std::to_string(__GNUC_PATCHLEVEL__)}}; +#elif defined(__HP_cc) || defined(__HP_aCC) + // same object layout as every other branch, so that the "c++" key + // can be added to it below + result["compiler"] = {{"family", "hp"}, {"version", "unknown"}}; +#elif defined(__IBMCPP__) + result["compiler"] = {{"family", "ilecpp"}, {"version", __IBMCPP__}}; +#elif defined(_MSC_VER) + result["compiler"] = {{"family", "msvc"}, {"version", _MSC_VER}}; +#elif defined(__PGI) + result["compiler"] = {{"family", "pgcpp"}, {"version", __PGI}}; +#elif defined(__SUNPRO_CC) + result["compiler"] = {{"family", "sunpro"}, {"version", __SUNPRO_CC}}; +#else + result["compiler"] = {{"family", "unknown"}, {"version", "unknown"}}; +#endif + +#ifdef __cplusplus + result["compiler"]["c++"] = std::to_string(__cplusplus); +#else + result["compiler"]["c++"] = "unknown"; +#endif + return result; + } + + + /////////////////////////// + // JSON value data types // + /////////////////////////// + + /// @name JSON value data types + /// The data types to store a JSON value. These types are derived from + /// the template arguments passed to class @ref basic_json.
+ /// @{ + + /*! + @brief a type for an object + + [RFC 7159](http://rfc7159.net/rfc7159) describes JSON objects as follows: + > An object is an unordered collection of zero or more name/value pairs, + > where a name is a string and a value is a string, number, boolean, null, + > object, or array. + + To store objects in C++, a type is defined by the template parameters + described below. + + @tparam ObjectType the container to store objects (e.g., `std::map` or + `std::unordered_map`) + @tparam StringType the type of the keys or names (e.g., `std::string`). + The comparison function `std::less` is used to order elements + inside the container. + @tparam AllocatorType the allocator to use for objects (e.g., + `std::allocator`) + + #### Default type + + With the default values for @a ObjectType (`std::map`), @a StringType + (`std::string`), and @a AllocatorType (`std::allocator`), the default + value for @a object_t is: + + @code {.cpp} + std::map< + std::string, // key_type + basic_json, // value_type + std::less, // key_compare + std::allocator> // allocator_type + > + @endcode + + #### Behavior + + The choice of @a object_t influences the behavior of the JSON class. With + the default type, objects have the following behavior: + + - When all names are unique, objects will be interoperable in the sense + that all software implementations receiving that object will agree on + the name-value mappings. + - When the names within an object are not unique, later stored name/value + pairs overwrite previously stored name/value pairs, leaving the used + names unique. For instance, `{"key": 1}` and `{"key": 2, "key": 1}` will + be treated as equal and both stored as `{"key": 1}`. + - Internally, name/value pairs are stored in lexicographical order of the + names. Objects will also be serialized (see @ref dump) in this order. + For instance, `{"b": 1, "a": 2}` and `{"a": 2, "b": 1}` will be stored + and serialized as `{"a": 2, "b": 1}`. 
+ - When comparing objects, the order of the name/value pairs is irrelevant. + This makes objects interoperable in the sense that they will not be + affected by these differences. For instance, `{"b": 1, "a": 2}` and + `{"a": 2, "b": 1}` will be treated as equal. + + #### Limits + + [RFC 7159](http://rfc7159.net/rfc7159) specifies: + > An implementation may set limits on the maximum depth of nesting. + + In this class, the object's limit of nesting is not explicitly constrained. + However, a maximum depth of nesting may be introduced by the compiler or + runtime environment. A theoretical limit can be queried by calling the + @ref max_size function of a JSON object. + + #### Storage + + Objects are stored as pointers in a @ref basic_json type. That is, for any + access to object values, a pointer of type `object_t*` must be + dereferenced. + + @sa @ref array_t -- type for an array value + + @since version 1.0.0 + + @note The order name/value pairs are added to the object is *not* + preserved by the library. Therefore, iterating an object may return + name/value pairs in a different order than they were originally stored. In + fact, keys will be traversed in alphabetical order as `std::map` with + `std::less` is used by default. Please note this behavior conforms to [RFC + 7159](http://rfc7159.net/rfc7159), because any order implements the + specified "unordered" nature of JSON objects. + */ + using object_t = ObjectType, + AllocatorType>>; + + /*! + @brief a type for an array + + [RFC 7159](http://rfc7159.net/rfc7159) describes JSON arrays as follows: + > An array is an ordered sequence of zero or more values. + + To store arrays in C++, a type is defined by the template parameters + explained below.
+ + @tparam ArrayType container type to store arrays (e.g., `std::vector` or + `std::list`) + @tparam AllocatorType allocator to use for arrays (e.g., `std::allocator`) + + #### Default type + + With the default values for @a ArrayType (`std::vector`) and @a + AllocatorType (`std::allocator`), the default value for @a array_t is: + + @code {.cpp} + std::vector< + basic_json, // value_type + std::allocator // allocator_type + > + @endcode + + #### Limits + + [RFC 7159](http://rfc7159.net/rfc7159) specifies: + > An implementation may set limits on the maximum depth of nesting. + + In this class, the array's limit of nesting is not explicitly constrained. + However, a maximum depth of nesting may be introduced by the compiler or + runtime environment. A theoretical limit can be queried by calling the + @ref max_size function of a JSON array. + + #### Storage + + Arrays are stored as pointers in a @ref basic_json type. That is, for any + access to array values, a pointer of type `array_t*` must be dereferenced. + + @sa @ref object_t -- type for an object value + + @since version 1.0.0 + */ + using array_t = ArrayType>; + + /*! + @brief a type for a string + + [RFC 7159](http://rfc7159.net/rfc7159) describes JSON strings as follows: + > A string is a sequence of zero or more Unicode characters. + + To store strings in C++, a type is defined by the template parameter + described below. Unicode values are split by the JSON class into + byte-sized characters during deserialization. + + @tparam StringType the container to store strings (e.g., `std::string`). + Note this container is used for keys/names in objects, see @ref object_t. + + #### Default type + + With the default values for @a StringType (`std::string`), the default + value for @a string_t is: + + @code {.cpp} + std::string + @endcode + + #### Encoding + + Strings are stored in UTF-8 encoding.
Therefore, functions like + `std::string::size()` or `std::string::length()` return the number of + bytes in the string rather than the number of characters or glyphs. + + #### String comparison + + [RFC 7159](http://rfc7159.net/rfc7159) states: + > Software implementations are typically required to test names of object + > members for equality. Implementations that transform the textual + > representation into sequences of Unicode code units and then perform the + > comparison numerically, code unit by code unit, are interoperable in the + > sense that implementations will agree in all cases on equality or + > inequality of two strings. For example, implementations that compare + > strings with escaped characters unconverted may incorrectly find that + > `"a\\b"` and `"a\u005Cb"` are not equal. + + This implementation is interoperable as it does compare strings code unit + by code unit. + + #### Storage + + String values are stored as pointers in a @ref basic_json type. That is, + for any access to string values, a pointer of type `string_t*` must be + dereferenced. + + @since version 1.0.0 + */ + using string_t = StringType; + + /*! + @brief a type for a boolean + + [RFC 7159](http://rfc7159.net/rfc7159) implicitly describes a boolean as a + type which differentiates the two literals `true` and `false`. + + To store booleans in C++, a type is defined by the template parameter @a + BooleanType which chooses the type to use. + + #### Default type + + With the default values for @a BooleanType (`bool`), the default value for + @a boolean_t is: + + @code {.cpp} + bool + @endcode + + #### Storage + + Boolean values are stored directly inside a @ref basic_json type. + + @since version 1.0.0 + */ + using boolean_t = BooleanType; + + /*! + @brief a type for a number (integer) + + [RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows: + > The representation of numbers is similar to that used in most + > programming languages.
A number is represented in base 10 using decimal + > digits. It contains an integer component that may be prefixed with an + > optional minus sign, which may be followed by a fraction part and/or an + > exponent part. Leading zeros are not allowed. (...) Numeric values that + > cannot be represented in the grammar below (such as Infinity and NaN) + > are not permitted. + + This description includes both integer and floating-point numbers. + However, C++ allows more precise storage if it is known whether the number + is a signed integer, an unsigned integer or a floating-point number. + Therefore, three different types, @ref number_integer_t, @ref + number_unsigned_t and @ref number_float_t are used. + + To store integer numbers in C++, a type is defined by the template + parameter @a NumberIntegerType which chooses the type to use. + + #### Default type + + With the default values for @a NumberIntegerType (`int64_t`), the default + value for @a number_integer_t is: + + @code {.cpp} + int64_t + @endcode + + #### Default behavior + + - The restrictions about leading zeros is not enforced in C++. Instead, + leading zeros in integer literals lead to an interpretation as octal + number. Internally, the value will be stored as decimal number. For + instance, the C++ integer literal `010` will be serialized to `8`. + During deserialization, leading zeros yield an error. + - Not-a-number (NaN) values will be serialized to `null`. + + #### Limits + + [RFC 7159](http://rfc7159.net/rfc7159) specifies: + > An implementation may set limits on the range and precision of numbers. + + When the default type is used, the maximal integer number that can be + stored is `9223372036854775807` (INT64_MAX) and the minimal integer number + that can be stored is `-9223372036854775808` (INT64_MIN). Integer numbers + that are out of range will yield over/underflow when used in a + constructor. 
During deserialization, too large or small integer numbers + will automatically be stored as @ref number_unsigned_t or @ref + number_float_t. + + [RFC 7159](http://rfc7159.net/rfc7159) further states: + > Note that when such software is used, numbers that are integers and are + > in the range \f$[-2^{53}+1, 2^{53}-1]\f$ are interoperable in the sense + > that implementations will agree exactly on their numeric values. + + As this range is a subrange of the exactly supported range [INT64_MIN, + INT64_MAX], this class's integer type is interoperable. + + #### Storage + + Integer number values are stored directly inside a @ref basic_json type. + + @sa @ref number_float_t -- type for number values (floating-point) + + @sa @ref number_unsigned_t -- type for number values (unsigned integer) + + @since version 1.0.0 + */ + using number_integer_t = NumberIntegerType; + + /*! + @brief a type for a number (unsigned) + + [RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows: + > The representation of numbers is similar to that used in most + > programming languages. A number is represented in base 10 using decimal + > digits. It contains an integer component that may be prefixed with an + > optional minus sign, which may be followed by a fraction part and/or an + > exponent part. Leading zeros are not allowed. (...) Numeric values that + > cannot be represented in the grammar below (such as Infinity and NaN) + > are not permitted. + + This description includes both integer and floating-point numbers. + However, C++ allows more precise storage if it is known whether the number + is a signed integer, an unsigned integer or a floating-point number. + Therefore, three different types, @ref number_integer_t, @ref + number_unsigned_t and @ref number_float_t are used. + + To store unsigned integer numbers in C++, a type is defined by the + template parameter @a NumberUnsignedType which chooses the type to use.
+ + #### Default type + + With the default values for @a NumberUnsignedType (`uint64_t`), the + default value for @a number_unsigned_t is: + + @code {.cpp} + uint64_t + @endcode + + #### Default behavior + + - The restriction on leading zeros is not enforced in C++. Instead, + leading zeros in integer literals lead to an interpretation as octal + number. Internally, the value will be stored as decimal number. For + instance, the C++ integer literal `010` will be serialized to `8`. + During deserialization, leading zeros yield an error. + - Not-a-number (NaN) values will be serialized to `null`. + + #### Limits + + [RFC 7159](http://rfc7159.net/rfc7159) specifies: + > An implementation may set limits on the range and precision of numbers. + + When the default type is used, the maximal integer number that can be + stored is `18446744073709551615` (UINT64_MAX) and the minimal integer + number that can be stored is `0`. Integer numbers that are out of range + will yield over/underflow when used in a constructor. During + deserialization, too large or small integer numbers will automatically + be stored as @ref number_integer_t or @ref number_float_t. + + [RFC 7159](http://rfc7159.net/rfc7159) further states: + > Note that when such software is used, numbers that are integers and are + > in the range \f$[-2^{53}+1, 2^{53}-1]\f$ are interoperable in the sense + > that implementations will agree exactly on their numeric values. + + As this range is a subrange (when considered in conjunction with the + number_integer_t type) of the exactly supported range [0, UINT64_MAX], + this class's integer type is interoperable. + + #### Storage + + Integer number values are stored directly inside a @ref basic_json type. + + @sa @ref number_float_t -- type for number values (floating-point) + @sa @ref number_integer_t -- type for number values (integer) + + @since version 2.0.0 + */ + using number_unsigned_t = NumberUnsignedType; + + /*!
+ @brief a type for a number (floating-point) + + [RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows: + > The representation of numbers is similar to that used in most + > programming languages. A number is represented in base 10 using decimal + > digits. It contains an integer component that may be prefixed with an + > optional minus sign, which may be followed by a fraction part and/or an + > exponent part. Leading zeros are not allowed. (...) Numeric values that + > cannot be represented in the grammar below (such as Infinity and NaN) + > are not permitted. + + This description includes both integer and floating-point numbers. + However, C++ allows more precise storage if it is known whether the number + is a signed integer, an unsigned integer or a floating-point number. + Therefore, three different types, @ref number_integer_t, @ref + number_unsigned_t and @ref number_float_t are used. + + To store floating-point numbers in C++, a type is defined by the template + parameter @a NumberFloatType which chooses the type to use. + + #### Default type + + With the default values for @a NumberFloatType (`double`), the default + value for @a number_float_t is: + + @code {.cpp} + double + @endcode + + #### Default behavior + + - The restrictions about leading zeros is not enforced in C++. Instead, + leading zeros in floating-point literals will be ignored. Internally, + the value will be stored as decimal number. For instance, the C++ + floating-point literal `01.2` will be serialized to `1.2`. During + deserialization, leading zeros yield an error. + - Not-a-number (NaN) values will be serialized to `null`. + + #### Limits + + [RFC 7159](http://rfc7159.net/rfc7159) states: + > This specification allows implementations to set limits on the range and + > precision of numbers accepted. 
Since software that implements IEEE + > 754-2008 binary64 (double precision) numbers is generally available and + > widely used, good interoperability can be achieved by implementations + > that expect no more precision or range than these provide, in the sense + > that implementations will approximate JSON numbers within the expected + > precision. + + This implementation does exactly follow this approach, as it uses double + precision floating-point numbers. Note values smaller than + `-1.79769313486232e+308` and values greater than `1.79769313486232e+308` + will be stored as NaN internally and be serialized to `null`. + + #### Storage + + Floating-point number values are stored directly inside a @ref basic_json + type. + + @sa @ref number_integer_t -- type for number values (integer) + + @sa @ref number_unsigned_t -- type for number values (unsigned integer) + + @since version 1.0.0 + */ + using number_float_t = NumberFloatType; + + /// @} + + private: + + /// helper for exception-safe object creation + /// + /// Allocates storage for one T with a local AllocatorType instance, + /// constructs the object in place from @a args and returns the raw + /// pointer. Until release() the storage is held by a std::unique_ptr + /// whose deleter only deallocates, so the storage is given back if the + /// constructor throws. + /// + /// NOTE(review): the assert runs after construct() has already used the + /// pointer, so it can only act as a post-condition check - confirm. + /// NOTE(review): std::allocator::construct is deprecated since C++17; + /// presumably fine for the standard this file targets - verify. + template + static T* create(Args&& ... args) + { + AllocatorType alloc; + auto deleter = [&](T * object) + { + alloc.deallocate(object, 1); + }; + std::unique_ptr object(alloc.allocate(1), deleter); + alloc.construct(object.get(), std::forward(args)...); + assert(object != nullptr); + return object.release(); + } + + //////////////////////// + // JSON value storage // + //////////////////////// + + /*! + @brief a JSON value + + The actual storage for a JSON value of the @ref basic_json class. This + union combines the different storage types for the JSON value types + defined in @ref value_t.
+ + JSON type | value_t type | used type + --------- | --------------- | ------------------------ + object | object | pointer to @ref object_t + array | array | pointer to @ref array_t + string | string | pointer to @ref string_t + boolean | boolean | @ref boolean_t + number | number_integer | @ref number_integer_t + number | number_unsigned | @ref number_unsigned_t + number | number_float | @ref number_float_t + null | null | *no value is stored* + + @note Variable-length types (objects, arrays, and strings) are stored as + pointers. The size of the union should not exceed 64 bits if the default + value types are used. + + @since version 1.0.0 + */ + union json_value + { + /// object (stored with pointer to save storage) + object_t* object; + /// array (stored with pointer to save storage) + array_t* array; + /// string (stored with pointer to save storage) + string_t* string; + /// boolean + boolean_t boolean; + /// number (integer) + number_integer_t number_integer; + /// number (unsigned integer) + number_unsigned_t number_unsigned; + /// number (floating-point) + number_float_t number_float; + + /// default constructor (for null values) + json_value() = default; + /// constructor for booleans + json_value(boolean_t v) noexcept : boolean(v) {} + /// constructor for numbers (integer) + json_value(number_integer_t v) noexcept : number_integer(v) {} + /// constructor for numbers (unsigned) + json_value(number_unsigned_t v) noexcept : number_unsigned(v) {} + /// constructor for numbers (floating-point) + json_value(number_float_t v) noexcept : number_float(v) {} + /// constructor for empty values of a given type + json_value(value_t t) + { + switch (t) + { + case value_t::object: + { + object = create(); + break; + } + + case value_t::array: + { + array = create(); + break; + } + + case value_t::string: + { + string = create(""); + break; + } + + case value_t::boolean: + { + boolean = boolean_t(false); + break; + } + + case value_t::number_integer: + { + 
number_integer = number_integer_t(0); + break; + } + + case value_t::number_unsigned: + { + number_unsigned = number_unsigned_t(0); + break; + } + + case value_t::number_float: + { + number_float = number_float_t(0.0); + break; + } + + case value_t::null: + { + break; + } + + default: + { + if (JSON_UNLIKELY(t == value_t::null)) + { + JSON_THROW(other_error::create(500, "961c151d2e87f2686a955a9be24d316f1362bf21 2.1.1")); // LCOV_EXCL_LINE + } + break; + } + } + } + + /// constructor for strings + json_value(const string_t& value) + { + string = create(value); + } + + /// constructor for rvalue strings + json_value(string_t&& value) + { + string = create(std::move(value)); + } + + /// constructor for objects + json_value(const object_t& value) + { + object = create(value); + } + + /// constructor for rvalue objects + json_value(object_t&& value) + { + object = create(std::move(value)); + } + + /// constructor for arrays + json_value(const array_t& value) + { + array = create(value); + } + + /// constructor for rvalue arrays + json_value(array_t&& value) + { + array = create(std::move(value)); + } + + void destroy(value_t t) + { + switch (t) + { + case value_t::object: + { + AllocatorType alloc; + alloc.destroy(object); + alloc.deallocate(object, 1); + break; + } + + case value_t::array: + { + AllocatorType alloc; + alloc.destroy(array); + alloc.deallocate(array, 1); + break; + } + + case value_t::string: + { + AllocatorType alloc; + alloc.destroy(string); + alloc.deallocate(string, 1); + break; + } + + default: + { + break; + } + } + } + }; + + /*! + @brief checks the class invariants + + This function asserts the class invariants. It needs to be called at the + end of every constructor to make sure that created objects respect the + invariant. Furthermore, it has to be called each time the type of a JSON + value is changed, because the invariant expresses a relationship between + @a m_type and @a m_value. 
+ */ + void assert_invariant() const + { + assert(m_type != value_t::object or m_value.object != nullptr); + assert(m_type != value_t::array or m_value.array != nullptr); + assert(m_type != value_t::string or m_value.string != nullptr); + } + + public: + ////////////////////////// + // JSON parser callback // + ////////////////////////// + + using parse_event_t = typename parser::parse_event_t; + + /*! + @brief per-element parser callback type + + With a parser callback function, the result of parsing a JSON text can be + influenced. When passed to @ref parse(std::istream&, const + parser_callback_t) or @ref parse(const CharT, const parser_callback_t), + it is called on certain events (passed as @ref parse_event_t via parameter + @a event) with a set recursion depth @a depth and context JSON value + @a parsed. The return value of the callback function is a boolean + indicating whether the element that emitted the callback shall be kept or + not. + + We distinguish six scenarios (determined by the event type) in which the + callback function can be called. The following table describes the values + of the parameters @a depth, @a event, and @a parsed. 
+ + parameter @a event | description | parameter @a depth | parameter @a parsed + ------------------ | ----------- | ------------------ | ------------------- + parse_event_t::object_start | the parser read `{` and started to process a JSON object | depth of the parent of the JSON object | a JSON value with type discarded + parse_event_t::key | the parser read a key of a value in an object | depth of the currently parsed JSON object | a JSON string containing the key + parse_event_t::object_end | the parser read `}` and finished processing a JSON object | depth of the parent of the JSON object | the parsed JSON object + parse_event_t::array_start | the parser read `[` and started to process a JSON array | depth of the parent of the JSON array | a JSON value with type discarded + parse_event_t::array_end | the parser read `]` and finished processing a JSON array | depth of the parent of the JSON array | the parsed JSON array + parse_event_t::value | the parser finished reading a JSON value | depth of the value | the parsed JSON value + + @image html callback_events.png "Example when certain parse events are triggered" + + Discarding a value (i.e., returning `false`) has different effects + depending on the context in which the function was called: + + - Discarded values in structured types are skipped. That is, the parser + will behave as if the discarded value was never read. + - In case a value outside a structured type is skipped, it is replaced + with `null`. This case happens if the top-level element is skipped. + + @param[in] depth the depth of the recursion during parsing + + @param[in] event an event of type parse_event_t indicating the context in + which the callback function has been called + + @param[in,out] parsed the current intermediate parse result; note that + writing to this value has no effect for parse_event_t::key events + + @return Whether the JSON value which called the function during parsing + should be kept (`true`) or not (`false`).
In the latter case, it is either + skipped completely or replaced by an empty discarded object. + + @sa @ref parse(std::istream&, parser_callback_t) or + @ref parse(const CharT, const parser_callback_t) for examples + + @since version 1.0.0 + */ + using parser_callback_t = typename parser::parser_callback_t; + + + ////////////////// + // constructors // + ////////////////// + + /// @name constructors and destructors + /// Constructors of class @ref basic_json, copy/move constructor, copy + /// assignment, static functions creating objects, and the destructor. + /// @{ + + /*! + @brief create an empty value with a given type + + Create an empty JSON value with a given type. The value will be default + initialized with an empty value which depends on the type: + + Value type | initial value + ----------- | ------------- + null | `null` + boolean | `false` + string | `""` + number | `0` + object | `{}` + array | `[]` + + @param[in] v the type of the value to create + + @complexity Constant. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The following code shows the constructor for different @ref + value_t values,basic_json__value_t} + + @sa @ref clear() -- restores the postcondition of this constructor + + @since version 1.0.0 + */ + basic_json(const value_t v) + : m_type(v), m_value(v) + { + assert_invariant(); + } + + /*! + @brief create a null object + + Create a `null` JSON value. It either takes a null pointer as parameter + (explicitly creating `null`) or no parameter (implicitly creating `null`). + The passed null pointer itself is not read -- it is only used to choose + the right constructor. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this constructor never throws + exceptions. 
+ + @liveexample{The following code shows the constructor with and without a + null pointer parameter.,basic_json__nullptr_t} + + @since version 1.0.0 + */ + basic_json(std::nullptr_t = nullptr) noexcept + : basic_json(value_t::null) + { + assert_invariant(); + } + + /*! + @brief create a JSON value + + This is a "catch all" constructor for all compatible JSON types; that is, + types for which a `to_json()` method exists. The constructor forwards the + parameter @a val to that method (to `json_serializer::to_json` method + with `U = uncvref_t`, to be exact). + + Template type @a CompatibleType includes, but is not limited to, the + following types: + - **arrays**: @ref array_t and all kinds of compatible containers such as + `std::vector`, `std::deque`, `std::list`, `std::forward_list`, + `std::array`, `std::valarray`, `std::set`, `std::unordered_set`, + `std::multiset`, and `std::unordered_multiset` with a `value_type` from + which a @ref basic_json value can be constructed. + - **objects**: @ref object_t and all kinds of compatible associative + containers such as `std::map`, `std::unordered_map`, `std::multimap`, + and `std::unordered_multimap` with a `key_type` compatible to + @ref string_t and a `value_type` from which a @ref basic_json value can + be constructed. + - **strings**: @ref string_t, string literals, and all compatible string + containers can be used. + - **numbers**: @ref number_integer_t, @ref number_unsigned_t, + @ref number_float_t, and all convertible number types such as `int`, + `size_t`, `int64_t`, `float` or `double` can be used. + - **boolean**: @ref boolean_t / `bool` can be used. + + See the examples below. + + @tparam CompatibleType a type such that: + - @a CompatibleType is not derived from `std::istream`, + - @a CompatibleType is not @ref basic_json (to avoid hijacking copy/move + constructors), + - @a CompatibleType is not a @ref basic_json nested type (e.g., + @ref json_pointer, @ref iterator, etc ...)
+ - @ref json_serializer has a + `to_json(basic_json_t&, CompatibleType&&)` method + + @tparam U = `uncvref_t` + + @param[in] val the value to be forwarded to the respective constructor + + @complexity Usually linear in the size of the passed @a val, also + depending on the implementation of the called `to_json()` + method. + + @exceptionsafety Depends on the called constructor. For types directly + supported by the library (i.e., all types for which no `to_json()` function + was provided), strong guarantee holds: if an exception is thrown, there are + no changes to any JSON value. + + @liveexample{The following code shows the constructor with several + compatible types.,basic_json__CompatibleType} + + @since version 2.1.0 + */ + template, + detail::enable_if_t::value and + not std::is_same::value and + not detail::is_basic_json_nested_type< + basic_json_t, U>::value and + detail::has_to_json::value, + int> = 0> + basic_json(CompatibleType && val) noexcept(noexcept(JSONSerializer::to_json( + std::declval(), std::forward(val)))) + { + JSONSerializer::to_json(*this, std::forward(val)); + assert_invariant(); + } + + /*! + @brief create a container (array or object) from an initializer list + + Creates a JSON value of type array or object from the passed initializer + list @a init. In case @a type_deduction is `true` (default), the type of + the JSON value to be created is deduced from the initializer list @a init + according to the following rules: + + 1. If the list is empty, an empty JSON object value `{}` is created. + 2. If the list consists of pairs whose first element is a string, a JSON + object value is created where the first elements of the pairs are + treated as keys and the second elements as values. + 3. In all other cases, an array is created. + + The rules aim to create the best fit between a C++ initializer list and + JSON values. The rationale is as follows: + + 1.
The empty initializer list is written as `{}` which is exactly an empty + JSON object. + 2. C++ has no way of describing mapped types other than to list a list of + pairs. As JSON requires that keys must be of type string, rule 2 is the + weakest constraint one can pose on initializer lists to interpret them + as an object. + 3. In all other cases, the initializer list could not be interpreted as + JSON object type, so interpreting it as JSON array type is safe. + + With the rules described above, the following JSON values cannot be + expressed by an initializer list: + + - the empty array (`[]`): use @ref array(initializer_list_t) + with an empty initializer list in this case + - arrays whose elements satisfy rule 2: use @ref + array(initializer_list_t) with the same initializer list + in this case + + @note When used without parentheses around an empty initializer list, @ref + basic_json() is called instead of this function, yielding the JSON null + value. + + @param[in] init initializer list with JSON values + + @param[in] type_deduction internal parameter; when set to `true`, the type + of the JSON value is deduced from the initializer list @a init; when set + to `false`, the type provided via @a manual_type is forced. This mode is + used by the functions @ref array(initializer_list_t) and + @ref object(initializer_list_t). + + @param[in] manual_type internal parameter; when @a type_deduction is set + to `false`, the created JSON value will use the provided type (only @ref + value_t::array and @ref value_t::object are valid); when @a type_deduction + is set to `true`, this parameter has no effect + + @throw type_error.301 if @a type_deduction is `false`, @a manual_type is + `value_t::object`, but @a init contains an element which is not a pair + whose first element is a string. In this case, the constructor could not + create an object. If @a type_deduction had been `true`, an array + would have been created.
See @ref object(initializer_list_t) + for an example. + + @complexity Linear in the size of the initializer list @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The example below shows how JSON values are created from + initializer lists.,basic_json__list_init_t} + + @sa @ref array(initializer_list_t) -- create a JSON array + value from an initializer list + @sa @ref object(initializer_list_t) -- create a JSON object + value from an initializer list + + @since version 1.0.0 + */ + basic_json(initializer_list_t init, + bool type_deduction = true, + value_t manual_type = value_t::array) + { + // check if each element is an array with two elements whose first + // element is a string + bool is_an_object = std::all_of(init.begin(), init.end(), + [](const detail::json_ref& element_ref) + { + return (element_ref->is_array() and element_ref->size() == 2 and (*element_ref)[0].is_string()); + }); + + // adjust type if type deduction is not wanted + if (not type_deduction) + { + // if array is wanted, do not create an object though possible + if (manual_type == value_t::array) + { + is_an_object = false; + } + + // if object is wanted but impossible, throw an exception + if (JSON_UNLIKELY(manual_type == value_t::object and not is_an_object)) + { + JSON_THROW(type_error::create(301, "cannot create object from initializer list")); + } + } + + if (is_an_object) + { + // the initializer list is a list of pairs -> create object + m_type = value_t::object; + m_value = value_t::object; + + std::for_each(init.begin(), init.end(), [this](const detail::json_ref& element_ref) + { + auto element = element_ref.moved_or_copied(); + m_value.object->emplace( + std::move(*((*element.m_value.array)[0].m_value.string)), + std::move((*element.m_value.array)[1])); + }); + } + else + { + // the initializer list describes an array -> create array + m_type = value_t::array; + m_value.array = create(init.begin(), 
init.end()); + } + + assert_invariant(); + } + + /*! + @brief explicitly create an array from an initializer list + + Creates a JSON array value from a given initializer list. That is, given a + list of values `a, b, c`, creates the JSON value `[a, b, c]`. If the + initializer list is empty, the empty array `[]` is created. + + @note This function is only needed to express two edge cases that cannot + be realized with the initializer list constructor (@ref + basic_json(initializer_list_t, bool, value_t)). These cases + are: + 1. creating an array whose elements are all pairs whose first element is a + string -- in this case, the initializer list constructor would create an + object, taking the first elements as keys + 2. creating an empty array -- passing the empty initializer list to the + initializer list constructor yields an empty object + + @param[in] init initializer list with JSON values to create an array from + (optional) + + @return JSON array value + + @complexity Linear in the size of @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The following code shows an example for the `array` + function.,array} + + @sa @ref basic_json(initializer_list_t, bool, value_t) -- + create a JSON value from an initializer list + @sa @ref object(initializer_list_t) -- create a JSON object + value from an initializer list + + @since version 1.0.0 + */ + static basic_json array(initializer_list_t init = {}) + { + return basic_json(init, false, value_t::array); + } + + /*! + @brief explicitly create an object from an initializer list + + Creates a JSON object value from a given initializer list. The initializer + lists elements must be pairs, and their first elements must be strings. If + the initializer list is empty, the empty object `{}` is created. + + @note This function is only added for symmetry reasons. 
In contrast to the + related function @ref array(initializer_list_t), there are + no cases which can only be expressed by this function. That is, any + initializer list @a init can also be passed to the initializer list + constructor @ref basic_json(initializer_list_t, bool, value_t). + + @param[in] init initializer list to create an object from (optional) + + @return JSON object value + + @throw type_error.301 if @a init is not a list of pairs whose first + elements are strings. In this case, no object can be created. When such a + value is passed to @ref basic_json(initializer_list_t, bool, value_t), + an array would have been created from the passed initializer list @a init. + See example below. + + @complexity Linear in the size of @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The following code shows an example for the `object` + function.,object} + + @sa @ref basic_json(initializer_list_t, bool, value_t) -- + create a JSON value from an initializer list + @sa @ref array(initializer_list_t) -- create a JSON array + value from an initializer list + + @since version 1.0.0 + */ + static basic_json object(initializer_list_t init = {}) + { + return basic_json(init, false, value_t::object); + } + + /*! + @brief construct an array with count copies of given value + + Constructs a JSON array value by creating @a cnt copies of a passed value. + In case @a cnt is `0`, an empty array is created. + + @param[in] cnt the number of JSON copies of @a val to create + @param[in] val the JSON value to copy + + @post `std::distance(begin(),end()) == cnt` holds. + + @complexity Linear in @a cnt. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. 
+ + @liveexample{The following code shows examples for the @ref + basic_json(size_type\, const basic_json&) + constructor.,basic_json__size_type_basic_json} + + @since version 1.0.0 + */ + basic_json(size_type cnt, const basic_json& val) + : m_type(value_t::array) + { + m_value.array = create(cnt, val); + assert_invariant(); + } + + /*! + @brief construct a JSON container given an iterator range + + Constructs the JSON value with the contents of the range `[first, last)`. + The semantics depends on the different types a JSON value can have: + - In case of a null type, invalid_iterator.206 is thrown. + - In case of other primitive types (number, boolean, or string), @a first + must be `begin()` and @a last must be `end()`. In this case, the value is + copied. Otherwise, invalid_iterator.204 is thrown. + - In case of structured types (array, object), the constructor behaves as + similar versions for `std::vector` or `std::map`; that is, a JSON array + or object is constructed from the values in the range. + + @tparam InputIT an input iterator type (@ref iterator or @ref + const_iterator) + + @param[in] first begin of the range to copy from (included) + @param[in] last end of the range to copy from (excluded) + + @pre Iterators @a first and @a last must be initialized. **This + precondition is enforced with an assertion (see warning).** If + assertions are switched off, a violation of this precondition yields + undefined behavior. + + @pre Range `[first, last)` is valid. Usually, this precondition cannot be + checked efficiently. Only certain edge cases are detected; see the + description of the exceptions below. A violation of this precondition + yields undefined behavior. + + @warning A precondition is enforced with a runtime assertion that will + result in calling `std::abort` if this precondition is not met. + Assertions can be disabled by defining `NDEBUG` at compile time. + See http://en.cppreference.com/w/cpp/error/assert for more + information. 
+ + @throw invalid_iterator.201 if iterators @a first and @a last are not + compatible (i.e., do not belong to the same JSON value). In this case, + the range `[first, last)` is undefined. + @throw invalid_iterator.204 if iterators @a first and @a last belong to a + primitive type (number, boolean, or string), but @a first does not point + to the first element any more. In this case, the range `[first, last)` is + undefined. See example code below. + @throw invalid_iterator.206 if iterators @a first and @a last belong to a + null value. In this case, the range `[first, last)` is undefined. + + @complexity Linear in distance between @a first and @a last. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The example below shows several ways to create JSON values by + specifying a subrange with iterators.,basic_json__InputIt_InputIt} + + @since version 1.0.0 + */ + template::value or + std::is_same::value, int>::type = 0> + basic_json(InputIT first, InputIT last) + { + assert(first.m_object != nullptr); + assert(last.m_object != nullptr); + + // make sure iterator fits the current value + if (JSON_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(201, "iterators are not compatible")); + } + + // copy type from first iterator + m_type = first.m_object->m_type; + + // check if iterator range is complete for primitive values + switch (m_type) + { + case value_t::boolean: + case value_t::number_float: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::string: + { + if (JSON_UNLIKELY(not first.m_it.primitive_iterator.is_begin() + or not last.m_it.primitive_iterator.is_end())) + { + JSON_THROW(invalid_iterator::create(204, "iterators out of range")); + } + break; + } + + default: + break; + } + + switch (m_type) + { + case value_t::number_integer: + { + m_value.number_integer = first.m_object->m_value.number_integer; + break; + } + + 
case value_t::number_unsigned: + { + m_value.number_unsigned = first.m_object->m_value.number_unsigned; + break; + } + + case value_t::number_float: + { + m_value.number_float = first.m_object->m_value.number_float; + break; + } + + case value_t::boolean: + { + m_value.boolean = first.m_object->m_value.boolean; + break; + } + + case value_t::string: + { + m_value = *first.m_object->m_value.string; + break; + } + + case value_t::object: + { + m_value.object = create(first.m_it.object_iterator, + last.m_it.object_iterator); + break; + } + + case value_t::array: + { + m_value.array = create(first.m_it.array_iterator, + last.m_it.array_iterator); + break; + } + + default: + JSON_THROW(invalid_iterator::create(206, "cannot construct with iterators from " + + std::string(first.m_object->type_name()))); + } + + assert_invariant(); + } + + + /////////////////////////////////////// + // other constructors and destructor // + /////////////////////////////////////// + + /// @private + basic_json(const detail::json_ref& ref) + : basic_json(ref.moved_or_copied()) + {} + + /*! + @brief copy constructor + + Creates a copy of a given JSON value. + + @param[in] other the JSON value to copy + + @post `*this == other` + + @complexity Linear in the size of @a other. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @requirement This function helps `basic_json` satisfying the + [Container](http://en.cppreference.com/w/cpp/concept/Container) + requirements: + - The complexity is linear. + - As postcondition, it holds: `other == basic_json(other)`. 
+ + @liveexample{The following code shows an example for the copy + constructor.,basic_json__basic_json} + + @since version 1.0.0 + */ + basic_json(const basic_json& other) + : m_type(other.m_type) + { + // check that passed value is valid + other.assert_invariant(); + + switch (m_type) + { + case value_t::object: + { + m_value = *other.m_value.object; + break; + } + + case value_t::array: + { + m_value = *other.m_value.array; + break; + } + + case value_t::string: + { + m_value = *other.m_value.string; + break; + } + + case value_t::boolean: + { + m_value = other.m_value.boolean; + break; + } + + case value_t::number_integer: + { + m_value = other.m_value.number_integer; + break; + } + + case value_t::number_unsigned: + { + m_value = other.m_value.number_unsigned; + break; + } + + case value_t::number_float: + { + m_value = other.m_value.number_float; + break; + } + + default: + break; + } + + assert_invariant(); + } + + /*! + @brief move constructor + + Move constructor. Constructs a JSON value with the contents of the given + value @a other using move semantics. It "steals" the resources from @a + other and leaves it as JSON null value. + + @param[in,out] other value to move to this object + + @post `*this` has the same value as @a other before the call. + @post @a other is a JSON null value. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this constructor never throws + exceptions. + + @requirement This function helps `basic_json` satisfying the + [MoveConstructible](http://en.cppreference.com/w/cpp/concept/MoveConstructible) + requirements.
+ + @liveexample{The code below shows the move constructor explicitly called + via std::move.,basic_json__moveconstructor} + + @since version 1.0.0 + */ + basic_json(basic_json&& other) noexcept + : m_type(std::move(other.m_type)), + m_value(std::move(other.m_value)) + { + // check that passed value is valid + other.assert_invariant(); + + // invalidate payload + other.m_type = value_t::null; + other.m_value = {}; + + assert_invariant(); + } + + /*! + @brief copy assignment + + Copy assignment operator. Copies a JSON value via the "copy and swap" + strategy: It is expressed in terms of the copy constructor, destructor, + and the `swap()` member function. + + @param[in] other value to copy from + + @complexity Linear. + + @requirement This function helps `basic_json` satisfying the + [Container](http://en.cppreference.com/w/cpp/concept/Container) + requirements: + - The complexity is linear. + + @liveexample{The code below shows an example for the copy assignment. It + creates a copy of value `a` which is then swapped with `b`. Finally\, the + copy of `a` (which is the null value after the swap) is + destroyed.,basic_json__copyassignment} + + @since version 1.0.0 + */ + reference& operator=(basic_json other) noexcept ( + std::is_nothrow_move_constructible::value and + std::is_nothrow_move_assignable::value and + std::is_nothrow_move_constructible::value and + std::is_nothrow_move_assignable::value + ) + { + // check that passed value is valid + other.assert_invariant(); + + using std::swap; + swap(m_type, other.m_type); + swap(m_value, other.m_value); + + assert_invariant(); + return *this; + } + + /*! + @brief destructor + + Destroys the JSON value and frees all allocated memory. + + @complexity Linear. + + @requirement This function helps `basic_json` satisfying the + [Container](http://en.cppreference.com/w/cpp/concept/Container) + requirements: + - The complexity is linear. + - All stored elements are destroyed and all memory is freed.
+ + @since version 1.0.0 + */ + ~basic_json() + { + assert_invariant(); + m_value.destroy(m_type); + } + + /// @} + + public: + /////////////////////// + // object inspection // + /////////////////////// + + /// @name object inspection + /// Functions to inspect the type of a JSON value. + /// @{ + + /*! + @brief serialization + + Serialization function for JSON values. The function tries to mimic + Python's `json.dumps()` function, and currently supports its @a indent + and @a ensure_ascii parameters. + + @param[in] indent If indent is nonnegative, then array elements and object + members will be pretty-printed with that indent level. An indent level of + `0` will only insert newlines. `-1` (the default) selects the most compact + representation. + @param[in] indent_char The character to use for indentation if @a indent is + greater than `0`. The default is ` ` (space). + @param[in] ensure_ascii If @a ensure_ascii is true, all non-ASCII characters + in the output are escaped with \uXXXX sequences, and the result consists + of ASCII characters only. + + @return string containing the serialization of the JSON value + + @complexity Linear. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @liveexample{The following example shows the effect of different @a indent\, + @a indent_char\, and @a ensure_ascii parameters to the result of the + serialization.,dump} + + @see https://docs.python.org/2/library/json.html#json.dump + + @since version 1.0.0; indentation character @a indent_char and option + @a ensure_ascii added in version 3.0.0 + */ + string_t dump(const int indent = -1, const char indent_char = ' ', + const bool ensure_ascii = false) const + { + string_t result; + serializer s(detail::output_adapter(result), indent_char); + + if (indent >= 0) + { + s.dump(*this, true, ensure_ascii, static_cast(indent)); + } + else + { + s.dump(*this, false, ensure_ascii, 0); + } + + return result; + } + + /*! 
+ @brief return the type of the JSON value (explicit) + + Return the type of the JSON value as a value from the @ref value_t + enumeration. + + @return the type of the JSON value + Value type | return value + ------------------------- | ------------------------- + null | value_t::null + boolean | value_t::boolean + string | value_t::string + number (integer) | value_t::number_integer + number (unsigned integer) | value_t::number_unsigned + number (floating-point) | value_t::number_float + object | value_t::object + array | value_t::array + discarded | value_t::discarded + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `type()` for all JSON + types.,type} + + @sa @ref operator value_t() -- return the type of the JSON value (implicit) + @sa @ref type_name() -- return the type as string + + @since version 1.0.0 + */ + constexpr value_t type() const noexcept + { + return m_type; + } + + /*! + @brief return whether type is primitive + + This function returns true if and only if the JSON type is primitive + (string, number, boolean, or null). + + @return `true` if type is primitive (string, number, boolean, or null), + `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_primitive()` for all JSON + types.,is_primitive} + + @sa @ref is_structured() -- returns whether JSON value is structured + @sa @ref is_null() -- returns whether JSON value is `null` + @sa @ref is_string() -- returns whether JSON value is a string + @sa @ref is_boolean() -- returns whether JSON value is a boolean + @sa @ref is_number() -- returns whether JSON value is a number + + @since version 1.0.0 + */ + constexpr bool is_primitive() const noexcept + { + return is_null() or is_string() or is_boolean() or is_number(); + } + + /*!
+ @brief return whether type is structured + + This function returns true if and only if the JSON type is structured + (array or object). + + @return `true` if type is structured (array or object), `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_structured()` for all JSON + types.,is_structured} + + @sa @ref is_primitive() -- returns whether value is primitive + @sa @ref is_array() -- returns whether value is an array + @sa @ref is_object() -- returns whether value is an object + + @since version 1.0.0 + */ + constexpr bool is_structured() const noexcept + { + return is_array() or is_object(); + } + + /*! + @brief return whether value is null + + This function returns true if and only if the JSON value is null. + + @return `true` if type is null, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_null()` for all JSON + types.,is_null} + + @since version 1.0.0 + */ + constexpr bool is_null() const noexcept + { + return (m_type == value_t::null); + } + + /*! + @brief return whether value is a boolean + + This function returns true if and only if the JSON value is a boolean. + + @return `true` if type is boolean, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_boolean()` for all JSON + types.,is_boolean} + + @since version 1.0.0 + */ + constexpr bool is_boolean() const noexcept + { + return (m_type == value_t::boolean); + } + + /*! + @brief return whether value is a number + + This function returns true if and only if the JSON value is a number. This + includes both integer (signed and unsigned) and floating-point values. 
+ + @return `true` if type is number (regardless whether integer, unsigned + integer or floating-type), `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number()` for all JSON + types.,is_number} + + @sa @ref is_number_integer() -- check if value is an integer or unsigned + integer number + @sa @ref is_number_unsigned() -- check if value is an unsigned integer + number + @sa @ref is_number_float() -- check if value is a floating-point number + + @since version 1.0.0 + */ + constexpr bool is_number() const noexcept + { + return is_number_integer() or is_number_float(); + } + + /*! + @brief return whether value is an integer number + + This function returns true if and only if the JSON value is a signed or + unsigned integer number. This excludes floating-point values. + + @return `true` if type is an integer or unsigned integer number, `false` + otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number_integer()` for all + JSON types.,is_number_integer} + + @sa @ref is_number() -- check if value is a number + @sa @ref is_number_unsigned() -- check if value is an unsigned integer + number + @sa @ref is_number_float() -- check if value is a floating-point number + + @since version 1.0.0 + */ + constexpr bool is_number_integer() const noexcept + { + return (m_type == value_t::number_integer or m_type == value_t::number_unsigned); + } + + /*! + @brief return whether value is an unsigned integer number + + This function returns true if and only if the JSON value is an unsigned + integer number. This excludes floating-point and signed integer values. + + @return `true` if type is an unsigned integer number, `false` otherwise. + + @complexity Constant. 
+ + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number_unsigned()` for all + JSON types.,is_number_unsigned} + + @sa @ref is_number() -- check if value is a number + @sa @ref is_number_integer() -- check if value is an integer or unsigned + integer number + @sa @ref is_number_float() -- check if value is a floating-point number + + @since version 2.0.0 + */ + constexpr bool is_number_unsigned() const noexcept + { + return (m_type == value_t::number_unsigned); + } + + /*! + @brief return whether value is a floating-point number + + This function returns true if and only if the JSON value is a + floating-point number. This excludes signed and unsigned integer values. + + @return `true` if type is a floating-point number, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number_float()` for all + JSON types.,is_number_float} + + @sa @ref is_number() -- check if value is number + @sa @ref is_number_integer() -- check if value is an integer number + @sa @ref is_number_unsigned() -- check if value is an unsigned integer + number + + @since version 1.0.0 + */ + constexpr bool is_number_float() const noexcept + { + return (m_type == value_t::number_float); + } + + /*! + @brief return whether value is an object + + This function returns true if and only if the JSON value is an object. + + @return `true` if type is object, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_object()` for all JSON + types.,is_object} + + @since version 1.0.0 + */ + constexpr bool is_object() const noexcept + { + return (m_type == value_t::object); + } + + /*! 
+ @brief return whether value is an array + + This function returns true if and only if the JSON value is an array. + + @return `true` if type is array, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_array()` for all JSON + types.,is_array} + + @since version 1.0.0 + */ + constexpr bool is_array() const noexcept + { + return (m_type == value_t::array); + } + + /*! + @brief return whether value is a string + + This function returns true if and only if the JSON value is a string. + + @return `true` if type is string, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_string()` for all JSON + types.,is_string} + + @since version 1.0.0 + */ + constexpr bool is_string() const noexcept + { + return (m_type == value_t::string); + } + + /*! + @brief return whether value is discarded + + This function returns true if and only if the JSON value was discarded + during parsing with a callback function (see @ref parser_callback_t). + + @note This function will always be `false` for JSON values after parsing. + That is, discarded values can only occur during parsing, but will be + removed when inside a structured value or replaced by null in other cases. + + @return `true` if type is discarded, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_discarded()` for all JSON + types.,is_discarded} + + @since version 1.0.0 + */ + constexpr bool is_discarded() const noexcept + { + return (m_type == value_t::discarded); + } + + /*! + @brief return the type of the JSON value (implicit) + + Implicitly return the type of the JSON value as a value from the @ref + value_t enumeration. 
+ + @return the type of the JSON value + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies the @ref value_t operator for + all JSON types.,operator__value_t} + + @sa @ref type() -- return the type of the JSON value (explicit) + @sa @ref type_name() -- return the type as string + + @since version 1.0.0 + */ + constexpr operator value_t() const noexcept + { + return m_type; + } + + /// @} + + private: + ////////////////// + // value access // + ////////////////// + + /// get a boolean (explicit) + boolean_t get_impl(boolean_t* /*unused*/) const + { + if (JSON_LIKELY(is_boolean())) + { + return m_value.boolean; + } + + JSON_THROW(type_error::create(302, "type must be boolean, but is " + std::string(type_name()))); + } + + /// get a pointer to the value (object) + object_t* get_impl_ptr(object_t* /*unused*/) noexcept + { + return is_object() ? m_value.object : nullptr; + } + + /// get a pointer to the value (object) + constexpr const object_t* get_impl_ptr(const object_t* /*unused*/) const noexcept + { + return is_object() ? m_value.object : nullptr; + } + + /// get a pointer to the value (array) + array_t* get_impl_ptr(array_t* /*unused*/) noexcept + { + return is_array() ? m_value.array : nullptr; + } + + /// get a pointer to the value (array) + constexpr const array_t* get_impl_ptr(const array_t* /*unused*/) const noexcept + { + return is_array() ? m_value.array : nullptr; + } + + /// get a pointer to the value (string) + string_t* get_impl_ptr(string_t* /*unused*/) noexcept + { + return is_string() ? m_value.string : nullptr; + } + + /// get a pointer to the value (string) + constexpr const string_t* get_impl_ptr(const string_t* /*unused*/) const noexcept + { + return is_string() ? m_value.string : nullptr; + } + + /// get a pointer to the value (boolean) + boolean_t* get_impl_ptr(boolean_t* /*unused*/) noexcept + { + return is_boolean() ? 
&m_value.boolean : nullptr; + } + + /// get a pointer to the value (boolean) + constexpr const boolean_t* get_impl_ptr(const boolean_t* /*unused*/) const noexcept + { + return is_boolean() ? &m_value.boolean : nullptr; + } + + /// get a pointer to the value (integer number) + number_integer_t* get_impl_ptr(number_integer_t* /*unused*/) noexcept + { + return is_number_integer() ? &m_value.number_integer : nullptr; + } + + /// get a pointer to the value (integer number) + constexpr const number_integer_t* get_impl_ptr(const number_integer_t* /*unused*/) const noexcept + { + return is_number_integer() ? &m_value.number_integer : nullptr; + } + + /// get a pointer to the value (unsigned number) + number_unsigned_t* get_impl_ptr(number_unsigned_t* /*unused*/) noexcept + { + return is_number_unsigned() ? &m_value.number_unsigned : nullptr; + } + + /// get a pointer to the value (unsigned number) + constexpr const number_unsigned_t* get_impl_ptr(const number_unsigned_t* /*unused*/) const noexcept + { + return is_number_unsigned() ? &m_value.number_unsigned : nullptr; + } + + /// get a pointer to the value (floating-point number) + number_float_t* get_impl_ptr(number_float_t* /*unused*/) noexcept + { + return is_number_float() ? &m_value.number_float : nullptr; + } + + /// get a pointer to the value (floating-point number) + constexpr const number_float_t* get_impl_ptr(const number_float_t* /*unused*/) const noexcept + { + return is_number_float() ? &m_value.number_float : nullptr; + } + + /*! 
+ @brief helper function to implement get_ref() + + This function helps to implement get_ref() without code duplication for + const and non-const overloads + + @tparam ThisType will be deduced as `basic_json` or `const basic_json` + + @throw type_error.303 if ReferenceType does not match underlying value + type of the current JSON + */ + template + static ReferenceType get_ref_impl(ThisType& obj) + { + // delegate the call to get_ptr<>() + auto ptr = obj.template get_ptr::type>(); + + if (JSON_LIKELY(ptr != nullptr)) + { + return *ptr; + } + + JSON_THROW(type_error::create(303, "incompatible ReferenceType for get_ref, actual type is " + std::string(obj.type_name()))); + } + + public: + /// @name value access + /// Direct access to the stored value of a JSON value. + /// @{ + + /*! + @brief get special-case overload + + This overload avoids a lot of template boilerplate; it can be seen as the + identity method + + @tparam BasicJsonType == @ref basic_json + + @return a copy of *this + + @complexity Constant. + + @since version 2.1.0 + */ + template < + typename BasicJsonType, + detail::enable_if_t::type, + basic_json_t>::value, + int> = 0 > + basic_json get() const + { + return *this; + } + + /*! + @brief get a value (explicit) + + Explicit type conversion between the JSON value and a compatible value + which is [CopyConstructible](http://en.cppreference.com/w/cpp/concept/CopyConstructible) + and [DefaultConstructible](http://en.cppreference.com/w/cpp/concept/DefaultConstructible). + The value is converted by calling the @ref json_serializer + `from_json()` method.
+ + The function is equivalent to executing + @code {.cpp} + ValueType ret; + JSONSerializer::from_json(*this, ret); + return ret; + @endcode + + This overload is chosen if: + - @a ValueType is not @ref basic_json, + - @ref json_serializer has a `from_json()` method of the form + `void from_json(const basic_json&, ValueType&)`, and + - @ref json_serializer does not have a `from_json()` method of + the form `ValueType from_json(const basic_json&)` + + @tparam ValueTypeCV the provided value type + @tparam ValueType the returned value type + + @return copy of the JSON value, converted to @a ValueType + + @throw what @ref json_serializer `from_json()` method throws + + @liveexample{The example below shows several conversions from JSON values + to other types. There are a few things to note: (1) Floating-point numbers can + be converted to integers\, (2) A JSON array can be converted to a standard + `std::vector`\, (3) A JSON object can be converted to C++ + associative containers such as `std::unordered_map`.,get__ValueType_const} + + @since version 2.1.0 + */ + template < + typename ValueTypeCV, + typename ValueType = detail::uncvref_t, + detail::enable_if_t < + not std::is_same::value and + detail::has_from_json::value and + not detail::has_non_default_from_json::value, + int > = 0 > + ValueType get() const noexcept(noexcept( + JSONSerializer::from_json(std::declval(), std::declval()))) + { + // we cannot static_assert on ValueTypeCV being non-const, because + // there is support for get(), which is why we + // still need the uncvref + static_assert(not std::is_reference::value, + "get() cannot be used with reference types, you might want to use get_ref()"); + static_assert(std::is_default_constructible::value, + "types must be DefaultConstructible when used with get()"); + + ValueType ret; + JSONSerializer::from_json(*this, ret); + return ret; + } + + /*!
+ @brief get a value (explicit); special case + + Explicit type conversion between the JSON value and a compatible value + which is **not** [CopyConstructible](http://en.cppreference.com/w/cpp/concept/CopyConstructible) + and **not** [DefaultConstructible](http://en.cppreference.com/w/cpp/concept/DefaultConstructible). + The value is converted by calling the @ref json_serializer + `from_json()` method. + + The function is equivalent to executing + @code {.cpp} + return JSONSerializer::from_json(*this); + @endcode + + This overload is chosen if: + - @a ValueType is not @ref basic_json and + - @ref json_serializer has a `from_json()` method of the form + `ValueType from_json(const basic_json&)` + + @note If @ref json_serializer has both overloads of + `from_json()`, this one is chosen. + + @tparam ValueTypeCV the provided value type + @tparam ValueType the returned value type + + @return copy of the JSON value, converted to @a ValueType + + @throw what @ref json_serializer `from_json()` method throws + + @since version 2.1.0 + */ + template < + typename ValueTypeCV, + typename ValueType = detail::uncvref_t, + detail::enable_if_t::value and + detail::has_non_default_from_json::value, int> = 0 > + ValueType get() const noexcept(noexcept( + JSONSerializer::from_json(std::declval()))) + { + static_assert(not std::is_reference::value, + "get() cannot be used with reference types, you might want to use get_ref()"); + return JSONSerializer::from_json(*this); + } + + /*! + @brief get a pointer value (explicit) + + Explicit pointer access to the internally stored JSON value. No copies are + made. + + @warning The pointer becomes invalid if the underlying JSON object + changes. + + @tparam PointerType pointer type; must be a pointer to @ref array_t, @ref + object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, + @ref number_unsigned_t, or @ref number_float_t.
+ + @return pointer to the internally stored JSON value if the requested + pointer type @a PointerType fits to the JSON value; `nullptr` otherwise + + @complexity Constant. + + @liveexample{The example below shows how pointers to internal values of a + JSON value can be requested. Note that no type conversions are made and a + `nullptr` is returned if the value and the requested pointer type does not + match.,get__PointerType} + + @sa @ref get_ptr() for explicit pointer-member access + + @since version 1.0.0 + */ + template::value, int>::type = 0> + PointerType get() noexcept + { + // delegate the call to get_ptr + return get_ptr(); + } + + /*! + @brief get a pointer value (explicit) + @copydoc get() + */ + template::value, int>::type = 0> + constexpr const PointerType get() const noexcept + { + // delegate the call to get_ptr + return get_ptr(); + } + + /*! + @brief get a pointer value (implicit) + + Implicit pointer access to the internally stored JSON value. No copies are + made. + + @warning Writing data to the pointee of the result yields an undefined + state. + + @tparam PointerType pointer type; must be a pointer to @ref array_t, @ref + object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, + @ref number_unsigned_t, or @ref number_float_t. Enforced by a static + assertion. + + @return pointer to the internally stored JSON value if the requested + pointer type @a PointerType fits to the JSON value; `nullptr` otherwise + + @complexity Constant. + + @liveexample{The example below shows how pointers to internal values of a + JSON value can be requested. 
Note that no type conversions are made and a + `nullptr` is returned if the value and the requested pointer type does not + match.,get_ptr} + + @since version 1.0.0 + */ + template::value, int>::type = 0> + PointerType get_ptr() noexcept + { + // get the type of the PointerType (remove pointer and const) + using pointee_t = typename std::remove_const::type>::type>::type; + // make sure the type matches the allowed types + static_assert( + std::is_same::value + or std::is_same::value + or std::is_same::value + or std::is_same::value + or std::is_same::value + or std::is_same::value + or std::is_same::value + , "incompatible pointer type"); + + // delegate the call to get_impl_ptr<>() + return get_impl_ptr(static_cast(nullptr)); + } + + /*! + @brief get a pointer value (implicit) + @copydoc get_ptr() + */ + template::value and + std::is_const::type>::value, int>::type = 0> + constexpr const PointerType get_ptr() const noexcept + { + // get the type of the PointerType (remove pointer and const) + using pointee_t = typename std::remove_const::type>::type>::type; + // make sure the type matches the allowed types + static_assert( + std::is_same::value + or std::is_same::value + or std::is_same::value + or std::is_same::value + or std::is_same::value + or std::is_same::value + or std::is_same::value + , "incompatible pointer type"); + + // delegate the call to get_impl_ptr<>() const + return get_impl_ptr(static_cast(nullptr)); + } + + /*! + @brief get a reference value (implicit) + + Implicit reference access to the internally stored JSON value. No copies + are made. + + @warning Writing data to the referee of the result yields an undefined + state. + + @tparam ReferenceType reference type; must be a reference to @ref array_t, + @ref object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, or + @ref number_float_t. Enforced by static assertion. 
+ + @return reference to the internally stored JSON value if the requested + reference type @a ReferenceType fits to the JSON value; throws + type_error.303 otherwise + + @throw type_error.303 in case passed type @a ReferenceType is incompatible + with the stored JSON value; see example below + + @complexity Constant. + + @liveexample{The example shows several calls to `get_ref()`.,get_ref} + + @since version 1.1.0 + */ + template::value, int>::type = 0> + ReferenceType get_ref() + { + // delegate call to get_ref_impl + return get_ref_impl(*this); + } + + /*! + @brief get a reference value (implicit) + @copydoc get_ref() + */ + template::value and + std::is_const::type>::value, int>::type = 0> + ReferenceType get_ref() const + { + // delegate call to get_ref_impl + return get_ref_impl(*this); + } + + /*! + @brief get a value (implicit) + + Implicit type conversion between the JSON value and a compatible value. + The call is realized by calling @ref get() const. + + @tparam ValueType non-pointer type compatible to the JSON value, for + instance `int` for JSON integer numbers, `bool` for JSON booleans, or + `std::vector` types for JSON arrays. The character type of @ref string_t + as well as an initializer list of this type is excluded to avoid + ambiguities as these types implicitly convert to `std::string`. + + @return copy of the JSON value, converted to type @a ValueType + + @throw type_error.302 in case passed type @a ValueType is incompatible + to the JSON value type (e.g., the JSON value is of type boolean, but a + string is requested); see example below + + @complexity Linear in the size of the JSON value. + + @liveexample{The example below shows several conversions from JSON values + to other types. 
There a few things to note: (1) Floating-point numbers can + be converted to integers\, (2) A JSON array can be converted to a standard + `std::vector`\, (3) A JSON object can be converted to C++ + associative containers such as `std::unordered_map`.,operator__ValueType} + + @since version 1.0.0 + */ + template < typename ValueType, typename std::enable_if < + not std::is_pointer::value and + not std::is_same>::value and + not std::is_same::value +#ifndef _MSC_VER // fix for issue #167 operator<< ambiguity under VS2015 + and not std::is_same>::value +#endif +#if (defined(__cplusplus) && __cplusplus >= 201703L) || (defined(_MSC_VER) && _MSC_VER >1900 && defined(_HAS_CXX17) && _HAS_CXX17 == 1) // fix for issue #464 + and not std::is_same::value +#endif + , int >::type = 0 > + operator ValueType() const + { + // delegate the call to get<>() const + return get(); + } + + /// @} + + + //////////////////// + // element access // + //////////////////// + + /// @name element access + /// Access to the JSON value. + /// @{ + + /*! + @brief access specified array element with bounds checking + + Returns a reference to the element at specified location @a idx, with + bounds checking. + + @param[in] idx index of the element to access + + @return reference to the element at index @a idx + + @throw type_error.304 if the JSON value is not an array; in this case, + calling `at` with an index makes no sense. See example below. + @throw out_of_range.401 if the index @a idx is out of range of the array; + that is, `idx >= size()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 1.0.0 + + @liveexample{The example below shows how array elements can be read and + written using `at()`. 
It also demonstrates the different exceptions that + can be thrown.,at__size_type} + */ + reference at(size_type idx) + { + // at only works for arrays + if (JSON_LIKELY(is_array())) + { + JSON_TRY + { + return m_value.array->at(idx); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); + } + } + + /*! + @brief access specified array element with bounds checking + + Returns a const reference to the element at specified location @a idx, + with bounds checking. + + @param[in] idx index of the element to access + + @return const reference to the element at index @a idx + + @throw type_error.304 if the JSON value is not an array; in this case, + calling `at` with an index makes no sense. See example below. + @throw out_of_range.401 if the index @a idx is out of range of the array; + that is, `idx >= size()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 1.0.0 + + @liveexample{The example below shows how array elements can be read using + `at()`. It also demonstrates the different exceptions that can be thrown., + at__size_type_const} + */ + const_reference at(size_type idx) const + { + // at only works for arrays + if (JSON_LIKELY(is_array())) + { + JSON_TRY + { + return m_value.array->at(idx); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); + } + } + + /*! 
+ @brief access specified object element with bounds checking + + Returns a reference to the element with specified key @a key, with + bounds checking. + + @param[in] key key of the element to access + + @return reference to the element at key @a key + + @throw type_error.304 if the JSON value is not an object; in this case, + calling `at` with a key makes no sense. See example below. + @throw out_of_range.403 if the key @a key is not stored in the object; + that is, `find(key) == end()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Logarithmic in the size of the container. + + @sa @ref operator[](const typename object_t::key_type&) for unchecked + access by reference + @sa @ref value() for access by value with a default value + + @since version 1.0.0 + + @liveexample{The example below shows how object elements can be read and + written using `at()`. It also demonstrates the different exceptions that + can be thrown.,at__object_t_key_type} + */ + reference at(const typename object_t::key_type& key) + { + // at only works for objects + if (JSON_LIKELY(is_object())) + { + JSON_TRY + { + return m_value.object->at(key); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(403, "key '" + key + "' not found")); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); + } + } + + /*! + @brief access specified object element with bounds checking + + Returns a const reference to the element with specified key @a key, + with bounds checking. + + @param[in] key key of the element to access + + @return const reference to the element at key @a key + + @throw type_error.304 if the JSON value is not an object; in this case, + calling `at` with a key makes no sense. See example below.
+ @throw out_of_range.403 if the key @a key is not stored in the object; + that is, `find(key) == end()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Logarithmic in the size of the container. + + @sa @ref operator[](const typename object_t::key_type&) for unchecked + access by reference + @sa @ref value() for access by value with a default value + + @since version 1.0.0 + + @liveexample{The example below shows how object elements can be read using + `at()`. It also demonstrates the different exceptions that can be thrown., + at__object_t_key_type_const} + */ + const_reference at(const typename object_t::key_type& key) const + { + // at only works for objects + if (JSON_LIKELY(is_object())) + { + JSON_TRY + { + return m_value.object->at(key); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(403, "key '" + key + "' not found")); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); + } + } + + /*! + @brief access specified array element + + Returns a reference to the element at specified location @a idx. + + @note If @a idx is beyond the range of the array (i.e., `idx >= size()`), + then the array is silently filled up with `null` values to make `idx` a + valid reference to the last stored element. + + @param[in] idx index of the element to access + + @return reference to the element at index @a idx + + @throw type_error.305 if the JSON value is not an array or null; in that + case, using the [] operator with an index makes no sense. + + @complexity Constant if @a idx is in the range of the array. Otherwise + linear in `idx - size()`. + + @liveexample{The example below shows how array elements can be read and + written using the `[]` operator.
Note the addition of `null` + values.,operatorarray__size_type} + + @since version 1.0.0 + */ + reference operator[](size_type idx) + { + // implicitly convert null value to an empty array + if (is_null()) + { + m_type = value_t::array; + m_value.array = create(); + assert_invariant(); + } + + // operator[] only works for arrays + if (JSON_LIKELY(is_array())) + { + // fill up array with null values if given idx is outside range + if (idx >= m_value.array->size()) + { + m_value.array->insert(m_value.array->end(), + idx - m_value.array->size() + 1, + basic_json()); + } + + return m_value.array->operator[](idx); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with " + std::string(type_name()))); + } + + /*! + @brief access specified array element + + Returns a const reference to the element at specified location @a idx. + + @param[in] idx index of the element to access + + @return const reference to the element at index @a idx + + @throw type_error.305 if the JSON value is not an array; in that cases, + using the [] operator with an index makes no sense. + + @complexity Constant. + + @liveexample{The example below shows how array elements can be read using + the `[]` operator.,operatorarray__size_type_const} + + @since version 1.0.0 + */ + const_reference operator[](size_type idx) const + { + // const operator[] only works for arrays + if (JSON_LIKELY(is_array())) + { + return m_value.array->operator[](idx); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with " + std::string(type_name()))); + } + + /*! + @brief access specified object element + + Returns a reference to the element at with specified key @a key. + + @note If @a key is not found in the object, then it is silently added to + the object and filled with a `null` value to make `key` a valid reference. + In case the value was `null` before, it is converted to an object. 
+ + @param[in] key key of the element to access + + @return reference to the element at key @a key + + @throw type_error.305 if the JSON value is not an object or null; in that + cases, using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be read and + written using the `[]` operator.,operatorarray__key_type} + + @sa @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa @ref value() for access by value with a default value + + @since version 1.0.0 + */ + reference operator[](const typename object_t::key_type& key) + { + // implicitly convert null value to an empty object + if (is_null()) + { + m_type = value_t::object; + m_value.object = create(); + assert_invariant(); + } + + // operator[] only works for objects + if (JSON_LIKELY(is_object())) + { + return m_value.object->operator[](key); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with " + std::string(type_name()))); + } + + /*! + @brief read-only access specified object element + + Returns a const reference to the element at with specified key @a key. No + bounds checking is performed. + + @warning If the element with key @a key does not exist, the behavior is + undefined. + + @param[in] key key of the element to access + + @return const reference to the element at key @a key + + @pre The element with key @a key must exist. **This precondition is + enforced with an assertion.** + + @throw type_error.305 if the JSON value is not an object; in that cases, + using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container. 
+ + @liveexample{The example below shows how object elements can be read using + the `[]` operator.,operatorarray__key_type_const} + + @sa @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa @ref value() for access by value with a default value + + @since version 1.0.0 + */ + const_reference operator[](const typename object_t::key_type& key) const + { + // const operator[] only works for objects + if (JSON_LIKELY(is_object())) + { + assert(m_value.object->find(key) != m_value.object->end()); + return m_value.object->find(key)->second; + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with " + std::string(type_name()))); + } + + /*! + @brief access specified object element + + Returns a reference to the element with the specified key @a key. + + @note If @a key is not found in the object, then it is silently added to + the object and filled with a `null` value to make `key` a valid reference. + In case the value was `null` before, it is converted to an object. + + @param[in] key key of the element to access + + @return reference to the element at key @a key + + @throw type_error.305 if the JSON value is not an object or null; in that + case, using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container.
+ + @liveexample{The example below shows how object elements can be read and + written using the `[]` operator.,operatorarray__key_type} + + @sa @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa @ref value() for access by value with a default value + + @since version 1.1.0 + */ + template + reference operator[](T* key) + { + // implicitly convert null to object + if (is_null()) + { + m_type = value_t::object; + m_value = value_t::object; + assert_invariant(); + } + + // operator[] only works for objects + if (JSON_LIKELY(is_object())) + { + return m_value.object->operator[](key); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with " + std::string(type_name()))); + } + + /*! + @brief read-only access specified object element + + Returns a const reference to the element with the specified key @a key. No + bounds checking is performed. + + @warning If the element with key @a key does not exist, the behavior is + undefined. + + @param[in] key key of the element to access + + @return const reference to the element at key @a key + + @pre The element with key @a key must exist. **This precondition is + enforced with an assertion.** + + @throw type_error.305 if the JSON value is not an object; in that case, + using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container.
+ + @liveexample{The example below shows how object elements can be read using + the `[]` operator.,operatorarray__key_type_const} + + @sa @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa @ref value() for access by value with a default value + + @since version 1.1.0 + */ + template + const_reference operator[](T* key) const + { + // const operator[] only works for objects + if (JSON_LIKELY(is_object())) + { + assert(m_value.object->find(key) != m_value.object->end()); + return m_value.object->find(key)->second; + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with " + std::string(type_name()))); + } + + /*! + @brief access specified object element with default value + + Returns either a copy of an object's element at the specified key @a key + or a given default value if no element with key @a key exists. + + The function is basically equivalent to executing + @code {.cpp} + try { + return at(key); + } catch(out_of_range) { + return default_value; + } + @endcode + + @note Unlike @ref at(const typename object_t::key_type&), this function + does not throw if the given key @a key was not found. + + @note Unlike @ref operator[](const typename object_t::key_type& key), this + function does not implicitly add an element to the position defined by @a + key. This function is furthermore also applicable to const objects. + + @param[in] key key of the element to access + @param[in] default_value the value to return if @a key is not found + + @tparam ValueType type compatible to JSON values, for instance `int` for + JSON integer numbers, `bool` for JSON booleans, or `std::vector` types for + JSON arrays. Note the type of the expected value at @a key and the default + value @a default_value must be compatible. + + @return copy of the element at key @a key or @a default_value if @a key + is not found + + @throw type_error.306 if the JSON value is not an object; in that case, + using `value()` with a key makes no sense.
+ + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be queried + with a default value.,basic_json__value} + + @sa @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa @ref operator[](const typename object_t::key_type&) for unchecked + access by reference + + @since version 1.0.0 + */ + template::value, int>::type = 0> + ValueType value(const typename object_t::key_type& key, const ValueType& default_value) const + { + // value() only works for objects + if (JSON_LIKELY(is_object())) + { + // if key is found, return its value; otherwise return the given default + const auto it = find(key); + if (it != end()) + { + return *it; + } + + return default_value; + } + + JSON_THROW(type_error::create(306, "cannot use value() with " + std::string(type_name()))); + } + + /*! + @brief overload for a default value of type const char* + @copydoc basic_json::value(const typename object_t::key_type&, ValueType) const + */ + string_t value(const typename object_t::key_type& key, const char* default_value) const + { + return value(key, string_t(default_value)); + } + + /*! + @brief access specified object element via JSON Pointer with default value + + Returns either a copy of the element referenced by the JSON pointer @a ptr + or a given default value if @a ptr does not resolve to an element. + + The function is basically equivalent to executing + @code {.cpp} + try { + return at(ptr); + } catch(out_of_range) { + return default_value; + } + @endcode + + @note Unlike @ref at(const json_pointer&), this function does not throw + if @a ptr does not resolve to an element. + + @param[in] ptr a JSON pointer to the element to access + @param[in] default_value the value to return if @a ptr found no value + + @tparam ValueType type compatible to JSON values, for instance `int` for + JSON integer numbers, `bool` for JSON booleans, or `std::vector` types for + JSON arrays.
Note the type of the expected value at @a ptr and the default + value @a default_value must be compatible. + + @return copy of the element referenced by @a ptr or @a default_value if + @a ptr found no value + + @throw type_error.306 if the JSON value is not an object; in that case, + using `value()` with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be queried + with a default value.,basic_json__value_ptr} + + @sa @ref operator[](const json_pointer&) for unchecked access by reference + + @since version 2.0.2 + */ + template::value, int>::type = 0> + ValueType value(const json_pointer& ptr, const ValueType& default_value) const + { + // value() only works for objects + if (JSON_LIKELY(is_object())) + { + // if the pointer resolves to a value, return it; otherwise use the default + JSON_TRY + { + return ptr.get_checked(this); + } + JSON_CATCH (out_of_range&) + { + return default_value; + } + } + + JSON_THROW(type_error::create(306, "cannot use value() with " + std::string(type_name()))); + } + + /*! + @brief overload for a default value of type const char* + @copydoc basic_json::value(const json_pointer&, ValueType) const + */ + string_t value(const json_pointer& ptr, const char* default_value) const + { + return value(ptr, string_t(default_value)); + } + + /*! + @brief access the first element + + Returns a reference to the first element in the container. For a JSON + container `c`, the expression `c.front()` is equivalent to `*c.begin()`. + + @return In case of a structured type (array or object), a reference to the + first element is returned. In case of number, string, or boolean values, a + reference to the value is returned. + + @complexity Constant. + + @pre The JSON value must not be `null` (would throw `invalid_iterator.214`) + or an empty array or object (undefined behavior, **guarded by + assertions**). + @post The JSON value remains unchanged.
+ + @throw invalid_iterator.214 when called on a `null` value + + @liveexample{The following code shows an example for `front()`.,front} + + @sa @ref back() -- access the last element + + @since version 1.0.0 + */ + reference front() + { + return *begin(); + } + + /*! + @copydoc basic_json::front() + */ + const_reference front() const + { + return *cbegin(); + } + + /*! + @brief access the last element + + Returns a reference to the last element in the container. For a JSON + container `c`, the expression `c.back()` is equivalent to + @code {.cpp} + auto tmp = c.end(); + --tmp; + return *tmp; + @endcode + + @return In case of a structured type (array or object), a reference to the + last element is returned. In case of number, string, or boolean values, a + reference to the value is returned. + + @complexity Constant. + + @pre The JSON value must not be `null` (would throw `invalid_iterator.214`) + or an empty array or object (undefined behavior, **guarded by + assertions**). + @post The JSON value remains unchanged. + + @throw invalid_iterator.214 when called on a `null` value. See example + below. + + @liveexample{The following code shows an example for `back()`.,back} + + @sa @ref front() -- access the first element + + @since version 1.0.0 + */ + reference back() + { + auto tmp = end(); + --tmp; + return *tmp; + } + + /*! + @copydoc basic_json::back() + */ + const_reference back() const + { + auto tmp = cend(); + --tmp; + return *tmp; + } + + /*! + @brief remove element given an iterator + + Removes the element specified by iterator @a pos. The iterator @a pos must + be valid and dereferenceable. Thus the `end()` iterator (which is valid, + but is not dereferenceable) cannot be used as a value for @a pos. + + If called on a primitive type other than `null`, the resulting JSON value + will be `null`. + + @param[in] pos iterator to the element to remove + @return Iterator following the last removed element.
If the iterator @a + pos refers to the last element, the `end()` iterator is returned. + + @tparam IteratorType an @ref iterator or @ref const_iterator + + @post Invalidates iterators and references at or after the point of the + erase, including the `end()` iterator. + + @throw type_error.307 if called on a `null` value; example: `"cannot use + erase() with null"` + @throw invalid_iterator.202 if called on an iterator which does not belong + to the current JSON value; example: `"iterator does not fit current + value"` + @throw invalid_iterator.205 if called on a primitive type with invalid + iterator (i.e., any iterator which is not `begin()`); example: `"iterator + out of range"` + + @complexity The complexity depends on the type: + - objects: amortized constant + - arrays: linear in distance between @a pos and the end of the container + - strings: linear in the length of the string + - other types: constant + + @liveexample{The example shows the result of `erase()` for different JSON + types.,erase__IteratorType} + + @sa @ref erase(IteratorType, IteratorType) -- removes the elements in + the given range + @sa @ref erase(const typename object_t::key_type&) -- removes the element + from an object at the given key + @sa @ref erase(const size_type) -- removes the element from an array at + the given index + + @since version 1.0.0 + */ + template::value or + std::is_same::value, int>::type + = 0> + IteratorType erase(IteratorType pos) + { + // make sure iterator fits the current value + if (JSON_UNLIKELY(this != pos.m_object)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); + } + + IteratorType result = end(); + + switch (m_type) + { + case value_t::boolean: + case value_t::number_float: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::string: + { + if (JSON_UNLIKELY(not pos.m_it.primitive_iterator.is_begin())) + { + JSON_THROW(invalid_iterator::create(205, "iterator out of range")); + } + + if
(is_string()) + { + AllocatorType alloc; + alloc.destroy(m_value.string); + alloc.deallocate(m_value.string, 1); + m_value.string = nullptr; + } + + m_type = value_t::null; + assert_invariant(); + break; + } + + case value_t::object: + { + result.m_it.object_iterator = m_value.object->erase(pos.m_it.object_iterator); + break; + } + + case value_t::array: + { + result.m_it.array_iterator = m_value.array->erase(pos.m_it.array_iterator); + break; + } + + default: + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); + } + + return result; + } + + /*! + @brief remove elements given an iterator range + + Removes the elements specified by the range `[first; last)`. The iterator + @a first does not need to be dereferenceable if `first == last`: erasing + an empty range is a no-op. + + If called on a primitive type other than `null`, the resulting JSON value + will be `null`. + + @param[in] first iterator to the beginning of the range to remove + @param[in] last iterator past the end of the range to remove + @return Iterator following the last removed element. If the iterator @a + last refers to the last element, the `end()` iterator is returned. + + @tparam IteratorType an @ref iterator or @ref const_iterator + + @post Invalidates iterators and references at or after the point of the + erase, including the `end()` iterator.
+ + @throw type_error.307 if called on a `null` value; example: `"cannot use + erase() with null"` + @throw invalid_iterator.203 if called on iterators which does not belong + to the current JSON value; example: `"iterators do not fit current value"` + @throw invalid_iterator.204 if called on a primitive type with invalid + iterators (i.e., if `first != begin()` and `last != end()`); example: + `"iterators out of range"` + + @complexity The complexity depends on the type: + - objects: `log(size()) + std::distance(first, last)` + - arrays: linear in the distance between @a first and @a last, plus linear + in the distance between @a last and end of the container + - strings: linear in the length of the string + - other types: constant + + @liveexample{The example shows the result of `erase()` for different JSON + types.,erase__IteratorType_IteratorType} + + @sa @ref erase(IteratorType) -- removes the element at a given position + @sa @ref erase(const typename object_t::key_type&) -- removes the element + from an object at the given key + @sa @ref erase(const size_type) -- removes the element from an array at + the given index + + @since version 1.0.0 + */ + template::value or + std::is_same::value, int>::type + = 0> + IteratorType erase(IteratorType first, IteratorType last) + { + // make sure iterator fits the current value + if (JSON_UNLIKELY(this != first.m_object or this != last.m_object)) + { + JSON_THROW(invalid_iterator::create(203, "iterators do not fit current value")); + } + + IteratorType result = end(); + + switch (m_type) + { + case value_t::boolean: + case value_t::number_float: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::string: + { + if (JSON_LIKELY(not first.m_it.primitive_iterator.is_begin() + or not last.m_it.primitive_iterator.is_end())) + { + JSON_THROW(invalid_iterator::create(204, "iterators out of range")); + } + + if (is_string()) + { + AllocatorType alloc; + alloc.destroy(m_value.string); + 
alloc.deallocate(m_value.string, 1); + m_value.string = nullptr; + } + + m_type = value_t::null; + assert_invariant(); + break; + } + + case value_t::object: + { + result.m_it.object_iterator = m_value.object->erase(first.m_it.object_iterator, + last.m_it.object_iterator); + break; + } + + case value_t::array: + { + result.m_it.array_iterator = m_value.array->erase(first.m_it.array_iterator, + last.m_it.array_iterator); + break; + } + + default: + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); + } + + return result; + } + + /*! + @brief remove element from a JSON object given a key + + Removes elements from a JSON object with the key value @a key. + + @param[in] key value of the elements to remove + + @return Number of elements removed. If @a ObjectType is the default + `std::map` type, the return value will always be `0` (@a key was not + found) or `1` (@a key was found). + + @post References and iterators to the erased elements are invalidated. + Other references and iterators are not affected. + + @throw type_error.307 when called on a type other than JSON object; + example: `"cannot use erase() with null"` + + @complexity `log(size()) + count(key)` + + @liveexample{The example shows the effect of `erase()`.,erase__key_type} + + @sa @ref erase(IteratorType) -- removes the element at a given position + @sa @ref erase(IteratorType, IteratorType) -- removes the elements in + the given range + @sa @ref erase(const size_type) -- removes the element from an array at + the given index + + @since version 1.0.0 + */ + size_type erase(const typename object_t::key_type& key) + { + // this erase only works for objects + if (JSON_LIKELY(is_object())) + { + return m_value.object->erase(key); + } + + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); + } + + /*! + @brief remove element from a JSON array given an index + + Removes element from a JSON array at the index @a idx. 
+ + @param[in] idx index of the element to remove + + @throw type_error.307 when called on a type other than JSON array; + example: `"cannot use erase() with null"` + @throw out_of_range.401 when `idx >= size()`; example: `"array index 17 + is out of range"` + + @complexity Linear in distance between @a idx and the end of the container. + + @liveexample{The example shows the effect of `erase()`.,erase__size_type} + + @sa @ref erase(IteratorType) -- removes the element at a given position + @sa @ref erase(IteratorType, IteratorType) -- removes the elements in + the given range + @sa @ref erase(const typename object_t::key_type&) -- removes the element + from an object at the given key + + @since version 1.0.0 + */ + void erase(const size_type idx) + { + // this erase only works for arrays + if (JSON_LIKELY(is_array())) + { + if (JSON_UNLIKELY(idx >= size())) + { + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); + } + + m_value.array->erase(m_value.array->begin() + static_cast(idx)); + } + else + { + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); + } + } + + /// @} + + + //////////// + // lookup // + //////////// + + /// @name lookup + /// @{ + + /*! + @brief find an element in a JSON object + + Finds an element in a JSON object with key equivalent to @a key. If the + element is not found or the JSON value is not an object, end() is + returned. + + @note This method always returns @ref end() when executed on a JSON type + that is not an object. + + @param[in] key key value of the element to search for + + @return Iterator to an element with key equivalent to @a key. If no such + element is found or the JSON value is not an object, past-the-end (see + @ref end()) iterator is returned. + + @complexity Logarithmic in the size of the JSON object.
+ + @liveexample{The example shows how `find()` is used.,find__key_type} + + @since version 1.0.0 + */ + iterator find(typename object_t::key_type key) + { + auto result = end(); + + if (is_object()) + { + result.m_it.object_iterator = m_value.object->find(key); + } + + return result; + } + + /*! + @brief find an element in a JSON object + @copydoc find(typename object_t::key_type) + */ + const_iterator find(typename object_t::key_type key) const + { + auto result = cend(); + + if (is_object()) + { + result.m_it.object_iterator = m_value.object->find(key); + } + + return result; + } + + /*! + @brief returns the number of occurrences of a key in a JSON object + + Returns the number of elements with key @a key. If ObjectType is the + default `std::map` type, the return value will always be `0` (@a key was + not found) or `1` (@a key was found). + + @note This method always returns `0` when executed on a JSON type that is + not an object. + + @param[in] key key value of the element to count + + @return Number of elements with key @a key. If the JSON value is not an + object, the return value will be `0`. + + @complexity Logarithmic in the size of the JSON object. + + @liveexample{The example shows how `count()` is used.,count} + + @since version 1.0.0 + */ + size_type count(typename object_t::key_type key) const + { + // return 0 for all nonobject types + return is_object() ? m_value.object->count(key) : 0; + } + + /// @} + + + /////////////// + // iterators // + /////////////// + + /// @name iterators + /// @{ + + /*! + @brief returns an iterator to the first element + + Returns an iterator to the first element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return iterator to the first element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](http://en.cppreference.com/w/cpp/concept/Container) + requirements: + - The complexity is constant. 
+ + @liveexample{The following code shows an example for `begin()`.,begin} + + @sa @ref cbegin() -- returns a const iterator to the beginning + @sa @ref end() -- returns an iterator to the end + @sa @ref cend() -- returns a const iterator to the end + + @since version 1.0.0 + */ + iterator begin() noexcept + { + iterator result(this); + result.set_begin(); + return result; + } + + /*! + @copydoc basic_json::cbegin() + */ + const_iterator begin() const noexcept + { + return cbegin(); + } + + /*! + @brief returns a const iterator to the first element + + Returns a const iterator to the first element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return const iterator to the first element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](http://en.cppreference.com/w/cpp/concept/Container) + requirements: + - The complexity is constant. + - Has the semantics of `const_cast(*this).begin()`. + + @liveexample{The following code shows an example for `cbegin()`.,cbegin} + + @sa @ref begin() -- returns an iterator to the beginning + @sa @ref end() -- returns an iterator to the end + @sa @ref cend() -- returns a const iterator to the end + + @since version 1.0.0 + */ + const_iterator cbegin() const noexcept + { + const_iterator result(this); + result.set_begin(); + return result; + } + + /*! + @brief returns an iterator to one past the last element + + Returns an iterator to one past the last element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return iterator one past the last element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](http://en.cppreference.com/w/cpp/concept/Container) + requirements: + - The complexity is constant. 
+ + @liveexample{The following code shows an example for `end()`.,end} + + @sa @ref cend() -- returns a const iterator to the end + @sa @ref begin() -- returns an iterator to the beginning + @sa @ref cbegin() -- returns a const iterator to the beginning + + @since version 1.0.0 + */ + iterator end() noexcept + { + iterator result(this); + result.set_end(); + return result; + } + + /*! + @copydoc basic_json::cend() + */ + const_iterator end() const noexcept + { + return cend(); + } + + /*! + @brief returns a const iterator to one past the last element + + Returns a const iterator to one past the last element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return const iterator one past the last element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](http://en.cppreference.com/w/cpp/concept/Container) + requirements: + - The complexity is constant. + - Has the semantics of `const_cast(*this).end()`. + + @liveexample{The following code shows an example for `cend()`.,cend} + + @sa @ref end() -- returns an iterator to the end + @sa @ref begin() -- returns an iterator to the beginning + @sa @ref cbegin() -- returns a const iterator to the beginning + + @since version 1.0.0 + */ + const_iterator cend() const noexcept + { + const_iterator result(this); + result.set_end(); + return result; + } + + /*! + @brief returns an iterator to the reverse-beginning + + Returns an iterator to the reverse-beginning; that is, the last element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](http://en.cppreference.com/w/cpp/concept/ReversibleContainer) + requirements: + - The complexity is constant. + - Has the semantics of `reverse_iterator(end())`. 
+ + @liveexample{The following code shows an example for `rbegin()`.,rbegin} + + @sa @ref crbegin() -- returns a const reverse iterator to the beginning + @sa @ref rend() -- returns a reverse iterator to the end + @sa @ref crend() -- returns a const reverse iterator to the end + + @since version 1.0.0 + */ + reverse_iterator rbegin() noexcept + { + return reverse_iterator(end()); + } + + /*! + @copydoc basic_json::crbegin() + */ + const_reverse_iterator rbegin() const noexcept + { + return crbegin(); + } + + /*! + @brief returns an iterator to the reverse-end + + Returns an iterator to the reverse-end; that is, one before the first + element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](http://en.cppreference.com/w/cpp/concept/ReversibleContainer) + requirements: + - The complexity is constant. + - Has the semantics of `reverse_iterator(begin())`. + + @liveexample{The following code shows an example for `rend()`.,rend} + + @sa @ref crend() -- returns a const reverse iterator to the end + @sa @ref rbegin() -- returns a reverse iterator to the beginning + @sa @ref crbegin() -- returns a const reverse iterator to the beginning + + @since version 1.0.0 + */ + reverse_iterator rend() noexcept + { + return reverse_iterator(begin()); + } + + /*! + @copydoc basic_json::crend() + */ + const_reverse_iterator rend() const noexcept + { + return crend(); + } + + /*! + @brief returns a const reverse iterator to the last element + + Returns a const iterator to the reverse-beginning; that is, the last + element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](http://en.cppreference.com/w/cpp/concept/ReversibleContainer) + requirements: + - The complexity is constant. 
+ - Has the semantics of `const_cast(*this).rbegin()`. + + @liveexample{The following code shows an example for `crbegin()`.,crbegin} + + @sa @ref rbegin() -- returns a reverse iterator to the beginning + @sa @ref rend() -- returns a reverse iterator to the end + @sa @ref crend() -- returns a const reverse iterator to the end + + @since version 1.0.0 + */ + const_reverse_iterator crbegin() const noexcept + { + return const_reverse_iterator(cend()); + } + + /*! + @brief returns a const reverse iterator to one before the first + + Returns a const reverse iterator to the reverse-end; that is, one before + the first element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](http://en.cppreference.com/w/cpp/concept/ReversibleContainer) + requirements: + - The complexity is constant. + - Has the semantics of `const_cast(*this).rend()`. + + @liveexample{The following code shows an example for `crend()`.,crend} + + @sa @ref rend() -- returns a reverse iterator to the end + @sa @ref rbegin() -- returns a reverse iterator to the beginning + @sa @ref crbegin() -- returns a const reverse iterator to the beginning + + @since version 1.0.0 + */ + const_reverse_iterator crend() const noexcept + { + return const_reverse_iterator(cbegin()); + } + + public: + /*! + @brief wrapper to access iterator member functions in range-based for + + This function allows to access @ref iterator::key() and @ref + iterator::value() during range-based for loops. In these loops, a + reference to the JSON values is returned, so there is no access to the + underlying iterator. + + @liveexample{The following code shows how the wrapper is used,iterator_wrapper} + + @note The name of this function is not yet final and may change in the + future. + */ + static iteration_proxy iterator_wrapper(reference cont) + { + return iteration_proxy(cont); + } + + /*! 
+ @copydoc iterator_wrapper(reference) + */ + static iteration_proxy iterator_wrapper(const_reference cont) + { + return iteration_proxy(cont); + } + + /// @} + + + ////////////// + // capacity // + ////////////// + + /// @name capacity + /// @{ + + /*! + @brief checks whether the container is empty. + + Checks if a JSON value has no elements (i.e. whether its @ref size is `0`). + + @return The return value depends on the different types and is + defined as follows: + Value type | return value + ----------- | ------------- + null | `true` + boolean | `false` + string | `false` + number | `false` + object | result of function `object_t::empty()` + array | result of function `array_t::empty()` + + @liveexample{The following code uses `empty()` to check if a JSON + object contains any elements.,empty} + + @complexity Constant, as long as @ref array_t and @ref object_t satisfy + the Container concept; that is, their `empty()` functions have constant + complexity. + + @iterators No changes. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @note This function does not return whether a string stored as JSON value + is empty - it returns whether the JSON container itself is empty which is + false in the case of a string. + + @requirement This function helps `basic_json` satisfying the + [Container](http://en.cppreference.com/w/cpp/concept/Container) + requirements: + - The complexity is constant. + - Has the semantics of `begin() == end()`. 
+ + @sa @ref size() -- returns the number of elements + + @since version 1.0.0 + */ + bool empty() const noexcept + { + switch (m_type) + { + case value_t::null: + { + // null values are empty + return true; + } + + case value_t::array: + { + // delegate call to array_t::empty() + return m_value.array->empty(); + } + + case value_t::object: + { + // delegate call to object_t::empty() + return m_value.object->empty(); + } + + default: + { + // all other types are nonempty + return false; + } + } + } + + /*! + @brief returns the number of elements + + Returns the number of elements in a JSON value. + + @return The return value depends on the different types and is + defined as follows: + Value type | return value + ----------- | ------------- + null | `0` + boolean | `1` + string | `1` + number | `1` + object | result of function object_t::size() + array | result of function array_t::size() + + @liveexample{The following code calls `size()` on the different value + types.,size} + + @complexity Constant, as long as @ref array_t and @ref object_t satisfy + the Container concept; that is, their size() functions have constant + complexity. + + @iterators No changes. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @note This function does not return the length of a string stored as JSON + value - it returns the number of elements in the JSON value which is 1 in + the case of a string. + + @requirement This function helps `basic_json` satisfying the + [Container](http://en.cppreference.com/w/cpp/concept/Container) + requirements: + - The complexity is constant. + - Has the semantics of `std::distance(begin(), end())`. 
+ + @sa @ref empty() -- checks whether the container is empty + @sa @ref max_size() -- returns the maximal number of elements + + @since version 1.0.0 + */ + size_type size() const noexcept + { + switch (m_type) + { + case value_t::null: + { + // null values are empty + return 0; + } + + case value_t::array: + { + // delegate call to array_t::size() + return m_value.array->size(); + } + + case value_t::object: + { + // delegate call to object_t::size() + return m_value.object->size(); + } + + default: + { + // all other types have size 1 + return 1; + } + } + } + + /*! + @brief returns the maximum possible number of elements + + Returns the maximum number of elements a JSON value is able to hold due to + system or library implementation limitations, i.e. `std::distance(begin(), + end())` for the JSON value. + + @return The return value depends on the different types and is + defined as follows: + Value type | return value + ----------- | ------------- + null | `0` (same as `size()`) + boolean | `1` (same as `size()`) + string | `1` (same as `size()`) + number | `1` (same as `size()`) + object | result of function `object_t::max_size()` + array | result of function `array_t::max_size()` + + @liveexample{The following code calls `max_size()` on the different value + types. Note the output is implementation specific.,max_size} + + @complexity Constant, as long as @ref array_t and @ref object_t satisfy + the Container concept; that is, their `max_size()` functions have constant + complexity. + + @iterators No changes. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @requirement This function helps `basic_json` satisfying the + [Container](http://en.cppreference.com/w/cpp/concept/Container) + requirements: + - The complexity is constant. + - Has the semantics of returning `b.size()` where `b` is the largest + possible JSON value. 
+ + @sa @ref size() -- returns the number of elements + + @since version 1.0.0 + */ + size_type max_size() const noexcept + { + switch (m_type) + { + case value_t::array: + { + // delegate call to array_t::max_size() + return m_value.array->max_size(); + } + + case value_t::object: + { + // delegate call to object_t::max_size() + return m_value.object->max_size(); + } + + default: + { + // all other types have max_size() == size() + return size(); + } + } + } + + /// @} + + + /////////////// + // modifiers // + /////////////// + + /// @name modifiers + /// @{ + + /*! + @brief clears the contents + + Clears the content of a JSON value and resets it to the default value as + if @ref basic_json(value_t) would have been called with the current value + type from @ref type(): + + Value type | initial value + ----------- | ------------- + null | `null` + boolean | `false` + string | `""` + number | `0` + object | `{}` + array | `[]` + + @post Has the same effect as calling + @code {.cpp} + *this = basic_json(type()); + @endcode + + @liveexample{The example below shows the effect of `clear()` to different + JSON types.,clear} + + @complexity Linear in the size of the JSON value. + + @iterators All iterators, pointers and references related to this container + are invalidated. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. 
+ + @sa @ref basic_json(value_t) -- constructor that creates an object with the + same value than calling `clear()` + + @since version 1.0.0 + */ + void clear() noexcept + { + switch (m_type) + { + case value_t::number_integer: + { + m_value.number_integer = 0; + break; + } + + case value_t::number_unsigned: + { + m_value.number_unsigned = 0; + break; + } + + case value_t::number_float: + { + m_value.number_float = 0.0; + break; + } + + case value_t::boolean: + { + m_value.boolean = false; + break; + } + + case value_t::string: + { + m_value.string->clear(); + break; + } + + case value_t::array: + { + m_value.array->clear(); + break; + } + + case value_t::object: + { + m_value.object->clear(); + break; + } + + default: + break; + } + } + + /*! + @brief add an object to an array + + Appends the given element @a val to the end of the JSON value. If the + function is called on a JSON null value, an empty array is created before + appending @a val. + + @param[in] val the value to add to the JSON array + + @throw type_error.308 when called on a type other than JSON array or + null; example: `"cannot use push_back() with number"` + + @complexity Amortized constant. + + @liveexample{The example shows how `push_back()` and `+=` can be used to + add elements to a JSON array. Note how the `null` value was silently + converted to a JSON array.,push_back} + + @since version 1.0.0 + */ + void push_back(basic_json&& val) + { + // push_back only works for null objects or arrays + if (JSON_UNLIKELY(not(is_null() or is_array()))) + { + JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()))); + } + + // transform null object into an array + if (is_null()) + { + m_type = value_t::array; + m_value = value_t::array; + assert_invariant(); + } + + // add element to array (move semantics) + m_value.array->push_back(std::move(val)); + // invalidate object + val.m_type = value_t::null; + } + + /*! 
+ @brief add an object to an array + @copydoc push_back(basic_json&&) + */ + reference operator+=(basic_json&& val) + { + push_back(std::move(val)); + return *this; + } + + /*! + @brief add an object to an array + @copydoc push_back(basic_json&&) + */ + void push_back(const basic_json& val) + { + // push_back only works for null objects or arrays + if (JSON_UNLIKELY(not(is_null() or is_array()))) + { + JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()))); + } + + // transform null object into an array + if (is_null()) + { + m_type = value_t::array; + m_value = value_t::array; + assert_invariant(); + } + + // add element to array + m_value.array->push_back(val); + } + + /*! + @brief add an object to an array + @copydoc push_back(basic_json&&) + */ + reference operator+=(const basic_json& val) + { + push_back(val); + return *this; + } + + /*! + @brief add an object to an object + + Inserts the given element @a val to the JSON object. If the function is + called on a JSON null value, an empty object is created before inserting + @a val. + + @param[in] val the value to add to the JSON object + + @throw type_error.308 when called on a type other than JSON object or + null; example: `"cannot use push_back() with number"` + + @complexity Logarithmic in the size of the container, O(log(`size()`)). + + @liveexample{The example shows how `push_back()` and `+=` can be used to + add elements to a JSON object. 
Note how the `null` value was silently + converted to a JSON object.,push_back__object_t__value} + + @since version 1.0.0 + */ + void push_back(const typename object_t::value_type& val) + { + // push_back only works for null objects or objects + if (JSON_UNLIKELY(not(is_null() or is_object()))) + { + JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()))); + } + + // transform null object into an object + if (is_null()) + { + m_type = value_t::object; + m_value = value_t::object; + assert_invariant(); + } + + // add element to array + m_value.object->insert(val); + } + + /*! + @brief add an object to an object + @copydoc push_back(const typename object_t::value_type&) + */ + reference operator+=(const typename object_t::value_type& val) + { + push_back(val); + return *this; + } + + /*! + @brief add an object to an object + + This function allows to use `push_back` with an initializer list. In case + + 1. the current value is an object, + 2. the initializer list @a init contains only two elements, and + 3. the first element of @a init is a string, + + @a init is converted into an object element and added using + @ref push_back(const typename object_t::value_type&). Otherwise, @a init + is converted to a JSON value and added using @ref push_back(basic_json&&). + + @param[in] init an initializer list + + @complexity Linear in the size of the initializer list @a init. + + @note This function is required to resolve an ambiguous overload error, + because pairs like `{"key", "value"}` can be both interpreted as + `object_t::value_type` or `std::initializer_list`, see + https://github.com/nlohmann/json/issues/235 for more information. 
+ + @liveexample{The example shows how initializer lists are treated as + objects when possible.,push_back__initializer_list} + */ + void push_back(initializer_list_t init) + { + if (is_object() and init.size() == 2 and (*init.begin())->is_string()) + { + basic_json&& key = init.begin()->moved_or_copied(); + push_back(typename object_t::value_type( + std::move(key.get_ref()), (init.begin() + 1)->moved_or_copied())); + } + else + { + push_back(basic_json(init)); + } + } + + /*! + @brief add an object to an object + @copydoc push_back(initializer_list_t) + */ + reference operator+=(initializer_list_t init) + { + push_back(init); + return *this; + } + + /*! + @brief add an object to an array + + Creates a JSON value from the passed parameters @a args to the end of the + JSON value. If the function is called on a JSON null value, an empty array + is created before appending the value created from @a args. + + @param[in] args arguments to forward to a constructor of @ref basic_json + @tparam Args compatible types to create a @ref basic_json object + + @throw type_error.311 when called on a type other than JSON array or + null; example: `"cannot use emplace_back() with number"` + + @complexity Amortized constant. + + @liveexample{The example shows how `push_back()` can be used to add + elements to a JSON array. Note how the `null` value was silently converted + to a JSON array.,emplace_back} + + @since version 2.0.8 + */ + template + void emplace_back(Args&& ... args) + { + // emplace_back only works for null objects or arrays + if (JSON_UNLIKELY(not(is_null() or is_array()))) + { + JSON_THROW(type_error::create(311, "cannot use emplace_back() with " + std::string(type_name()))); + } + + // transform null object into an array + if (is_null()) + { + m_type = value_t::array; + m_value = value_t::array; + assert_invariant(); + } + + // add element to array (perfect forwarding) + m_value.array->emplace_back(std::forward(args)...); + } + + /*! 
+ @brief add an object to an object if key does not exist + + Inserts a new element into a JSON object constructed in-place with the + given @a args if there is no element with the key in the container. If the + function is called on a JSON null value, an empty object is created before + appending the value created from @a args. + + @param[in] args arguments to forward to a constructor of @ref basic_json + @tparam Args compatible types to create a @ref basic_json object + + @return a pair consisting of an iterator to the inserted element, or the + already-existing element if no insertion happened, and a bool + denoting whether the insertion took place. + + @throw type_error.311 when called on a type other than JSON object or + null; example: `"cannot use emplace() with number"` + + @complexity Logarithmic in the size of the container, O(log(`size()`)). + + @liveexample{The example shows how `emplace()` can be used to add elements + to a JSON object. Note how the `null` value was silently converted to a + JSON object. Further note how no value is added if there was already one + value stored with the same key.,emplace} + + @since version 2.0.8 + */ + template + std::pair emplace(Args&& ... args) + { + // emplace only works for null objects or arrays + if (JSON_UNLIKELY(not(is_null() or is_object()))) + { + JSON_THROW(type_error::create(311, "cannot use emplace() with " + std::string(type_name()))); + } + + // transform null object into an object + if (is_null()) + { + m_type = value_t::object; + m_value = value_t::object; + assert_invariant(); + } + + // add element to array (perfect forwarding) + auto res = m_value.object->emplace(std::forward(args)...); + // create result iterator and set iterator to the result of emplace + auto it = begin(); + it.m_it.object_iterator = res.first; + + // return pair of iterator and boolean + return {it, res.second}; + } + + /*! + @brief inserts element + + Inserts element @a val before iterator @a pos. 
+ + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] val element to insert + @return iterator pointing to the inserted @a val. + + @throw type_error.309 if called on JSON values other than arrays; + example: `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + + @complexity Constant plus linear in the distance between @a pos and end of + the container. + + @liveexample{The example shows how `insert()` is used.,insert} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, const basic_json& val) + { + // insert only works for arrays + if (JSON_LIKELY(is_array())) + { + // check if iterator pos fits to this JSON value + if (JSON_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); + } + + // insert to array and return iterator + iterator result(this); + result.m_it.array_iterator = m_value.array->insert(pos.m_it.array_iterator, val); + return result; + } + + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); + } + + /*! + @brief inserts element + @copydoc insert(const_iterator, const basic_json&) + */ + iterator insert(const_iterator pos, basic_json&& val) + { + return insert(pos, val); + } + + /*! + @brief inserts elements + + Inserts @a cnt copies of @a val before iterator @a pos. 
+ + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] cnt number of copies of @a val to insert + @param[in] val element to insert + @return iterator pointing to the first element inserted, or @a pos if + `cnt==0` + + @throw type_error.309 if called on JSON values other than arrays; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + + @complexity Linear in @a cnt plus linear in the distance between @a pos + and end of the container. + + @liveexample{The example shows how `insert()` is used.,insert__count} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, size_type cnt, const basic_json& val) + { + // insert only works for arrays + if (JSON_LIKELY(is_array())) + { + // check if iterator pos fits to this JSON value + if (JSON_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); + } + + // insert to array and return iterator + iterator result(this); + result.m_it.array_iterator = m_value.array->insert(pos.m_it.array_iterator, cnt, val); + return result; + } + + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); + } + + /*! + @brief inserts elements + + Inserts elements from range `[first, last)` before iterator @a pos. 
+ + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] first begin of the range of elements to insert + @param[in] last end of the range of elements to insert + + @throw type_error.309 if called on JSON values other than arrays; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + @throw invalid_iterator.210 if @a first and @a last do not belong to the + same JSON value; example: `"iterators do not fit"` + @throw invalid_iterator.211 if @a first or @a last are iterators into + container for which insert is called; example: `"passed iterators may not + belong to container"` + + @return iterator pointing to the first element inserted, or @a pos if + `first==last` + + @complexity Linear in `std::distance(first, last)` plus linear in the + distance between @a pos and end of the container. + + @liveexample{The example shows how `insert()` is used.,insert__range} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, const_iterator first, const_iterator last) + { + // insert only works for arrays + if (JSON_UNLIKELY(not is_array())) + { + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); + } + + // check if iterator pos fits to this JSON value + if (JSON_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); + } + + // check if range iterators belong to the same JSON object + if (JSON_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(210, "iterators do not fit")); + } + + if (JSON_UNLIKELY(first.m_object == this or last.m_object == this)) + { + JSON_THROW(invalid_iterator::create(211, "passed iterators may not belong to container")); + } + + // insert to array and return iterator + iterator result(this); + result.m_it.array_iterator = 
m_value.array->insert( + pos.m_it.array_iterator, + first.m_it.array_iterator, + last.m_it.array_iterator); + return result; + } + + /*! + @brief inserts elements + + Inserts elements from initializer list @a ilist before iterator @a pos. + + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] ilist initializer list to insert the values from + + @throw type_error.309 if called on JSON values other than arrays; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + + @return iterator pointing to the first element inserted, or @a pos if + `ilist` is empty + + @complexity Linear in `ilist.size()` plus linear in the distance between + @a pos and end of the container. + + @liveexample{The example shows how `insert()` is used.,insert__ilist} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, initializer_list_t ilist) + { + // insert only works for arrays + if (JSON_UNLIKELY(not is_array())) + { + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); + } + + // check if iterator pos fits to this JSON value + if (JSON_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); + } + + // insert to array and return iterator + iterator result(this); + result.m_it.array_iterator = m_value.array->insert(pos.m_it.array_iterator, ilist.begin(), ilist.end()); + return result; + } + + /*! + @brief inserts elements + + Inserts elements from range `[first, last)`. 
+ + @param[in] first begin of the range of elements to insert + @param[in] last end of the range of elements to insert + + @throw type_error.309 if called on JSON values other than objects; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if iterator @a first or @a last does does not + point to an object; example: `"iterators first and last must point to + objects"` + @throw invalid_iterator.210 if @a first and @a last do not belong to the + same JSON value; example: `"iterators do not fit"` + + @complexity Logarithmic: `O(N*log(size() + N))`, where `N` is the number + of elements to insert. + + @liveexample{The example shows how `insert()` is used.,insert__range_object} + + @since version 3.0.0 + */ + void insert(const_iterator first, const_iterator last) + { + // insert only works for objects + if (JSON_UNLIKELY(not is_object())) + { + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); + } + + // check if range iterators belong to the same JSON object + if (JSON_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(210, "iterators do not fit")); + } + + // passed iterators must belong to objects + if (JSON_UNLIKELY(not first.m_object->is_object() + or not last.m_object->is_object())) + { + JSON_THROW(invalid_iterator::create(202, "iterators first and last must point to objects")); + } + + m_value.object->insert(first.m_it.object_iterator, last.m_it.object_iterator); + } + + /*! + @brief updates a JSON object from another object, overwriting existing keys + + Inserts all values from JSON object @a j and overwrites existing keys. + + @param[in] j JSON object to read values from + + @throw type_error.312 if called on JSON values other than objects; example: + `"cannot use update() with string"` + + @complexity O(N*log(size() + N)), where N is the number of elements to + insert. 
+ + @liveexample{The example shows how `update()` is used.,update} + + @sa https://docs.python.org/3.6/library/stdtypes.html#dict.update + + @since version 3.0.0 + */ + void update(const_reference j) + { + // implicitly convert null value to an empty object + if (is_null()) + { + m_type = value_t::object; + m_value.object = create(); + assert_invariant(); + } + + if (JSON_UNLIKELY(not is_object())) + { + JSON_THROW(type_error::create(312, "cannot use update() with " + std::string(type_name()))); + } + if (JSON_UNLIKELY(not j.is_object())) + { + JSON_THROW(type_error::create(312, "cannot use update() with " + std::string(j.type_name()))); + } + + for (auto it = j.begin(); it != j.end(); ++it) + { + m_value.object->operator[](it.key()) = it.value(); + } + } + + /*! + @brief updates a JSON object from another object, overwriting existing keys + + Inserts all values from from range `[first, last)` and overwrites existing + keys. + + @param[in] first begin of the range of elements to insert + @param[in] last end of the range of elements to insert + + @throw type_error.312 if called on JSON values other than objects; example: + `"cannot use update() with string"` + @throw invalid_iterator.202 if iterator @a first or @a last does does not + point to an object; example: `"iterators first and last must point to + objects"` + @throw invalid_iterator.210 if @a first and @a last do not belong to the + same JSON value; example: `"iterators do not fit"` + + @complexity O(N*log(size() + N)), where N is the number of elements to + insert. 
+ + @liveexample{The example shows how `update()` is used__range.,update} + + @sa https://docs.python.org/3.6/library/stdtypes.html#dict.update + + @since version 3.0.0 + */ + void update(const_iterator first, const_iterator last) + { + // implicitly convert null value to an empty object + if (is_null()) + { + m_type = value_t::object; + m_value.object = create(); + assert_invariant(); + } + + if (JSON_UNLIKELY(not is_object())) + { + JSON_THROW(type_error::create(312, "cannot use update() with " + std::string(type_name()))); + } + + // check if range iterators belong to the same JSON object + if (JSON_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(210, "iterators do not fit")); + } + + // passed iterators must belong to objects + if (JSON_UNLIKELY(not first.m_object->is_object() + or not first.m_object->is_object())) + { + JSON_THROW(invalid_iterator::create(202, "iterators first and last must point to objects")); + } + + for (auto it = first; it != last; ++it) + { + m_value.object->operator[](it.key()) = it.value(); + } + } + + /*! + @brief exchanges the values + + Exchanges the contents of the JSON value with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other JSON value to exchange the contents with + + @complexity Constant. + + @liveexample{The example below shows how JSON values can be swapped with + `swap()`.,swap__reference} + + @since version 1.0.0 + */ + void swap(reference other) noexcept ( + std::is_nothrow_move_constructible::value and + std::is_nothrow_move_assignable::value and + std::is_nothrow_move_constructible::value and + std::is_nothrow_move_assignable::value + ) + { + std::swap(m_type, other.m_type); + std::swap(m_value, other.m_value); + assert_invariant(); + } + + /*! 
+ @brief exchanges the values + + Exchanges the contents of a JSON array with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other array to exchange the contents with + + @throw type_error.310 when JSON value is not an array; example: `"cannot + use swap() with string"` + + @complexity Constant. + + @liveexample{The example below shows how arrays can be swapped with + `swap()`.,swap__array_t} + + @since version 1.0.0 + */ + void swap(array_t& other) + { + // swap only works for arrays + if (JSON_LIKELY(is_array())) + { + std::swap(*(m_value.array), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); + } + } + + /*! + @brief exchanges the values + + Exchanges the contents of a JSON object with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other object to exchange the contents with + + @throw type_error.310 when JSON value is not an object; example: + `"cannot use swap() with string"` + + @complexity Constant. + + @liveexample{The example below shows how objects can be swapped with + `swap()`.,swap__object_t} + + @since version 1.0.0 + */ + void swap(object_t& other) + { + // swap only works for objects + if (JSON_LIKELY(is_object())) + { + std::swap(*(m_value.object), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); + } + } + + /*! + @brief exchanges the values + + Exchanges the contents of a JSON string with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. 
+ + @param[in,out] other string to exchange the contents with + + @throw type_error.310 when JSON value is not a string; example: `"cannot + use swap() with boolean"` + + @complexity Constant. + + @liveexample{The example below shows how strings can be swapped with + `swap()`.,swap__string_t} + + @since version 1.0.0 + */ + void swap(string_t& other) + { + // swap only works for strings + if (JSON_LIKELY(is_string())) + { + std::swap(*(m_value.string), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); + } + } + + /// @} + + public: + ////////////////////////////////////////// + // lexicographical comparison operators // + ////////////////////////////////////////// + + /// @name lexicographical comparison operators + /// @{ + + /*! + @brief comparison: equal + + Compares two JSON values for equality according to the following rules: + - Two JSON values are equal if (1) they are from the same type and (2) + their stored values are the same according to their respective + `operator==`. + - Integer and floating-point numbers are automatically converted before + comparison. Note than two NaN values are always treated as unequal. + - Two JSON null values are equal. + + @note Floating-point inside JSON values numbers are compared with + `json::number_float_t::operator==` which is `double::operator==` by + default. To compare floating-point while respecting an epsilon, an alternative + [comparison function](https://github.com/mariokonrad/marnav/blob/master/src/marnav/math/floatingpoint.hpp#L34-#L39) + could be used, for instance + @code {.cpp} + template ::value, T>::type> + inline bool is_same(T a, T b, T epsilon = std::numeric_limits::epsilon()) noexcept + { + return std::abs(a - b) <= epsilon; + } + @endcode + + @note NaN values never compare equal to themselves or to other NaN values. 
+ + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether the values @a lhs and @a rhs are equal + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @complexity Linear. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__equal} + + @since version 1.0.0 + */ + friend bool operator==(const_reference lhs, const_reference rhs) noexcept + { + const auto lhs_type = lhs.type(); + const auto rhs_type = rhs.type(); + + if (lhs_type == rhs_type) + { + switch (lhs_type) + { + case value_t::array: + return (*lhs.m_value.array == *rhs.m_value.array); + + case value_t::object: + return (*lhs.m_value.object == *rhs.m_value.object); + + case value_t::null: + return true; + + case value_t::string: + return (*lhs.m_value.string == *rhs.m_value.string); + + case value_t::boolean: + return (lhs.m_value.boolean == rhs.m_value.boolean); + + case value_t::number_integer: + return (lhs.m_value.number_integer == rhs.m_value.number_integer); + + case value_t::number_unsigned: + return (lhs.m_value.number_unsigned == rhs.m_value.number_unsigned); + + case value_t::number_float: + return (lhs.m_value.number_float == rhs.m_value.number_float); + + default: + return false; + } + } + else if (lhs_type == value_t::number_integer and rhs_type == value_t::number_float) + { + return (static_cast(lhs.m_value.number_integer) == rhs.m_value.number_float); + } + else if (lhs_type == value_t::number_float and rhs_type == value_t::number_integer) + { + return (lhs.m_value.number_float == static_cast(rhs.m_value.number_integer)); + } + else if (lhs_type == value_t::number_unsigned and rhs_type == value_t::number_float) + { + return (static_cast(lhs.m_value.number_unsigned) == rhs.m_value.number_float); + } + else if (lhs_type == value_t::number_float and rhs_type == value_t::number_unsigned) + { + return (lhs.m_value.number_float == static_cast(rhs.m_value.number_unsigned)); + } + else if 
(lhs_type == value_t::number_unsigned and rhs_type == value_t::number_integer) + { + return (static_cast(lhs.m_value.number_unsigned) == rhs.m_value.number_integer); + } + else if (lhs_type == value_t::number_integer and rhs_type == value_t::number_unsigned) + { + return (lhs.m_value.number_integer == static_cast(rhs.m_value.number_unsigned)); + } + + return false; + } + + /*! + @brief comparison: equal + @copydoc operator==(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator==(const_reference lhs, const ScalarType rhs) noexcept + { + return (lhs == basic_json(rhs)); + } + + /*! + @brief comparison: equal + @copydoc operator==(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator==(const ScalarType lhs, const_reference rhs) noexcept + { + return (basic_json(lhs) == rhs); + } + + /*! + @brief comparison: not equal + + Compares two JSON values for inequality by calculating `not (lhs == rhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether the values @a lhs and @a rhs are not equal + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__notequal} + + @since version 1.0.0 + */ + friend bool operator!=(const_reference lhs, const_reference rhs) noexcept + { + return not (lhs == rhs); + } + + /*! + @brief comparison: not equal + @copydoc operator!=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator!=(const_reference lhs, const ScalarType rhs) noexcept + { + return (lhs != basic_json(rhs)); + } + + /*! 
+    @brief comparison: not equal
+    @copydoc operator!=(const_reference, const_reference)
+    */
+    template<typename ScalarType, typename std::enable_if<
+                 std::is_scalar<ScalarType>::value, int>::type = 0>
+    friend bool operator!=(const ScalarType lhs, const_reference rhs) noexcept
+    {
+        // wrap the scalar and reuse the JSON/JSON comparison
+        return (basic_json(lhs) != rhs);
+    }
+
+    /*!
+    @brief comparison: less than
+
+    Compares whether one JSON value @a lhs is less than another JSON value @a
+    rhs according to the following rules:
+    - If @a lhs and @a rhs have the same type, the values are compared using
+      the default `<` operator.
+    - Integer and floating-point numbers are automatically converted before
+      comparison
+    - In case @a lhs and @a rhs have different types, the values are ignored
+      and the order of the types is considered, see
+      @ref operator<(const value_t, const value_t).
+
+    @param[in] lhs  first JSON value to consider
+    @param[in] rhs  second JSON value to consider
+    @return whether @a lhs is less than @a rhs
+
+    @complexity Linear.
+
+    @exceptionsafety No-throw guarantee: this function never throws exceptions.
+ + @liveexample{The example demonstrates comparing several JSON + types.,operator__less} + + @since version 1.0.0 + */ + friend bool operator<(const_reference lhs, const_reference rhs) noexcept + { + const auto lhs_type = lhs.type(); + const auto rhs_type = rhs.type(); + + if (lhs_type == rhs_type) + { + switch (lhs_type) + { + case value_t::array: + return (*lhs.m_value.array) < (*rhs.m_value.array); + + case value_t::object: + return *lhs.m_value.object < *rhs.m_value.object; + + case value_t::null: + return false; + + case value_t::string: + return *lhs.m_value.string < *rhs.m_value.string; + + case value_t::boolean: + return lhs.m_value.boolean < rhs.m_value.boolean; + + case value_t::number_integer: + return lhs.m_value.number_integer < rhs.m_value.number_integer; + + case value_t::number_unsigned: + return lhs.m_value.number_unsigned < rhs.m_value.number_unsigned; + + case value_t::number_float: + return lhs.m_value.number_float < rhs.m_value.number_float; + + default: + return false; + } + } + else if (lhs_type == value_t::number_integer and rhs_type == value_t::number_float) + { + return static_cast(lhs.m_value.number_integer) < rhs.m_value.number_float; + } + else if (lhs_type == value_t::number_float and rhs_type == value_t::number_integer) + { + return lhs.m_value.number_float < static_cast(rhs.m_value.number_integer); + } + else if (lhs_type == value_t::number_unsigned and rhs_type == value_t::number_float) + { + return static_cast(lhs.m_value.number_unsigned) < rhs.m_value.number_float; + } + else if (lhs_type == value_t::number_float and rhs_type == value_t::number_unsigned) + { + return lhs.m_value.number_float < static_cast(rhs.m_value.number_unsigned); + } + else if (lhs_type == value_t::number_integer and rhs_type == value_t::number_unsigned) + { + return lhs.m_value.number_integer < static_cast(rhs.m_value.number_unsigned); + } + else if (lhs_type == value_t::number_unsigned and rhs_type == value_t::number_integer) + { + return 
static_cast(lhs.m_value.number_unsigned) < rhs.m_value.number_integer; + } + + // We only reach this line if we cannot compare values. In that case, + // we compare types. Note we have to call the operator explicitly, + // because MSVC has problems otherwise. + return operator<(lhs_type, rhs_type); + } + + /*! + @brief comparison: less than + @copydoc operator<(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<(const_reference lhs, const ScalarType rhs) noexcept + { + return (lhs < basic_json(rhs)); + } + + /*! + @brief comparison: less than + @copydoc operator<(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<(const ScalarType lhs, const_reference rhs) noexcept + { + return (basic_json(lhs) < rhs); + } + + /*! + @brief comparison: less than or equal + + Compares whether one JSON value @a lhs is less than or equal to another + JSON value by calculating `not (rhs < lhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether @a lhs is less than or equal to @a rhs + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__greater} + + @since version 1.0.0 + */ + friend bool operator<=(const_reference lhs, const_reference rhs) noexcept + { + return not (rhs < lhs); + } + + /*! + @brief comparison: less than or equal + @copydoc operator<=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<=(const_reference lhs, const ScalarType rhs) noexcept + { + return (lhs <= basic_json(rhs)); + } + + /*! 
+ @brief comparison: less than or equal + @copydoc operator<=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<=(const ScalarType lhs, const_reference rhs) noexcept + { + return (basic_json(lhs) <= rhs); + } + + /*! + @brief comparison: greater than + + Compares whether one JSON value @a lhs is greater than another + JSON value by calculating `not (lhs <= rhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether @a lhs is greater than to @a rhs + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__lessequal} + + @since version 1.0.0 + */ + friend bool operator>(const_reference lhs, const_reference rhs) noexcept + { + return not (lhs <= rhs); + } + + /*! + @brief comparison: greater than + @copydoc operator>(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>(const_reference lhs, const ScalarType rhs) noexcept + { + return (lhs > basic_json(rhs)); + } + + /*! + @brief comparison: greater than + @copydoc operator>(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>(const ScalarType lhs, const_reference rhs) noexcept + { + return (basic_json(lhs) > rhs); + } + + /*! + @brief comparison: greater than or equal + + Compares whether one JSON value @a lhs is greater than or equal to another + JSON value by calculating `not (lhs < rhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether @a lhs is greater than or equal to @a rhs + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. 
+ + @liveexample{The example demonstrates comparing several JSON + types.,operator__greaterequal} + + @since version 1.0.0 + */ + friend bool operator>=(const_reference lhs, const_reference rhs) noexcept + { + return not (lhs < rhs); + } + + /*! + @brief comparison: greater than or equal + @copydoc operator>=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>=(const_reference lhs, const ScalarType rhs) noexcept + { + return (lhs >= basic_json(rhs)); + } + + /*! + @brief comparison: greater than or equal + @copydoc operator>=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>=(const ScalarType lhs, const_reference rhs) noexcept + { + return (basic_json(lhs) >= rhs); + } + + /// @} + + /////////////////// + // serialization // + /////////////////// + + /// @name serialization + /// @{ + + /*! + @brief serialize to stream + + Serialize the given JSON value @a j to the output stream @a o. The JSON + value will be serialized using the @ref dump member function. + + - The indentation of the output can be controlled with the member variable + `width` of the output stream @a o. For instance, using the manipulator + `std::setw(4)` on @a o sets the indentation level to `4` and the + serialization result is the same as calling `dump(4)`. + + - The indentation character can be controlled with the member variable + `fill` of the output stream @a o. For instance, the manipulator + `std::setfill('\\t')` sets indentation to use a tab character rather than + the default space character. + + @param[in,out] o stream to serialize to + @param[in] j JSON value to serialize + + @return the stream @a o + + @complexity Linear.
+ + @liveexample{The example below shows the serialization with different + parameters to `width` to adjust the indentation level.,operator_serialize} + + @since version 1.0.0; indentation character added in version 3.0.0 + */ + friend std::ostream& operator<<(std::ostream& o, const basic_json& j) + { + // read width member and use it as indentation parameter if nonzero + const bool pretty_print = (o.width() > 0); + const auto indentation = (pretty_print ? o.width() : 0); + + // reset width to 0 for subsequent calls to this stream + o.width(0); + + // do the actual serialization + serializer s(detail::output_adapter(o), o.fill()); + s.dump(j, pretty_print, false, static_cast(indentation)); + return o; + } + + /*! + @brief serialize to stream + @deprecated This stream operator is deprecated and will be removed in a + future version of the library. Please use + @ref operator<<(std::ostream&, const basic_json&) + instead; that is, replace calls like `j >> o;` with `o << j;`. + */ + JSON_DEPRECATED + friend std::ostream& operator>>(const basic_json& j, std::ostream& o) + { + return o << j; + } + + /// @} + + + ///////////////////// + // deserialization // + ///////////////////// + + /// @name deserialization + /// @{ + + /*! + @brief deserialize from a compatible input + + This function reads from a compatible input. Examples are: + - an array of 1-byte values + - strings with character/literal type with size of 1 byte + - input streams + - container with contiguous storage of 1-byte values. Compatible container + types include `std::vector`, `std::string`, `std::array`, + `std::valarray`, and `std::initializer_list`. Furthermore, C-style + arrays can be used with `std::begin()`/`std::end()`. User-defined + containers can be used as long as they implement random-access iterators + and a contiguous storage. + + @pre Each element of the container has a size of 1 byte. Violating this + precondition yields undefined behavior.
**This precondition is enforced + with a static assertion.** + + @pre The container storage is contiguous. Violating this precondition + yields undefined behavior. **This precondition is enforced with an + assertion.** + @pre Each element of the container has a size of 1 byte. Violating this + precondition yields undefined behavior. **This precondition is enforced + with a static assertion.** + + @warning There is no way to enforce all preconditions at compile-time. If + the function is called with a noncompliant container and with + assertions switched off, the behavior is undefined and will most + likely yield segmentation violation. + + @param[in] i input to read from + @param[in] cb a parser callback function of type @ref parser_callback_t + which is used to control the deserialization by filtering unwanted values + (optional) + + @return result of the deserialization + + @throw parse_error.101 if a parse error occurs; example: `""unexpected end + of input; expected string literal""` + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + + @complexity Linear in the length of the input. The parser is a predictive + LL(1) parser. The complexity can be higher if the parser callback function + @a cb has a super-linear complexity. + + @note A UTF-8 byte order mark is silently ignored. 
+ + @liveexample{The example below demonstrates the `parse()` function reading + from an array.,parse__array__parser_callback_t} + + @liveexample{The example below demonstrates the `parse()` function with + and without callback function.,parse__string__parser_callback_t} + + @liveexample{The example below demonstrates the `parse()` function with + and without callback function.,parse__istream__parser_callback_t} + + @liveexample{The example below demonstrates the `parse()` function reading + from a contiguous container.,parse__contiguouscontainer__parser_callback_t} + + @since version 2.0.3 (contiguous containers) + */ + static basic_json parse(detail::input_adapter i, + const parser_callback_t cb = nullptr, + const bool allow_exceptions = true) + { + basic_json result; + parser(i, cb, allow_exceptions).parse(true, result); + return result; + } + + /*! + @copydoc basic_json parse(detail::input_adapter, const parser_callback_t) + */ + static basic_json parse(detail::input_adapter& i, + const parser_callback_t cb = nullptr, + const bool allow_exceptions = true) + { + basic_json result; + parser(i, cb, allow_exceptions).parse(true, result); + return result; + } + + static bool accept(detail::input_adapter i) + { + return parser(i).accept(true); + } + + static bool accept(detail::input_adapter& i) + { + return parser(i).accept(true); + } + + /*! + @brief deserialize from an iterator range with contiguous storage + + This function reads from an iterator range of a container with contiguous + storage of 1-byte values. Compatible container types include + `std::vector`, `std::string`, `std::array`, `std::valarray`, and + `std::initializer_list`. Furthermore, C-style arrays can be used with + `std::begin()`/`std::end()`. User-defined containers can be used as long + as they implement random-access iterators and a contiguous storage. + + @pre The iterator range is contiguous. Violating this precondition yields + undefined behavior. 
**This precondition is enforced with an assertion.** + @pre Each element in the range has a size of 1 byte. Violating this + precondition yields undefined behavior. **This precondition is enforced + with a static assertion.** + + @warning There is no way to enforce all preconditions at compile-time. If + the function is called with noncompliant iterators and with + assertions switched off, the behavior is undefined and will most + likely yield segmentation violation. + + @tparam IteratorType iterator of container with contiguous storage + @param[in] first begin of the range to parse (included) + @param[in] last end of the range to parse (excluded) + @param[in] cb a parser callback function of type @ref parser_callback_t + which is used to control the deserialization by filtering unwanted values + (optional) + + @return result of the deserialization + + @throw parse_error.101 in case of an unexpected token + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + + @complexity Linear in the length of the input. The parser is a predictive + LL(1) parser. The complexity can be higher if the parser callback function + @a cb has a super-linear complexity. + + @note A UTF-8 byte order mark is silently ignored. + + @liveexample{The example below demonstrates the `parse()` function reading + from an iterator range.,parse__iteratortype__parser_callback_t} + + @since version 2.0.3 + */ + template::iterator_category>::value, int>::type = 0> + static basic_json parse(IteratorType first, IteratorType last, + const parser_callback_t cb = nullptr, + const bool allow_exceptions = true) + { + basic_json result; + parser(detail::input_adapter(first, last), cb, allow_exceptions).parse(true, result); + return result; + } + + template::iterator_category>::value, int>::type = 0> + static bool accept(IteratorType first, IteratorType last) + { + return parser(detail::input_adapter(first, last)).accept(true); + } + + /*! 
+ @brief deserialize from stream + @deprecated This stream operator is deprecated and will be removed in a + future version of the library. Please use + @ref operator>>(std::istream&, basic_json&) + instead; that is, replace calls like `j << i;` with `i >> j;`. + */ + JSON_DEPRECATED + friend std::istream& operator<<(basic_json& j, std::istream& i) + { + return operator>>(i, j); + } + + /*! + @brief deserialize from stream + + Deserializes an input stream to a JSON value. + + @param[in,out] i input stream to read a serialized JSON value from + @param[in,out] j JSON value to write the deserialized input to + + @throw parse_error.101 in case of an unexpected token + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + + @complexity Linear in the length of the input. The parser is a predictive + LL(1) parser. + + @note A UTF-8 byte order mark is silently ignored. + + @liveexample{The example below shows how a JSON value is constructed by + reading a serialization from a stream.,operator_deserialize} + + @sa parse(std::istream&, const parser_callback_t) for a variant with a + parser callback function to filter values while parsing + + @since version 1.0.0 + */ + friend std::istream& operator>>(std::istream& i, basic_json& j) + { + parser(detail::input_adapter(i)).parse(false, j); + return i; + } + + /// @} + + /////////////////////////// + // convenience functions // + /////////////////////////// + + /*! + @brief return the type as string + + Returns the type name as string to be used in error messages - usually to + indicate that a function was called on a wrong JSON type. 
+ + @return a string representation of the @a m_type member: + Value type | return value + ----------- | ------------- + null | `"null"` + boolean | `"boolean"` + string | `"string"` + number | `"number"` (for all number types) + object | `"object"` + array | `"array"` + discarded | `"discarded"` + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @complexity Constant. + + @liveexample{The following code exemplifies `type_name()` for all JSON + types.,type_name} + + @sa @ref type() -- return the type of the JSON value + @sa @ref operator value_t() -- return the type of the JSON value (implicit) + + @since version 1.0.0, public since 2.1.0, `const char*` and `noexcept` + since 3.0.0 + */ + const char* type_name() const noexcept + { + { + switch (m_type) + { + case value_t::null: + return "null"; + case value_t::object: + return "object"; + case value_t::array: + return "array"; + case value_t::string: + return "string"; + case value_t::boolean: + return "boolean"; + case value_t::discarded: + return "discarded"; + default: + return "number"; + } + } + } + + + private: + ////////////////////// + // member variables // + ////////////////////// + + /// the type of the current element + value_t m_type = value_t::null; + + /// the value of the current element + json_value m_value = {}; + + ////////////////////////////////////////// + // binary serialization/deserialization // + ////////////////////////////////////////// + + /// @name binary serialization/deserialization support + /// @{ + + public: + /*! + @brief create a CBOR serialization of a given JSON value + + Serializes a given JSON value @a j to a byte vector using the CBOR (Concise + Binary Object Representation) serialization format. CBOR is a binary + serialization format which aims to be more compact than JSON itself, yet + more efficient to parse.
+ + The library uses the following mapping from JSON value types to + CBOR types according to the CBOR specification (RFC 7049): + + JSON value type | value/range | CBOR type | first byte + --------------- | ------------------------------------------ | ---------------------------------- | --------------- + null | `null` | Null | 0xf6 + boolean | `true` | True | 0xf5 + boolean | `false` | False | 0xf4 + number_integer | -9223372036854775808..-2147483649 | Negative integer (8 bytes follow) | 0x3b + number_integer | -2147483648..-32769 | Negative integer (4 bytes follow) | 0x3a + number_integer | -32768..-129 | Negative integer (2 bytes follow) | 0x39 + number_integer | -128..-25 | Negative integer (1 byte follow) | 0x38 + number_integer | -24..-1 | Negative integer | 0x20..0x37 + number_integer | 0..23 | Integer | 0x00..0x17 + number_integer | 24..255 | Unsigned integer (1 byte follow) | 0x18 + number_integer | 256..65535 | Unsigned integer (2 bytes follow) | 0x19 + number_integer | 65536..4294967295 | Unsigned integer (4 bytes follow) | 0x1a + number_integer | 4294967296..18446744073709551615 | Unsigned integer (8 bytes follow) | 0x1b + number_unsigned | 0..23 | Integer | 0x00..0x17 + number_unsigned | 24..255 | Unsigned integer (1 byte follow) | 0x18 + number_unsigned | 256..65535 | Unsigned integer (2 bytes follow) | 0x19 + number_unsigned | 65536..4294967295 | Unsigned integer (4 bytes follow) | 0x1a + number_unsigned | 4294967296..18446744073709551615 | Unsigned integer (8 bytes follow) | 0x1b + number_float | *any value* | Double-Precision Float | 0xfb + string | *length*: 0..23 | UTF-8 string | 0x60..0x77 + string | *length*: 24..255 | UTF-8 string (1 byte follow) | 0x78 + string | *length*: 256..65535 | UTF-8 string (2 bytes follow) | 0x79 + string | *length*: 65536..4294967295 | UTF-8 string (4 bytes follow) | 0x7a + string | *length*: 4294967296..18446744073709551615 | UTF-8 string (8 bytes follow) | 0x7b + array | *size*: 0..23 | array | 0x80..0x97 +
array | *size*: 24..255 | array (1 byte follow) | 0x98 + array | *size*: 256..65535 | array (2 bytes follow) | 0x99 + array | *size*: 65536..4294967295 | array (4 bytes follow) | 0x9a + array | *size*: 4294967296..18446744073709551615 | array (8 bytes follow) | 0x9b + object | *size*: 0..23 | map | 0xa0..0xb7 + object | *size*: 24..255 | map (1 byte follow) | 0xb8 + object | *size*: 256..65535 | map (2 bytes follow) | 0xb9 + object | *size*: 65536..4294967295 | map (4 bytes follow) | 0xba + object | *size*: 4294967296..18446744073709551615 | map (8 bytes follow) | 0xbb + + @note The mapping is **complete** in the sense that any JSON value type + can be converted to a CBOR value. + + @note If NaN or Infinity are stored inside a JSON number, they are + serialized properly. This behavior differs from the @ref dump() + function which serializes NaN or Infinity to `null`. + + @note The following CBOR types are not used in the conversion: + - byte strings (0x40..0x5f) + - UTF-8 strings terminated by "break" (0x7f) + - arrays terminated by "break" (0x9f) + - maps terminated by "break" (0xbf) + - date/time (0xc0..0xc1) + - bignum (0xc2..0xc3) + - decimal fraction (0xc4) + - bigfloat (0xc5) + - tagged items (0xc6..0xd4, 0xd8..0xdb) + - expected conversions (0xd5..0xd7) + - simple values (0xe0..0xf3, 0xf8) + - undefined (0xf7) + - half and single-precision floats (0xf9-0xfa) + - break (0xff) + + @param[in] j JSON value to serialize + @return CBOR serialization as byte vector + + @complexity Linear in the size of the JSON value @a j.
+ + @liveexample{The example shows the serialization of a JSON value to a byte + vector in CBOR format.,to_cbor} + + @sa http://cbor.io + @sa @ref from_cbor(const std::vector&, const size_t) for the + analogous deserialization + @sa @ref to_msgpack(const basic_json&) for the related MessagePack format + + @since version 2.0.9 + */ + static std::vector to_cbor(const basic_json& j) + { + std::vector result; + to_cbor(j, result); + return result; + } + + static void to_cbor(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_cbor(j); + } + + static void to_cbor(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_cbor(j); + } + + /*! + @brief create a MessagePack serialization of a given JSON value + + Serializes a given JSON value @a j to a byte vector using the MessagePack + serialization format. MessagePack is a binary serialization format which + aims to be more compact than JSON itself, yet more efficient to parse. + + The library uses the following mapping from JSON values types to + MessagePack types according to the MessagePack specification: + + JSON value type | value/range | MessagePack type | first byte + --------------- | --------------------------------- | ---------------- | ---------- + null | `null` | nil | 0xc0 + boolean | `true` | true | 0xc3 + boolean | `false` | false | 0xc2 + number_integer | -9223372036854775808..-2147483649 | int64 | 0xd3 + number_integer | -2147483648..-32769 | int32 | 0xd2 + number_integer | -32768..-129 | int16 | 0xd1 + number_integer | -128..-33 | int8 | 0xd0 + number_integer | -32..-1 | negative fixint | 0xe0..0xff + number_integer | 0..127 | positive fixint | 0x00..0x7f + number_integer | 128..255 | uint 8 | 0xcc + number_integer | 256..65535 | uint 16 | 0xcd + number_integer | 65536..4294967295 | uint 32 | 0xce + number_integer | 4294967296..18446744073709551615 | uint 64 | 0xcf + number_unsigned | 0..127 | positive fixint | 0x00..0x7f + number_unsigned | 128..255 | uint 8 | 
0xcc + number_unsigned | 256..65535 | uint 16 | 0xcd + number_unsigned | 65536..4294967295 | uint 32 | 0xce + number_unsigned | 4294967296..18446744073709551615 | uint 64 | 0xcf + number_float | *any value* | float 64 | 0xcb + string | *length*: 0..31 | fixstr | 0xa0..0xbf + string | *length*: 32..255 | str 8 | 0xd9 + string | *length*: 256..65535 | str 16 | 0xda + string | *length*: 65536..4294967295 | str 32 | 0xdb + array | *size*: 0..15 | fixarray | 0x90..0x9f + array | *size*: 16..65535 | array 16 | 0xdc + array | *size*: 65536..4294967295 | array 32 | 0xdd + object | *size*: 0..15 | fix map | 0x80..0x8f + object | *size*: 16..65535 | map 16 | 0xde + object | *size*: 65536..4294967295 | map 32 | 0xdf + + @note The mapping is **complete** in the sense that any JSON value type + can be converted to a MessagePack value. + + @note The following values can **not** be converted to a MessagePack value: + - strings with more than 4294967295 bytes + - arrays with more than 4294967295 elements + - objects with more than 4294967295 elements + + @note The following MessagePack types are not used in the conversion: + - bin 8 - bin 32 (0xc4..0xc6) + - ext 8 - ext 32 (0xc7..0xc9) + - float 32 (0xca) + - fixext 1 - fixext 16 (0xd4..0xd8) + + @note Any MessagePack output created @ref to_msgpack can be successfully + parsed by @ref from_msgpack. + + @note If NaN or Infinity are stored inside a JSON number, they are + serialized properly. This behavior differs from the @ref dump() + function which serializes NaN or Infinity to `null`. + + @param[in] j JSON value to serialize + @return MessagePack serialization as byte vector + + @complexity Linear in the size of the JSON value @a j. 
+ + @liveexample{The example shows the serialization of a JSON value to a byte + vector in MessagePack format.,to_msgpack} + + @sa http://msgpack.org + @sa @ref from_msgpack(const std::vector&, const size_t) for the + analogous deserialization + @sa @ref to_cbor(const basic_json& for the related CBOR format + + @since version 2.0.9 + */ + static std::vector to_msgpack(const basic_json& j) + { + std::vector result; + to_msgpack(j, result); + return result; + } + + static void to_msgpack(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_msgpack(j); + } + + static void to_msgpack(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_msgpack(j); + } + + /*! + @brief create a JSON value from an input in CBOR format + + Deserializes a given input @a i to a JSON value using the CBOR (Concise + Binary Object Representation) serialization format. + + The library maps CBOR types to JSON value types as follows: + + CBOR type | JSON value type | first byte + ---------------------- | --------------- | ---------- + Integer | number_unsigned | 0x00..0x17 + Unsigned integer | number_unsigned | 0x18 + Unsigned integer | number_unsigned | 0x19 + Unsigned integer | number_unsigned | 0x1a + Unsigned integer | number_unsigned | 0x1b + Negative integer | number_integer | 0x20..0x37 + Negative integer | number_integer | 0x38 + Negative integer | number_integer | 0x39 + Negative integer | number_integer | 0x3a + Negative integer | number_integer | 0x3b + Negative integer | number_integer | 0x40..0x57 + UTF-8 string | string | 0x60..0x77 + UTF-8 string | string | 0x78 + UTF-8 string | string | 0x79 + UTF-8 string | string | 0x7a + UTF-8 string | string | 0x7b + UTF-8 string | string | 0x7f + array | array | 0x80..0x97 + array | array | 0x98 + array | array | 0x99 + array | array | 0x9a + array | array | 0x9b + array | array | 0x9f + map | object | 0xa0..0xb7 + map | object | 0xb8 + map | object | 0xb9 + map | object | 0xba + map | object | 
0xbb + map | object | 0xbf + False | `false` | 0xf4 + True | `true` | 0xf5 + Null | `null` | 0xf6 + Half-Precision Float | number_float | 0xf9 + Single-Precision Float | number_float | 0xfa + Double-Precision Float | number_float | 0xfb + + @warning The mapping is **incomplete** in the sense that not all CBOR + types can be converted to a JSON value. The following CBOR types + are not supported and will yield parse errors (parse_error.112): + - byte strings (0x40..0x5f) + - date/time (0xc0..0xc1) + - bignum (0xc2..0xc3) + - decimal fraction (0xc4) + - bigfloat (0xc5) + - tagged items (0xc6..0xd4, 0xd8..0xdb) + - expected conversions (0xd5..0xd7) + - simple values (0xe0..0xf3, 0xf8) + - undefined (0xf7) + + @warning CBOR allows map keys of any type, whereas JSON only allows + strings as keys in object values. Therefore, CBOR maps with keys + other than UTF-8 strings are rejected (parse_error.113). + + @note Any CBOR output created by @ref to_cbor can be successfully parsed by + @ref from_cbor. + + @param[in] i an input in CBOR format convertible to an input adapter + @param[in] strict whether to expect the input to be consumed until EOF + (true by default) + @return deserialized JSON value + + @throw parse_error.110 if the given input ends prematurely or the end of + file was not reached when @a strict was set to true + @throw parse_error.112 if unsupported features from CBOR were + used in the given input @a i or if the input is not valid CBOR + @throw parse_error.113 if a string was expected as map key, but not found + + @complexity Linear in the size of the input @a i.
+ + @liveexample{The example shows the deserialization of a byte vector in CBOR + format to a JSON value.,from_cbor} + + @sa http://cbor.io + @sa @ref to_cbor(const basic_json&) for the analogous serialization + @sa @ref from_msgpack(detail::input_adapter, const bool) for the + related MessagePack format + + @since version 2.0.9; parameter @a start_index since 2.1.1; changed to + consume input adapters, removed start_index parameter, and added + @a strict parameter since 3.0.0 + */ + static basic_json from_cbor(detail::input_adapter i, + const bool strict = true) + { + return binary_reader(i).parse_cbor(strict); + } + + /*! + @copydoc from_cbor(detail::input_adapter, const bool) + */ + template::value, int> = 0> + static basic_json from_cbor(A1 && a1, A2 && a2, const bool strict = true) + { + return binary_reader(detail::input_adapter(std::forward(a1), std::forward(a2))).parse_cbor(strict); + } + + /*! + @brief create a JSON value from an input in MessagePack format + + Deserializes a given input @a i to a JSON value using the MessagePack + serialization format. 
+ + The library maps MessagePack types to JSON value types as follows: + + MessagePack type | JSON value type | first byte + ---------------- | --------------- | ---------- + positive fixint | number_unsigned | 0x00..0x7f + fixmap | object | 0x80..0x8f + fixarray | array | 0x90..0x9f + fixstr | string | 0xa0..0xbf + nil | `null` | 0xc0 + false | `false` | 0xc2 + true | `true` | 0xc3 + float 32 | number_float | 0xca + float 64 | number_float | 0xcb + uint 8 | number_unsigned | 0xcc + uint 16 | number_unsigned | 0xcd + uint 32 | number_unsigned | 0xce + uint 64 | number_unsigned | 0xcf + int 8 | number_integer | 0xd0 + int 16 | number_integer | 0xd1 + int 32 | number_integer | 0xd2 + int 64 | number_integer | 0xd3 + str 8 | string | 0xd9 + str 16 | string | 0xda + str 32 | string | 0xdb + array 16 | array | 0xdc + array 32 | array | 0xdd + map 16 | object | 0xde + map 32 | object | 0xdf + negative fixint | number_integer | 0xe0-0xff + + @warning The mapping is **incomplete** in the sense that not all + MessagePack types can be converted to a JSON value. The following + MessagePack types are not supported and will yield parse errors: + - bin 8 - bin 32 (0xc4..0xc6) + - ext 8 - ext 32 (0xc7..0xc9) + - fixext 1 - fixext 16 (0xd4..0xd8) + + @note Any MessagePack output created @ref to_msgpack can be successfully + parsed by @ref from_msgpack. + + @param[in] i an input in MessagePack format convertible to an input + adapter + @param[in] strict whether to expect the input to be consumed until EOF + (true by default) + + @throw parse_error.110 if the given input ends prematurely or the end of + file was not reached when @a strict was set to true + @throw parse_error.112 if unsupported features from MessagePack were + used in the given input @a i or if the input is not valid MessagePack + @throw parse_error.113 if a string was expected as map key, but not found + + @complexity Linear in the size of the input @a i. 
+ + @liveexample{The example shows the deserialization of a byte vector in + MessagePack format to a JSON value.,from_msgpack} + + @sa http://msgpack.org + @sa @ref to_msgpack(const basic_json&) for the analogous serialization + @sa @ref from_cbor(detail::input_adapter, const bool) for the related CBOR + format + + @since version 2.0.9; parameter @a start_index since 2.1.1; changed to + consume input adapters, removed start_index parameter, and added + @a strict parameter since 3.0.0 + */ + static basic_json from_msgpack(detail::input_adapter i, + const bool strict = true) + { + return binary_reader(i).parse_msgpack(strict); + } + + /*! + @copydoc from_msgpack(detail::input_adapter, const bool) + */ + template::value, int> = 0> + static basic_json from_msgpack(A1 && a1, A2 && a2, const bool strict = true) + { + return binary_reader(detail::input_adapter(std::forward(a1), std::forward(a2))).parse_msgpack(strict); + } + + /// @} + + ////////////////////////// + // JSON Pointer support // + ////////////////////////// + + /// @name JSON Pointer functions + /// @{ + + /*! + @brief access specified element via JSON Pointer + + Uses a JSON pointer to retrieve a reference to the respective JSON value. + No bound checking is performed. Similar to @ref operator[](const typename + object_t::key_type&), `null` values are created in arrays and objects if + necessary. + + In particular: + - If the JSON pointer points to an object key that does not exist, it + is created and filled with a `null` value before a reference to it + is returned. + - If the JSON pointer points to an array index that does not exist, it + is created and filled with a `null` value before a reference to it + is returned. All indices between the current maximum and the given + index are also filled with `null`. + - The special value `-` is treated as a synonym for the index past the + end. + + @param[in] ptr a JSON pointer + + @return reference to the element pointed to by @a ptr + + @complexity Constant.
+ + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.404 if the JSON pointer can not be resolved + + @liveexample{The behavior is shown in the example.,operatorjson_pointer} + + @since version 2.0.0 + */ + reference operator[](const json_pointer& ptr) + { + return ptr.get_unchecked(this); + } + + /*! + @brief access specified element via JSON Pointer + + Uses a JSON pointer to retrieve a reference to the respective JSON value. + No bound checking is performed. The function does not change the JSON + value; no `null` values are created. In particular, the special value + `-` yields an exception. + + @param[in] ptr JSON pointer to the desired element + + @return const reference to the element pointed to by @a ptr + + @complexity Constant. + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + + @liveexample{The behavior is shown in the example.,operatorjson_pointer_const} + + @since version 2.0.0 + */ + const_reference operator[](const json_pointer& ptr) const + { + return ptr.get_unchecked(this); + } + + /*! + @brief access specified element via JSON Pointer + + Returns a reference to the element at the specified JSON pointer @a ptr, + with bounds checking. + + @param[in] ptr JSON pointer to the desired element + + @return reference to the element pointed to by @a ptr + + @throw parse_error.106 if an array index in the passed JSON pointer @a ptr + begins with '0'. See example below. + + @throw parse_error.109 if an array index in the passed JSON pointer @a ptr + is not a number. See example below. + + @throw out_of_range.401 if an array index in the passed JSON pointer @a ptr + is out of range. See example below. 
+ + @throw out_of_range.402 if the array index '-' is used in the passed JSON + pointer @a ptr. As `at` provides checked access (and no elements are + implicitly inserted), the index '-' is always invalid. See example below. + + @throw out_of_range.404 if the JSON pointer @a ptr can not be resolved. + See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 2.0.0 + + @liveexample{The behavior is shown in the example.,at_json_pointer} + */ + reference at(const json_pointer& ptr) + { + return ptr.get_checked(this); + } + + /*! + @brief access specified element via JSON Pointer + + Returns a const reference to the element at the specified JSON pointer @a + ptr, with bounds checking. + + @param[in] ptr JSON pointer to the desired element + + @return const reference to the element pointed to by @a ptr + + @throw parse_error.106 if an array index in the passed JSON pointer @a ptr + begins with '0'. See example below. + + @throw parse_error.109 if an array index in the passed JSON pointer @a ptr + is not a number. See example below. + + @throw out_of_range.401 if an array index in the passed JSON pointer @a ptr + is out of range. See example below. + + @throw out_of_range.402 if the array index '-' is used in the passed JSON + pointer @a ptr. As `at` provides checked access (and no elements are + implicitly inserted), the index '-' is always invalid. See example below. + + @throw out_of_range.404 if the JSON pointer @a ptr can not be resolved. + See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 2.0.0 + + @liveexample{The behavior is shown in the example.,at_json_pointer_const} + */ + const_reference at(const json_pointer& ptr) const + { + return ptr.get_checked(this); + } + + /*! 
+ @brief return flattened JSON value + + The function creates a JSON object whose keys are JSON pointers (see [RFC + 6901](https://tools.ietf.org/html/rfc6901)) and whose values are all + primitive. The original JSON value can be restored using the @ref + unflatten() function. + + @return an object that maps JSON pointers to primitive values + + @note Empty objects and arrays are flattened to `null` and will not be + reconstructed correctly by the @ref unflatten() function. + + @complexity Linear in the size of the JSON value. + + @liveexample{The following code shows how a JSON object is flattened to an + object whose keys consist of JSON pointers.,flatten} + + @sa @ref unflatten() for the reverse function + + @since version 2.0.0 + */ + basic_json flatten() const + { + basic_json result(value_t::object); + json_pointer::flatten("", *this, result); + return result; + } + + /*! + @brief unflatten a previously flattened JSON value + + The function restores the arbitrary nesting of a JSON value that has been + flattened before using the @ref flatten() function. The JSON value must + meet certain constraints: + 1. The value must be an object. + 2. The keys must be JSON pointers (see + [RFC 6901](https://tools.ietf.org/html/rfc6901)) + 3. The mapped values must be primitive JSON types. + + @return the original JSON from a flattened version + + @note Empty objects and arrays are flattened by @ref flatten() to `null` + values and cannot be unflattened to their original type. Apart from + this exception, for a JSON value `j`, the following is always true: + `j == j.flatten().unflatten()`. + + @complexity Linear in the size of the JSON value. 
+ + @throw type_error.314 if value is not an object + @throw type_error.315 if object values are not primitive + + @liveexample{The following code shows how a flattened JSON object is + unflattened into the original nested JSON object.,unflatten} + + @sa @ref flatten() for the reverse function + + @since version 2.0.0 + */ + basic_json unflatten() const + { + return json_pointer::unflatten(*this); + } + + /// @} + + ////////////////////////// + // JSON Patch functions // + ////////////////////////// + + /// @name JSON Patch functions + /// @{ + + /*! + @brief applies a JSON patch + + [JSON Patch](http://jsonpatch.com) defines a JSON document structure for + expressing a sequence of operations to apply to a JSON document. With + this function, a JSON Patch is applied to the current JSON value by + executing all operations from the patch. + + @param[in] json_patch JSON patch document + @return patched document + + @note The application of a patch is atomic: Either all operations succeed + and the patched document is returned or an exception is thrown. In + any case, the original value is not changed: the patch is applied + to a copy of the value. + + @throw parse_error.104 if the JSON patch does not consist of an array of + objects + + @throw parse_error.105 if the JSON patch is malformed (e.g., mandatory + attributes are missing); example: `"operation add must have member path"` + + @throw out_of_range.401 if an array index is out of range. + + @throw out_of_range.403 if a JSON pointer inside the patch could not be + resolved successfully in the current JSON value; example: `"key baz not + found"` + + @throw out_of_range.405 if JSON pointer has no parent ("add", "remove", + "move") + + @throw other_error.501 if "test" operation was unsuccessful + + @complexity Linear in the size of the JSON value and the length of the + JSON patch. As usually only a fraction of the JSON value is affected by + the patch, the complexity can usually be neglected. 
+ + @liveexample{The following code shows how a JSON patch is applied to a + value.,patch} + + @sa @ref diff -- create a JSON patch by comparing two JSON values + + @sa [RFC 6902 (JSON Patch)](https://tools.ietf.org/html/rfc6902) + @sa [RFC 6901 (JSON Pointer)](https://tools.ietf.org/html/rfc6901) + + @since version 2.0.0 + */ + basic_json patch(const basic_json& json_patch) const + { + // make a working copy to apply the patch to + basic_json result = *this; + + // the valid JSON Patch operations + enum class patch_operations {add, remove, replace, move, copy, test, invalid}; + + const auto get_op = [](const std::string & op) + { + if (op == "add") + { + return patch_operations::add; + } + if (op == "remove") + { + return patch_operations::remove; + } + if (op == "replace") + { + return patch_operations::replace; + } + if (op == "move") + { + return patch_operations::move; + } + if (op == "copy") + { + return patch_operations::copy; + } + if (op == "test") + { + return patch_operations::test; + } + + return patch_operations::invalid; + }; + + // wrapper for "add" operation; add value at ptr + const auto operation_add = [&result](json_pointer & ptr, basic_json val) + { + // adding to the root of the target document means replacing it + if (ptr.is_root()) + { + result = val; + } + else + { + // make sure the top element of the pointer exists + json_pointer top_pointer = ptr.top(); + if (top_pointer != ptr) + { + result.at(top_pointer); + } + + // get reference to parent of JSON pointer ptr + const auto last_path = ptr.pop_back(); + basic_json& parent = result[ptr]; + + switch (parent.m_type) + { + case value_t::null: + case value_t::object: + { + // use operator[] to add value + parent[last_path] = val; + break; + } + + case value_t::array: + { + if (last_path == "-") + { + // special case: append to back + parent.push_back(val); + } + else + { + const auto idx = std::stoi(last_path); + if (JSON_UNLIKELY(static_cast(idx) > parent.size())) + { + // avoid undefined 
behavior + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); + } + else + { + // default case: insert add offset + parent.insert(parent.begin() + static_cast(idx), val); + } + } + break; + } + + default: + { + // if there exists a parent it cannot be primitive + assert(false); // LCOV_EXCL_LINE + } + } + } + }; + + // wrapper for "remove" operation; remove value at ptr + const auto operation_remove = [&result](json_pointer & ptr) + { + // get reference to parent of JSON pointer ptr + const auto last_path = ptr.pop_back(); + basic_json& parent = result.at(ptr); + + // remove child + if (parent.is_object()) + { + // perform range check + auto it = parent.find(last_path); + if (JSON_LIKELY(it != parent.end())) + { + parent.erase(it); + } + else + { + JSON_THROW(out_of_range::create(403, "key '" + last_path + "' not found")); + } + } + else if (parent.is_array()) + { + // note erase performs range check + parent.erase(static_cast(std::stoi(last_path))); + } + }; + + // type check: top level value must be an array + if (JSON_UNLIKELY(not json_patch.is_array())) + { + JSON_THROW(parse_error::create(104, 0, "JSON patch must be an array of objects")); + } + + // iterate and apply the operations + for (const auto& val : json_patch) + { + // wrapper to get a value for an operation + const auto get_value = [&val](const std::string & op, + const std::string & member, + bool string_type) -> basic_json& + { + // find value + auto it = val.m_value.object->find(member); + + // context-sensitive error message + const auto error_msg = (op == "op") ? 
"operation" : "operation '" + op + "'"; + + // check if desired value is present + if (JSON_UNLIKELY(it == val.m_value.object->end())) + { + JSON_THROW(parse_error::create(105, 0, error_msg + " must have member '" + member + "'")); + } + + // check if result is of type string + if (JSON_UNLIKELY(string_type and not it->second.is_string())) + { + JSON_THROW(parse_error::create(105, 0, error_msg + " must have string member '" + member + "'")); + } + + // no error: return value + return it->second; + }; + + // type check: every element of the array must be an object + if (JSON_UNLIKELY(not val.is_object())) + { + JSON_THROW(parse_error::create(104, 0, "JSON patch must be an array of objects")); + } + + // collect mandatory members + const std::string op = get_value("op", "op", true); + const std::string path = get_value(op, "path", true); + json_pointer ptr(path); + + switch (get_op(op)) + { + case patch_operations::add: + { + operation_add(ptr, get_value("add", "value", false)); + break; + } + + case patch_operations::remove: + { + operation_remove(ptr); + break; + } + + case patch_operations::replace: + { + // the "path" location must exist - use at() + result.at(ptr) = get_value("replace", "value", false); + break; + } + + case patch_operations::move: + { + const std::string from_path = get_value("move", "from", true); + json_pointer from_ptr(from_path); + + // the "from" location must exist - use at() + basic_json v = result.at(from_ptr); + + // The move operation is functionally identical to a + // "remove" operation on the "from" location, followed + // immediately by an "add" operation at the target + // location with the value that was just removed. 
+ operation_remove(from_ptr); + operation_add(ptr, v); + break; + } + + case patch_operations::copy: + { + const std::string from_path = get_value("copy", "from", true); + const json_pointer from_ptr(from_path); + + // the "from" location must exist - use at() + result[ptr] = result.at(from_ptr); + break; + } + + case patch_operations::test: + { + bool success = false; + JSON_TRY + { + // check if "value" matches the one at "path" + // the "path" location must exist - use at() + success = (result.at(ptr) == get_value("test", "value", false)); + } + JSON_CATCH (out_of_range&) + { + // ignore out of range errors: success remains false + } + + // throw an exception if test fails + if (JSON_UNLIKELY(not success)) + { + JSON_THROW(other_error::create(501, "unsuccessful: " + val.dump())); + } + + break; + } + + case patch_operations::invalid: + { + // op must be "add", "remove", "replace", "move", "copy", or + // "test" + JSON_THROW(parse_error::create(105, 0, "operation value '" + op + "' is invalid")); + } + } + } + + return result; + } + + /*! + @brief creates a diff as a JSON patch + + Creates a [JSON Patch](http://jsonpatch.com) so that value @a source can + be changed into the value @a target by calling @ref patch function. + + @invariant For two JSON values @a source and @a target, the following code + yields always `true`: + @code {.cpp} + source.patch(diff(source, target)) == target; + @endcode + + @note Currently, only `remove`, `add`, and `replace` operations are + generated. + + @param[in] source JSON value to compare from + @param[in] target JSON value to compare against + @param[in] path helper value to create JSON pointers + + @return a JSON patch to convert the @a source to @a target + + @complexity Linear in the lengths of @a source and @a target. 
+ + @liveexample{The following code shows how a JSON patch is created as a + diff for two JSON values.,diff} + + @sa @ref patch -- apply a JSON patch + + @sa [RFC 6902 (JSON Patch)](https://tools.ietf.org/html/rfc6902) + + @since version 2.0.0 + */ + static basic_json diff(const basic_json& source, const basic_json& target, + const std::string& path = "") + { + // the patch + basic_json result(value_t::array); + + // if the values are the same, return empty patch + if (source == target) + { + return result; + } + + if (source.type() != target.type()) + { + // different types: replace value + result.push_back( + { + {"op", "replace"}, {"path", path}, {"value", target} + }); + } + else + { + switch (source.type()) + { + case value_t::array: + { + // first pass: traverse common elements + std::size_t i = 0; + while (i < source.size() and i < target.size()) + { + // recursive call to compare array values at index i + auto temp_diff = diff(source[i], target[i], path + "/" + std::to_string(i)); + result.insert(result.end(), temp_diff.begin(), temp_diff.end()); + ++i; + } + + // i now reached the end of at least one array + // in a second pass, traverse the remaining elements + + // remove my remaining elements + const auto end_index = static_cast(result.size()); + while (i < source.size()) + { + // add operations in reverse order to avoid invalid + // indices + result.insert(result.begin() + end_index, object( + { + {"op", "remove"}, + {"path", path + "/" + std::to_string(i)} + })); + ++i; + } + + // add other remaining elements + while (i < target.size()) + { + result.push_back( + { + {"op", "add"}, + {"path", path + "/" + std::to_string(i)}, + {"value", target[i]} + }); + ++i; + } + + break; + } + + case value_t::object: + { + // first pass: traverse this object's elements + for (auto it = source.begin(); it != source.end(); ++it) + { + // escape the key name to be used in a JSON patch + const auto key = json_pointer::escape(it.key()); + + if (target.find(it.key()) != 
target.end()) + { + // recursive call to compare object values at key it + auto temp_diff = diff(it.value(), target[it.key()], path + "/" + key); + result.insert(result.end(), temp_diff.begin(), temp_diff.end()); + } + else + { + // found a key that is not in o -> remove it + result.push_back(object( + { + {"op", "remove"}, {"path", path + "/" + key} + })); + } + } + + // second pass: traverse other object's elements + for (auto it = target.begin(); it != target.end(); ++it) + { + if (source.find(it.key()) == source.end()) + { + // found a key that is not in this -> add it + const auto key = json_pointer::escape(it.key()); + result.push_back( + { + {"op", "add"}, {"path", path + "/" + key}, + {"value", it.value()} + }); + } + } + + break; + } + + default: + { + // both primitive type: replace value + result.push_back( + { + {"op", "replace"}, {"path", path}, {"value", target} + }); + break; + } + } + } + + return result; + } + + /// @} +}; + +///////////// +// presets // +///////////// + +/*! +@brief default JSON class + +This type is the default specialization of the @ref basic_json class which +uses the standard template types. 
+ +@since version 1.0.0 +*/ +using json = basic_json<>; + +////////////////// +// json_pointer // +////////////////// + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +NLOHMANN_BASIC_JSON_TPL& +json_pointer::get_and_create(NLOHMANN_BASIC_JSON_TPL& j) const +{ + using size_type = typename NLOHMANN_BASIC_JSON_TPL::size_type; + auto result = &j; + + // in case no reference tokens exist, return a reference to the JSON value + // j which will be overwritten by a primitive value + for (const auto& reference_token : reference_tokens) + { + switch (result->m_type) + { + case detail::value_t::null: + { + if (reference_token == "0") + { + // start a new array if reference token is 0 + result = &result->operator[](0); + } + else + { + // start a new object otherwise + result = &result->operator[](reference_token); + } + break; + } + + case detail::value_t::object: + { + // create an entry in the object + result = &result->operator[](reference_token); + break; + } + + case detail::value_t::array: + { + // create an entry in the array + JSON_TRY + { + result = &result->operator[](static_cast(std::stoi(reference_token))); + } + JSON_CATCH(std::invalid_argument&) + { + JSON_THROW(detail::parse_error::create(109, 0, "array index '" + reference_token + "' is not a number")); + } + break; + } + + /* + The following code is only reached if there exists a reference + token _and_ the current value is primitive. In this case, we have + an error situation, because primitive values may only occur as + single value; that is, with an empty list of reference tokens. 
+ */ + default: + JSON_THROW(detail::type_error::create(313, "invalid value to unflatten")); + } + } + + return *result; +} + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +NLOHMANN_BASIC_JSON_TPL& +json_pointer::get_unchecked(NLOHMANN_BASIC_JSON_TPL* ptr) const +{ + using size_type = typename NLOHMANN_BASIC_JSON_TPL::size_type; + for (const auto& reference_token : reference_tokens) + { + // convert null values to arrays or objects before continuing + if (ptr->m_type == detail::value_t::null) + { + // check if reference token is a number + const bool nums = + std::all_of(reference_token.begin(), reference_token.end(), + [](const char x) + { + return (x >= '0' and x <= '9'); + }); + + // change value to array for numbers or "-" or to object otherwise + *ptr = (nums or reference_token == "-") + ? detail::value_t::array + : detail::value_t::object; + } + + switch (ptr->m_type) + { + case detail::value_t::object: + { + // use unchecked object access + ptr = &ptr->operator[](reference_token); + break; + } + + case detail::value_t::array: + { + // error condition (cf. RFC 6901, Sect. 
4) + if (JSON_UNLIKELY(reference_token.size() > 1 and reference_token[0] == '0')) + { + JSON_THROW(detail::parse_error::create(106, 0, + "array index '" + reference_token + + "' must not begin with '0'")); + } + + if (reference_token == "-") + { + // explicitly treat "-" as index beyond the end + ptr = &ptr->operator[](ptr->m_value.array->size()); + } + else + { + // convert array index to number; unchecked access + JSON_TRY + { + ptr = &ptr->operator[]( + static_cast(std::stoi(reference_token))); + } + JSON_CATCH(std::invalid_argument&) + { + JSON_THROW(detail::parse_error::create(109, 0, "array index '" + reference_token + "' is not a number")); + } + } + break; + } + + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); + } + } + + return *ptr; +} + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +NLOHMANN_BASIC_JSON_TPL& +json_pointer::get_checked(NLOHMANN_BASIC_JSON_TPL* ptr) const +{ + using size_type = typename NLOHMANN_BASIC_JSON_TPL::size_type; + for (const auto& reference_token : reference_tokens) + { + switch (ptr->m_type) + { + case detail::value_t::object: + { + // note: at performs range check + ptr = &ptr->at(reference_token); + break; + } + + case detail::value_t::array: + { + if (JSON_UNLIKELY(reference_token == "-")) + { + // "-" always fails the range check + JSON_THROW(detail::out_of_range::create(402, + "array index '-' (" + std::to_string(ptr->m_value.array->size()) + + ") is out of range")); + } + + // error condition (cf. RFC 6901, Sect. 
4) + if (JSON_UNLIKELY(reference_token.size() > 1 and reference_token[0] == '0')) + { + JSON_THROW(detail::parse_error::create(106, 0, + "array index '" + reference_token + + "' must not begin with '0'")); + } + + // note: at performs range check + JSON_TRY + { + ptr = &ptr->at(static_cast(std::stoi(reference_token))); + } + JSON_CATCH(std::invalid_argument&) + { + JSON_THROW(detail::parse_error::create(109, 0, "array index '" + reference_token + "' is not a number")); + } + break; + } + + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); + } + } + + return *ptr; +} + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +const NLOHMANN_BASIC_JSON_TPL& +json_pointer::get_unchecked(const NLOHMANN_BASIC_JSON_TPL* ptr) const +{ + using size_type = typename NLOHMANN_BASIC_JSON_TPL::size_type; + for (const auto& reference_token : reference_tokens) + { + switch (ptr->m_type) + { + case detail::value_t::object: + { + // use unchecked object access + ptr = &ptr->operator[](reference_token); + break; + } + + case detail::value_t::array: + { + if (JSON_UNLIKELY(reference_token == "-")) + { + // "-" cannot be used for const access + JSON_THROW(detail::out_of_range::create(402, + "array index '-' (" + std::to_string(ptr->m_value.array->size()) + + ") is out of range")); + } + + // error condition (cf. RFC 6901, Sect. 
4) + if (JSON_UNLIKELY(reference_token.size() > 1 and reference_token[0] == '0')) + { + JSON_THROW(detail::parse_error::create(106, 0, + "array index '" + reference_token + + "' must not begin with '0'")); + } + + // use unchecked array access + JSON_TRY + { + ptr = &ptr->operator[]( + static_cast(std::stoi(reference_token))); + } + JSON_CATCH(std::invalid_argument&) + { + JSON_THROW(detail::parse_error::create(109, 0, "array index '" + reference_token + "' is not a number")); + } + break; + } + + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); + } + } + + return *ptr; +} + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +const NLOHMANN_BASIC_JSON_TPL& +json_pointer::get_checked(const NLOHMANN_BASIC_JSON_TPL* ptr) const +{ + using size_type = typename NLOHMANN_BASIC_JSON_TPL::size_type; + for (const auto& reference_token : reference_tokens) + { + switch (ptr->m_type) + { + case detail::value_t::object: + { + // note: at performs range check + ptr = &ptr->at(reference_token); + break; + } + + case detail::value_t::array: + { + if (JSON_UNLIKELY(reference_token == "-")) + { + // "-" always fails the range check + JSON_THROW(detail::out_of_range::create(402, + "array index '-' (" + std::to_string(ptr->m_value.array->size()) + + ") is out of range")); + } + + // error condition (cf. RFC 6901, Sect. 
4) + if (JSON_UNLIKELY(reference_token.size() > 1 and reference_token[0] == '0')) + { + JSON_THROW(detail::parse_error::create(106, 0, + "array index '" + reference_token + + "' must not begin with '0'")); + } + + // note: at performs range check + JSON_TRY + { + ptr = &ptr->at(static_cast(std::stoi(reference_token))); + } + JSON_CATCH(std::invalid_argument&) + { + JSON_THROW(detail::parse_error::create(109, 0, "array index '" + reference_token + "' is not a number")); + } + break; + } + + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); + } + } + + return *ptr; +} + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +void json_pointer::flatten(const std::string& reference_string, + const NLOHMANN_BASIC_JSON_TPL& value, + NLOHMANN_BASIC_JSON_TPL& result) +{ + switch (value.m_type) + { + case detail::value_t::array: + { + if (value.m_value.array->empty()) + { + // flatten empty array as null + result[reference_string] = nullptr; + } + else + { + // iterate array and use index as reference string + for (std::size_t i = 0; i < value.m_value.array->size(); ++i) + { + flatten(reference_string + "/" + std::to_string(i), + value.m_value.array->operator[](i), result); + } + } + break; + } + + case detail::value_t::object: + { + if (value.m_value.object->empty()) + { + // flatten empty object as null + result[reference_string] = nullptr; + } + else + { + // iterate object and use keys as reference string + for (const auto& element : *value.m_value.object) + { + flatten(reference_string + "/" + escape(element.first), element.second, result); + } + } + break; + } + + default: + { + // add primitive value with its reference string + result[reference_string] = value; + break; + } + } +} + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +NLOHMANN_BASIC_JSON_TPL +json_pointer::unflatten(const NLOHMANN_BASIC_JSON_TPL& value) +{ + if (JSON_UNLIKELY(not value.is_object())) + { + JSON_THROW(detail::type_error::create(314, "only objects can be 
unflattened")); + } + + NLOHMANN_BASIC_JSON_TPL result; + + // iterate the JSON object values + for (const auto& element : *value.m_value.object) + { + if (JSON_UNLIKELY(not element.second.is_primitive())) + { + JSON_THROW(detail::type_error::create(315, "values in object must be primitive")); + } + + // assign value to reference pointed to by JSON pointer; Note that if + // the JSON pointer is "" (i.e., points to the whole value), function + // get_and_create returns a reference to result itself. An assignment + // will then create a primitive value. + json_pointer(element.first).get_and_create(result) = element.second; + } + + return result; +} + +inline bool operator==(json_pointer const& lhs, json_pointer const& rhs) noexcept +{ + return (lhs.reference_tokens == rhs.reference_tokens); +} + +inline bool operator!=(json_pointer const& lhs, json_pointer const& rhs) noexcept +{ + return not (lhs == rhs); +} +} // namespace nlohmann + + +/////////////////////// +// nonmember support // +/////////////////////// + +// specialization of std::swap, and std::hash +namespace std +{ +/*! +@brief exchanges the values of two JSON objects + +@since version 1.0.0 +*/ +template<> +inline void swap(nlohmann::json& j1, + nlohmann::json& j2) noexcept( + is_nothrow_move_constructible::value and + is_nothrow_move_assignable::value + ) +{ + j1.swap(j2); +} + +/// hash value for JSON objects +template<> +struct hash +{ + /*! + @brief return a hash value for a JSON object + + @since version 1.0.0 + */ + std::size_t operator()(const nlohmann::json& j) const + { + // a naive hashing via the string representation + const auto& h = hash(); + return h(j.dump()); + } +}; + +/// specialization for std::less +/// @note: do not remove the space after '<', +/// see https://github.com/nlohmann/json/pull/679 +template<> +struct less< ::nlohmann::detail::value_t> +{ + /*! 
+ @brief compare two value_t enum values + @since version 3.0.0 + */ + bool operator()(nlohmann::detail::value_t lhs, + nlohmann::detail::value_t rhs) const noexcept + { + return nlohmann::detail::operator<(lhs, rhs); + } +}; + +} // namespace std + +/*! +@brief user-defined string literal for JSON values + +This operator implements a user-defined string literal for JSON objects. It +can be used by adding `"_json"` to a string literal and returns a JSON object +if no parse error occurred. + +@param[in] s a string representation of a JSON object +@param[in] n the length of string @a s +@return a JSON object + +@since version 1.0.0 +*/ +inline nlohmann::json operator "" _json(const char* s, std::size_t n) +{ + return nlohmann::json::parse(s, s + n); +} + +/*! +@brief user-defined string literal for JSON pointer + +This operator implements a user-defined string literal for JSON Pointers. It +can be used by adding `"_json_pointer"` to a string literal and returns a JSON pointer +object if no parse error occurred. 
+ +@param[in] s a string representation of a JSON Pointer +@param[in] n the length of string @a s +@return a JSON pointer object + +@since version 2.0.0 +*/ +inline nlohmann::json::json_pointer operator "" _json_pointer(const char* s, std::size_t n) +{ + return nlohmann::json::json_pointer(std::string(s, n)); +} + +// restore GCC/clang diagnostic settings +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) + #pragma GCC diagnostic pop +#endif +#if defined(__clang__) + #pragma GCC diagnostic pop +#endif + +// clean up +#undef JSON_CATCH +#undef JSON_THROW +#undef JSON_TRY +#undef JSON_LIKELY +#undef JSON_UNLIKELY +#undef JSON_DEPRECATED +#undef NLOHMANN_BASIC_JSON_TPL_DECLARATION +#undef NLOHMANN_BASIC_JSON_TPL + +#endif diff --git a/lib/Lattice.h b/lib/lattice/Lattice.h similarity index 100% rename from lib/Lattice.h rename to lib/lattice/Lattice.h diff --git a/lib/lattice/Lattice_arith.h b/lib/lattice/Lattice_arith.h index 6527c487..c3093167 100644 --- a/lib/lattice/Lattice_arith.h +++ b/lib/lattice/Lattice_arith.h @@ -39,8 +39,7 @@ namespace Grid { ret.checkerboard = lhs.checkerboard; conformable(ret,rhs); conformable(lhs,rhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES obj1 tmp; mult(&tmp,&lhs._odata[ss],&rhs._odata[ss]); @@ -56,8 +55,7 @@ PARALLEL_FOR_LOOP ret.checkerboard = lhs.checkerboard; conformable(ret,rhs); conformable(lhs,rhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES obj1 tmp; mac(&tmp,&lhs._odata[ss],&rhs._odata[ss]); @@ -73,8 +71,7 @@ PARALLEL_FOR_LOOP ret.checkerboard = lhs.checkerboard; conformable(ret,rhs); conformable(lhs,rhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES obj1 tmp; sub(&tmp,&lhs._odata[ss],&rhs._odata[ss]); @@ -89,8 +86,7 @@ PARALLEL_FOR_LOOP ret.checkerboard = lhs.checkerboard; 
conformable(ret,rhs); conformable(lhs,rhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES obj1 tmp; add(&tmp,&lhs._odata[ss],&rhs._odata[ss]); @@ -108,8 +104,7 @@ PARALLEL_FOR_LOOP void mult(Lattice &ret,const Lattice &lhs,const obj3 &rhs){ ret.checkerboard = lhs.checkerboard; conformable(lhs,ret); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ obj1 tmp; mult(&tmp,&lhs._odata[ss],&rhs); vstream(ret._odata[ss],tmp); @@ -120,8 +115,7 @@ PARALLEL_FOR_LOOP void mac(Lattice &ret,const Lattice &lhs,const obj3 &rhs){ ret.checkerboard = lhs.checkerboard; conformable(ret,lhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ obj1 tmp; mac(&tmp,&lhs._odata[ss],&rhs); vstream(ret._odata[ss],tmp); @@ -132,8 +126,7 @@ PARALLEL_FOR_LOOP void sub(Lattice &ret,const Lattice &lhs,const obj3 &rhs){ ret.checkerboard = lhs.checkerboard; conformable(ret,lhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES obj1 tmp; sub(&tmp,&lhs._odata[ss],&rhs); @@ -147,8 +140,7 @@ PARALLEL_FOR_LOOP void add(Lattice &ret,const Lattice &lhs,const obj3 &rhs){ ret.checkerboard = lhs.checkerboard; conformable(lhs,ret); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES obj1 tmp; add(&tmp,&lhs._odata[ss],&rhs); @@ -166,8 +158,7 @@ PARALLEL_FOR_LOOP void mult(Lattice &ret,const obj2 &lhs,const Lattice &rhs){ ret.checkerboard = rhs.checkerboard; conformable(ret,rhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES obj1 tmp; mult(&tmp,&lhs,&rhs._odata[ss]); @@ -182,8 +173,7 @@ PARALLEL_FOR_LOOP void mac(Lattice &ret,const obj2 &lhs,const Lattice &rhs){ ret.checkerboard = rhs.checkerboard; conformable(ret,rhs); -PARALLEL_FOR_LOOP - for(int 
ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES obj1 tmp; mac(&tmp,&lhs,&rhs._odata[ss]); @@ -198,8 +188,7 @@ PARALLEL_FOR_LOOP void sub(Lattice &ret,const obj2 &lhs,const Lattice &rhs){ ret.checkerboard = rhs.checkerboard; conformable(ret,rhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES obj1 tmp; sub(&tmp,&lhs,&rhs._odata[ss]); @@ -213,8 +202,7 @@ PARALLEL_FOR_LOOP void add(Lattice &ret,const obj2 &lhs,const Lattice &rhs){ ret.checkerboard = rhs.checkerboard; conformable(ret,rhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES obj1 tmp; add(&tmp,&lhs,&rhs._odata[ss]); @@ -230,8 +218,7 @@ PARALLEL_FOR_LOOP ret.checkerboard = x.checkerboard; conformable(ret,x); conformable(x,y); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES vobj tmp = a*x._odata[ss]+y._odata[ss]; vstream(ret._odata[ss],tmp); @@ -245,8 +232,7 @@ PARALLEL_FOR_LOOP ret.checkerboard = x.checkerboard; conformable(ret,x); conformable(x,y); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ #ifdef STREAMING_STORES vobj tmp = a*x._odata[ss]+b*y._odata[ss]; vstream(ret._odata[ss],tmp); diff --git a/lib/lattice/Lattice_base.h b/lib/lattice/Lattice_base.h index e4dc1ca8..014e443d 100644 --- a/lib/lattice/Lattice_base.h +++ b/lib/lattice/Lattice_base.h @@ -121,8 +121,7 @@ public: assert( (cb==Odd) || (cb==Even)); checkerboard=cb; -PARALLEL_FOR_LOOP - for(int ss=0;ss<_grid->oSites();ss++){ + parallel_for(int ss=0;ss<_grid->oSites();ss++){ #ifdef STREAMING_STORES vobj tmp = eval(ss,expr); vstream(_odata[ss] ,tmp); @@ -144,8 +143,7 @@ PARALLEL_FOR_LOOP assert( (cb==Odd) || (cb==Even)); checkerboard=cb; -PARALLEL_FOR_LOOP - for(int ss=0;ss<_grid->oSites();ss++){ + parallel_for(int ss=0;ss<_grid->oSites();ss++){ #ifdef 
STREAMING_STORES vobj tmp = eval(ss,expr); vstream(_odata[ss] ,tmp); @@ -167,8 +165,7 @@ PARALLEL_FOR_LOOP assert( (cb==Odd) || (cb==Even)); checkerboard=cb; -PARALLEL_FOR_LOOP - for(int ss=0;ss<_grid->oSites();ss++){ + parallel_for(int ss=0;ss<_grid->oSites();ss++){ #ifdef STREAMING_STORES //vobj tmp = eval(ss,expr); vstream(_odata[ss] ,eval(ss,expr)); @@ -191,8 +188,7 @@ PARALLEL_FOR_LOOP checkerboard=cb; _odata.resize(_grid->oSites()); -PARALLEL_FOR_LOOP - for(int ss=0;ss<_grid->oSites();ss++){ + parallel_for(int ss=0;ss<_grid->oSites();ss++){ #ifdef STREAMING_STORES vobj tmp = eval(ss,expr); vstream(_odata[ss] ,tmp); @@ -213,8 +209,7 @@ PARALLEL_FOR_LOOP checkerboard=cb; _odata.resize(_grid->oSites()); -PARALLEL_FOR_LOOP - for(int ss=0;ss<_grid->oSites();ss++){ + parallel_for(int ss=0;ss<_grid->oSites();ss++){ #ifdef STREAMING_STORES vobj tmp = eval(ss,expr); vstream(_odata[ss] ,tmp); @@ -235,73 +230,79 @@ PARALLEL_FOR_LOOP checkerboard=cb; _odata.resize(_grid->oSites()); -PARALLEL_FOR_LOOP - for(int ss=0;ss<_grid->oSites();ss++){ + parallel_for(int ss=0;ss<_grid->oSites();ss++){ vstream(_odata[ss] ,eval(ss,expr)); } }; - ////////////////////////////////////////////////////////////////// - // Constructor requires "grid" passed. - // what about a default grid? - ////////////////////////////////////////////////////////////////// - Lattice(GridBase *grid) : _odata(grid->oSites()) { - _grid = grid; + ////////////////////////////////////////////////////////////////// + // Constructor requires "grid" passed. + // what about a default grid? 
+ ////////////////////////////////////////////////////////////////// + Lattice(GridBase *grid) : _odata(grid->oSites()) { + _grid = grid; // _odata.reserve(_grid->oSites()); // _odata.resize(_grid->oSites()); // std::cout << "Constructing lattice object with Grid pointer "<<_grid<oSites());// essential - PARALLEL_FOR_LOOP - for(int ss=0;ss<_grid->oSites();ss++){ - _odata[ss]=r._odata[ss]; - } - } - - - - virtual ~Lattice(void) = default; + assert((((uint64_t)&_odata[0])&0xF) ==0); + checkerboard=0; + } + + Lattice(const Lattice& r){ // copy constructor + _grid = r._grid; + checkerboard = r.checkerboard; + _odata.resize(_grid->oSites());// essential + parallel_for(int ss=0;ss<_grid->oSites();ss++){ + _odata[ss]=r._odata[ss]; + } + } + + + + virtual ~Lattice(void) = default; - template strong_inline Lattice & operator = (const sobj & r){ -PARALLEL_FOR_LOOP - for(int ss=0;ss<_grid->oSites();ss++){ - this->_odata[ss]=r; - } - return *this; - } - template strong_inline Lattice & operator = (const Lattice & r){ - this->checkerboard = r.checkerboard; - conformable(*this,r); - -PARALLEL_FOR_LOOP - for(int ss=0;ss<_grid->oSites();ss++){ - this->_odata[ss]=r._odata[ss]; - } - return *this; + void reset(GridBase* grid) { + if (_grid != grid) { + _grid = grid; + _odata.resize(grid->oSites()); + checkerboard = 0; } + } + - // *=,+=,-= operators inherit behvour from correspond */+/- operation - template strong_inline Lattice &operator *=(const T &r) { - *this = (*this)*r; - return *this; + template strong_inline Lattice & operator = (const sobj & r){ + parallel_for(int ss=0;ss<_grid->oSites();ss++){ + this->_odata[ss]=r; } - - template strong_inline Lattice &operator -=(const T &r) { - *this = (*this)-r; - return *this; + return *this; + } + + template strong_inline Lattice & operator = (const Lattice & r){ + this->checkerboard = r.checkerboard; + conformable(*this,r); + + parallel_for(int ss=0;ss<_grid->oSites();ss++){ + this->_odata[ss]=r._odata[ss]; } - template strong_inline 
Lattice &operator +=(const T &r) { - *this = (*this)+r; - return *this; - } - }; // class Lattice - + return *this; + } + + // *=,+=,-= operators inherit behvour from correspond */+/- operation + template strong_inline Lattice &operator *=(const T &r) { + *this = (*this)*r; + return *this; + } + + template strong_inline Lattice &operator -=(const T &r) { + *this = (*this)-r; + return *this; + } + template strong_inline Lattice &operator +=(const T &r) { + *this = (*this)+r; + return *this; + } +}; // class Lattice + template std::ostream& operator<< (std::ostream& stream, const Lattice &o){ std::vector gcoor; typedef typename vobj::scalar_object sobj; @@ -319,7 +320,7 @@ PARALLEL_FOR_LOOP } return stream; } - + } diff --git a/lib/lattice/Lattice_comparison.h b/lib/lattice/Lattice_comparison.h index 1b5b0624..9bf1fb2d 100644 --- a/lib/lattice/Lattice_comparison.h +++ b/lib/lattice/Lattice_comparison.h @@ -45,90 +45,87 @@ namespace Grid { ////////////////////////////////////////////////////////////////////////// template inline Lattice LLComparison(vfunctor op,const Lattice &lhs,const Lattice &rhs) - { - Lattice ret(rhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ - ret._odata[ss]=op(lhs._odata[ss],rhs._odata[ss]); - } - return ret; + { + Lattice ret(rhs._grid); + parallel_for(int ss=0;ssoSites(); ss++){ + ret._odata[ss]=op(lhs._odata[ss],rhs._odata[ss]); } + return ret; + } ////////////////////////////////////////////////////////////////////////// // compare lattice to scalar ////////////////////////////////////////////////////////////////////////// - template + template inline Lattice LSComparison(vfunctor op,const Lattice &lhs,const robj &rhs) - { - Lattice ret(lhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ - ret._odata[ss]=op(lhs._odata[ss],rhs); - } - return ret; + { + Lattice ret(lhs._grid); + parallel_for(int ss=0;ssoSites(); ss++){ + ret._odata[ss]=op(lhs._odata[ss],rhs); } + return ret; + } 
////////////////////////////////////////////////////////////////////////// // compare scalar to lattice ////////////////////////////////////////////////////////////////////////// - template + template inline Lattice SLComparison(vfunctor op,const lobj &lhs,const Lattice &rhs) - { - Lattice ret(rhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ - ret._odata[ss]=op(lhs._odata[ss],rhs); - } - return ret; + { + Lattice ret(rhs._grid); + parallel_for(int ss=0;ssoSites(); ss++){ + ret._odata[ss]=op(lhs._odata[ss],rhs); } - + return ret; + } + ////////////////////////////////////////////////////////////////////////// // Map to functors ////////////////////////////////////////////////////////////////////////// - // Less than - template - inline Lattice operator < (const Lattice & lhs, const Lattice & rhs) { - return LLComparison(vlt(),lhs,rhs); - } - template - inline Lattice operator < (const Lattice & lhs, const robj & rhs) { - return LSComparison(vlt(),lhs,rhs); - } - template - inline Lattice operator < (const lobj & lhs, const Lattice & rhs) { - return SLComparison(vlt(),lhs,rhs); - } - - // Less than equal - template - inline Lattice operator <= (const Lattice & lhs, const Lattice & rhs) { - return LLComparison(vle(),lhs,rhs); - } - template - inline Lattice operator <= (const Lattice & lhs, const robj & rhs) { - return LSComparison(vle(),lhs,rhs); - } - template - inline Lattice operator <= (const lobj & lhs, const Lattice & rhs) { - return SLComparison(vle(),lhs,rhs); - } - - // Greater than - template - inline Lattice operator > (const Lattice & lhs, const Lattice & rhs) { - return LLComparison(vgt(),lhs,rhs); - } - template - inline Lattice operator > (const Lattice & lhs, const robj & rhs) { - return LSComparison(vgt(),lhs,rhs); - } - template - inline Lattice operator > (const lobj & lhs, const Lattice & rhs) { + // Less than + template + inline Lattice operator < (const Lattice & lhs, const Lattice & rhs) { + return LLComparison(vlt(),lhs,rhs); + 
} + template + inline Lattice operator < (const Lattice & lhs, const robj & rhs) { + return LSComparison(vlt(),lhs,rhs); + } + template + inline Lattice operator < (const lobj & lhs, const Lattice & rhs) { + return SLComparison(vlt(),lhs,rhs); + } + + // Less than equal + template + inline Lattice operator <= (const Lattice & lhs, const Lattice & rhs) { + return LLComparison(vle(),lhs,rhs); + } + template + inline Lattice operator <= (const Lattice & lhs, const robj & rhs) { + return LSComparison(vle(),lhs,rhs); + } + template + inline Lattice operator <= (const lobj & lhs, const Lattice & rhs) { + return SLComparison(vle(),lhs,rhs); + } + + // Greater than + template + inline Lattice operator > (const Lattice & lhs, const Lattice & rhs) { + return LLComparison(vgt(),lhs,rhs); + } + template + inline Lattice operator > (const Lattice & lhs, const robj & rhs) { + return LSComparison(vgt(),lhs,rhs); + } + template + inline Lattice operator > (const lobj & lhs, const Lattice & rhs) { return SLComparison(vgt(),lhs,rhs); - } - - - // Greater than equal + } + + + // Greater than equal template - inline Lattice operator >= (const Lattice & lhs, const Lattice & rhs) { + inline Lattice operator >= (const Lattice & lhs, const Lattice & rhs) { return LLComparison(vge(),lhs,rhs); } template @@ -136,38 +133,37 @@ PARALLEL_FOR_LOOP return LSComparison(vge(),lhs,rhs); } template - inline Lattice operator >= (const lobj & lhs, const Lattice & rhs) { + inline Lattice operator >= (const lobj & lhs, const Lattice & rhs) { return SLComparison(vge(),lhs,rhs); } - + // equal template - inline Lattice operator == (const Lattice & lhs, const Lattice & rhs) { + inline Lattice operator == (const Lattice & lhs, const Lattice & rhs) { return LLComparison(veq(),lhs,rhs); } template - inline Lattice operator == (const Lattice & lhs, const robj & rhs) { + inline Lattice operator == (const Lattice & lhs, const robj & rhs) { return LSComparison(veq(),lhs,rhs); } template - inline Lattice operator 
== (const lobj & lhs, const Lattice & rhs) { + inline Lattice operator == (const lobj & lhs, const Lattice & rhs) { return SLComparison(veq(),lhs,rhs); } - - + + // not equal template - inline Lattice operator != (const Lattice & lhs, const Lattice & rhs) { + inline Lattice operator != (const Lattice & lhs, const Lattice & rhs) { return LLComparison(vne(),lhs,rhs); } template - inline Lattice operator != (const Lattice & lhs, const robj & rhs) { + inline Lattice operator != (const Lattice & lhs, const robj & rhs) { return LSComparison(vne(),lhs,rhs); } template - inline Lattice operator != (const lobj & lhs, const Lattice & rhs) { + inline Lattice operator != (const lobj & lhs, const Lattice & rhs) { return SLComparison(vne(),lhs,rhs); } - } #endif diff --git a/lib/lattice/Lattice_local.h b/lib/lattice/Lattice_local.h index 65d1d929..9dae1cd9 100644 --- a/lib/lattice/Lattice_local.h +++ b/lib/lattice/Lattice_local.h @@ -34,47 +34,42 @@ Author: Peter Boyle namespace Grid { - ///////////////////////////////////////////////////// - // Non site, reduced locally reduced routines - ///////////////////////////////////////////////////// - - // localNorm2, - template + ///////////////////////////////////////////////////// + // Non site, reduced locally reduced routines + ///////////////////////////////////////////////////// + + // localNorm2, + template inline auto localNorm2 (const Lattice &rhs)-> Lattice { Lattice ret(rhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ - ret._odata[ss]=innerProduct(rhs._odata[ss],rhs._odata[ss]); - } - return ret; + parallel_for(int ss=0;ssoSites(); ss++){ + ret._odata[ss]=innerProduct(rhs._odata[ss],rhs._odata[ss]); + } + return ret; } - - // localInnerProduct - template + + // localInnerProduct + template inline auto localInnerProduct (const Lattice &lhs,const Lattice &rhs) -> Lattice { Lattice ret(rhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ + parallel_for(int ss=0;ssoSites(); ss++){ 
ret._odata[ss]=innerProduct(lhs._odata[ss],rhs._odata[ss]); } return ret; } - - // outerProduct Scalar x Scalar -> Scalar - // Vector x Vector -> Matrix - template + + // outerProduct Scalar x Scalar -> Scalar + // Vector x Vector -> Matrix + template inline auto outerProduct (const Lattice &lhs,const Lattice &rhs) -> Lattice - { - Lattice ret(rhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ - ret._odata[ss]=outerProduct(lhs._odata[ss],rhs._odata[ss]); - } - return ret; - } - + { + Lattice ret(rhs._grid); + parallel_for(int ss=0;ssoSites(); ss++){ + ret._odata[ss]=outerProduct(lhs._odata[ss],rhs._odata[ss]); + } + return ret; + } } - #endif diff --git a/lib/lattice/Lattice_overload.h b/lib/lattice/Lattice_overload.h index 2a5d16a1..0906b610 100644 --- a/lib/lattice/Lattice_overload.h +++ b/lib/lattice/Lattice_overload.h @@ -37,8 +37,7 @@ namespace Grid { inline Lattice operator -(const Lattice &r) { Lattice ret(r._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ vstream(ret._odata[ss], -r._odata[ss]); } return ret; @@ -74,8 +73,7 @@ PARALLEL_FOR_LOOP inline auto operator * (const left &lhs,const Lattice &rhs) -> Lattice { Lattice ret(rhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ + parallel_for(int ss=0;ssoSites(); ss++){ decltype(lhs*rhs._odata[0]) tmp=lhs*rhs._odata[ss]; vstream(ret._odata[ss],tmp); // ret._odata[ss]=lhs*rhs._odata[ss]; @@ -86,8 +84,7 @@ PARALLEL_FOR_LOOP inline auto operator + (const left &lhs,const Lattice &rhs) -> Lattice { Lattice ret(rhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ + parallel_for(int ss=0;ssoSites(); ss++){ decltype(lhs+rhs._odata[0]) tmp =lhs-rhs._odata[ss]; vstream(ret._odata[ss],tmp); // ret._odata[ss]=lhs+rhs._odata[ss]; @@ -98,11 +95,9 @@ PARALLEL_FOR_LOOP inline auto operator - (const left &lhs,const Lattice &rhs) -> Lattice { Lattice ret(rhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ + 
parallel_for(int ss=0;ssoSites(); ss++){ decltype(lhs-rhs._odata[0]) tmp=lhs-rhs._odata[ss]; vstream(ret._odata[ss],tmp); - // ret._odata[ss]=lhs-rhs._odata[ss]; } return ret; } @@ -110,8 +105,7 @@ PARALLEL_FOR_LOOP inline auto operator * (const Lattice &lhs,const right &rhs) -> Lattice { Lattice ret(lhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ + parallel_for(int ss=0;ssoSites(); ss++){ decltype(lhs._odata[0]*rhs) tmp =lhs._odata[ss]*rhs; vstream(ret._odata[ss],tmp); // ret._odata[ss]=lhs._odata[ss]*rhs; @@ -122,8 +116,7 @@ PARALLEL_FOR_LOOP inline auto operator + (const Lattice &lhs,const right &rhs) -> Lattice { Lattice ret(lhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ + parallel_for(int ss=0;ssoSites(); ss++){ decltype(lhs._odata[0]+rhs) tmp=lhs._odata[ss]+rhs; vstream(ret._odata[ss],tmp); // ret._odata[ss]=lhs._odata[ss]+rhs; @@ -134,15 +127,12 @@ PARALLEL_FOR_LOOP inline auto operator - (const Lattice &lhs,const right &rhs) -> Lattice { Lattice ret(lhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ + parallel_for(int ss=0;ssoSites(); ss++){ decltype(lhs._odata[0]-rhs) tmp=lhs._odata[ss]-rhs; vstream(ret._odata[ss],tmp); // ret._odata[ss]=lhs._odata[ss]-rhs; } return ret; } - - } #endif diff --git a/lib/lattice/Lattice_peekpoke.h b/lib/lattice/Lattice_peekpoke.h index 19d349c4..3d6268d2 100644 --- a/lib/lattice/Lattice_peekpoke.h +++ b/lib/lattice/Lattice_peekpoke.h @@ -44,22 +44,20 @@ namespace Grid { { Lattice(lhs._odata[0],i))> ret(lhs._grid); ret.checkerboard=lhs.checkerboard; -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ - ret._odata[ss] = peekIndex(lhs._odata[ss],i); - } - return ret; + parallel_for(int ss=0;ssoSites();ss++){ + ret._odata[ss] = peekIndex(lhs._odata[ss],i); + } + return ret; }; template - auto PeekIndex(const Lattice &lhs,int i,int j) -> Lattice(lhs._odata[0],i,j))> + auto PeekIndex(const Lattice &lhs,int i,int j) -> Lattice(lhs._odata[0],i,j))> { Lattice(lhs._odata[0],i,j))> 
ret(lhs._grid); ret.checkerboard=lhs.checkerboard; -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ - ret._odata[ss] = peekIndex(lhs._odata[ss],i,j); - } - return ret; + parallel_for(int ss=0;ssoSites();ss++){ + ret._odata[ss] = peekIndex(lhs._odata[ss],i,j); + } + return ret; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -68,25 +66,23 @@ PARALLEL_FOR_LOOP template void PokeIndex(Lattice &lhs,const Lattice(lhs._odata[0],0))> & rhs,int i) { -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ - pokeIndex(lhs._odata[ss],rhs._odata[ss],i); - } + parallel_for(int ss=0;ssoSites();ss++){ + pokeIndex(lhs._odata[ss],rhs._odata[ss],i); + } } template void PokeIndex(Lattice &lhs,const Lattice(lhs._odata[0],0,0))> & rhs,int i,int j) { -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ - pokeIndex(lhs._odata[ss],rhs._odata[ss],i,j); - } + parallel_for(int ss=0;ssoSites();ss++){ + pokeIndex(lhs._odata[ss],rhs._odata[ss],i,j); + } } ////////////////////////////////////////////////////// // Poke a scalar object into the SIMD array ////////////////////////////////////////////////////// template - void pokeSite(const sobj &s,Lattice &l,std::vector &site){ + void pokeSite(const sobj &s,Lattice &l,const std::vector &site){ GridBase *grid=l._grid; @@ -120,7 +116,7 @@ PARALLEL_FOR_LOOP // Peek a scalar object from the SIMD array ////////////////////////////////////////////////////////// template - void peekSite(sobj &s,const Lattice &l,std::vector &site){ + void peekSite(sobj &s,const Lattice &l,const std::vector &site){ GridBase *grid=l._grid; @@ -131,9 +127,6 @@ PARALLEL_FOR_LOOP assert( l.checkerboard == l._grid->CheckerBoard(site)); - // FIXME - // assert( sizeof(sobj)*Nsimd == sizeof(vobj)); - int rank,odx,idx; grid->GlobalCoorToRankIndex(rank,odx,idx,site); diff --git a/lib/lattice/Lattice_reality.h b/lib/lattice/Lattice_reality.h index 10add8cd..7e7b2631 100644 --- a/lib/lattice/Lattice_reality.h +++ 
b/lib/lattice/Lattice_reality.h @@ -40,8 +40,7 @@ namespace Grid { template inline Lattice adj(const Lattice &lhs){ Lattice ret(lhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ ret._odata[ss] = adj(lhs._odata[ss]); } return ret; @@ -49,13 +48,10 @@ PARALLEL_FOR_LOOP template inline Lattice conjugate(const Lattice &lhs){ Lattice ret(lhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ - ret._odata[ss] = conjugate(lhs._odata[ss]); + parallel_for(int ss=0;ssoSites();ss++){ + ret._odata[ss] = conjugate(lhs._odata[ss]); } return ret; }; - - } #endif diff --git a/lib/lattice/Lattice_reduction.h b/lib/lattice/Lattice_reduction.h index 2615af48..db012c8c 100644 --- a/lib/lattice/Lattice_reduction.h +++ b/lib/lattice/Lattice_reduction.h @@ -1,159 +1,154 @@ - /************************************************************************************* - +/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid - Source file: ./lib/lattice/Lattice_reduction.h - Copyright (C) 2015 - Author: Azusa Yamaguchi Author: Peter Boyle Author: paboyle - This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. - This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. - You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
- See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ #ifndef GRID_LATTICE_REDUCTION_H #define GRID_LATTICE_REDUCTION_H +#include + namespace Grid { #ifdef GRID_WARN_SUBOPTIMAL #warning "Optimisation alert all these reduction loops are NOT threaded " #endif - //////////////////////////////////////////////////////////////////////////////////////////////////// - // Deterministic Reduction operations - //////////////////////////////////////////////////////////////////////////////////////////////////// - template inline RealD norm2(const Lattice &arg){ - ComplexD nrm = innerProduct(arg,arg); - return std::real(nrm); - } + //////////////////////////////////////////////////////////////////////////////////////////////////// + // Deterministic Reduction operations + //////////////////////////////////////////////////////////////////////////////////////////////////// +template inline RealD norm2(const Lattice &arg){ + ComplexD nrm = innerProduct(arg,arg); + return std::real(nrm); +} - template - inline ComplexD innerProduct(const Lattice &left,const Lattice &right) - { - typedef typename vobj::scalar_type scalar_type; - typedef typename vobj::vector_type vector_type; - scalar_type nrm; - - GridBase *grid = left._grid; - - std::vector > sumarray(grid->SumArraySize()); - for(int i=0;iSumArraySize();i++){ - sumarray[i]=zero; - } - -PARALLEL_FOR_LOOP - for(int thr=0;thrSumArraySize();thr++){ - int nwork, mywork, myoff; - GridThread::GetWork(left._grid->oSites(),thr,mywork,myoff); - - decltype(innerProduct(left._odata[0],right._odata[0])) vnrm=zero; // private to thread; sub summation - for(int ss=myoff;ss +inline ComplexD innerProduct(const Lattice &left,const Lattice &right) +{ + typedef typename vobj::scalar_type scalar_type; + typedef typename vobj::vector_typeD vector_type; + scalar_type nrm; + + GridBase *grid = left._grid; + + std::vector > 
sumarray(grid->SumArraySize()); + + parallel_for(int thr=0;thrSumArraySize();thr++){ + int nwork, mywork, myoff; + GridThread::GetWork(left._grid->oSites(),thr,mywork,myoff); - vector_type vvnrm; vvnrm=zero; // sum across threads - for(int i=0;iSumArraySize();i++){ - vvnrm = vvnrm+sumarray[i]; - } - nrm = Reduce(vvnrm);// sum across simd - right._grid->GlobalSum(nrm); - return nrm; + decltype(innerProductD(left._odata[0],right._odata[0])) vnrm=zero; // private to thread; sub summation + for(int ss=myoff;ssSumArraySize();i++){ + vvnrm = vvnrm+sumarray[i]; + } + nrm = Reduce(vvnrm);// sum across simd + right._grid->GlobalSum(nrm); + return nrm; +} + +template +inline auto sum(const LatticeUnaryExpression & expr) + ->typename decltype(expr.first.func(eval(0,std::get<0>(expr.second))))::scalar_object +{ + return sum(closure(expr)); +} - template - inline auto sum(const LatticeUnaryExpression & expr) - ->typename decltype(expr.first.func(eval(0,std::get<0>(expr.second))))::scalar_object - { - return sum(closure(expr)); - } - - template - inline auto sum(const LatticeBinaryExpression & expr) +template +inline auto sum(const LatticeBinaryExpression & expr) ->typename decltype(expr.first.func(eval(0,std::get<0>(expr.second)),eval(0,std::get<1>(expr.second))))::scalar_object - { - return sum(closure(expr)); - } - - - template - inline auto sum(const LatticeTrinaryExpression & expr) - ->typename decltype(expr.first.func(eval(0,std::get<0>(expr.second)), - eval(0,std::get<1>(expr.second)), - eval(0,std::get<2>(expr.second)) - ))::scalar_object - { - return sum(closure(expr)); - } - - template - inline typename vobj::scalar_object sum(const Lattice &arg){ - - GridBase *grid=arg._grid; - int Nsimd = grid->Nsimd(); - - std::vector > sumarray(grid->SumArraySize()); - for(int i=0;iSumArraySize();i++){ - sumarray[i]=zero; - } - -PARALLEL_FOR_LOOP - for(int thr=0;thrSumArraySize();thr++){ - int nwork, mywork, myoff; - GridThread::GetWork(grid->oSites(),thr,mywork,myoff); - - vobj 
vvsum=zero; - for(int ss=myoff;ssSumArraySize();i++){ - vsum = vsum+sumarray[i]; - } - - typedef typename vobj::scalar_object sobj; - sobj ssum=zero; - - std::vector buf(Nsimd); - extract(vsum,buf); - - for(int i=0;iGlobalSum(ssum); - - return ssum; +{ + return sum(closure(expr)); +} + + +template +inline auto sum(const LatticeTrinaryExpression & expr) + ->typename decltype(expr.first.func(eval(0,std::get<0>(expr.second)), + eval(0,std::get<1>(expr.second)), + eval(0,std::get<2>(expr.second)) + ))::scalar_object +{ + return sum(closure(expr)); +} + +template +inline typename vobj::scalar_object sum(const Lattice &arg) +{ + GridBase *grid=arg._grid; + int Nsimd = grid->Nsimd(); + + std::vector > sumarray(grid->SumArraySize()); + for(int i=0;iSumArraySize();i++){ + sumarray[i]=zero; + } + + parallel_for(int thr=0;thrSumArraySize();thr++){ + int nwork, mywork, myoff; + GridThread::GetWork(grid->oSites(),thr,mywork,myoff); + + vobj vvsum=zero; + for(int ss=myoff;ssSumArraySize();i++){ + vsum = vsum+sumarray[i]; + } + + typedef typename vobj::scalar_object sobj; + sobj ssum=zero; + + std::vector buf(Nsimd); + extract(vsum,buf); + + for(int i=0;iGlobalSum(ssum); + + return ssum; +} +////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// sliceSum, sliceInnerProduct, sliceAxpy, sliceNorm etc... 
+////////////////////////////////////////////////////////////////////////////////////////////////////////////// template inline void sliceSum(const Lattice &Data,std::vector &result,int orthogdim) { + /////////////////////////////////////////////////////// + // FIXME precision promoted summation + // may be important for correlation functions + // But easily avoided by using double precision fields + /////////////////////////////////////////////////////// typedef typename vobj::scalar_object sobj; GridBase *grid = Data._grid; assert(grid!=NULL); - // FIXME - // std::cout<SumArraySize()<<" threads "<_ndimension; const int Nsimd = grid->Nsimd(); @@ -165,23 +160,31 @@ template inline void sliceSum(const Lattice &Data,std::vector< int rd=grid->_rdimensions[orthogdim]; std::vector > lvSum(rd); // will locally sum vectors first - std::vector lsSum(ld,zero); // sum across these down to scalars - std::vector extracted(Nsimd); // splitting the SIMD + std::vector lsSum(ld,zero); // sum across these down to scalars + std::vector extracted(Nsimd); // splitting the SIMD - result.resize(fd); // And then global sum to return the same vector to every node for IO to file + result.resize(fd); // And then global sum to return the same vector to every node for(int r=0;r coor(Nd); + int e1= grid->_slice_nblock[orthogdim]; + int e2= grid->_slice_block [orthogdim]; + int stride=grid->_slice_stride[orthogdim]; // sum over reduced dimension planes, breaking out orthog dir + // Parallel over orthog direction + parallel_for(int r=0;roSites();ss++){ - Lexicographic::CoorFromIndex(coor,ss,grid->_rdimensions); - int r = coor[orthogdim]; - lvSum[r]=lvSum[r]+Data._odata[ss]; - } + int so=r*grid->_ostride[orthogdim]; // base offset for start of plane + + for(int n=0;n icoor(Nd); @@ -216,10 +219,354 @@ template inline void sliceSum(const Lattice &Data,std::vector< result[t]=gsum; } - } +template +static void sliceInnerProductVector( std::vector & result, const Lattice &lhs,const Lattice &rhs,int 
orthogdim) +{ + typedef typename vobj::vector_type vector_type; + typedef typename vobj::scalar_type scalar_type; + GridBase *grid = lhs._grid; + assert(grid!=NULL); + conformable(grid,rhs._grid); + const int Nd = grid->_ndimension; + const int Nsimd = grid->Nsimd(); + + assert(orthogdim >= 0); + assert(orthogdim < Nd); + + int fd=grid->_fdimensions[orthogdim]; + int ld=grid->_ldimensions[orthogdim]; + int rd=grid->_rdimensions[orthogdim]; + + std::vector > lvSum(rd); // will locally sum vectors first + std::vector lsSum(ld,scalar_type(0.0)); // sum across these down to scalars + std::vector > extracted(Nsimd); // splitting the SIMD + + result.resize(fd); // And then global sum to return the same vector to every node for IO to file + for(int r=0;r_slice_nblock[orthogdim]; + int e2= grid->_slice_block [orthogdim]; + int stride=grid->_slice_stride[orthogdim]; + + parallel_for(int r=0;r_ostride[orthogdim]; // base offset for start of plane + + for(int n=0;n icoor(Nd); + for(int rt=0;rt temp; + temp._internal = lvSum[rt]; + extract(temp,extracted); + + for(int idx=0;idxiCoorFromIindex(icoor,idx); + + int ldx =rt+icoor[orthogdim]*rd; + + lsSum[ldx]=lsSum[ldx]+extracted[idx]._internal; + + } + } + + // sum over nodes. 
+ scalar_type gsum; + for(int t=0;t_processor_coor[orthogdim] ) { + gsum=lsSum[lt]; + } else { + gsum=scalar_type(0.0); + } + + grid->GlobalSum(gsum); + + result[t]=gsum; + } } +template +static void sliceNorm (std::vector &sn,const Lattice &rhs,int Orthog) +{ + typedef typename vobj::scalar_object sobj; + typedef typename vobj::scalar_type scalar_type; + typedef typename vobj::vector_type vector_type; + + int Nblock = rhs._grid->GlobalDimensions()[Orthog]; + std::vector ip(Nblock); + sn.resize(Nblock); + + sliceInnerProductVector(ip,rhs,rhs,Orthog); + for(int ss=0;ss +static void sliceMaddVector(Lattice &R,std::vector &a,const Lattice &X,const Lattice &Y, + int orthogdim,RealD scale=1.0) +{ + typedef typename vobj::scalar_object sobj; + typedef typename vobj::scalar_type scalar_type; + typedef typename vobj::vector_type vector_type; + typedef typename vobj::tensor_reduced tensor_reduced; + + scalar_type zscale(scale); + + GridBase *grid = X._grid; + + int Nsimd =grid->Nsimd(); + int Nblock =grid->GlobalDimensions()[orthogdim]; + + int fd =grid->_fdimensions[orthogdim]; + int ld =grid->_ldimensions[orthogdim]; + int rd =grid->_rdimensions[orthogdim]; + + int e1 =grid->_slice_nblock[orthogdim]; + int e2 =grid->_slice_block [orthogdim]; + int stride =grid->_slice_stride[orthogdim]; + + std::vector icoor; + + for(int r=0;r_ostride[orthogdim]; // base offset for start of plane + + vector_type av; + + for(int l=0;liCoorFromIindex(icoor,l); + int ldx =r+icoor[orthogdim]*rd; + scalar_type *as =(scalar_type *)&av; + as[l] = scalar_type(a[ldx])*zscale; + } + + tensor_reduced at; at=av; + + parallel_for_nest2(int n=0;n_ndimension; + int nsimd = BlockSolverGrid->Nsimd(); + + std::vector latt_phys(0); + std::vector simd_phys(0); + std::vector mpi_phys(0); + + for(int d=0;d_fdimensions[d]); + simd_phys.push_back(BlockSolverGrid->_simd_layout[d]); + mpi_phys.push_back(BlockSolverGrid->_processors[d]); + } + } + return (GridBase *)new GridCartesian(latt_phys,simd_phys,mpi_phys); 
+} +*/ + +template +static void sliceMaddMatrix (Lattice &R,Eigen::MatrixXcd &aa,const Lattice &X,const Lattice &Y,int Orthog,RealD scale=1.0) +{ + typedef typename vobj::scalar_object sobj; + typedef typename vobj::scalar_type scalar_type; + typedef typename vobj::vector_type vector_type; + + int Nblock = X._grid->GlobalDimensions()[Orthog]; + + GridBase *FullGrid = X._grid; + // GridBase *SliceGrid = makeSubSliceGrid(FullGrid,Orthog); + + // Lattice Xslice(SliceGrid); + // Lattice Rslice(SliceGrid); + + assert( FullGrid->_simd_layout[Orthog]==1); + int nh = FullGrid->_ndimension; + // int nl = SliceGrid->_ndimension; + int nl = nh-1; + + //FIXME package in a convenient iterator + //Should loop over a plane orthogonal to direction "Orthog" + int stride=FullGrid->_slice_stride[Orthog]; + int block =FullGrid->_slice_block [Orthog]; + int nblock=FullGrid->_slice_nblock[Orthog]; + int ostride=FullGrid->_ostride[Orthog]; +#pragma omp parallel + { + std::vector s_x(Nblock); + +#pragma omp for collapse(2) + for(int n=0;n +static void sliceMulMatrix (Lattice &R,Eigen::MatrixXcd &aa,const Lattice &X,int Orthog,RealD scale=1.0) +{ + typedef typename vobj::scalar_object sobj; + typedef typename vobj::scalar_type scalar_type; + typedef typename vobj::vector_type vector_type; + + int Nblock = X._grid->GlobalDimensions()[Orthog]; + + GridBase *FullGrid = X._grid; + // GridBase *SliceGrid = makeSubSliceGrid(FullGrid,Orthog); + // Lattice Xslice(SliceGrid); + // Lattice Rslice(SliceGrid); + + assert( FullGrid->_simd_layout[Orthog]==1); + int nh = FullGrid->_ndimension; + // int nl = SliceGrid->_ndimension; + int nl=1; + + //FIXME package in a convenient iterator + //Should loop over a plane orthogonal to direction "Orthog" + int stride=FullGrid->_slice_stride[Orthog]; + int block =FullGrid->_slice_block [Orthog]; + int nblock=FullGrid->_slice_nblock[Orthog]; + int ostride=FullGrid->_ostride[Orthog]; +#pragma omp parallel + { + std::vector s_x(Nblock); + +#pragma omp for 
collapse(2) + for(int n=0;n +static void sliceInnerProductMatrix( Eigen::MatrixXcd &mat, const Lattice &lhs,const Lattice &rhs,int Orthog) +{ + typedef typename vobj::scalar_object sobj; + typedef typename vobj::scalar_type scalar_type; + typedef typename vobj::vector_type vector_type; + + GridBase *FullGrid = lhs._grid; + // GridBase *SliceGrid = makeSubSliceGrid(FullGrid,Orthog); + + int Nblock = FullGrid->GlobalDimensions()[Orthog]; + + // Lattice Lslice(SliceGrid); + // Lattice Rslice(SliceGrid); + + mat = Eigen::MatrixXcd::Zero(Nblock,Nblock); + + assert( FullGrid->_simd_layout[Orthog]==1); + int nh = FullGrid->_ndimension; + // int nl = SliceGrid->_ndimension; + int nl = nh-1; + + //FIXME package in a convenient iterator + //Should loop over a plane orthogonal to direction "Orthog" + int stride=FullGrid->_slice_stride[Orthog]; + int block =FullGrid->_slice_block [Orthog]; + int nblock=FullGrid->_slice_nblock[Orthog]; + int ostride=FullGrid->_ostride[Orthog]; + + typedef typename vobj::vector_typeD vector_typeD; + +#pragma omp parallel + { + std::vector Left(Nblock); + std::vector Right(Nblock); + Eigen::MatrixXcd mat_thread = Eigen::MatrixXcd::Zero(Nblock,Nblock); + +#pragma omp for collapse(2) + for(int n=0;nGlobalSum(sum); + mat(i,j)=sum; + }} + + return; +} + +} /*END NAMESPACE GRID*/ #endif + + diff --git a/lib/lattice/Lattice_rng.h b/lib/lattice/Lattice_rng.h index 51cc16ec..6dc50fd2 100644 --- a/lib/lattice/Lattice_rng.h +++ b/lib/lattice/Lattice_rng.h @@ -6,8 +6,8 @@ Copyright (C) 2015 -Author: Peter Boyle -Author: paboyle + Author: Peter Boyle + Author: Guido Cossu This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -31,8 +31,17 @@ Author: paboyle #include -namespace Grid { +#ifdef RNG_SITMO +#include +#endif +#if defined(RNG_SITMO) +#define RNG_FAST_DISCARD +#else +#undef RNG_FAST_DISCARD +#endif + +namespace Grid { 
////////////////////////////////////////////////////////////// // Allow the RNG state to be less dense than the fine grid @@ -63,111 +72,188 @@ namespace Grid { multiplicity = multiplicity *fine->_rdimensions[fd] / coarse->_rdimensions[d]; } - return multiplicity; } + +// merge of April 11 2017 +//<<<<<<< HEAD + + + // this function is necessary for the LS vectorised field + inline int RNGfillable_general(GridBase *coarse,GridBase *fine) + { + int rngdims = coarse->_ndimension; + + // trivially extended in higher dims, with locality guaranteeing RNG state is local to node + int lowerdims = fine->_ndimension - coarse->_ndimension; assert(lowerdims >= 0); + // assumes that the higher dimensions are not using more processors + // all further divisions are local + for(int d=0;d_processors[d]==1); + for(int d=0;d_processors[d] == fine->_processors[d+lowerdims]); + + + // then divide the number of local sites + // check that the total number of sims agree, meanse the iSites are the same + assert(fine->Nsimd() == coarse->Nsimd()); + + // check that the two grids divide cleanly + assert( (fine->lSites() / coarse->lSites() ) * coarse->lSites() == fine->lSites() ); + + return fine->lSites() / coarse->lSites(); + } + + /* // Wrap seed_seq to give common interface with random_device class fixedSeed { public: - typedef std::seed_seq::result_type result_type; - std::seed_seq src; fixedSeed(const std::vector &seeds) : src(seeds.begin(),seeds.end()) {}; result_type operator () (void){ - std::vector list(1); - src.generate(list.begin(),list.end()); - return list[0]; - } }; +======= +>>>>>>> develop + */ + // real scalars are one component - template void fillScalar(scalar &s,distribution &dist,generator & gen) + template + void fillScalar(scalar &s,distribution &dist,generator & gen) { s=dist(gen); } - template void fillScalar(ComplexF &s,distribution &dist, generator &gen) + template + void fillScalar(ComplexF &s,distribution &dist, generator &gen) { 
s=ComplexF(dist(gen),dist(gen)); } - template void fillScalar(ComplexD &s,distribution &dist,generator &gen) + template + void fillScalar(ComplexD &s,distribution &dist,generator &gen) { s=ComplexD(dist(gen),dist(gen)); } class GridRNGbase { - public: - - int _seeded; // One generator per site. // Uniform and Gaussian distributions from these generators. #ifdef RNG_RANLUX - typedef uint64_t RngStateType; typedef std::ranlux48 RngEngine; + typedef uint64_t RngStateType; static const int RngStateCount = 15; -#else +#endif +#ifdef RNG_MT19937 typedef std::mt19937 RngEngine; typedef uint32_t RngStateType; static const int RngStateCount = std::mt19937::state_size; #endif - std::vector _generators; - std::vector> _uniform; - std::vector> _gaussian; - std::vector> _bernoulli; +#ifdef RNG_SITMO + typedef sitmo::prng_engine RngEngine; + typedef uint64_t RngStateType; + static const int RngStateCount = 13; +#endif - void GetState(std::vector & saved,int gen) { + std::vector _generators; + std::vector > _uniform; + std::vector > _gaussian; + std::vector > _bernoulli; + std::vector > _uid; + + /////////////////////// + // support for parallel init + /////////////////////// +#ifdef RNG_FAST_DISCARD + static void Skip(RngEngine &eng) + { + ///////////////////////////////////////////////////////////////////////////////////// + // Skip by 2^40 elements between successive lattice sites + // This goes by 10^12. + // Consider quenched updating; likely never exceeding rate of 1000 sweeps + // per second on any machine. This gives us of order 10^9 seconds, or 100 years + // skip ahead. + // For HMC unlikely to go at faster than a solve per second, and + // tens of seconds per trajectory so this is clean in all reasonable cases, + // and margin of safety is orders of magnitude. 
+ // We could hack Sitmo to skip in the higher order words of state if necessary + ///////////////////////////////////////////////////////////////////////////////////// + uint64_t skip = 0x1; skip = skip<<40; + eng.discard(skip); + } +#endif + static RngEngine Reseed(RngEngine &eng) + { + std::vector newseed; + std::uniform_int_distribution uid; + return Reseed(eng,newseed,uid); + } + static RngEngine Reseed(RngEngine &eng,std::vector & newseed, + std::uniform_int_distribution &uid) + { + const int reseeds=4; + + newseed.resize(reseeds); + for(int i=0;i & saved,RngEngine &eng) { saved.resize(RngStateCount); std::stringstream ss; - ss<<_generators[gen]; + ss<>saved[i]; + ss>>saved[i]; } } - void SetState(std::vector & saved,int gen){ + void GetState(std::vector & saved,int gen) { + GetState(saved,_generators[gen]); + } + void SetState(std::vector & saved,RngEngine &eng){ assert(saved.size()==RngStateCount); std::stringstream ss; for(int i=0;i>_generators[gen]; + ss>>eng; } + void SetState(std::vector & saved,int gen){ + SetState(saved,_generators[gen]); + } + void SetEngine(RngEngine &Eng, int gen){ + _generators[gen]=Eng; + } + void GetEngine(RngEngine &Eng, int gen){ + Eng=_generators[gen]; + } + template void Seed(source &src, int gen) + { + _generators[gen] = RngEngine(src); + } }; class GridSerialRNG : public GridRNGbase { public: - // FIXME ... do we require lockstep draws of randoms - // from all nodes keeping seeds consistent. 
- // place a barrier/broadcast in the fill routine - template void Seed(source &src) - { - typename source::result_type init = src(); - CartesianCommunicator::BroadcastWorld(0,(void *)&init,sizeof(init)); - _generators[0] = RngEngine(init); - _seeded=1; - } - GridSerialRNG() : GridRNGbase() { _generators.resize(1); _uniform.resize(1,std::uniform_real_distribution{0,1}); _gaussian.resize(1,std::normal_distribution(0.0,1.0) ); _bernoulli.resize(1,std::discrete_distribution{1,1}); - _seeded=0; + _uid.resize(1,std::uniform_int_distribution() ); } - - template inline void fill(sobj &l,std::vector &dist){ typedef typename sobj::scalar_type scalar_type; @@ -178,9 +264,9 @@ namespace Grid { dist[0].reset(); for(int idx=0;idx &seeds){ - fixedSeed src(seeds); - Seed(src); + CartesianCommunicator::BroadcastWorld(0,(void *)&seeds[0],sizeof(int)*seeds.size()); + std::seed_seq src(seeds.begin(),seeds.end()); + Seed(src,0); } - }; class GridParallelRNG : public GridRNGbase { + + double _time_counter; + public: - GridBase *_grid; - int _vol; + unsigned int _vol; - int generator_idx(int os,int is){ + int generator_idx(int os,int is) { return is*_grid->oSites()+os; } GridParallelRNG(GridBase *grid) : GridRNGbase() { - _grid=grid; - _vol =_grid->iSites()*_grid->oSites(); + _grid = grid; + _vol =_grid->iSites()*_grid->oSites(); _generators.resize(_vol); _uniform.resize(_vol,std::uniform_real_distribution{0,1}); _gaussian.resize(_vol,std::normal_distribution(0.0,1.0) ); _bernoulli.resize(_vol,std::discrete_distribution{1,1}); - _seeded=0; + _uid.resize(_vol,std::uniform_int_distribution() ); } - - // This loop could be made faster to avoid the Ahmdahl by - // i) seed generators on each timeslice, for x=y=z=0; - // ii) seed generators on each z for x=y=0 - // iii)seed generators on each y,z for x=0 - // iv) seed generators on each y,z,x - // made possible by physical indexing. 
- template void Seed(source &src) - { - std::vector gcoor; - - int gsites = _grid->_gsites; - - typename source::result_type init = src(); - RngEngine pseeder(init); - std::uniform_int_distribution ui; - - for(int gidx=0;gidxGlobalIndexToGlobalCoor(gidx,gcoor); - _grid->GlobalCoorToRankIndex(rank,o_idx,i_idx,gcoor); - - int l_idx=generator_idx(o_idx,i_idx); - - const int num_rand_seed=16; - std::vector site_seeds(num_rand_seed); - for(int i=0;iBroadcast(0,(void *)&site_seeds[0],sizeof(int)*site_seeds.size()); - - if( rank == _grid->ThisRank() ){ - fixedSeed ssrc(site_seeds); - typename source::result_type sinit = ssrc(); - _generators[l_idx] = RngEngine(sinit); - } - } - _seeded=1; - } - - //FIXME implement generic IO and create state save/restore - //void SaveState(const std::string &file); - //void LoadState(const std::string &file); - template inline void fill(Lattice &l,std::vector &dist){ typedef typename vobj::scalar_object scalar_object; typedef typename vobj::scalar_type scalar_type; typedef typename vobj::vector_type vector_type; - - int multiplicity = RNGfillable(_grid,l._grid); - int Nsimd =_grid->Nsimd(); - int osites=_grid->oSites(); - int words=sizeof(scalar_object)/sizeof(scalar_type); + double inner_time_counter = usecond(); + int multiplicity = RNGfillable_general(_grid, l._grid); // l has finer or same grid + int Nsimd = _grid->Nsimd(); // guaranteed to be the same for l._grid too + int osites = _grid->oSites(); // guaranteed to be <= l._grid->oSites() by a factor multiplicity + int words = sizeof(scalar_object) / sizeof(scalar_type); -PARALLEL_FOR_LOOP - for(int ss=0;ss buf(Nsimd); + for (int m = 0; m < multiplicity; m++) { // Draw from same generator multiplicity times - std::vector buf(Nsimd); - for(int m=0;m &seeds){ - fixedSeed src(seeds); - Seed(src); + + // Everyone generates the same seed_seq based on input seeds + CartesianCommunicator::BroadcastWorld(0,(void *)&seeds[0],sizeof(int)*seeds.size()); + + std::seed_seq 
source(seeds.begin(),seeds.end()); + + RngEngine master_engine(source); + +#ifdef RNG_FAST_DISCARD + //////////////////////////////////////////////// + // Skip ahead through a single stream. + // Applicable to SITMO and other has based/crypto RNGs + // Should be applicable to Mersenne Twister, but the C++11 + // MT implementation does not implement fast discard even though + // in principle this is possible + //////////////////////////////////////////////// + std::vector gcoor; + int rank,o_idx,i_idx; + + // Everybody loops over global volume. + for(int gidx=0;gidx<_grid->_gsites;gidx++){ + + Skip(master_engine); // Skip to next RNG sequence + + // Where is it? + _grid->GlobalIndexToGlobalCoor(gidx,gcoor); + _grid->GlobalCoorToRankIndex(rank,o_idx,i_idx,gcoor); + + // If this is one of mine we take it + if( rank == _grid->ThisRank() ){ + int l_idx=generator_idx(o_idx,i_idx); + _generators[l_idx] = master_engine; + } + + } +#else + //////////////////////////////////////////////////////////////// + // Machine and thread decomposition dependent seeding is efficient + // and maximally parallel; but NOT reproducible from machine to machine. + // Not ideal, but fastest way to reseed all nodes. 
+ //////////////////////////////////////////////////////////////// + { + // Obtain one Reseed per processor + int Nproc = _grid->ProcessorCount(); + std::vector seeders(Nproc); + int me= _grid->ThisRank(); + for(int p=0;p seeders(Nthread); + for(int t=0;t newseeds; + std::uniform_int_distribution uid; + for(int l=0;l<_grid->lSites();l++) { + if ( (l%Nthread)==t ) { + _generators[l] = Reseed(seeders[t],newseeds,uid); + } + } + } + } +#endif + } + + void Report(){ + std::cout << GridLogMessage << "Time spent in the fill() routine by GridParallelRNG: "<< _time_counter/1e3 << " ms" << std::endl; + } + + + //////////////////////////////////////////////////////////////////////// + // Support for rigorous test of RNG's + // Return uniform random uint32_t from requested site generator + //////////////////////////////////////////////////////////////////////// + uint32_t GlobalU01(int gsite){ + + uint32_t the_number; + // who + std::vector gcoor; + int rank,o_idx,i_idx; + _grid->GlobalIndexToGlobalCoor(gsite,gcoor); + _grid->GlobalCoorToRankIndex(rank,o_idx,i_idx,gcoor); + + // draw + int l_idx=generator_idx(o_idx,i_idx); + if( rank == _grid->ThisRank() ){ + the_number = _uid[l_idx](_generators[l_idx]); + } + + // share & return + _grid->Broadcast(rank,(void *)&the_number,sizeof(the_number)); + return the_number; } }; - template inline void random(GridParallelRNG &rng,Lattice &l){ - rng.fill(l,rng._uniform); - } + template inline void random(GridParallelRNG &rng,Lattice &l) { rng.fill(l,rng._uniform); } + template inline void gaussian(GridParallelRNG &rng,Lattice &l) { rng.fill(l,rng._gaussian); } + template inline void bernoulli(GridParallelRNG &rng,Lattice &l){ rng.fill(l,rng._bernoulli);} - template inline void gaussian(GridParallelRNG &rng,Lattice &l){ - rng.fill(l,rng._gaussian); - } - - template inline void bernoulli(GridParallelRNG &rng,Lattice &l){ - rng.fill(l,rng._bernoulli); - } - - template inline void random(GridSerialRNG &rng,sobj &l){ - 
rng.fill(l,rng._uniform); - } - - template inline void gaussian(GridSerialRNG &rng,sobj &l){ - rng.fill(l,rng._gaussian); - } - - template inline void bernoulli(GridSerialRNG &rng,sobj &l){ - rng.fill(l,rng._bernoulli); - } + template inline void random(GridSerialRNG &rng,sobj &l) { rng.fill(l,rng._uniform ); } + template inline void gaussian(GridSerialRNG &rng,sobj &l) { rng.fill(l,rng._gaussian ); } + template inline void bernoulli(GridSerialRNG &rng,sobj &l){ rng.fill(l,rng._bernoulli); } } #endif diff --git a/lib/lattice/Lattice_trace.h b/lib/lattice/Lattice_trace.h index a341ff7c..449c55f8 100644 --- a/lib/lattice/Lattice_trace.h +++ b/lib/lattice/Lattice_trace.h @@ -42,8 +42,7 @@ namespace Grid { -> Lattice { Lattice ret(lhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ ret._odata[ss] = trace(lhs._odata[ss]); } return ret; @@ -56,8 +55,7 @@ PARALLEL_FOR_LOOP inline auto TraceIndex(const Lattice &lhs) -> Lattice(lhs._odata[0]))> { Lattice(lhs._odata[0]))> ret(lhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ ret._odata[ss] = traceIndex(lhs._odata[ss]); } return ret; diff --git a/lib/lattice/Lattice_transfer.h b/lib/lattice/Lattice_transfer.h index cc4617de..cbf31f86 100644 --- a/lib/lattice/Lattice_transfer.h +++ b/lib/lattice/Lattice_transfer.h @@ -1,4 +1,4 @@ - /************************************************************************************* +/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid @@ -51,7 +51,7 @@ inline void subdivides(GridBase *coarse,GridBase *fine) template inline void pickCheckerboard(int cb,Lattice &half,const Lattice &full){ half.checkerboard = cb; int ssh=0; - //PARALLEL_FOR_LOOP + //parallel_for for(int ss=0;ssoSites();ss++){ std::vector coor; int cbos; @@ -68,7 +68,7 @@ inline void subdivides(GridBase *coarse,GridBase *fine) template 
inline void setCheckerboard(Lattice &full,const Lattice &half){ int cb = half.checkerboard; int ssh=0; - //PARALLEL_FOR_LOOP + //parallel_for for(int ss=0;ssoSites();ss++){ std::vector coor; int cbos; @@ -153,8 +153,7 @@ inline void blockZAXPY(Lattice &fineZ, assert(block_r[d]*coarse->_rdimensions[d]==fine->_rdimensions[d]); } -PARALLEL_FOR_LOOP - for(int sf=0;sfoSites();sf++){ + parallel_for(int sf=0;sfoSites();sf++){ int sc; std::vector coor_c(_ndimension); @@ -186,8 +185,7 @@ template fine_inner = localInnerProduct(fineX,fineY); blockSum(coarse_inner,fine_inner); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ CoarseInner._odata[ss] = coarse_inner._odata[ss]; } } @@ -333,9 +331,6 @@ void localConvert(const Lattice &in,Lattice &out) typedef typename vobj::scalar_object sobj; typedef typename vvobj::scalar_object ssobj; - sobj s; - ssobj ss; - GridBase *ig = in._grid; GridBase *og = out._grid; @@ -347,10 +342,13 @@ void localConvert(const Lattice &in,Lattice &out) for(int d=0;d_processors[d] == og->_processors[d]); assert(ig->_ldimensions[d] == og->_ldimensions[d]); + assert(ig->lSites() == og->lSites()); } - //PARALLEL_FOR_LOOP - for(int idx=0;idxlSites();idx++){ + parallel_for(int idx=0;idxlSites();idx++){ + sobj s; + ssobj ss; + std::vector lcoor(ni); ig->LocalIndexToLocalCoor(idx,lcoor); peekLocalSite(s,in,lcoor); @@ -361,10 +359,9 @@ void localConvert(const Lattice &in,Lattice &out) template -void InsertSlice(Lattice &lowDim,Lattice & higherDim,int slice, int orthog) +void InsertSlice(const Lattice &lowDim,Lattice & higherDim,int slice, int orthog) { typedef typename vobj::scalar_object sobj; - sobj s; GridBase *lg = lowDim._grid; GridBase *hg = higherDim._grid; @@ -386,16 +383,16 @@ void InsertSlice(Lattice &lowDim,Lattice & higherDim,int slice, int } // the above should guarantee that the operations are local - //PARALLEL_FOR_LOOP - for(int idx=0;idxlSites();idx++){ + parallel_for(int idx=0;idxlSites();idx++){ + 
sobj s; std::vector lcoor(nl); std::vector hcoor(nh); lg->LocalIndexToLocalCoor(idx,lcoor); - dl=0; + int ddl=0; hcoor[orthog] = slice; for(int d=0;d &lowDim,Lattice & higherDim,int slice, int } template -void ExtractSlice(Lattice &lowDim, Lattice & higherDim,int slice, int orthog) +void ExtractSlice(Lattice &lowDim,const Lattice & higherDim,int slice, int orthog) { typedef typename vobj::scalar_object sobj; - sobj s; GridBase *lg = lowDim._grid; GridBase *hg = higherDim._grid; @@ -428,16 +424,16 @@ void ExtractSlice(Lattice &lowDim, Lattice & higherDim,int slice, in } } // the above should guarantee that the operations are local - //PARALLEL_FOR_LOOP - for(int idx=0;idxlSites();idx++){ + parallel_for(int idx=0;idxlSites();idx++){ + sobj s; std::vector lcoor(nl); std::vector hcoor(nh); lg->LocalIndexToLocalCoor(idx,lcoor); - dl=0; + int ddl=0; hcoor[orthog] = slice; for(int d=0;d &lowDim, Lattice & higherDim,int slice, in template -void InsertSliceLocal(Lattice &lowDim, Lattice & higherDim,int slice_lo,int slice_hi, int orthog) +void InsertSliceLocal(const Lattice &lowDim, Lattice & higherDim,int slice_lo,int slice_hi, int orthog) { typedef typename vobj::scalar_object sobj; - sobj s; GridBase *lg = lowDim._grid; GridBase *hg = higherDim._grid; @@ -468,8 +463,8 @@ void InsertSliceLocal(Lattice &lowDim, Lattice & higherDim,int slice } // the above should guarantee that the operations are local - //PARALLEL_FOR_LOOP - for(int idx=0;idxlSites();idx++){ + parallel_for(int idx=0;idxlSites();idx++){ + sobj s; std::vector lcoor(nl); std::vector hcoor(nh); lg->LocalIndexToLocalCoor(idx,lcoor); @@ -487,7 +482,6 @@ template void ExtractSliceLocal(Lattice &lowDim, Lattice & higherDim,int slice_lo,int slice_hi, int orthog) { typedef typename vobj::scalar_object sobj; - sobj s; GridBase *lg = lowDim._grid; GridBase *hg = higherDim._grid; @@ -504,8 +498,8 @@ void ExtractSliceLocal(Lattice &lowDim, Lattice & higherDim,int slic } // the above should guarantee that the operations 
are local - //PARALLEL_FOR_LOOP - for(int idx=0;idxlSites();idx++){ + parallel_for(int idx=0;idxlSites();idx++){ + sobj s; std::vector lcoor(nl); std::vector hcoor(nh); lg->LocalIndexToLocalCoor(idx,lcoor); @@ -557,7 +551,10 @@ void Replicate(Lattice &coarse,Lattice & fine) //Copy SIMD-vectorized lattice to array of scalar objects in lexicographic order template -typename std::enable_if::value && !isSIMDvectorized::value, void>::type unvectorizeToLexOrdArray(std::vector &out, const Lattice &in){ +typename std::enable_if::value && !isSIMDvectorized::value, void>::type +unvectorizeToLexOrdArray(std::vector &out, const Lattice &in) +{ + typedef typename vobj::vector_type vtype; GridBase* in_grid = in._grid; @@ -573,8 +570,7 @@ typename std::enable_if::value && !isSIMDvectorized in_grid->iCoorFromIindex(in_icoor[lane], lane); } -PARALLEL_FOR_LOOP - for(int in_oidx = 0; in_oidx < in_grid->oSites(); in_oidx++){ //loop over outer index + parallel_for(int in_oidx = 0; in_oidx < in_grid->oSites(); in_oidx++){ //loop over outer index //Assemble vector of pointers to output elements std::vector out_ptrs(in_nsimd); @@ -597,6 +593,54 @@ PARALLEL_FOR_LOOP extract1(in_vobj, out_ptrs, 0); } } +//Copy SIMD-vectorized lattice to array of scalar objects in lexicographic order +template +typename std::enable_if::value + && !isSIMDvectorized::value, void>::type +vectorizeFromLexOrdArray( std::vector &in, Lattice &out) +{ + + typedef typename vobj::vector_type vtype; + + GridBase* grid = out._grid; + assert(in.size()==grid->lSites()); + + int ndim = grid->Nd(); + int nsimd = vtype::Nsimd(); + + std::vector > icoor(nsimd); + + for(int lane=0; lane < nsimd; lane++){ + icoor[lane].resize(ndim); + grid->iCoorFromIindex(icoor[lane],lane); + } + + parallel_for(uint64_t oidx = 0; oidx < grid->oSites(); oidx++){ //loop over outer index + //Assemble vector of pointers to output elements + std::vector ptrs(nsimd); + + std::vector ocoor(ndim); + grid->oCoorFromOindex(ocoor, oidx); + + std::vector 
lcoor(grid->Nd()); + + for(int lane=0; lane < nsimd; lane++){ + + for(int mu=0;mu_rdimensions[mu]*icoor[lane][mu]; + } + + int lex; + Lexicographic::IndexFromCoor(lcoor, lex, grid->_ldimensions); + ptrs[lane] = &in[lex]; + } + + //pack from those ptrs + vobj vecobj; + merge1(vecobj, ptrs, 0); + out._odata[oidx] = vecobj; + } +} //Convert a Lattice from one precision to another template @@ -622,8 +666,7 @@ void precisionChange(Lattice &out, const Lattice &in){ std::vector in_slex_conv(in_grid->lSites()); unvectorizeToLexOrdArray(in_slex_conv, in); - PARALLEL_FOR_LOOP - for(int out_oidx=0;out_oidxoSites();out_oidx++){ + parallel_for(uint64_t out_oidx=0;out_oidxoSites();out_oidx++){ std::vector out_ocoor(ndim); out_grid->oCoorFromOindex(out_ocoor, out_oidx); @@ -641,10 +684,6 @@ void precisionChange(Lattice &out, const Lattice &in){ merge(out._odata[out_oidx], ptrs, 0); } } - - - - } #endif diff --git a/lib/lattice/Lattice_transpose.h b/lib/lattice/Lattice_transpose.h index c8d349a6..0ae7c6b3 100644 --- a/lib/lattice/Lattice_transpose.h +++ b/lib/lattice/Lattice_transpose.h @@ -40,27 +40,24 @@ namespace Grid { //////////////////////////////////////////////////////////////////////////////////////////////////// template inline Lattice transpose(const Lattice &lhs){ - Lattice ret(lhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ - ret._odata[ss] = transpose(lhs._odata[ss]); - } - return ret; - }; + Lattice ret(lhs._grid); + parallel_for(int ss=0;ssoSites();ss++){ + ret._odata[ss] = transpose(lhs._odata[ss]); + } + return ret; + }; - //////////////////////////////////////////////////////////////////////////////////////////////////// - // Index level dependent transpose - //////////////////////////////////////////////////////////////////////////////////////////////////// - template + //////////////////////////////////////////////////////////////////////////////////////////////////// + // Index level dependent transpose + 
//////////////////////////////////////////////////////////////////////////////////////////////////// + template inline auto TransposeIndex(const Lattice &lhs) -> Lattice(lhs._odata[0]))> - { - Lattice(lhs._odata[0]))> ret(lhs._grid); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ - ret._odata[ss] = transposeIndex(lhs._odata[ss]); - } - return ret; - }; - + { + Lattice(lhs._odata[0]))> ret(lhs._grid); + parallel_for(int ss=0;ssoSites();ss++){ + ret._odata[ss] = transposeIndex(lhs._odata[ss]); + } + return ret; + }; } #endif diff --git a/lib/lattice/Lattice_unary.h b/lib/lattice/Lattice_unary.h index f3c54896..44b7b4f1 100644 --- a/lib/lattice/Lattice_unary.h +++ b/lib/lattice/Lattice_unary.h @@ -37,8 +37,7 @@ namespace Grid { Lattice ret(rhs._grid); ret.checkerboard = rhs.checkerboard; conformable(ret,rhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ ret._odata[ss]=pow(rhs._odata[ss],y); } return ret; @@ -47,8 +46,7 @@ PARALLEL_FOR_LOOP Lattice ret(rhs._grid); ret.checkerboard = rhs.checkerboard; conformable(ret,rhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ ret._odata[ss]=mod(rhs._odata[ss],y); } return ret; @@ -58,22 +56,26 @@ PARALLEL_FOR_LOOP Lattice ret(rhs._grid); ret.checkerboard = rhs.checkerboard; conformable(ret,rhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ ret._odata[ss]=div(rhs._odata[ss],y); } return ret; } - template Lattice expMat(const Lattice &rhs, ComplexD alpha, Integer Nexp = DEFAULT_MAT_EXP){ + template Lattice expMat(const Lattice &rhs, RealD alpha, Integer Nexp = DEFAULT_MAT_EXP){ Lattice ret(rhs._grid); ret.checkerboard = rhs.checkerboard; conformable(ret,rhs); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ ret._odata[ss]=Exponentiate(rhs._odata[ss],alpha, Nexp); } + return ret; + + + + + } diff --git a/lib/lattice/Lattice_where.h 
b/lib/lattice/Lattice_where.h index cff372a0..6686d1b3 100644 --- a/lib/lattice/Lattice_where.h +++ b/lib/lattice/Lattice_where.h @@ -56,8 +56,7 @@ inline void whereWolf(Lattice &ret,const Lattice &predicate,Lattice< std::vector truevals (Nsimd); std::vector falsevals(Nsimd); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites(); ss++){ + parallel_for(int ss=0;ssoSites(); ss++){ extract(iftrue._odata[ss] ,truevals); extract(iffalse._odata[ss] ,falsevals); diff --git a/lib/Log.cc b/lib/log/Log.cc similarity index 97% rename from lib/Log.cc rename to lib/log/Log.cc index 7521657b..65dc2812 100644 --- a/lib/Log.cc +++ b/lib/log/Log.cc @@ -29,9 +29,11 @@ See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include +#include +#include #include +#include namespace Grid { @@ -93,7 +95,7 @@ void GridLogConfigure(std::vector &logstreams) { //////////////////////////////////////////////////////////// void Grid_quiesce_nodes(void) { int me = 0; -#if defined(GRID_COMMS_MPI) || defined(GRID_COMMS_MPI3) || defined(GRID_COMMS_MPI3L) +#if defined(GRID_COMMS_MPI) || defined(GRID_COMMS_MPI3) || defined(GRID_COMMS_MPIT) MPI_Comm_rank(MPI_COMM_WORLD, &me); #endif #ifdef GRID_COMMS_SHMEM diff --git a/lib/Log.h b/lib/log/Log.h similarity index 96% rename from lib/Log.h rename to lib/log/Log.h index d7422ca9..74d080bb 100644 --- a/lib/Log.h +++ b/lib/log/Log.h @@ -110,8 +110,8 @@ public: friend std::ostream& operator<< (std::ostream& stream, Logger& log){ if ( log.active ) { - stream << log.background()<< log.topName << log.background()<< " : "; - stream << log.colour() < -Author: paboyle + Author: Peter Boyle + Author: Guido Cossu This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -29,57 +29,165 @@ Author: paboyle #ifndef GRID_BINARY_IO_H #define GRID_BINARY_IO_H +#if 
defined(GRID_COMMS_MPI) || defined(GRID_COMMS_MPI3) || defined(GRID_COMMS_MPIT) +#define USE_MPI_IO +#else +#undef USE_MPI_IO +#endif #ifdef HAVE_ENDIAN_H #include #endif + #include #include -// 64bit endian swap is a portability pain -#ifndef __has_builtin // Optional of course. -#define __has_builtin(x) 0 // Compatibility with non-clang compilers. -#endif - -#if HAVE_DECL_BE64TOH -#undef Grid_ntohll -#define Grid_ntohll be64toh -#endif - -#if HAVE_DECL_NTOHLL -#undef Grid_ntohll -#define Grid_ntohll ntohll -#endif - -#ifndef Grid_ntohll - -#if BYTE_ORDER == BIG_ENDIAN - -#define Grid_ntohll(A) (A) - -#else - -#if __has_builtin(__builtin_bswap64) -#define Grid_ntohll(A) __builtin_bswap64(A) -#else -#error -#endif - -#endif - -#endif namespace Grid { - // A little helper - inline void removeWhitespace(std::string &key) - { - key.erase(std::remove_if(key.begin(), key.end(), ::isspace),key.end()); - } +///////////////////////////////////////////////////////////////////////////////// +// Byte reversal garbage +///////////////////////////////////////////////////////////////////////////////// +inline uint32_t byte_reverse32(uint32_t f) { + f = ((f&0xFF)<<24) | ((f&0xFF00)<<8) | ((f&0xFF0000)>>8) | ((f&0xFF000000UL)>>24) ; + return f; +} +inline uint64_t byte_reverse64(uint64_t f) { + uint64_t g; + g = ((f&0xFF)<<24) | ((f&0xFF00)<<8) | ((f&0xFF0000)>>8) | ((f&0xFF000000UL)>>24) ; + g = g << 32; + f = f >> 32; + g|= ((f&0xFF)<<24) | ((f&0xFF00)<<8) | ((f&0xFF0000)>>8) | ((f&0xFF000000UL)>>24) ; + return g; +} + +#if BYTE_ORDER == BIG_ENDIAN +inline uint64_t Grid_ntohll(uint64_t A) { return A; } +#else +inline uint64_t Grid_ntohll(uint64_t A) { + return byte_reverse64(A); +} +#endif + +// A little helper +inline void removeWhitespace(std::string &key) +{ + key.erase(std::remove_if(key.begin(), key.end(), ::isspace),key.end()); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Static class holding the parallel IO 
code +// Could just use a namespace +/////////////////////////////////////////////////////////////////////////////////////////////////// class BinaryIO { - public: + ///////////////////////////////////////////////////////////////////////////// + // more byte manipulation helpers + ///////////////////////////////////////////////////////////////////////////// + + template static inline void Uint32Checksum(Lattice &lat,uint32_t &nersc_csum) + { + typedef typename vobj::scalar_object sobj; + + GridBase *grid = lat._grid; + int lsites = grid->lSites(); + + std::vector scalardata(lsites); + unvectorizeToLexOrdArray(scalardata,lat); + + NerscChecksum(grid,scalardata,nersc_csum); + } + + template + static inline void NerscChecksum(GridBase *grid, std::vector &fbuf, uint32_t &nersc_csum) + { + const uint64_t size32 = sizeof(fobj) / sizeof(uint32_t); + + uint64_t lsites = grid->lSites(); + if (fbuf.size() == 1) + { + lsites = 1; + } + + #pragma omp parallel + { + uint32_t nersc_csum_thr = 0; + + #pragma omp for + for (uint64_t local_site = 0; local_site < lsites; local_site++) + { + uint32_t *site_buf = (uint32_t *)&fbuf[local_site]; + for (uint64_t j = 0; j < size32; j++) + { + nersc_csum_thr = nersc_csum_thr + site_buf[j]; + } + } + + #pragma omp critical + { + nersc_csum += nersc_csum_thr; + } + } + } + + template static inline void ScidacChecksum(GridBase *grid,std::vector &fbuf,uint32_t &scidac_csuma,uint32_t &scidac_csumb) + { + const uint64_t size32 = sizeof(fobj)/sizeof(uint32_t); + + + int nd = grid->_ndimension; + + uint64_t lsites =grid->lSites(); + if (fbuf.size()==1) { + lsites=1; + } + std::vector local_vol =grid->LocalDimensions(); + std::vector local_start =grid->LocalStarts(); + std::vector global_vol =grid->FullDimensions(); + +#pragma omp parallel + { + std::vector coor(nd); + uint32_t scidac_csuma_thr=0; + uint32_t scidac_csumb_thr=0; + uint32_t site_crc=0; + +#pragma omp for + for(uint64_t local_site=0;local_site>(32-gsite29); + scidac_csumb_thr ^= 
site_crc<>(32-gsite31); + } + +#pragma omp critical + { + scidac_csuma^= scidac_csuma_thr; + scidac_csumb^= scidac_csumb_thr; + } + } + } // Network is big endian static inline void htobe32_v(void *file_object,uint32_t bytes){ be32toh_v(file_object,bytes);} @@ -87,21 +195,22 @@ class BinaryIO { static inline void htole32_v(void *file_object,uint32_t bytes){ le32toh_v(file_object,bytes);} static inline void htole64_v(void *file_object,uint32_t bytes){ le64toh_v(file_object,bytes);} - static inline void be32toh_v(void *file_object,uint32_t bytes) + static inline void be32toh_v(void *file_object,uint64_t bytes) { uint32_t * f = (uint32_t *)file_object; - for(int i=0;i*sizeof(uint32_t)>8) | ((f&0xFF000000UL)>>24) ; @@ -110,21 +219,23 @@ class BinaryIO { } // BE is same as network - static inline void be64toh_v(void *file_object,uint32_t bytes) + static inline void be64toh_v(void *file_object,uint64_t bytes) { uint64_t * f = (uint64_t *)file_object; - for(int i=0;i*sizeof(uint64_t)>8) | ((f&0xFF000000UL)>>24) ; @@ -134,547 +245,465 @@ class BinaryIO { fp[i] = Grid_ntohll(g); } } + ///////////////////////////////////////////////////////////////////////////// + // Real action: + // Read or Write distributed lexico array of ANY object to a specific location in file + ////////////////////////////////////////////////////////////////////////////////////// - template static inline void Uint32Checksum(Lattice &lat,munger munge,uint32_t &csum) + static const int BINARYIO_MASTER_APPEND = 0x10; + static const int BINARYIO_UNORDERED = 0x08; + static const int BINARYIO_LEXICOGRAPHIC = 0x04; + static const int BINARYIO_READ = 0x02; + static const int BINARYIO_WRITE = 0x01; + + template + static inline void IOobject(word w, + GridBase *grid, + std::vector &iodata, + std::string file, + int offset, + const std::string &format, int control, + uint32_t &nersc_csum, + uint32_t &scidac_csuma, + uint32_t &scidac_csumb) { - typedef typename vobj::scalar_object sobj; - GridBase *grid = 
lat._grid ; - std::cout < lcoor; - for(int l=0;llSites();l++){ - Lexicographic::CoorFromIndex(lcoor,l,grid->_ldimensions); - peekLocalSite(siteObj,lat,lcoor); - munge(siteObj,fileObj,csum); - } - grid->GlobalSum(csum); - } + grid->Barrier(); + GridStopWatch timer; + GridStopWatch bstimer; - static inline void Uint32Checksum(uint32_t *buf,uint32_t buf_size_bytes,uint32_t &csum) - { - for(int i=0;i*sizeof(uint32_t)Dimensions(); + int nrank = grid->ProcessorCount(); + int myrank = grid->ThisRank(); + + std::vector psizes = grid->ProcessorGrid(); + std::vector pcoor = grid->ThisProcessorCoor(); + std::vector gLattice= grid->GlobalDimensions(); + std::vector lLattice= grid->LocalDimensions(); + + std::vector lStart(ndim); + std::vector gStart(ndim); + + // Flatten the file + uint64_t lsites = grid->lSites(); + if ( control & BINARYIO_MASTER_APPEND ) { + assert(iodata.size()==1); + } else { + assert(lsites==iodata.size()); + } + for(int d=0;d - static inline uint32_t readObjectSerial(Lattice &Umu,std::string file,munger munge,int offset,const std::string &format) - { - typedef typename vobj::scalar_object sobj; - GridBase *grid = Umu._grid; +#ifdef USE_MPI_IO + std::vector distribs(ndim,MPI_DISTRIBUTE_BLOCK); + std::vector dargs (ndim,MPI_DISTRIBUTE_DFLT_DARG); + MPI_Datatype mpiObject; + MPI_Datatype fileArray; + MPI_Datatype localArray; + MPI_Datatype mpiword; + MPI_Offset disp = offset; + MPI_File fh ; + MPI_Status status; + int numword; - std::cout<< GridLogMessage<< "Serial read I/O "<< file<< std::endl; - GridStopWatch timer; timer.Start(); + if ( sizeof( word ) == sizeof(float ) ) { + numword = sizeof(fobj)/sizeof(float); + mpiword = MPI_FLOAT; + } else { + numword = sizeof(fobj)/sizeof(double); + mpiword = MPI_DOUBLE; + } + ////////////////////////////////////////////////////////////////////////////// + // Sobj in MPI phrasing + ////////////////////////////////////////////////////////////////////////////// + int ierr; + ierr = 
MPI_Type_contiguous(numword,mpiword,&mpiObject); assert(ierr==0); + ierr = MPI_Type_commit(&mpiObject); + + ////////////////////////////////////////////////////////////////////////////// + // File global array data type + ////////////////////////////////////////////////////////////////////////////// + ierr=MPI_Type_create_subarray(ndim,&gLattice[0],&lLattice[0],&gStart[0],MPI_ORDER_FORTRAN, mpiObject,&fileArray); assert(ierr==0); + ierr=MPI_Type_commit(&fileArray); assert(ierr==0); + + ////////////////////////////////////////////////////////////////////////////// + // local lattice array + ////////////////////////////////////////////////////////////////////////////// + ierr=MPI_Type_create_subarray(ndim,&lLattice[0],&lLattice[0],&lStart[0],MPI_ORDER_FORTRAN, mpiObject,&localArray); assert(ierr==0); + ierr=MPI_Type_commit(&localArray); assert(ierr==0); +#endif + + ////////////////////////////////////////////////////////////////////////////// + // Byte order + ////////////////////////////////////////////////////////////////////////////// int ieee32big = (format == std::string("IEEE32BIG")); int ieee32 = (format == std::string("IEEE32")); int ieee64big = (format == std::string("IEEE64BIG")); int ieee64 = (format == std::string("IEEE64")); - // Find the location of each site and send to primary node - // Take loop order from Chroma; defines loop order now that NERSC doc no longer - // available (how short sighted is that?) 
- std::ifstream fin(file,std::ios::binary|std::ios::in); - fin.seekg(offset); + ////////////////////////////////////////////////////////////////////////////// + // Do the I/O + ////////////////////////////////////////////////////////////////////////////// + if ( control & BINARYIO_READ ) { - Umu = zero; - uint32_t csum=0; - uint64_t bytes=0; - fobj file_object; - sobj munged; - - for(int t=0;t_fdimensions[3];t++){ - for(int z=0;z_fdimensions[2];z++){ - for(int y=0;y_fdimensions[1];y++){ - for(int x=0;x_fdimensions[0];x++){ + timer.Start(); - std::vector site({x,y,z,t}); - - if (grid->IsBoss()) { - fin.read((char *)&file_object, sizeof(file_object)); - bytes += sizeof(file_object); - if (ieee32big) be32toh_v((void *)&file_object, sizeof(file_object)); - if (ieee32) le32toh_v((void *)&file_object, sizeof(file_object)); - if (ieee64big) be64toh_v((void *)&file_object, sizeof(file_object)); - if (ieee64) le64toh_v((void *)&file_object, sizeof(file_object)); - - munge(file_object, munged, csum); - } - // The boss who read the file has their value poked - pokeSite(munged,Umu,site); - }}}} - timer.Stop(); - std::cout< - static inline uint32_t writeObjectSerial(Lattice &Umu,std::string file,munger munge,int offset,const std::string & format) - { - typedef typename vobj::scalar_object sobj; - - GridBase *grid = Umu._grid; - - int ieee32big = (format == std::string("IEEE32BIG")); - int ieee32 = (format == std::string("IEEE32")); - int ieee64big = (format == std::string("IEEE64BIG")); - int ieee64 = (format == std::string("IEEE64")); - - ////////////////////////////////////////////////// - // Serialise through node zero - ////////////////////////////////////////////////// - std::cout<< GridLogMessage<< "Serial write I/O "<< file<IsBoss() ) { - fout.open(file,std::ios::binary|std::ios::out|std::ios::in); - fout.seekp(offset); - } - uint64_t bytes=0; - uint32_t csum=0; - fobj file_object; - sobj unmunged; - for(int t=0;t_fdimensions[3];t++){ - for(int z=0;z_fdimensions[2];z++){ 
- for(int y=0;y_fdimensions[1];y++){ - for(int x=0;x_fdimensions[0];x++){ - - std::vector site({x,y,z,t}); - // peek & write - peekSite(unmunged,Umu,site); - - munge(unmunged,file_object,csum); - - - if ( grid->IsBoss() ) { - - if(ieee32big) htobe32_v((void *)&file_object,sizeof(file_object)); - if(ieee32) htole32_v((void *)&file_object,sizeof(file_object)); - if(ieee64big) htobe64_v((void *)&file_object,sizeof(file_object)); - if(ieee64) htole64_v((void *)&file_object,sizeof(file_object)); - - // NB could gather an xstrip as an optimisation. - fout.write((char *)&file_object,sizeof(file_object)); - bytes+=sizeof(file_object); - } - }}}} - timer.Stop(); - std::cout<_gsites; - - ////////////////////////////////////////////////// - // Serialise through node zero - ////////////////////////////////////////////////// - std::cout<< GridLogMessage<< "Serial RNG write I/O "<< file<IsBoss() ) { - fout.open(file,std::ios::binary|std::ios::out|std::ios::in); - fout.seekp(offset); - } - - uint32_t csum=0; - std::vector saved(RngStateCount); - int bytes = sizeof(RngStateType)*saved.size(); - std::vector gcoor; - - for(int gidx=0;gidxGlobalIndexToGlobalCoor(gidx,gcoor); - grid->GlobalCoorToRankIndex(rank,o_idx,i_idx,gcoor); - int l_idx=parallel.generator_idx(o_idx,i_idx); - - if( rank == grid->ThisRank() ){ - // std::cout << "rank" << rank<<" Getting state for index "<Broadcast(rank,(void *)&saved[0],bytes); - - if ( grid->IsBoss() ) { - Uint32Checksum((uint32_t *)&saved[0],bytes,csum); - fout.write((char *)&saved[0],bytes); - } - - } - - if ( grid->IsBoss() ) { - serial.GetState(saved,0); - Uint32Checksum((uint32_t *)&saved[0],bytes,csum); - fout.write((char *)&saved[0],bytes); - } - grid->Broadcast(0,(void *)&csum,sizeof(csum)); - return csum; - } - static inline uint32_t readRNGSerial(GridSerialRNG &serial,GridParallelRNG ¶llel,std::string file,int offset) - { - typedef typename GridSerialRNG::RngStateType RngStateType; - const int RngStateCount = 
GridSerialRNG::RngStateCount; - - GridBase *grid = parallel._grid; - int gsites = grid->_gsites; - - ////////////////////////////////////////////////// - // Serialise through node zero - ////////////////////////////////////////////////// - std::cout<< GridLogMessage<< "Serial RNG read I/O "<< file< saved(RngStateCount); - int bytes = sizeof(RngStateType)*saved.size(); - std::vector gcoor; - - for(int gidx=0;gidxGlobalIndexToGlobalCoor(gidx,gcoor); - grid->GlobalCoorToRankIndex(rank,o_idx,i_idx,gcoor); - int l_idx=parallel.generator_idx(o_idx,i_idx); - - if ( grid->IsBoss() ) { - fin.read((char *)&saved[0],bytes); - Uint32Checksum((uint32_t *)&saved[0],bytes,csum); - } - - grid->Broadcast(0,(void *)&saved[0],bytes); - - if( rank == grid->ThisRank() ){ - parallel.SetState(saved,l_idx); - } - - } - - if ( grid->IsBoss() ) { - fin.read((char *)&saved[0],bytes); - serial.SetState(saved,0); - Uint32Checksum((uint32_t *)&saved[0],bytes,csum); - } - - grid->Broadcast(0,(void *)&csum,sizeof(csum)); - - return csum; - } - - - template - static inline uint32_t readObjectParallel(Lattice &Umu,std::string file,munger munge,int offset,const std::string &format) - { - typedef typename vobj::scalar_object sobj; - - GridBase *grid = Umu._grid; - - int ieee32big = (format == std::string("IEEE32BIG")); - int ieee32 = (format == std::string("IEEE32")); - int ieee64big = (format == std::string("IEEE64BIG")); - int ieee64 = (format == std::string("IEEE64")); - - - // Take into account block size of parallel file systems want about - // 4-16MB chunks. - // Ideally one reader/writer per xy plane and read these contiguously - // with comms from nominated I/O nodes. 
- std::ifstream fin; - - int nd = grid->_ndimension; - std::vector parallel(nd,1); - std::vector ioproc (nd); - std::vector start(nd); - std::vector range(nd); - - for(int d=0;dCheckerBoarded(d) == 0); - } - - uint64_t slice_vol = 1; - - int IOnode = 1; - for(int d=0;d_ndimension;d++) { - - if ( d == 0 ) parallel[d] = 0; - if (parallel[d]) { - range[d] = grid->_ldimensions[d]; - start[d] = grid->_processor_coor[d]*range[d]; - ioproc[d]= grid->_processor_coor[d]; + if ( (control & BINARYIO_LEXICOGRAPHIC) && (nrank > 1) ) { +#ifdef USE_MPI_IO + std::cout<< GridLogMessage<< "MPI read I/O "<< file<< std::endl; + ierr=MPI_File_open(grid->communicator,(char *) file.c_str(), MPI_MODE_RDONLY, MPI_INFO_NULL, &fh); assert(ierr==0); + ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0); + ierr=MPI_File_read_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0); + MPI_File_close(&fh); + MPI_Type_free(&fileArray); + MPI_Type_free(&localArray); +#else + assert(0); +#endif } else { - range[d] = grid->_gdimensions[d]; - start[d] = 0; - ioproc[d]= 0; - - if ( grid->_processor_coor[d] != 0 ) IOnode = 0; + std::cout << GridLogMessage << "C++ read I/O " << file << " : " + << iodata.size() * sizeof(fobj) << " bytes" << std::endl; + std::ifstream fin; + fin.open(file, std::ios::binary | std::ios::in); + if (control & BINARYIO_MASTER_APPEND) + { + fin.seekg(-sizeof(fobj), fin.end); + } + else + { + fin.seekg(offset + myrank * lsites * sizeof(fobj)); + } + fin.read((char *)&iodata[0], iodata.size() * sizeof(fobj)); + assert(fin.fail() == 0); + fin.close(); } - slice_vol = slice_vol * range[d]; + timer.Stop(); + + grid->Barrier(); + + bstimer.Start(); + ScidacChecksum(grid,iodata,scidac_csuma,scidac_csumb); + if (ieee32big) be32toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size()); + if (ieee32) le32toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size()); + if (ieee64big) be64toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size()); + if 
(ieee64) le64toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size()); + NerscChecksum(grid,iodata,nersc_csum); + bstimer.Stop(); } + + if ( control & BINARYIO_WRITE ) { - { - uint32_t tmp = IOnode; - grid->GlobalSum(tmp); - std::cout<< std::dec ; - std::cout<< GridLogMessage<< "Parallel read I/O to "<< file << " with " <_ndimension;d++){ - std::cout<< range[d]; - if( d< grid->_ndimension-1 ) - std::cout<< " x "; - } - std::cout << std::endl; - } + bstimer.Start(); + NerscChecksum(grid,iodata,nersc_csum); + if (ieee32big) htobe32_v((void *)&iodata[0], sizeof(fobj)*iodata.size()); + if (ieee32) htole32_v((void *)&iodata[0], sizeof(fobj)*iodata.size()); + if (ieee64big) htobe64_v((void *)&iodata[0], sizeof(fobj)*iodata.size()); + if (ieee64) htole64_v((void *)&iodata[0], sizeof(fobj)*iodata.size()); + ScidacChecksum(grid,iodata,scidac_csuma,scidac_csumb); + bstimer.Stop(); - GridStopWatch timer; timer.Start(); - uint64_t bytes=0; + grid->Barrier(); - int myrank = grid->ThisRank(); - int iorank = grid->RankFromProcessorCoor(ioproc); + timer.Start(); + if ( (control & BINARYIO_LEXICOGRAPHIC) && (nrank > 1) ) { +#ifdef USE_MPI_IO + std::cout << GridLogMessage << "MPI write I/O " << file << std::endl; + ierr = MPI_File_open(grid->communicator, (char *)file.c_str(), MPI_MODE_RDWR | MPI_MODE_CREATE, MPI_INFO_NULL, &fh); + std::cout << GridLogMessage << "Checking for errors" << std::endl; + if (ierr != MPI_SUCCESS) + { + char error_string[BUFSIZ]; + int length_of_error_string, error_class; - if ( IOnode ) { - fin.open(file,std::ios::binary|std::ios::in); - } + MPI_Error_class(ierr, &error_class); + MPI_Error_string(error_class, error_string, &length_of_error_string); + fprintf(stderr, "%3d: %s\n", myrank, error_string); + MPI_Error_string(ierr, error_string, &length_of_error_string); + fprintf(stderr, "%3d: %s\n", myrank, error_string); + MPI_Abort(MPI_COMM_WORLD, 1); //assert(ierr == 0); + } - ////////////////////////////////////////////////////////// - // Find the location 
of each site and send to primary node - // Take loop order from Chroma; defines loop order now that NERSC doc no longer - // available (how short sighted is that?) - ////////////////////////////////////////////////////////// - Umu = zero; - static uint32_t csum; csum=0; - fobj fileObj; - static sobj siteObj; // Static to place in symmetric region for SHMEM + std::cout << GridLogDebug << "MPI read I/O set view " << file << std::endl; + ierr = MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); + assert(ierr == 0); - // need to implement these loops in Nd independent way with a lexico conversion - for(int tlex=0;tlex tsite(nd); // temporary mixed up site - std::vector gsite(nd); - std::vector lsite(nd); - std::vector iosite(nd); + std::cout << GridLogDebug << "MPI read I/O write all " << file << std::endl; + ierr = MPI_File_write_all(fh, &iodata[0], 1, localArray, &status); + assert(ierr == 0); - Lexicographic::CoorFromIndex(tsite,tlex,range); - - for(int d=0;d_ldimensions[d]; // local site - gsite[d] = tsite[d]+start[d]; // global site - } - - ///////////////////////// - // Get the rank of owner of data - ///////////////////////// - int rank, o_idx,i_idx, g_idx; - grid->GlobalCoorToRankIndex(rank,o_idx,i_idx,gsite); - grid->GlobalCoorToGlobalIndex(gsite,g_idx); - - //////////////////////////////// - // iorank reads from the seek - //////////////////////////////// - if (myrank == iorank) { - - fin.seekg(offset+g_idx*sizeof(fileObj)); - fin.read((char *)&fileObj,sizeof(fileObj)); - bytes+=sizeof(fileObj); - - if(ieee32big) be32toh_v((void *)&fileObj,sizeof(fileObj)); - if(ieee32) le32toh_v((void *)&fileObj,sizeof(fileObj)); - if(ieee64big) be64toh_v((void *)&fileObj,sizeof(fileObj)); - if(ieee64) le64toh_v((void *)&fileObj,sizeof(fileObj)); - - munge(fileObj,siteObj,csum); - - } - - // Possibly do transport through pt2pt - if ( rank != iorank ) { - if ( (myrank == rank) || (myrank==iorank) ) { - grid->SendRecvPacket((void *)&siteObj,(void 
*)&siteObj,iorank,rank,sizeof(siteObj)); + MPI_File_close(&fh); + MPI_Type_free(&fileArray); + MPI_Type_free(&localArray); +#else + assert(0); +#endif + } else { + + std::ofstream fout; + fout.exceptions ( std::fstream::failbit | std::fstream::badbit ); + try { + fout.open(file,std::ios::binary|std::ios::out|std::ios::in); + } catch (const std::fstream::failure& exc) { + std::cout << GridLogError << "Error in opening the file " << file << " for output" <Barrier(); // necessary? - } - grid->GlobalSum(csum); - grid->GlobalSum(bytes); + fout.close(); + } + timer.Stop(); + } + + std::cout<Barrier(); + grid->GlobalSum(nersc_csum); + grid->GlobalXOR(scidac_csuma); + grid->GlobalXOR(scidac_csumb); + grid->Barrier(); + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Read a Lattice of object + ////////////////////////////////////////////////////////////////////////////////////// + template + static inline void readLatticeObject(Lattice &Umu, + std::string file, + munger munge, + int offset, + const std::string &format, + uint32_t &nersc_csum, + uint32_t &scidac_csuma, + uint32_t &scidac_csumb) + { + typedef typename vobj::scalar_object sobj; + typedef typename vobj::Realified::scalar_type word; word w=0; + + GridBase *grid = Umu._grid; + int lsites = grid->lSites(); + + std::vector scalardata(lsites); + std::vector iodata(lsites); // Munge, checksum, byte order in here + + IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_LEXICOGRAPHIC, + nersc_csum,scidac_csuma,scidac_csumb); + + GridStopWatch timer; + timer.Start(); + + parallel_for(int x=0;xBarrier(); timer.Stop(); - std::cout< - static inline uint32_t writeObjectParallel(Lattice &Umu,std::string file,munger munge,int offset,const std::string & format) + static inline void writeLatticeObject(Lattice &Umu, + std::string file, + munger munge, + int offset, + const std::string &format, + uint32_t &nersc_csum, + uint32_t &scidac_csuma, + uint32_t &scidac_csumb) { 
typedef typename vobj::scalar_object sobj; + typedef typename vobj::Realified::scalar_type word; word w=0; GridBase *grid = Umu._grid; + int lsites = grid->lSites(); - int ieee32big = (format == std::string("IEEE32BIG")); - int ieee32 = (format == std::string("IEEE32")); - int ieee64big = (format == std::string("IEEE64BIG")); - int ieee64 = (format == std::string("IEEE64")); - - int nd = grid->_ndimension; - for(int d=0;dCheckerBoarded(d) == 0); - } - - std::vector parallel(nd,1); - std::vector ioproc (nd); - std::vector start(nd); - std::vector range(nd); - - uint64_t slice_vol = 1; - - int IOnode = 1; - - for(int d=0;d_ndimension;d++) { - - if ( d!= grid->_ndimension-1 ) parallel[d] = 0; - - if (parallel[d]) { - range[d] = grid->_ldimensions[d]; - start[d] = grid->_processor_coor[d]*range[d]; - ioproc[d]= grid->_processor_coor[d]; - } else { - range[d] = grid->_gdimensions[d]; - start[d] = 0; - ioproc[d]= 0; - - if ( grid->_processor_coor[d] != 0 ) IOnode = 0; - } - - slice_vol = slice_vol * range[d]; - } - - { - uint32_t tmp = IOnode; - grid->GlobalSum(tmp); - std::cout<< GridLogMessage<< "Parallel write I/O from "<< file << " with " <_ndimension;d++){ - std::cout<< range[d]; - if( d< grid->_ndimension-1 ) - std::cout<< " x "; - } - std::cout << std::endl; - } + std::vector scalardata(lsites); + std::vector iodata(lsites); // Munge, checksum, byte order in here + ////////////////////////////////////////////////////////////////////////////// + // Munge [ .e.g 3rd row recon ] + ////////////////////////////////////////////////////////////////////////////// GridStopWatch timer; timer.Start(); - uint64_t bytes=0; + unvectorizeToLexOrdArray(scalardata,Umu); - int myrank = grid->ThisRank(); - int iorank = grid->RankFromProcessorCoor(ioproc); + parallel_for(int x=0;xBarrier(); + timer.Stop(); - ////////////////////////////////////////////////////////// - // Find the location of each site and send to primary node - // Take loop order from Chroma; defines loop order now 
that NERSC doc no longer - // available (how short sighted is that?) - ////////////////////////////////////////////////////////// + IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_LEXICOGRAPHIC, + nersc_csum,scidac_csuma,scidac_csumb); - uint32_t csum=0; - fobj fileObj; - static sobj siteObj; // static for SHMEM target; otherwise dynamic allocate with AlignedAllocator - - // should aggregate a whole chunk and then write. - // need to implement these loops in Nd independent way with a lexico conversion - for(int tlex=0;tlex tsite(nd); // temporary mixed up site - std::vector gsite(nd); - std::vector lsite(nd); - std::vector iosite(nd); - - Lexicographic::CoorFromIndex(tsite,tlex,range); - - for(int d=0;d_ldimensions[d]; // local site - gsite[d] = tsite[d]+start[d]; // global site - } - - - ///////////////////////// - // Get the rank of owner of data - ///////////////////////// - int rank, o_idx,i_idx, g_idx; - grid->GlobalCoorToRankIndex(rank,o_idx,i_idx,gsite); - grid->GlobalCoorToGlobalIndex(gsite,g_idx); - - //////////////////////////////// - // iorank writes from the seek - //////////////////////////////// - - // Owner of data peeks it - peekLocalSite(siteObj,Umu,lsite); - - // Pair of nodes may need to do pt2pt send - if ( rank != iorank ) { // comms is necessary - if ( (myrank == rank) || (myrank==iorank) ) { // and we have to do it - // Send to IOrank - grid->SendRecvPacket((void *)&siteObj,(void *)&siteObj,rank,iorank,sizeof(siteObj)); + std::cout<Barrier(); // necessary? 
- - if (myrank == iorank) { - munge(siteObj,fileObj,csum); + ///////////////////////////////////////////////////////////////////////////// + // Read a RNG; use IOobject and lexico map to an array of state + ////////////////////////////////////////////////////////////////////////////////////// + static inline void readRNG(GridSerialRNG &serial, + GridParallelRNG ¶llel, + std::string file, + int offset, + uint32_t &nersc_csum, + uint32_t &scidac_csuma, + uint32_t &scidac_csumb) + { + typedef typename GridSerialRNG::RngStateType RngStateType; + const int RngStateCount = GridSerialRNG::RngStateCount; + typedef std::array RNGstate; + typedef RngStateType word; word w=0; - if(ieee32big) htobe32_v((void *)&fileObj,sizeof(fileObj)); - if(ieee32) htole32_v((void *)&fileObj,sizeof(fileObj)); - if(ieee64big) htobe64_v((void *)&fileObj,sizeof(fileObj)); - if(ieee64) htole64_v((void *)&fileObj,sizeof(fileObj)); - - fout.seekp(offset+g_idx*sizeof(fileObj)); - fout.write((char *)&fileObj,sizeof(fileObj)); - bytes+=sizeof(fileObj); - } + std::string format = "IEEE32BIG"; + + GridBase *grid = parallel._grid; + int gsites = grid->gSites(); + int lsites = grid->lSites(); + + uint32_t nersc_csum_tmp = 0; + uint32_t scidac_csuma_tmp = 0; + uint32_t scidac_csumb_tmp = 0; + + GridStopWatch timer; + + std::cout << GridLogMessage << "RNG read I/O on file " << file << std::endl; + + std::vector iodata(lsites); + IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_LEXICOGRAPHIC, + nersc_csum,scidac_csuma,scidac_csumb); + + timer.Start(); + parallel_for(int lidx=0;lidx tmp(RngStateCount); + std::copy(iodata[lidx].begin(),iodata[lidx].end(),tmp.begin()); + parallel.SetState(tmp,lidx); + } + timer.Stop(); + + iodata.resize(1); + IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_MASTER_APPEND, + nersc_csum_tmp,scidac_csuma_tmp,scidac_csumb_tmp); + + { + std::vector tmp(RngStateCount); + std::copy(iodata[0].begin(),iodata[0].end(),tmp.begin()); + 
serial.SetState(tmp,0); } - grid->GlobalSum(csum); - grid->GlobalSum(bytes); + nersc_csum = nersc_csum + nersc_csum_tmp; + scidac_csuma = scidac_csuma ^ scidac_csuma_tmp; + scidac_csumb = scidac_csumb ^ scidac_csumb_tmp; - timer.Stop(); - std::cout< RNGstate; + GridBase *grid = parallel._grid; + int gsites = grid->gSites(); + int lsites = grid->lSites(); + + uint32_t nersc_csum_tmp; + uint32_t scidac_csuma_tmp; + uint32_t scidac_csumb_tmp; + + GridStopWatch timer; + std::string format = "IEEE32BIG"; + + std::cout << GridLogMessage << "RNG write I/O on file " << file << std::endl; + + timer.Start(); + std::vector iodata(lsites); + parallel_for(int lidx=0;lidx tmp(RngStateCount); + parallel.GetState(tmp,lidx); + std::copy(tmp.begin(),tmp.end(),iodata[lidx].begin()); + } + timer.Stop(); + + IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_LEXICOGRAPHIC, + nersc_csum,scidac_csuma,scidac_csumb); + + iodata.resize(1); + { + std::vector tmp(RngStateCount); + serial.GetState(tmp,0); + std::copy(tmp.begin(),tmp.end(),iodata[0].begin()); + } + IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_MASTER_APPEND, + nersc_csum_tmp,scidac_csuma_tmp,scidac_csumb_tmp); + + nersc_csum = nersc_csum + nersc_csum_tmp; + scidac_csuma = scidac_csuma ^ scidac_csuma_tmp; + scidac_csumb = scidac_csumb ^ scidac_csumb_tmp; + + std::cout << GridLogMessage << "RNG file checksum " << std::hex << nersc_csum << std::dec << std::endl; + std::cout << GridLogMessage << "RNG file checksuma " << std::hex << scidac_csuma << std::dec << std::endl; + std::cout << GridLogMessage << "RNG file checksumb " << std::hex << scidac_csumb << std::dec << std::endl; + std::cout << GridLogMessage << "RNG state overhead " << timer.Elapsed() << std::endl; + } }; - } - #endif diff --git a/lib/parallelIO/IldgIO.h b/lib/parallelIO/IldgIO.h new file mode 100644 index 00000000..17ce4a06 --- /dev/null +++ b/lib/parallelIO/IldgIO.h @@ -0,0 +1,716 @@ 
+/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/parallelIO/IldgIO.h + +Copyright (C) 2015 + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef GRID_ILDG_IO_H +#define GRID_ILDG_IO_H + +#ifdef HAVE_LIME +#include +#include +#include +#include +#include + +#include +#include +#include + +//C-Lime is a must have for this functionality +extern "C" { +#include "lime.h" +} + +namespace Grid { +namespace QCD { + + ///////////////////////////////// + // Encode word types as strings + ///////////////////////////////// + template inline std::string ScidacWordMnemonic(void){ return std::string("unknown"); } + template<> inline std::string ScidacWordMnemonic (void){ return std::string("D"); } + template<> inline std::string ScidacWordMnemonic (void){ return std::string("F"); } + template<> inline std::string ScidacWordMnemonic< int32_t>(void){ return std::string("I32_t"); } + template<> inline std::string ScidacWordMnemonic(void){ return std::string("U32_t"); } + template<> inline std::string ScidacWordMnemonic< int64_t>(void){ return 
std::string("I64_t"); } + template<> inline std::string ScidacWordMnemonic(void){ return std::string("U64_t"); } + + ///////////////////////////////////////// + // Encode a generic tensor as a string + ///////////////////////////////////////// + template std::string ScidacRecordTypeString(int &colors, int &spins, int & typesize,int &datacount) { + + typedef typename getPrecision::real_scalar_type stype; + + int _ColourN = indexRank(); + int _ColourScalar = isScalar(); + int _ColourVector = isVector(); + int _ColourMatrix = isMatrix(); + + int _SpinN = indexRank(); + int _SpinScalar = isScalar(); + int _SpinVector = isVector(); + int _SpinMatrix = isMatrix(); + + int _LorentzN = indexRank(); + int _LorentzScalar = isScalar(); + int _LorentzVector = isVector(); + int _LorentzMatrix = isMatrix(); + + std::stringstream stream; + + stream << "GRID_"; + stream << ScidacWordMnemonic(); + + // std::cout << " Lorentz N/S/V/M : " << _LorentzN<<" "<<_LorentzScalar<<"/"<<_LorentzVector<<"/"<<_LorentzMatrix< std::string ScidacRecordTypeString(Lattice & lat,int &colors, int &spins, int & typesize,int &datacount) { + return ScidacRecordTypeString(colors,spins,typesize,datacount); + }; + + + //////////////////////////////////////////////////////////// + // Helper to fill out metadata + //////////////////////////////////////////////////////////// + template void ScidacMetaData(Lattice & field, + FieldMetaData &header, + scidacRecord & _scidacRecord, + scidacFile & _scidacFile) + { + typedef typename getPrecision::real_scalar_type stype; + + ///////////////////////////////////// + // Pull Grid's metadata + ///////////////////////////////////// + PrepareMetaData(field,header); + + ///////////////////////////////////// + // Scidac Private File structure + ///////////////////////////////////// + _scidacFile = scidacFile(field._grid); + + ///////////////////////////////////// + // Scidac Private Record structure + ///////////////////////////////////// + scidacRecord sr; + sr.datatype = 
ScidacRecordTypeString(field,sr.colors,sr.spins,sr.typesize,sr.datacount); + sr.date = header.creation_date; + sr.precision = ScidacWordMnemonic(); + sr.recordtype = GRID_IO_FIELD; + + _scidacRecord = sr; + + std::cout << GridLogMessage << "Build SciDAC datatype " < + void readLimeLatticeBinaryObject(Lattice &field,std::string record_name) + { + typedef typename vobj::scalar_object sobj; + scidacChecksum scidacChecksum_; + uint32_t nersc_csum,scidac_csuma,scidac_csumb; + + std::string format = getFormatString(); + + while ( limeReaderNextRecord(LimeR) == LIME_SUCCESS ) { + + std::cout << GridLogMessage << limeReaderType(LimeR) < munge; + BinaryIO::readLatticeObject< sobj, sobj >(field, filename, munge, offset, format,nersc_csum,scidac_csuma,scidac_csumb); + + ///////////////////////////////////////////// + // Insist checksum is next record + ///////////////////////////////////////////// + readLimeObject(scidacChecksum_,std::string("scidacChecksum"),record_name); + + ///////////////////////////////////////////// + // Verify checksums + ///////////////////////////////////////////// + scidacChecksumVerify(scidacChecksum_,scidac_csuma,scidac_csumb); + return; + } + } + } + //////////////////////////////////////////// + // Read a generic serialisable object + //////////////////////////////////////////// + template + void readLimeObject(serialisable_object &object,std::string object_name,std::string record_name) + { + std::string xmlstring; + // should this be a do while; can we miss a first record?? 
+ while ( limeReaderNextRecord(LimeR) == LIME_SUCCESS ) { + + uint64_t nbytes = limeReaderBytes(LimeR);//size of this record (configuration) + + if ( strncmp(limeReaderType(LimeR), record_name.c_str(),strlen(record_name.c_str()) ) ) { + std::vector xmlc(nbytes+1,'\0'); + limeReaderReadData((void *)&xmlc[0], &nbytes, LimeR); + XmlReader RD(&xmlc[0],""); + read(RD,object_name,object); + return; + } + + } + assert(0); + } +}; + +class GridLimeWriter : public BinaryIO { + public: + /////////////////////////////////////////////////// + // FIXME: format for RNG? Now just binary out instead + /////////////////////////////////////////////////// + + FILE *File; + LimeWriter *LimeW; + std::string filename; + + void open(std::string &_filename) { + filename= _filename; + File = fopen(filename.c_str(), "w"); + LimeW = limeCreateWriter(File); assert(LimeW != NULL ); + } + ///////////////////////////////////////////// + // Close the file + ///////////////////////////////////////////// + void close(void) { + fclose(File); + // limeDestroyWriter(LimeW); + } + /////////////////////////////////////////////////////// + // Lime utility functions + /////////////////////////////////////////////////////// + int createLimeRecordHeader(std::string message, int MB, int ME, size_t PayloadSize) + { + LimeRecordHeader *h; + h = limeCreateHeader(MB, ME, const_cast(message.c_str()), PayloadSize); + assert(limeWriteRecordHeader(h, LimeW) >= 0); + limeDestroyHeader(h); + return LIME_SUCCESS; + } + //////////////////////////////////////////// + // Write a generic serialisable object + //////////////////////////////////////////// + template + void writeLimeObject(int MB,int ME,serialisable_object &object,std::string object_name,std::string record_name) + { + std::string xmlstring; + { + XmlWriter WR("",""); + write(WR,object_name,object); + xmlstring = WR.XmlString(); + } + uint64_t nbytes = xmlstring.size(); + int err; + LimeRecordHeader *h = limeCreateHeader(MB, ME,(char *)record_name.c_str(), 
nbytes); assert(h!= NULL); + + err=limeWriteRecordHeader(h, LimeW); assert(err>=0); + err=limeWriteRecordData(&xmlstring[0], &nbytes, LimeW); assert(err>=0); + err=limeWriterCloseRecord(LimeW); assert(err>=0); + limeDestroyHeader(h); + } + //////////////////////////////////////////// + // Write a generic lattice field and csum + //////////////////////////////////////////// + template + void writeLimeLatticeBinaryObject(Lattice &field,std::string record_name) + { + //////////////////////////////////////////// + // Create record header + //////////////////////////////////////////// + typedef typename vobj::scalar_object sobj; + int err; + uint32_t nersc_csum,scidac_csuma,scidac_csumb; + uint64_t PayloadSize = sizeof(sobj) * field._grid->_gsites; + createLimeRecordHeader(record_name, 0, 0, PayloadSize); + + //////////////////////////////////////////////////////////////////// + // NB: FILE and iostream are jointly writing disjoint sequences in the + // the same file through different file handles (integer units). + // + // These are both buffered, so why I think this code is right is as follows. + // + // i) write record header to FILE *File, telegraphing the size. + // ii) ftell reads the offset from FILE *File . + // iii) iostream / MPI Open independently seek this offset. Write sequence direct to disk. + // Closes iostream and flushes. + // iv) fseek on FILE * to end of this disjoint section. + // v) Continue writing scidac record. 
+ //////////////////////////////////////////////////////////////////// + off_t offset = ftell(File); + std::string format = getFormatString(); + BinarySimpleMunger munge; + BinaryIO::writeLatticeObject(field, filename, munge, offset, format,nersc_csum,scidac_csuma,scidac_csumb); + err=limeWriterCloseRecord(LimeW); assert(err>=0); + //////////////////////////////////////// + // Write checksum element, propagaing forward from the BinaryIO + // Always pair a checksum with a binary object, and close message + //////////////////////////////////////// + scidacChecksum checksum; + std::stringstream streama; streama << std::hex << scidac_csuma; + std::stringstream streamb; streamb << std::hex << scidac_csumb; + checksum.suma= streama.str(); + checksum.sumb= streamb.str(); + std::cout << GridLogMessage<<" writing scidac checksums "< + void writeScidacFileRecord(GridBase *grid,SerialisableUserFile &_userFile) + { + scidacFile _scidacFile(grid); + writeLimeObject(1,0,_scidacFile,_scidacFile.SerialisableClassName(),std::string(SCIDAC_PRIVATE_FILE_XML)); + writeLimeObject(0,1,_userFile,_userFile.SerialisableClassName(),std::string(SCIDAC_FILE_XML)); + } + //////////////////////////////////////////////// + // Write generic lattice field in scidac format + //////////////////////////////////////////////// + template + void writeScidacFieldRecord(Lattice &field,userRecord _userRecord) + { + typedef typename vobj::scalar_object sobj; + uint64_t nbytes; + GridBase * grid = field._grid; + + //////////////////////////////////////// + // fill the Grid header + //////////////////////////////////////// + FieldMetaData header; + scidacRecord _scidacRecord; + scidacFile _scidacFile; + + ScidacMetaData(field,header,_scidacRecord,_scidacFile); + + ////////////////////////////////////////////// + // Fill the Lime file record by record + ////////////////////////////////////////////// + writeLimeObject(1,0,header ,std::string("FieldMetaData"),std::string(GRID_FORMAT)); // Open message + 
writeLimeObject(0,0,_userRecord,_userRecord.SerialisableClassName(),std::string(SCIDAC_RECORD_XML)); + writeLimeObject(0,0,_scidacRecord,_scidacRecord.SerialisableClassName(),std::string(SCIDAC_PRIVATE_RECORD_XML)); + writeLimeLatticeBinaryObject(field,std::string(ILDG_BINARY_DATA)); // Closes message with checksum + } +}; + +class IldgWriter : public ScidacWriter { + public: + + /////////////////////////////////// + // A little helper + /////////////////////////////////// + void writeLimeIldgLFN(std::string &LFN) + { + uint64_t PayloadSize = LFN.size(); + int err; + createLimeRecordHeader(ILDG_DATA_LFN, 0 , 0, PayloadSize); + err=limeWriteRecordData(const_cast(LFN.c_str()), &PayloadSize,LimeW); assert(err>=0); + err=limeWriterCloseRecord(LimeW); assert(err>=0); + } + + //////////////////////////////////////////////////////////////// + // Special ILDG operations ; gauge configs only. + // Don't require scidac records EXCEPT checksum + // Use Grid MetaData object if present. + //////////////////////////////////////////////////////////////// + template + void writeConfiguration(Lattice > &Umu,int sequence,std::string LFN,std::string description) + { + GridBase * grid = Umu._grid; + typedef Lattice > GaugeField; + typedef iLorentzColourMatrix vobj; + typedef typename vobj::scalar_object sobj; + + uint64_t nbytes; + + //////////////////////////////////////// + // fill the Grid header + //////////////////////////////////////// + FieldMetaData header; + scidacRecord _scidacRecord; + scidacFile _scidacFile; + + ScidacMetaData(Umu,header,_scidacRecord,_scidacFile); + + std::string format = header.floating_point; + header.ensemble_id = description; + header.ensemble_label = description; + header.sequence_number = sequence; + header.ildg_lfn = LFN; + + assert ( (format == std::string("IEEE32BIG")) + ||(format == std::string("IEEE64BIG")) ); + + ////////////////////////////////////////////////////// + // Fill ILDG header data struct + 
////////////////////////////////////////////////////// + ildgFormat ildgfmt ; + ildgfmt.field = std::string("su3gauge"); + + if ( format == std::string("IEEE32BIG") ) { + ildgfmt.precision = 32; + } else { + ildgfmt.precision = 64; + } + ildgfmt.version = 1.0; + ildgfmt.lx = header.dimension[0]; + ildgfmt.ly = header.dimension[1]; + ildgfmt.lz = header.dimension[2]; + ildgfmt.lt = header.dimension[3]; + assert(header.nd==4); + assert(header.nd==header.dimension.size()); + + ////////////////////////////////////////////////////////////////////////////// + // Fill the USQCD info field + ////////////////////////////////////////////////////////////////////////////// + usqcdInfo info; + info.version=1.0; + info.plaq = header.plaquette; + info.linktr = header.link_trace; + + std::cout << GridLogMessage << " Writing config; IldgIO "< + void readConfiguration(Lattice > &Umu, FieldMetaData &FieldMetaData_) { + + typedef Lattice > GaugeField; + typedef typename GaugeField::vector_object vobj; + typedef typename vobj::scalar_object sobj; + + typedef LorentzColourMatrixF fobj; + typedef LorentzColourMatrixD dobj; + + GridBase *grid = Umu._grid; + + std::vector dims = Umu._grid->FullDimensions(); + + assert(dims.size()==4); + + // Metadata holders + ildgFormat ildgFormat_ ; + std::string ildgLFN_ ; + scidacChecksum scidacChecksum_; + usqcdInfo usqcdInfo_ ; + + // track what we read from file + int found_ildgFormat =0; + int found_ildgLFN =0; + int found_scidacChecksum=0; + int found_usqcdInfo =0; + int found_ildgBinary =0; + int found_FieldMetaData =0; + + uint32_t nersc_csum; + uint32_t scidac_csuma; + uint32_t scidac_csumb; + + // Binary format + std::string format; + + ////////////////////////////////////////////////////////////////////////// + // Loop over all records + // -- Order is poorly guaranteed except ILDG header preceeds binary section. + // -- Run like an event loop. + // -- Impose trust hierarchy. 
Grid takes precedence & look for ILDG, and failing + // that Scidac. + // -- Insist on Scidac checksum record. + ////////////////////////////////////////////////////////////////////////// + + while ( limeReaderNextRecord(LimeR) == LIME_SUCCESS ) { + + uint64_t nbytes = limeReaderBytes(LimeR);//size of this record (configuration) + + ////////////////////////////////////////////////////////////////// + // If not BINARY_DATA read a string and parse + ////////////////////////////////////////////////////////////////// + if ( strncmp(limeReaderType(LimeR), ILDG_BINARY_DATA,strlen(ILDG_BINARY_DATA) ) ) { + + // Copy out the string + std::vector xmlc(nbytes+1,'\0'); + limeReaderReadData((void *)&xmlc[0], &nbytes, LimeR); + std::cout << GridLogMessage<< "Non binary record :" < munge; + BinaryIO::readLatticeObject< vobj, dobj >(Umu, filename, munge, offset, format,nersc_csum,scidac_csuma,scidac_csumb); + } else { + GaugeSimpleMunger munge; + BinaryIO::readLatticeObject< vobj, fobj >(Umu, filename, munge, offset, format,nersc_csum,scidac_csuma,scidac_csumb); + } + + found_ildgBinary = 1; + } + + } + + ////////////////////////////////////////////////////// + // Minimally must find binary segment and checksum + // Since this is an ILDG reader require ILDG format + ////////////////////////////////////////////////////// + assert(found_ildgBinary); + assert(found_ildgFormat); + assert(found_scidacChecksum); + + // Must find something with the lattice dimensions + assert(found_FieldMetaData||found_ildgFormat); + + if ( found_FieldMetaData ) { + + std::cout << GridLogMessage<<"Grid MetaData was record found: configuration was probably written by Grid ! Yay ! 
"<1.1416 16 16 32 0 +//////////////////////// +struct scidacFile : Serializable { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(scidacFile, + double, version, + int, spacetime, + std::string, dims, // must convert to int + int, volfmt); + + std::vector getDimensions(void) { + std::stringstream stream(dims); + std::vector dimensions; + int n; + while(stream >> n){ + dimensions.push_back(n); + } + return dimensions; + } + + void setDimensions(std::vector dimensions) { + char delimiter = ' '; + std::stringstream stream; + for(int i=0;i_ndimension; + setDimensions(grid->FullDimensions()); + volfmt = GRID_IO_SINGLEFILE; + } + +}; + +/////////////////////////////////////////////////////////////////////// +// scidac-private-record-xml : example +// +// 1.1Tue Jul 26 21:14:44 2011 UTC0 +// QDP_D3_ColorMatrixD34 +// 1444 +// +/////////////////////////////////////////////////////////////////////// + +struct scidacRecord : Serializable { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(scidacRecord, + double, version, + std::string, date, + int, recordtype, + std::string, datatype, + std::string, precision, + int, colors, + int, spins, + int, typesize, + int, datacount); + + scidacRecord() { version =1.0; } + +}; + +//////////////////////// +// ILDG format +//////////////////////// +struct ildgFormat : Serializable { +public: + GRID_SERIALIZABLE_CLASS_MEMBERS(ildgFormat, + double, version, + std::string, field, + int, precision, + int, lx, + int, ly, + int, lz, + int, lt); + ildgFormat() { version=1.0; }; +}; +//////////////////////// +// USQCD info +//////////////////////// +struct usqcdInfo : Serializable { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(usqcdInfo, + double, version, + double, plaq, + double, linktr, + std::string, info); + usqcdInfo() { + version=1.0; + }; +}; +//////////////////////// +// Scidac Checksum +//////////////////////// +struct scidacChecksum : Serializable { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(scidacChecksum, + double, version, + std::string, 
suma, + std::string, sumb); + scidacChecksum() { + version=1.0; + }; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Type: scidac-file-xml MILC ILDG archival gauge configuration +//////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Type: +//////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////// +// Scidac private file xml +// 1.1416 16 16 32 0 +//////////////////////// + +#if 0 +//////////////////////////////////////////////////////////////////////////////////////// +// From http://www.physics.utah.edu/~detar/scidac/qio_2p3.pdf +//////////////////////////////////////////////////////////////////////////////////////// +struct usqcdPropFile : Serializable { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(usqcdPropFile, + double, version, + std::string, type, + std::string, info); + usqcdPropFile() { + version=1.0; + }; +}; +struct usqcdSourceInfo : Serializable { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(usqcdSourceInfo, + double, version, + std::string, info); + usqcdSourceInfo() { + version=1.0; + }; +}; +struct usqcdPropInfo : Serializable { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(usqcdPropInfo, + double, version, + int, spin, + int, color, + std::string, info); + usqcdPropInfo() { + version=1.0; + }; +}; +#endif + +} +#endif +#endif diff --git a/lib/parallelIO/MetaData.h b/lib/parallelIO/MetaData.h new file mode 100644 index 00000000..6d45d0a5 --- /dev/null +++ b/lib/parallelIO/MetaData.h @@ -0,0 +1,325 @@ +/************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + Source file: ./lib/parallelIO/NerscIO.h + + Copyright (C) 2015 + + + Author: Peter 
Boyle + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + + See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Grid { + + /////////////////////////////////////////////////////// + // Precision mapping + /////////////////////////////////////////////////////// + template static std::string getFormatString (void) + { + std::string format; + typedef typename getPrecision::real_scalar_type stype; + if ( sizeof(stype) == sizeof(float) ) { + format = std::string("IEEE32BIG"); + } + if ( sizeof(stype) == sizeof(double) ) { + format = std::string("IEEE64BIG"); + } + return format; + } + //////////////////////////////////////////////////////////////////////////////// + // header specification/interpretation + //////////////////////////////////////////////////////////////////////////////// + class FieldMetaData : Serializable { + public: + + GRID_SERIALIZABLE_CLASS_MEMBERS(FieldMetaData, + int, nd, + std::vector, dimension, + std::vector, boundary, + int, data_start, + std::string, hdr_version, + std::string, storage_format, + double, link_trace, + double, plaquette, + uint32_t, checksum, + uint32_t, 
scidac_checksuma, + uint32_t, scidac_checksumb, + unsigned int, sequence_number, + std::string, data_type, + std::string, ensemble_id, + std::string, ensemble_label, + std::string, ildg_lfn, + std::string, creator, + std::string, creator_hardware, + std::string, creation_date, + std::string, archive_date, + std::string, floating_point); + FieldMetaData(void) { + nd=4; + dimension.resize(4); + boundary.resize(4); + } + }; + + + + namespace QCD { + + using namespace Grid; + + + ////////////////////////////////////////////////////////////////////// + // Bit and Physical Checksumming and QA of data + ////////////////////////////////////////////////////////////////////// + inline void GridMetaData(GridBase *grid,FieldMetaData &header) + { + int nd = grid->_ndimension; + header.nd = nd; + header.dimension.resize(nd); + header.boundary.resize(nd); + for(int d=0;d_fdimensions[d]; + } + for(int d=0;dpw_name); + + // When + std::time_t t = std::time(nullptr); + std::tm tm_ = *std::localtime(&t); + std::ostringstream oss; + // oss << std::put_time(&tm_, "%c %Z"); + header.creation_date = oss.str(); + header.archive_date = header.creation_date; + + // What + struct utsname name; uname(&name); + header.creator_hardware = std::string(name.nodename)+"-"; + header.creator_hardware+= std::string(name.machine)+"-"; + header.creator_hardware+= std::string(name.sysname)+"-"; + header.creator_hardware+= std::string(name.release); + } + +#define dump_meta_data(field, s) \ + s << "BEGIN_HEADER" << std::endl; \ + s << "HDR_VERSION = " << field.hdr_version << std::endl; \ + s << "DATATYPE = " << field.data_type << std::endl; \ + s << "STORAGE_FORMAT = " << field.storage_format << std::endl; \ + for(int i=0;i<4;i++){ \ + s << "DIMENSION_" << i+1 << " = " << field.dimension[i] << std::endl ; \ + } \ + s << "LINK_TRACE = " << std::setprecision(10) << field.link_trace << std::endl; \ + s << "PLAQUETTE = " << std::setprecision(10) << field.plaquette << std::endl; \ + for(int i=0;i<4;i++){ \ + s 
<< "BOUNDARY_"< inline void PrepareMetaData(Lattice & field, FieldMetaData &header) +{ + GridBase *grid = field._grid; + std::string format = getFormatString(); + header.floating_point = format; + header.checksum = 0x0; // Nersc checksum unused in ILDG, Scidac + GridMetaData(grid,header); + MachineCharacteristics(header); + } + inline void GaugeStatistics(Lattice & data,FieldMetaData &header) + { + // How to convert data precision etc... + header.link_trace=Grid::QCD::WilsonLoops::linkTrace(data); + header.plaquette =Grid::QCD::WilsonLoops::avgPlaquette(data); + } + inline void GaugeStatistics(Lattice & data,FieldMetaData &header) + { + // How to convert data precision etc... + header.link_trace=Grid::QCD::WilsonLoops::linkTrace(data); + header.plaquette =Grid::QCD::WilsonLoops::avgPlaquette(data); + } + template<> inline void PrepareMetaData(Lattice & field, FieldMetaData &header) + { + + GridBase *grid = field._grid; + std::string format = getFormatString(); + header.floating_point = format; + header.checksum = 0x0; // Nersc checksum unused in ILDG, Scidac + GridMetaData(grid,header); + GaugeStatistics(field,header); + MachineCharacteristics(header); + } + template<> inline void PrepareMetaData(Lattice & field, FieldMetaData &header) + { + GridBase *grid = field._grid; + std::string format = getFormatString(); + header.floating_point = format; + header.checksum = 0x0; // Nersc checksum unused in ILDG, Scidac + GridMetaData(grid,header); + GaugeStatistics(field,header); + MachineCharacteristics(header); + } + + ////////////////////////////////////////////////////////////////////// + // Utilities ; these are QCD aware + ////////////////////////////////////////////////////////////////////// + inline void reconstruct3(LorentzColourMatrix & cm) + { + const int x=0; + const int y=1; + const int z=2; + for(int mu=0;mu using iLorentzColour2x3 = iVector, 2>, Nd >; + + typedef iLorentzColour2x3 LorentzColour2x3; + typedef iLorentzColour2x3 LorentzColour2x3F; + typedef 
iLorentzColour2x3 LorentzColour2x3D; + +///////////////////////////////////////////////////////////////////////////////// +// Simple classes for precision conversion +///////////////////////////////////////////////////////////////////////////////// +template +struct BinarySimpleUnmunger { + typedef typename getPrecision::real_scalar_type fobj_stype; + typedef typename getPrecision::real_scalar_type sobj_stype; + + void operator()(sobj &in, fobj &out) { + // take word by word and transform accoding to the status + fobj_stype *out_buffer = (fobj_stype *)&out; + sobj_stype *in_buffer = (sobj_stype *)∈ + size_t fobj_words = sizeof(out) / sizeof(fobj_stype); + size_t sobj_words = sizeof(in) / sizeof(sobj_stype); + assert(fobj_words == sobj_words); + + for (unsigned int word = 0; word < sobj_words; word++) + out_buffer[word] = in_buffer[word]; // type conversion on the fly + + } +}; + +template +struct BinarySimpleMunger { + typedef typename getPrecision::real_scalar_type fobj_stype; + typedef typename getPrecision::real_scalar_type sobj_stype; + + void operator()(fobj &in, sobj &out) { + // take word by word and transform accoding to the status + fobj_stype *in_buffer = (fobj_stype *)∈ + sobj_stype *out_buffer = (sobj_stype *)&out; + size_t fobj_words = sizeof(in) / sizeof(fobj_stype); + size_t sobj_words = sizeof(out) / sizeof(sobj_stype); + assert(fobj_words == sobj_words); + + for (unsigned int word = 0; word < sobj_words; word++) + out_buffer[word] = in_buffer[word]; // type conversion on the fly + + } +}; + + + template + struct GaugeSimpleMunger{ + void operator()(fobj &in, sobj &out) { + for (int mu = 0; mu < Nd; mu++) { + for (int i = 0; i < Nc; i++) { + for (int j = 0; j < Nc; j++) { + out(mu)()(i, j) = in(mu)()(i, j); + }} + } + }; + }; + + template + struct GaugeSimpleUnmunger { + + void operator()(sobj &in, fobj &out) { + for (int mu = 0; mu < Nd; mu++) { + for (int i = 0; i < Nc; i++) { + for (int j = 0; j < Nc; j++) { + out(mu)()(i, j) = in(mu)()(i, j); + 
}} + } + }; + }; + + template + struct Gauge3x2munger{ + void operator() (fobj &in,sobj &out){ + for(int mu=0;mu + struct Gauge3x2unmunger{ + void operator() (sobj &in,fobj &out){ + for(int mu=0;mu -Author: Peter Boyle -Author: paboyle + Author: Matt Spraggs + Author: Peter Boyle + Author: paboyle This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -25,251 +25,59 @@ Author: paboyle 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. See the full license in the file "LICENSE" in the top level distribution directory - *************************************************************************************/ - /* END LEGAL */ +*************************************************************************************/ +/* END LEGAL */ #ifndef GRID_NERSC_IO_H #define GRID_NERSC_IO_H -#include -#include -#include -#include -#include - -#include -#include -#include - namespace Grid { -namespace QCD { + namespace QCD { -using namespace Grid; + using namespace Grid; -//////////////////////////////////////////////////////////////////////////////// -// Some data types for intermediate storage -//////////////////////////////////////////////////////////////////////////////// - template using iLorentzColour2x3 = iVector, 2>, 4 >; + //////////////////////////////////////////////////////////////////////////////// + // Write and read from fstream; comput header offset for payload + //////////////////////////////////////////////////////////////////////////////// + class NerscIO : public BinaryIO { + public: - typedef iLorentzColour2x3 LorentzColour2x3; - typedef iLorentzColour2x3 LorentzColour2x3F; - typedef iLorentzColour2x3 LorentzColour2x3D; - -//////////////////////////////////////////////////////////////////////////////// -// header specification/interpretation -//////////////////////////////////////////////////////////////////////////////// -class NerscField { - public: - // header strings 
(not in order) - int dimension[4]; - std::string boundary[4]; - int data_start; - std::string hdr_version; - std::string storage_format; - // Checks on data - double link_trace; - double plaquette; - uint32_t checksum; - unsigned int sequence_number; - std::string data_type; - std::string ensemble_id ; - std::string ensemble_label ; - std::string creator ; - std::string creator_hardware ; - std::string creation_date ; - std::string archive_date ; - std::string floating_point; -}; - -////////////////////////////////////////////////////////////////////// -// Bit and Physical Checksumming and QA of data -////////////////////////////////////////////////////////////////////// - -inline void NerscGrid(GridBase *grid,NerscField &header) -{ - assert(grid->_ndimension==4); - for(int d=0;d<4;d++) { - header.dimension[d] = grid->_fdimensions[d]; - } - for(int d=0;d<4;d++) { - header.boundary[d] = std::string("PERIODIC"); - } -} -template -inline void NerscStatistics(GaugeField & data,NerscField &header) -{ - // How to convert data precision etc... 
- header.link_trace=Grid::QCD::WilsonLoops::linkTrace(data); - header.plaquette =Grid::QCD::WilsonLoops::avgPlaquette(data); -} - -inline void NerscMachineCharacteristics(NerscField &header) -{ - // Who - struct passwd *pw = getpwuid (getuid()); - if (pw) header.creator = std::string(pw->pw_name); - - // When - std::time_t t = std::time(nullptr); - std::tm tm = *std::localtime(&t); - std::ostringstream oss; - // oss << std::put_time(&tm, "%c %Z"); - header.creation_date = oss.str(); - header.archive_date = header.creation_date; - - // What - struct utsname name; uname(&name); - header.creator_hardware = std::string(name.nodename)+"-"; - header.creator_hardware+= std::string(name.machine)+"-"; - header.creator_hardware+= std::string(name.sysname)+"-"; - header.creator_hardware+= std::string(name.release); - -} -////////////////////////////////////////////////////////////////////// -// Utilities ; these are QCD aware -////////////////////////////////////////////////////////////////////// - inline void NerscChecksum(uint32_t *buf,uint32_t buf_size_bytes,uint32_t &csum) - { - BinaryIO::Uint32Checksum(buf,buf_size_bytes,csum); - } - inline void reconstruct3(LorentzColourMatrix & cm) - { - const int x=0; - const int y=1; - const int z=2; - for(int mu=0;mu<4;mu++){ - cm(mu)()(2,x) = adj(cm(mu)()(0,y)*cm(mu)()(1,z)-cm(mu)()(0,z)*cm(mu)()(1,y)); //x= yz-zy - cm(mu)()(2,y) = adj(cm(mu)()(0,z)*cm(mu)()(1,x)-cm(mu)()(0,x)*cm(mu)()(1,z)); //y= zx-xz - cm(mu)()(2,z) = adj(cm(mu)()(0,x)*cm(mu)()(1,y)-cm(mu)()(0,y)*cm(mu)()(1,x)); //z= xy-yx + static inline void truncate(std::string file){ + std::ofstream fout(file,std::ios::out); } + + static inline unsigned int writeHeader(FieldMetaData &field,std::string file) + { + std::ofstream fout(file,std::ios::out|std::ios::in); + fout.seekp(0,std::ios::beg); + dump_meta_data(field, fout); + field.data_start = fout.tellp(); + return field.data_start; } - template - struct NerscSimpleMunger{ + // for the header-reader + static inline int 
readHeader(std::string file,GridBase *grid, FieldMetaData &field) + { + int offset=0; + std::map header; + std::string line; - void operator() (fobj &in,sobj &out,uint32_t &csum){ + ////////////////////////////////////////////////// + // read the header + ////////////////////////////////////////////////// + std::ifstream fin(file); - for(int mu=0;mu<4;mu++){ - for(int i=0;i<3;i++){ - for(int j=0;j<3;j++){ - out(mu)()(i,j) = in(mu)()(i,j); - }}} - NerscChecksum((uint32_t *)&in,sizeof(in),csum); - }; - }; + getline(fin,line); // read one line and insist is - template - struct NerscSimpleUnmunger{ - void operator() (sobj &in,fobj &out,uint32_t &csum){ - for(int mu=0;mu - struct Nersc3x2munger{ - void operator() (fobj &in,sobj &out,uint32_t &csum){ - - NerscChecksum((uint32_t *)&in,sizeof(in),csum); + removeWhitespace(line); + std::cout << GridLogMessage << "* " << line << std::endl; - for(int mu=0;mu<4;mu++){ - for(int i=0;i<2;i++){ - for(int j=0;j<3;j++){ - out(mu)()(i,j) = in(mu)(i)(j); - }} - } - reconstruct3(out); - } - }; + assert(line==std::string("BEGIN_HEADER")); - template - struct Nersc3x2unmunger{ - - void operator() (sobj &in,fobj &out,uint32_t &csum){ - - - for(int mu=0;mu<4;mu++){ - for(int i=0;i<2;i++){ - for(int j=0;j<3;j++){ - out(mu)(i)(j) = in(mu)()(i,j); - }} - } - - NerscChecksum((uint32_t *)&out,sizeof(out),csum); - - } - }; - - -//////////////////////////////////////////////////////////////////////////////// -// Write and read from fstream; comput header offset for payload -//////////////////////////////////////////////////////////////////////////////// -class NerscIO : public BinaryIO { - public: - - static inline void truncate(std::string file){ - std::ofstream fout(file,std::ios::out); - } - - #define dump_nersc_header(field, s)\ - s << "BEGIN_HEADER" << std::endl;\ - s << "HDR_VERSION = " << field.hdr_version << std::endl;\ - s << "DATATYPE = " << field.data_type << std::endl;\ - s << "STORAGE_FORMAT = " << field.storage_format << 
std::endl;\ - for(int i=0;i<4;i++){\ - s << "DIMENSION_" << i+1 << " = " << field.dimension[i] << std::endl ;\ - }\ - s << "LINK_TRACE = " << std::setprecision(10) << field.link_trace << std::endl;\ - s << "PLAQUETTE = " << std::setprecision(10) << field.plaquette << std::endl;\ - for(int i=0;i<4;i++){\ - s << "BOUNDARY_"< header; - std::string line; - - ////////////////////////////////////////////////// - // read the header - ////////////////////////////////////////////////// - std::ifstream fin(file); - - getline(fin,line); // read one line and insist is - - removeWhitespace(line); - assert(line==std::string("BEGIN_HEADER")); - - do { - getline(fin,line); // read one line - int eq = line.find("="); - if(eq >0) { + do { + getline(fin,line); // read one line + std::cout << GridLogMessage << "* "<0) { std::string key=line.substr(0,eq); std::string val=line.substr(eq+1); removeWhitespace(key); @@ -277,249 +85,269 @@ static inline int readHeader(std::string file,GridBase *grid, NerscField &field header[key] = val; } - } while( line.find("END_HEADER") == std::string::npos ); + } while( line.find("END_HEADER") == std::string::npos ); - field.data_start = fin.tellg(); + field.data_start = fin.tellg(); - ////////////////////////////////////////////////// - // chomp the values - ////////////////////////////////////////////////// - field.hdr_version = header["HDR_VERSION"]; - field.data_type = header["DATATYPE"]; - field.storage_format = header["STORAGE_FORMAT"]; + ////////////////////////////////////////////////// + // chomp the values + ////////////////////////////////////////////////// + field.hdr_version = header["HDR_VERSION"]; + field.data_type = header["DATATYPE"]; + field.storage_format = header["STORAGE_FORMAT"]; - field.dimension[0] = std::stol(header["DIMENSION_1"]); - field.dimension[1] = std::stol(header["DIMENSION_2"]); - field.dimension[2] = std::stol(header["DIMENSION_3"]); - field.dimension[3] = std::stol(header["DIMENSION_4"]); + field.dimension[0] = 
std::stol(header["DIMENSION_1"]); + field.dimension[1] = std::stol(header["DIMENSION_2"]); + field.dimension[2] = std::stol(header["DIMENSION_3"]); + field.dimension[3] = std::stol(header["DIMENSION_4"]); - assert(grid->_ndimension == 4); - for(int d=0;d<4;d++){ - assert(grid->_fdimensions[d]==field.dimension[d]); - } - - field.link_trace = std::stod(header["LINK_TRACE"]); - field.plaquette = std::stod(header["PLAQUETTE"]); - - field.boundary[0] = header["BOUNDARY_1"]; - field.boundary[1] = header["BOUNDARY_2"]; - field.boundary[2] = header["BOUNDARY_3"]; - field.boundary[3] = header["BOUNDARY_4"]; - - field.checksum = std::stoul(header["CHECKSUM"],0,16); - field.ensemble_id = header["ENSEMBLE_ID"]; - field.ensemble_label = header["ENSEMBLE_LABEL"]; - field.sequence_number = std::stol(header["SEQUENCE_NUMBER"]); - field.creator = header["CREATOR"]; - field.creator_hardware = header["CREATOR_HARDWARE"]; - field.creation_date = header["CREATION_DATE"]; - field.archive_date = header["ARCHIVE_DATE"]; - field.floating_point = header["FLOATING_POINT"]; - - return field.data_start; -} - -///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// Now the meat: the object readers -///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - -template -static inline void readConfiguration(Lattice > &Umu,NerscField& header,std::string file) -{ - typedef Lattice > GaugeField; - - GridBase *grid = Umu._grid; - int offset = readHeader(file,Umu._grid,header); - - NerscField clone(header); - - std::string format(header.floating_point); - - int ieee32big = (format == std::string("IEEE32BIG")); - int ieee32 = (format == std::string("IEEE32")); - int ieee64big = (format == std::string("IEEE64BIG")); - int ieee64 = (format == std::string("IEEE64")); - - uint32_t csum; - // depending on datatype, set up munger; - // munger is a function of - if ( header.data_type == 
std::string("4D_SU3_GAUGE") ) { - if ( ieee32 || ieee32big ) { - // csum=BinaryIO::readObjectSerial, LorentzColour2x3F> - csum=BinaryIO::readObjectParallel, LorentzColour2x3F> - (Umu,file,Nersc3x2munger(), offset,format); + assert(grid->_ndimension == 4); + for(int d=0;d<4;d++){ + assert(grid->_fdimensions[d]==field.dimension[d]); } - if ( ieee64 || ieee64big ) { - //csum=BinaryIO::readObjectSerial, LorentzColour2x3D> - csum=BinaryIO::readObjectParallel, LorentzColour2x3D> - (Umu,file,Nersc3x2munger(),offset,format); + + field.link_trace = std::stod(header["LINK_TRACE"]); + field.plaquette = std::stod(header["PLAQUETTE"]); + + field.boundary[0] = header["BOUNDARY_1"]; + field.boundary[1] = header["BOUNDARY_2"]; + field.boundary[2] = header["BOUNDARY_3"]; + field.boundary[3] = header["BOUNDARY_4"]; + + field.checksum = std::stoul(header["CHECKSUM"],0,16); + field.ensemble_id = header["ENSEMBLE_ID"]; + field.ensemble_label = header["ENSEMBLE_LABEL"]; + field.sequence_number = std::stol(header["SEQUENCE_NUMBER"]); + field.creator = header["CREATOR"]; + field.creator_hardware = header["CREATOR_HARDWARE"]; + field.creation_date = header["CREATION_DATE"]; + field.archive_date = header["ARCHIVE_DATE"]; + field.floating_point = header["FLOATING_POINT"]; + + return field.data_start; } - } else if ( header.data_type == std::string("4D_SU3_GAUGE_3x3") ) { - if ( ieee32 || ieee32big ) { - //csum=BinaryIO::readObjectSerial,LorentzColourMatrixF> - csum=BinaryIO::readObjectParallel,LorentzColourMatrixF> - (Umu,file,NerscSimpleMunger(),offset,format); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Now the meat: the object readers + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + template + static inline void readConfiguration(Lattice > &Umu, + FieldMetaData& header, + std::string file) + { + typedef Lattice > GaugeField; + + 
GridBase *grid = Umu._grid; + int offset = readHeader(file,Umu._grid,header); + + FieldMetaData clone(header); + + std::string format(header.floating_point); + + int ieee32big = (format == std::string("IEEE32BIG")); + int ieee32 = (format == std::string("IEEE32")); + int ieee64big = (format == std::string("IEEE64BIG")); + int ieee64 = (format == std::string("IEEE64")); + + uint32_t nersc_csum,scidac_csuma,scidac_csumb; + // depending on datatype, set up munger; + // munger is a function of + if ( header.data_type == std::string("4D_SU3_GAUGE") ) { + if ( ieee32 || ieee32big ) { + BinaryIO::readLatticeObject, LorentzColour2x3F> + (Umu,file,Gauge3x2munger(), offset,format, + nersc_csum,scidac_csuma,scidac_csumb); + } + if ( ieee64 || ieee64big ) { + BinaryIO::readLatticeObject, LorentzColour2x3D> + (Umu,file,Gauge3x2munger(),offset,format, + nersc_csum,scidac_csuma,scidac_csumb); + } + } else if ( header.data_type == std::string("4D_SU3_GAUGE_3x3") ) { + if ( ieee32 || ieee32big ) { + BinaryIO::readLatticeObject,LorentzColourMatrixF> + (Umu,file,GaugeSimpleMunger(),offset,format, + nersc_csum,scidac_csuma,scidac_csumb); + } + if ( ieee64 || ieee64big ) { + BinaryIO::readLatticeObject,LorentzColourMatrixD> + (Umu,file,GaugeSimpleMunger(),offset,format, + nersc_csum,scidac_csuma,scidac_csumb); + } + } else { + assert(0); + } + + GaugeStatistics(Umu,clone); + + std::cout<= 1.0e-5 ) { + std::cout << " Plaquette mismatch "<,LorentzColourMatrixD> - csum=BinaryIO::readObjectParallel,LorentzColourMatrixD> - (Umu,file,NerscSimpleMunger(),offset,format); - } - } else { - assert(0); - } - NerscStatistics(Umu,clone); + template + static inline void writeConfiguration(Lattice > &Umu, + std::string file, + int two_row, + int bits32) + { + typedef Lattice > GaugeField; - assert(fabs(clone.plaquette -header.plaquette ) < 1.0e-5 ); - assert(fabs(clone.link_trace-header.link_trace) < 1.0e-6 ); + typedef iLorentzColourMatrix vobj; + typedef typename vobj::scalar_object sobj; - 
assert(csum == header.checksum ); + FieldMetaData header; + /////////////////////////////////////////// + // Following should become arguments + /////////////////////////////////////////// + header.sequence_number = 1; + header.ensemble_id = "UKQCD"; + header.ensemble_label = "DWF"; - std::cout< -static inline void writeConfiguration(Lattice > &Umu,std::string file, int two_row,int bits32) -{ - typedef Lattice > GaugeField; - - typedef iLorentzColourMatrix vobj; - typedef typename vobj::scalar_object sobj; - - // Following should become arguments - NerscField header; - header.sequence_number = 1; - header.ensemble_id = "UKQCD"; - header.ensemble_label = "DWF"; - - typedef LorentzColourMatrixD fobj3D; - typedef LorentzColour2x3D fobj2D; - typedef LorentzColourMatrixF fobj3f; - typedef LorentzColour2x3F fobj2f; - - GridBase *grid = Umu._grid; - - NerscGrid(grid,header); - NerscStatistics(Umu,header); - NerscMachineCharacteristics(header); - - uint32_t csum; - int offset; + typedef LorentzColourMatrixD fobj3D; + typedef LorentzColour2x3D fobj2D; - truncate(file); + GridBase *grid = Umu._grid; - if ( two_row ) { + GridMetaData(grid,header); + assert(header.nd==4); + GaugeStatistics(Umu,header); + MachineCharacteristics(header); - header.floating_point = std::string("IEEE64BIG"); - header.data_type = std::string("4D_SU3_GAUGE"); - Nersc3x2unmunger munge; - BinaryIO::Uint32Checksum(Umu, munge,header.checksum); - offset = writeHeader(header,file); - csum=BinaryIO::writeObjectSerial(Umu,file,munge,offset,header.floating_point); + int offset; + + truncate(file); - std::string file1 = file+"para"; - int offset1 = writeHeader(header,file1); - int csum1=BinaryIO::writeObjectParallel(Umu,file1,munge,offset,header.floating_point); - //int csum1=BinaryIO::writeObjectSerial(Umu,file1,munge,offset,header.floating_point); + // Sod it -- always write 3x3 double + header.floating_point = std::string("IEEE64BIG"); + header.data_type = std::string("4D_SU3_GAUGE_3x3"); + 
GaugeSimpleUnmunger munge; + offset = writeHeader(header,file); - - std::cout << GridLogMessage << " TESTING PARALLEL WRITE offsets " << offset1 << " "<< offset << std::endl; - std::cout << GridLogMessage << " TESTING PARALLEL WRITE csums " << csum1 << " "<(Umu,file,munge,offset,header.floating_point, + nersc_csum,scidac_csuma,scidac_csumb); + header.checksum = nersc_csum; + writeHeader(header,file); - assert(offset1==offset); - assert(csum1==csum); + std::cout< munge; - BinaryIO::Uint32Checksum(Umu, munge,header.checksum); - offset = writeHeader(header,file); - // csum=BinaryIO::writeObjectSerial(Umu,file,munge,offset,header.floating_point); - csum=BinaryIO::writeObjectParallel(Umu,file,munge,offset,header.floating_point); - } + } + /////////////////////////////// + // RNG state + /////////////////////////////// + static inline void writeRNGState(GridSerialRNG &serial,GridParallelRNG ¶llel,std::string file) + { + typedef typename GridParallelRNG::RngStateType RngStateType; - std::cout< - uint32_t csum=BinaryIO::readRNGSerial(serial,parallel,file,offset); + // depending on datatype, set up munger; + // munger is a function of + uint32_t nersc_csum,scidac_csuma,scidac_csumb; + BinaryIO::readRNG(serial,parallel,file,offset,nersc_csum,scidac_csuma,scidac_csumb); - std::cerr<<" Csum "<< csum << " "<< header.checksum < *************************************************************************************/ /* END LEGAL */ -#include -#include +#include +#include namespace Grid { @@ -40,7 +40,7 @@ const PerformanceCounter::PerformanceCounterConfig PerformanceCounter::Performan { PERF_TYPE_HARDWARE, PERF_COUNT_HW_CPU_CYCLES , "CPUCYCLES.........." , INSTRUCTIONS}, { PERF_TYPE_HARDWARE, PERF_COUNT_HW_INSTRUCTIONS , "INSTRUCTIONS......." 
, CPUCYCLES }, // 4 -#ifdef AVX512 +#ifdef KNL { PERF_TYPE_RAW, RawConfig(0x40,0x04), "ALL_LOADS..........", CPUCYCLES }, { PERF_TYPE_RAW, RawConfig(0x01,0x04), "L1_MISS_LOADS......", L1D_READ_ACCESS }, { PERF_TYPE_RAW, RawConfig(0x40,0x04), "ALL_LOADS..........", L1D_READ_ACCESS }, diff --git a/lib/PerfCount.h b/lib/perfmon/PerfCount.h similarity index 96% rename from lib/PerfCount.h rename to lib/perfmon/PerfCount.h index 5ab07c02..73d2c70f 100644 --- a/lib/PerfCount.h +++ b/lib/perfmon/PerfCount.h @@ -172,7 +172,7 @@ public: const char * name = PerformanceCounterConfigs[PCT].name; fd = perf_event_open(&pe, 0, -1, -1, 0); // pid 0, cpu -1 current process any cpu. group -1 if (fd == -1) { - fprintf(stderr, "Error opening leader %llx for event %s\n", pe.config,name); + fprintf(stderr, "Error opening leader %llx for event %s\n",(long long) pe.config,name); perror("Error is"); } int norm = PerformanceCounterConfigs[PCT].normalisation; @@ -181,7 +181,7 @@ public: name = PerformanceCounterConfigs[norm].name; cyclefd = perf_event_open(&pe, 0, -1, -1, 0); // pid 0, cpu -1 current process any cpu. 
group -1 if (cyclefd == -1) { - fprintf(stderr, "Error opening leader %llx for event %s\n", pe.config,name); + fprintf(stderr, "Error opening leader %llx for event %s\n",(long long) pe.config,name); perror("Error is"); } #endif @@ -206,11 +206,13 @@ public: count=0; cycles=0; #ifdef __linux__ + ssize_t ign; if ( fd!= -1) { ::ioctl(fd, PERF_EVENT_IOC_DISABLE, 0); ::ioctl(cyclefd, PERF_EVENT_IOC_DISABLE, 0); - ::read(fd, &count, sizeof(long long)); - ::read(cyclefd, &cycles, sizeof(long long)); + ign=::read(fd, &count, sizeof(long long)); + ign+=::read(cyclefd, &cycles, sizeof(long long)); + assert(ign=2*sizeof(long long)); } elapsed = cyclecount() - begin; #else diff --git a/lib/Stat.cc b/lib/perfmon/Stat.cc similarity index 98% rename from lib/Stat.cc rename to lib/perfmon/Stat.cc index 7f2e4086..3f47fd83 100644 --- a/lib/Stat.cc +++ b/lib/perfmon/Stat.cc @@ -1,11 +1,9 @@ -#include -#include -#include - +#include +#include +#include namespace Grid { - bool PmuStat::pmu_initialized=false; diff --git a/lib/Stat.h b/lib/perfmon/Stat.h similarity index 100% rename from lib/Stat.h rename to lib/perfmon/Stat.h diff --git a/lib/Timer.h b/lib/perfmon/Timer.h similarity index 100% rename from lib/Timer.h rename to lib/perfmon/Timer.h diff --git a/lib/pugixml/pugixml.cc b/lib/pugixml/pugixml.cc index 525d1419..a4f8fde2 100644 --- a/lib/pugixml/pugixml.cc +++ b/lib/pugixml/pugixml.cc @@ -14,7 +14,7 @@ #ifndef SOURCE_PUGIXML_CPP #define SOURCE_PUGIXML_CPP -#include +#include #include #include diff --git a/lib/qcd/LatticeTheories.h b/lib/qcd/LatticeTheories.h new file mode 100644 index 00000000..74c68d83 --- /dev/null +++ b/lib/qcd/LatticeTheories.h @@ -0,0 +1,124 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/QCD.h + +Copyright (C) 2015 + +Author: Azusa Yamaguchi +Author: Peter Boyle +Author: Peter Boyle +Author: neo +Author: paboyle + +This program is 
free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef GRID_LT_H +#define GRID_LT_H +namespace Grid{ + +// First steps in the complete generalization of the Physics part +// Design not final +namespace LatticeTheories { + +template +struct LatticeTheory { + static const int Nd = Dimensions; + static const int Nds = Dimensions * 2; // double stored field + template + using iSinglet = iScalar > >; +}; + +template +struct LatticeGaugeTheory : public LatticeTheory { + static const int Nds = Dimensions * 2; + static const int Nd = Dimensions; + static const int Nc = Colours; + + template + using iColourMatrix = iScalar > >; + template + using iLorentzColourMatrix = iVector >, Nd>; + template + using iDoubleStoredColourMatrix = iVector >, Nds>; + template + using iColourVector = iScalar > >; +}; + +template +struct FermionicLatticeGaugeTheory + : public LatticeGaugeTheory { + static const int Nd = Dimensions; + static const int Nds = Dimensions * 2; + static const int Nc = Colours; + static const int Ns = Spin; + + template + using iSpinMatrix = iScalar, Ns> >; + template + using iSpinColourMatrix = iScalar, Ns> >; + template + using iSpinVector = iScalar, Ns> >; + 
template + using iSpinColourVector = iScalar, Ns> >; + // These 2 only if Spin is a multiple of 2 + static const int Nhs = Spin / 2; + template + using iHalfSpinVector = iScalar, Nhs> >; + template + using iHalfSpinColourVector = iScalar, Nhs> >; + + //tests + typedef iColourMatrix ColourMatrix; + typedef iColourMatrix ColourMatrixF; + typedef iColourMatrix ColourMatrixD; + + +}; + +// Examples, not complete now. +struct QCD : public FermionicLatticeGaugeTheory<4, 3, 4> { + static const int Xp = 0; + static const int Yp = 1; + static const int Zp = 2; + static const int Tp = 3; + static const int Xm = 4; + static const int Ym = 5; + static const int Zm = 6; + static const int Tm = 7; + + typedef FermionicLatticeGaugeTheory FLGT; + + typedef FLGT::iSpinMatrix SpinMatrix; + typedef FLGT::iSpinMatrix SpinMatrixF; + typedef FLGT::iSpinMatrix SpinMatrixD; + +}; +struct QED : public FermionicLatticeGaugeTheory<4, 1, 4> {//fill +}; + +template +struct Scalar : public LatticeTheory {}; + +}; // LatticeTheories + +} // Grid + +#endif diff --git a/lib/qcd/QCD.h b/lib/qcd/QCD.h index f434bdd9..fa336020 100644 --- a/lib/qcd/QCD.h +++ b/lib/qcd/QCD.h @@ -29,12 +29,15 @@ Author: paboyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#ifndef GRID_QCD_H -#define GRID_QCD_H +#ifndef GRID_QCD_BASE_H +#define GRID_QCD_BASE_H namespace Grid{ - namespace QCD { + static const int Xdir = 0; + static const int Ydir = 1; + static const int Zdir = 2; + static const int Tdir = 3; static const int Xp = 0; static const int Yp = 1; @@ -62,7 +65,6 @@ namespace QCD { #define SpinIndex 1 #define LorentzIndex 0 - // Also should make these a named enum type static const int DaggerNo=0; static const int DaggerYes=1; @@ -355,36 +357,36 @@ namespace QCD { ////////////////////////////////////////////// template void pokeColour(Lattice &lhs, - const 
Lattice(lhs._odata[0],0))> & rhs, - int i) + const Lattice(lhs._odata[0],0))> & rhs, + int i) { PokeIndex(lhs,rhs,i); } template void pokeColour(Lattice &lhs, - const Lattice(lhs._odata[0],0,0))> & rhs, - int i,int j) + const Lattice(lhs._odata[0],0,0))> & rhs, + int i,int j) { PokeIndex(lhs,rhs,i,j); } template void pokeSpin(Lattice &lhs, - const Lattice(lhs._odata[0],0))> & rhs, - int i) + const Lattice(lhs._odata[0],0))> & rhs, + int i) { PokeIndex(lhs,rhs,i); } template void pokeSpin(Lattice &lhs, - const Lattice(lhs._odata[0],0,0))> & rhs, - int i,int j) + const Lattice(lhs._odata[0],0,0))> & rhs, + int i,int j) { PokeIndex(lhs,rhs,i,j); } template void pokeLorentz(Lattice &lhs, - const Lattice(lhs._odata[0],0))> & rhs, - int i) + const Lattice(lhs._odata[0],0))> & rhs, + int i) { PokeIndex(lhs,rhs,i); } @@ -493,27 +495,38 @@ namespace QCD { } //namespace QCD } // Grid - +/* +<<<<<<< HEAD #include #include #include #include #include -// Include representations +// Include representations #include #include #include #include +// Scalar field +#include + #include #include #include #include +#include #include +//#include +======= + +>>>>>>> develop +*/ + #endif diff --git a/lib/qcd/action/Action.h b/lib/qcd/action/Action.h new file mode 100644 index 00000000..7272c90d --- /dev/null +++ b/lib/qcd/action/Action.h @@ -0,0 +1,50 @@ + /************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + Source file: ./lib/qcd/action/Actions.h + + Copyright (C) 2015 + +Author: Azusa Yamaguchi +Author: Peter Boyle +Author: Peter Boyle +Author: Peter Boyle +Author: neo +Author: paboyle + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. 
+ + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + + See the full license in the file "LICENSE" in the top level distribution directory + *************************************************************************************/ + /* END LEGAL */ +#ifndef GRID_QCD_ACTION_H +#define GRID_QCD_ACTION_H + +//////////////////////////////////////////// +// Abstract base interface +//////////////////////////////////////////// +#include +//////////////////////////////////////////////////////////////////////// +// Fermion actions; prevent coupling fermion.cc files to other headers +//////////////////////////////////////////////////////////////////////// +#include +#include +//////////////////////////////////////// +// Pseudo fermion combinations for HMC +//////////////////////////////////////// +#include + +#endif diff --git a/lib/qcd/action/ActionBase.h b/lib/qcd/action/ActionBase.h index 56d6b8e0..8d853d45 100644 --- a/lib/qcd/action/ActionBase.h +++ b/lib/qcd/action/ActionBase.h @@ -4,10 +4,11 @@ Grid physics library, www.github.com/paboyle/Grid Source file: ./lib/qcd/action/ActionBase.h -Copyright (C) 2015 +Copyright (C) 2015-2016 Author: Peter Boyle Author: neo +Author: Guido Cossu This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -27,127 +28,29 @@ See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#ifndef QCD_ACTION_BASE -#define QCD_ACTION_BASE + 
+#ifndef ACTION_BASE_H +#define ACTION_BASE_H + namespace Grid { namespace QCD { -template -class Action { +template +class Action +{ + public: bool is_smeared = false; - // Boundary conditions? // Heatbath? - virtual void refresh(const GaugeField& U, - GridParallelRNG& pRNG) = 0; // refresh pseudofermions - virtual RealD S(const GaugeField& U) = 0; // evaluate the action - virtual void deriv(const GaugeField& U, - GaugeField& dSdU) = 0; // evaluate the action derivative - virtual ~Action(){}; + // Heatbath? + virtual void refresh(const GaugeField& U, GridParallelRNG& pRNG) = 0; // refresh pseudofermions + virtual RealD S(const GaugeField& U) = 0; // evaluate the action + virtual void deriv(const GaugeField& U, GaugeField& dSdU) = 0; // evaluate the action derivative + virtual std::string action_name() = 0; // return the action name + virtual std::string LogParameters() = 0; // prints action parameters + virtual ~Action(){} }; -// Indexing of tuple types -template -struct Index; - -template -struct Index> { - static const std::size_t value = 0; -}; - -template -struct Index> { - static const std::size_t value = 1 + Index>::value; -}; - -/* -template -struct ActionLevel { - public: - typedef Action* - ActPtr; // now force the same colours as the rest of the code - - //Add supported representations here - - - unsigned int multiplier; - - std::vector actions; - - ActionLevel(unsigned int mul = 1) : actions(0), multiplier(mul) { - assert(mul >= 1); - }; - - void push_back(ActPtr ptr) { actions.push_back(ptr); } -}; -*/ - -template -struct ActionLevel { - public: - unsigned int multiplier; - - // Fundamental repr actions separated because of the smearing - typedef Action* ActPtr; - - // construct a tuple of vectors of the actions for the corresponding higher - // representation fields - typedef typename AccessTypes::VectorCollection action_collection; - action_collection actions_hirep; - typedef typename AccessTypes::FieldTypeCollection action_hirep_types; - - 
std::vector& actions; - - // Temporary conversion between ActionLevel and ActionLevelHirep - //ActionLevelHirep(ActionLevel& AL ):actions(AL.actions), multiplier(AL.multiplier){} - - ActionLevel(unsigned int mul = 1) : actions(std::get<0>(actions_hirep)), multiplier(mul) { - // initialize the hirep vectors to zero. - //apply(this->resize, actions_hirep, 0); //need a working resize - assert(mul >= 1); - }; - - //void push_back(ActPtr ptr) { actions.push_back(ptr); } - - - - template < class Field > - void push_back(Action* ptr) { - // insert only in the correct vector - std::get< Index < Field, action_hirep_types>::value >(actions_hirep).push_back(ptr); - }; - - - - template < class ActPtr> - static void resize(ActPtr ap, unsigned int n){ - ap->resize(n); - - } - - //template - //auto getRepresentation(Repr& R)->decltype(std::get(R).U) {return std::get(R).U;} - - // Loop on tuple for a callable function - template - inline typename std::enable_if::value, void>::type apply( - Callable, Repr& R,Args&...) const {} - - template - inline typename std::enable_if::value, void>::type apply( - Callable fn, Repr& R, Args&... 
arguments) const { - fn(std::get(actions_hirep), std::get(R.rep), arguments...); - apply(fn, R, arguments...); - } - -}; - - -//template -//using ActionSet = std::vector >; - -template -using ActionSet = std::vector >; - } } -#endif + +#endif // ACTION_BASE_H diff --git a/lib/qcd/action/ActionCore.h b/lib/qcd/action/ActionCore.h new file mode 100644 index 00000000..7a5caf15 --- /dev/null +++ b/lib/qcd/action/ActionCore.h @@ -0,0 +1,61 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/ActionCore.h + +Copyright (C) 2015 + +Author: Peter Boyle +Author: neo + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef QCD_ACTION_CORE +#define QCD_ACTION_CORE + +#include +#include +#include + +//////////////////////////////////////////// +// Gauge Actions +//////////////////////////////////////////// +#include + +//////////////////////////////////////////// +// Fermion prereqs +//////////////////////////////////////////// +#include + +//////////////////////////////////////////// +// Scalar Actions +//////////////////////////////////////////// +#include + +//////////////////////////////////////////// +// Utility functions +//////////////////////////////////////////// +#include +#include + + + + +#endif diff --git a/lib/qcd/action/ActionParams.h b/lib/qcd/action/ActionParams.h index 91e94741..d25b60a9 100644 --- a/lib/qcd/action/ActionParams.h +++ b/lib/qcd/action/ActionParams.h @@ -1,67 +1,92 @@ - /************************************************************************************* +/************************************************************************************* - Grid physics library, www.github.com/paboyle/Grid +Grid physics library, www.github.com/paboyle/Grid - Source file: ./lib/qcd/action/ActionParams.h +Source file: ./lib/qcd/action/ActionParams.h - Copyright (C) 2015 +Copyright (C) 2015 Author: Peter Boyle Author: paboyle +Author: Guido Cossu - This program is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
- This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. - You should have received a copy of the GNU General Public License along - with this program; if not, write to the Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ - See the full license in the file "LICENSE" in the top level distribution directory - *************************************************************************************/ - /* END LEGAL */ #ifndef GRID_QCD_ACTION_PARAMS_H #define GRID_QCD_ACTION_PARAMS_H namespace Grid { namespace QCD { - // These can move into a params header and be given MacroMagic serialisation - struct GparityWilsonImplParams { - bool overlapCommsCompute; - std::vector twists; - GparityWilsonImplParams () : twists(Nd,0), overlapCommsCompute(false) {}; - + // These can move into a params header and be given MacroMagic serialisation + struct GparityWilsonImplParams { + bool overlapCommsCompute; + std::vector twists; + GparityWilsonImplParams() : twists(Nd, 0), overlapCommsCompute(false){}; + }; + + struct WilsonImplParams { + bool overlapCommsCompute; + std::vector boundary_phases; + WilsonImplParams() : 
overlapCommsCompute(false) { + boundary_phases.resize(Nd, 1.0); }; + WilsonImplParams(const std::vector phi) + : boundary_phases(phi), overlapCommsCompute(false) {} + }; - struct WilsonImplParams { - bool overlapCommsCompute; - WilsonImplParams() : overlapCommsCompute(false) {}; - }; + struct StaggeredImplParams { + StaggeredImplParams() {}; + }; + + struct OneFlavourRationalParams : Serializable { + GRID_SERIALIZABLE_CLASS_MEMBERS(OneFlavourRationalParams, + RealD, lo, + RealD, hi, + int, MaxIter, + RealD, tolerance, + int, degree, + int, precision); + + // MaxIter and tolerance, vectors?? + + // constructor + OneFlavourRationalParams( RealD _lo = 0.0, + RealD _hi = 1.0, + int _maxit = 1000, + RealD tol = 1.0e-8, + int _degree = 10, + int _precision = 64) + : lo(_lo), + hi(_hi), + MaxIter(_maxit), + tolerance(tol), + degree(_degree), + precision(_precision){}; + }; + + +} +} - struct StaggeredImplParams { - StaggeredImplParams() {}; - }; - struct OneFlavourRationalParams { - RealD lo; - RealD hi; - int MaxIter; // Vector? - RealD tolerance; // Vector? 
- int degree=10; - int precision=64; - OneFlavourRationalParams (RealD _lo,RealD _hi,int _maxit,RealD tol=1.0e-8,int _degree = 10,int _precision=64) : - lo(_lo), hi(_hi), MaxIter(_maxit), tolerance(tol), degree(_degree), precision(_precision) - {}; - }; - -}} #endif diff --git a/lib/qcd/action/ActionSet.h b/lib/qcd/action/ActionSet.h new file mode 100644 index 00000000..4ed6a582 --- /dev/null +++ b/lib/qcd/action/ActionSet.h @@ -0,0 +1,116 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/ActionSet.h + +Copyright (C) 2015 + +Author: Peter Boyle +Author: neo + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef ACTION_SET_H +#define ACTION_SET_H + +namespace Grid { + +// Should drop this namespace here +namespace QCD { + +////////////////////////////////// +// Indexing of tuple types +////////////////////////////////// + +template +struct Index; + +template +struct Index> { + static const std::size_t value = 0; +}; + +template +struct Index> { + static const std::size_t value = 1 + Index>::value; +}; + + +//////////////////////////////////////////// +// Action Level +// Action collection +// in a integration level +// (for multilevel integration schemes) +//////////////////////////////////////////// + +template +struct ActionLevel { + public: + unsigned int multiplier; + + // Fundamental repr actions separated because of the smearing + typedef Action* ActPtr; + + // construct a tuple of vectors of the actions for the corresponding higher + // representation fields + typedef typename AccessTypes::VectorCollection action_collection; + typedef typename AccessTypes::FieldTypeCollection action_hirep_types; + + action_collection actions_hirep; + std::vector& actions; + + explicit ActionLevel(unsigned int mul = 1) : + actions(std::get<0>(actions_hirep)), multiplier(mul) { + // initialize the hirep vectors to zero. + // apply(this->resize, actions_hirep, 0); //need a working resize + assert(mul >= 1); + } + + template < class GenField > + void push_back(Action* ptr) { + // insert only in the correct vector + std::get< Index < GenField, action_hirep_types>::value >(actions_hirep).push_back(ptr); + }; + + template + static void resize(ActPtr ap, unsigned int n) { + ap->resize(n); + } + + // Loop on tuple for a callable function + template + inline typename std::enable_if::value, void>::type apply(Callable, Repr& R,Args&...) 
const {} + + template + inline typename std::enable_if::value, void>::type apply(Callable fn, Repr& R, Args&... arguments) const { + fn(std::get(actions_hirep), std::get(R.rep), arguments...); + apply(fn, R, arguments...); + } + +}; + +// Define the ActionSet +template +using ActionSet = std::vector >; + +} // QCD +} // Grid + +#endif // ACTION_SET_H diff --git a/lib/qcd/action/fermion/.dirstamp b/lib/qcd/action/fermion/.dirstamp deleted file mode 100644 index e69de29b..00000000 diff --git a/lib/qcd/action/fermion/AbstractEOFAFermion.h b/lib/qcd/action/fermion/AbstractEOFAFermion.h new file mode 100644 index 00000000..15faa401 --- /dev/null +++ b/lib/qcd/action/fermion/AbstractEOFAFermion.h @@ -0,0 +1,100 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/AbstractEOFAFermion.h + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef GRID_QCD_ABSTRACT_EOFA_FERMION_H +#define GRID_QCD_ABSTRACT_EOFA_FERMION_H + +#include + +namespace Grid { +namespace QCD { + + // DJM: Abstract base class for EOFA fermion types. + // Defines layout of additional EOFA-specific parameters and operators. + // Use to construct EOFA pseudofermion actions that are agnostic to + // Shamir / Mobius / etc., and ensure that no one can construct EOFA + // pseudofermion action with non-EOFA fermion type. + template + class AbstractEOFAFermion : public CayleyFermion5D { + public: + INHERIT_IMPL_TYPES(Impl); + + public: + // Fermion operator: D(mq1) + shift*\gamma_{5}*R_{5}*\Delta_{\pm}(mq2,mq3)*P_{\pm} + RealD mq1; + RealD mq2; + RealD mq3; + RealD shift; + int pm; + + RealD alpha; // Mobius scale + RealD k; // EOFA normalization constant + + virtual void Instantiatable(void) = 0; + + // EOFA-specific operations + // Force user to implement in derived classes + virtual void Omega (const FermionField& in, FermionField& out, int sign, int dag) = 0; + virtual void Dtilde (const FermionField& in, FermionField& out) = 0; + virtual void DtildeInv(const FermionField& in, FermionField& out) = 0; + + // Implement derivatives in base class: + // for EOFA both DWF and Mobius just need d(Dw)/dU + virtual void MDeriv(GaugeField& mat, const FermionField& U, const FermionField& V, int dag){ + this->DhopDeriv(mat, U, V, dag); + }; + virtual void MoeDeriv(GaugeField& mat, const FermionField& U, const FermionField& V, int dag){ + this->DhopDerivOE(mat, U, V, dag); + }; + virtual void MeoDeriv(GaugeField& mat, const FermionField& U, const FermionField& V, int dag){ + this->DhopDerivEO(mat, U, V, dag); + }; + + // Recompute 5D coefficients for different value of shift constant + // (needed for heatbath loop over poles) + virtual void 
RefreshShiftCoefficients(RealD new_shift) = 0; + + // Constructors + AbstractEOFAFermion(GaugeField& _Umu, GridCartesian& FiveDimGrid, GridRedBlackCartesian& FiveDimRedBlackGrid, + GridCartesian& FourDimGrid, GridRedBlackCartesian& FourDimRedBlackGrid, + RealD _mq1, RealD _mq2, RealD _mq3, RealD _shift, int _pm, + RealD _M5, RealD _b, RealD _c, const ImplParams& p=ImplParams()) + : CayleyFermion5D(_Umu, FiveDimGrid, FiveDimRedBlackGrid, FourDimGrid, FourDimRedBlackGrid, + _mq1, _M5, p), mq1(_mq1), mq2(_mq2), mq3(_mq3), shift(_shift), pm(_pm) + { + int Ls = this->Ls; + this->alpha = _b + _c; + this->k = this->alpha * (_mq3-_mq2) * std::pow(this->alpha+1.0,2*Ls) / + ( std::pow(this->alpha+1.0,Ls) + _mq2*std::pow(this->alpha-1.0,Ls) ) / + ( std::pow(this->alpha+1.0,Ls) + _mq3*std::pow(this->alpha-1.0,Ls) ); + }; + }; +}} + +#endif diff --git a/lib/qcd/action/fermion/CayleyFermion5D.cc b/lib/qcd/action/fermion/CayleyFermion5D.cc index b8e98dce..838b1c3d 100644 --- a/lib/qcd/action/fermion/CayleyFermion5D.cc +++ b/lib/qcd/action/fermion/CayleyFermion5D.cc @@ -29,8 +29,9 @@ Author: paboyle *************************************************************************************/ /* END LEGAL */ -#include - +#include +#include +#include namespace Grid { namespace QCD { @@ -48,18 +49,31 @@ namespace QCD { FourDimGrid, FourDimRedBlackGrid,_M5,p), mass(_mass) - { } + { + } template void CayleyFermion5D::Dminus(const FermionField &psi, FermionField &chi) { int Ls=this->Ls; - FermionField tmp(psi._grid); - this->DW(psi,tmp,DaggerNo); + FermionField tmp_f(this->FermionGrid()); + this->DW(psi,tmp_f,DaggerNo); for(int s=0;s +void CayleyFermion5D::DminusDag(const FermionField &psi, FermionField &chi) +{ + int Ls=this->Ls; + + FermionField tmp_f(this->FermionGrid()); + this->DW(psi,tmp_f,DaggerYes); + + for(int s=0;s void CayleyFermion5D::CayleyReport(void) std::cout << GridLogMessage << "CayleyFermion5D Number of MooeeInv Calls : " << MooeeInvCalls << std::endl; std::cout << 
GridLogMessage << "CayleyFermion5D ComputeTime/Calls : " << MooeeInvTime / MooeeInvCalls << " us" << std::endl; - // Flops = 9*12*Ls*vol/2 - RealD mflops = 9.0*12*volume*MooeeInvCalls/MooeeInvTime/2; // 2 for red black counting + // Flops = MADD * Ls *Ls *4dvol * spin/colour/complex + RealD mflops = 2.0*24*this->Ls*volume*MooeeInvCalls/MooeeInvTime/2; // 2 for red black counting std::cout << GridLogMessage << "Average mflops/s per call : " << mflops << std::endl; std::cout << GridLogMessage << "Average mflops/s per call per rank : " << mflops/NP << std::endl; } @@ -106,18 +120,6 @@ template void CayleyFermion5D::CayleyZeroCounters(void) } -template -void CayleyFermion5D::DminusDag(const FermionField &psi, FermionField &chi) -{ - int Ls=this->Ls; - FermionField tmp(psi._grid); - - this->DW(psi,tmp,DaggerYes); - - for(int s=0;s void CayleyFermion5D::M5D (const FermionField &psi, FermionField &chi) { @@ -138,6 +140,7 @@ void CayleyFermion5D::Meooe5D (const FermionField &psi, FermionField &D lower[0] =-mass*lower[0]; M5D(psi,psi,Din,lower,diag,upper); } +// FIXME Redunant with the above routine; check this and eliminate template void CayleyFermion5D::Meo5D (const FermionField &psi, FermionField &chi) { int Ls=this->Ls; @@ -167,7 +170,6 @@ void CayleyFermion5D::Mooee (const FermionField &psi, FermionField & lower[0] =-mass*lower[0]; M5D(psi,psi,chi,lower,diag,upper); } - template void CayleyFermion5D::MooeeDag (const FermionField &psi, FermionField &chi) { @@ -189,7 +191,12 @@ void CayleyFermion5D::MooeeDag (const FermionField &psi, FermionField & lower[s]=-cee[s-1]; } } - + // Conjugate the terms + for (int s=0;s::MeooeDag5D (const FermionField &psi, FermionField int Ls=this->Ls; std::vector diag =bs; std::vector upper=cs; - std::vector lower=cs; - upper[Ls-1]=-mass*upper[Ls-1]; - lower[0] =-mass*lower[0]; + std::vector lower=cs; + + for (int s=0;s void CayleyFermion5D::Meooe (const FermionField &psi, FermionField &chi) { int Ls=this->Ls; - FermionField tmp(psi._grid); 
- Meooe5D(psi,tmp); + Meooe5D(psi,this->tmp()); if ( psi.checkerboard == Odd ) { - this->DhopEO(tmp,chi,DaggerNo); + this->DhopEO(this->tmp(),chi,DaggerNo); } else { - this->DhopOE(tmp,chi,DaggerNo); + this->DhopOE(this->tmp(),chi,DaggerNo); } } template void CayleyFermion5D::MeooeDag (const FermionField &psi, FermionField &chi) { - FermionField tmp(psi._grid); // Apply 4d dslash if ( psi.checkerboard == Odd ) { - this->DhopEO(psi,tmp,DaggerYes); + this->DhopEO(psi,this->tmp(),DaggerYes); } else { - this->DhopOE(psi,tmp,DaggerYes); + this->DhopOE(psi,this->tmp(),DaggerYes); } - MeooeDag5D(tmp,chi); + MeooeDag5D(this->tmp(),chi); } template void CayleyFermion5D::Mdir (const FermionField &psi, FermionField &chi,int dir,int disp){ - FermionField tmp(psi._grid); - Meo5D(psi,tmp); + Meo5D(psi,this->tmp()); // Apply 4d dslash fragment - this->DhopDir(tmp,chi,dir,disp); + this->DhopDir(this->tmp(),chi,dir,disp); } // force terms; five routines; default to Dhop on diagonal template @@ -317,8 +335,8 @@ void CayleyFermion5D::MoeDeriv(GaugeField &mat,const FermionField &U,const this->DhopDerivOE(mat,U,Din,dag); } else { // U d/du [D_w D5]^dag V = U D5^dag d/du DW^dag Y // implicit adj on U in call - Meooe5D(U,Din); - this->DhopDerivOE(mat,Din,V,dag); + Meooe5D(U,Din); + this->DhopDerivOE(mat,Din,V,dag); } }; template @@ -362,6 +380,8 @@ void CayleyFermion5D::SetCoefficientsInternal(RealD zolo_hi,std::vector::SetCoefficientsInternal(RealD zolo_hi,std::vector::SetCoefficientsInternal(RealD zolo_hi,std::vectorM5) +1.0); + bee[i]=as[i]*(bs[i]*(4.0-this->M5) +1.0); + assert(bee[i]!=Coeff_t(0.0)); cee[i]=as[i]*(1.0-cs[i]*(4.0-this->M5)); beo[i]=as[i]*bs[i]; ceo[i]=-as[i]*cs[i]; } - aee.resize(Ls); aeo.resize(Ls); for(int i=0;i::SetCoefficientsInternal(RealD zolo_hi,std::vector::SetCoefficientsInternal(RealD zolo_hi,std::vectorMooeeInternalCompute(0,inv,MatpInv,MatmInv); + this->MooeeInternalCompute(1,inv,MatpInvDag,MatmInvDag); } +template +void 
CayleyFermion5D::MooeeInternalCompute(int dag, int inv, + Vector > & Matp, + Vector > & Matm) +{ + int Ls=this->Ls; + + GridBase *grid = this->FermionRedBlackGrid(); + int LLs = grid->_rdimensions[0]; + + if ( LLs == Ls ) { + return; // Not vectorised in 5th direction + } + + Eigen::MatrixXcd Pplus = Eigen::MatrixXcd::Zero(Ls,Ls); + Eigen::MatrixXcd Pminus = Eigen::MatrixXcd::Zero(Ls,Ls); + + for(int s=0;s::iscomplex() ) { + sp[l] = PplusMat (l*istride+s1*ostride,s2); + sm[l] = PminusMat(l*istride+s1*ostride,s2); + } else { + // if real + scalar_type tmp; + tmp = PplusMat (l*istride+s1*ostride,s2); + sp[l] = scalar_type(tmp.real(),tmp.real()); + tmp = PminusMat(l*istride+s1*ostride,s2); + sm[l] = scalar_type(tmp.real(),tmp.real()); + } + } + Matp[LLs*s2+s1] = Vp; + Matm[LLs*s2+s1] = Vm; + }} +} + FermOpTemplateInstantiate(CayleyFermion5D); GparityFermOpTemplateInstantiate(CayleyFermion5D); diff --git a/lib/qcd/action/fermion/CayleyFermion5D.h b/lib/qcd/action/fermion/CayleyFermion5D.h index 6fb58234..ef75235a 100644 --- a/lib/qcd/action/fermion/CayleyFermion5D.h +++ b/lib/qcd/action/fermion/CayleyFermion5D.h @@ -1,6 +1,6 @@ /************************************************************************************* - Grid physics library, www.github.com/paboyle/Grid + Grid physics library, www.github.com/paboyle/Grid Source file: ./lib/qcd/action/fermion/CayleyFermion5D.h @@ -29,10 +29,37 @@ Author: Peter Boyle #ifndef GRID_QCD_CAYLEY_FERMION_H #define GRID_QCD_CAYLEY_FERMION_H +#include + namespace Grid { namespace QCD { + template struct switcheroo { + static inline int iscomplex() { return 0; } + + template + static inline vec mult(vec a, vec b) { + return real_mult(a,b); + } + }; + template<> struct switcheroo { + static inline int iscomplex() { return 1; } + + template + static inline vec mult(vec a, vec b) { + return a*b; + } + }; + template<> struct switcheroo { + static inline int iscomplex() { return 1; } + template + static inline vec mult(vec a, vec b) { + 
return a*b; + } + }; + + template class CayleyFermion5D : public WilsonFermion5D { @@ -63,19 +90,31 @@ namespace Grid { // Instantiate different versions depending on Impl ///////////////////////////////////////////////////// void M5D(const FermionField &psi, - const FermionField &phi, + const FermionField &phi, FermionField &chi, std::vector &lower, std::vector &diag, std::vector &upper); void M5Ddag(const FermionField &psi, - const FermionField &phi, + const FermionField &phi, FermionField &chi, std::vector &lower, std::vector &diag, std::vector &upper); + void MooeeInternal(const FermionField &in, FermionField &out,int dag,int inv); + void MooeeInternalCompute(int dag, int inv, Vector > & Matp, Vector > & Matm); + + void MooeeInternalAsm(const FermionField &in, FermionField &out, + int LLs, int site, + Vector > &Matp, + Vector > &Matm); + void MooeeInternalZAsm(const FermionField &in, FermionField &out, + int LLs, int site, + Vector > &Matp, + Vector > &Matm); + virtual void Instantiatable(void)=0; @@ -86,7 +125,7 @@ namespace Grid { // Efficient support for multigrid coarsening virtual void Mdir (const FermionField &in, FermionField &out,int dir,int disp); - + void Meooe5D (const FermionField &in, FermionField &out); void MeooeDag5D (const FermionField &in, FermionField &out); @@ -94,23 +133,29 @@ namespace Grid { RealD mass; // Cayley form Moebius (tanh and zolotarev) - std::vector omega; + std::vector omega; std::vector bs; // S dependent coeffs - std::vector cs; - std::vector as; + std::vector cs; + std::vector as; // For preconditioning Cayley form - std::vector bee; - std::vector cee; - std::vector aee; - std::vector beo; - std::vector ceo; - std::vector aeo; + std::vector bee; + std::vector cee; + std::vector aee; + std::vector beo; + std::vector ceo; + std::vector aeo; // LDU factorisation of the eeoo matrix - std::vector lee; - std::vector leem; - std::vector uee; - std::vector ueem; - std::vector dee; + std::vector lee; + std::vector leem; + 
std::vector uee; + std::vector ueem; + std::vector dee; + + // Matrices of 5d ee inverse params + Vector > MatpInv; + Vector > MatmInv; + Vector > MatpInvDag; + Vector > MatmInvDag; // Constructors CayleyFermion5D(GaugeField &_Umu, @@ -120,7 +165,7 @@ namespace Grid { GridRedBlackCartesian &FourDimRedBlackGrid, RealD _mass,RealD _M5,const ImplParams &p= ImplParams()); - + void CayleyReport(void); void CayleyZeroCounters(void); @@ -134,9 +179,9 @@ namespace Grid { double MooeeInvTime; protected: - void SetCoefficientsZolotarev(RealD zolohi,Approx::zolotarev_data *zdata,RealD b,RealD c); - void SetCoefficientsTanh(Approx::zolotarev_data *zdata,RealD b,RealD c); - void SetCoefficientsInternal(RealD zolo_hi,std::vector & gamma,RealD b,RealD c); + virtual void SetCoefficientsZolotarev(RealD zolohi,Approx::zolotarev_data *zdata,RealD b,RealD c); + virtual void SetCoefficientsTanh(Approx::zolotarev_data *zdata,RealD b,RealD c); + virtual void SetCoefficientsInternal(RealD zolo_hi,std::vector & gamma,RealD b,RealD c); }; } @@ -149,7 +194,9 @@ template void CayleyFermion5D< A >::M5Ddag(const FermionField &psi,const Fermion template void CayleyFermion5D< A >::MooeeInv (const FermionField &psi, FermionField &chi); \ template void CayleyFermion5D< A >::MooeeInvDag (const FermionField &psi, FermionField &chi); -#define CAYLEY_DPERP_CACHE +#undef CAYLEY_DPERP_DENSE +#define CAYLEY_DPERP_CACHE #undef CAYLEY_DPERP_LINALG +#define CAYLEY_DPERP_VEC #endif diff --git a/lib/qcd/action/fermion/CayleyFermion5Dcache.cc b/lib/qcd/action/fermion/CayleyFermion5Dcache.cc index 8e7df945..dd6ec7bf 100644 --- a/lib/qcd/action/fermion/CayleyFermion5Dcache.cc +++ b/lib/qcd/action/fermion/CayleyFermion5Dcache.cc @@ -29,7 +29,8 @@ Author: paboyle *************************************************************************************/ /* END LEGAL */ -#include +#include +#include namespace Grid { @@ -54,8 +55,8 @@ void CayleyFermion5D::M5D(const FermionField &psi, // Flops = 6.0*(Nc*Ns) *Ls*vol 
M5Dcalls++; M5Dtime-=usecond(); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss+=Ls){ // adds Ls + + parallel_for(int ss=0;ssoSites();ss+=Ls){ // adds Ls for(int s=0;s::M5Ddag(const FermionField &psi, // Flops = 6.0*(Nc*Ns) *Ls*vol M5Dcalls++; M5Dtime-=usecond(); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss+=Ls){ // adds Ls + + parallel_for(int ss=0;ssoSites();ss+=Ls){ // adds Ls auto tmp = psi._odata[0]; for(int s=0;s::MooeeInv (const FermionField &psi, FermionField & MooeeInvCalls++; MooeeInvTime-=usecond(); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss+=Ls){ // adds Ls + parallel_for(int ss=0;ssoSites();ss+=Ls){ // adds Ls auto tmp = psi._odata[0]; // flops = 12*2*Ls + 12*2*Ls + 3*12*Ls + 12*2*Ls = 12*Ls * (9) = 108*Ls flops @@ -181,11 +181,22 @@ void CayleyFermion5D::MooeeInvDag (const FermionField &psi, FermionField & assert(psi.checkerboard == psi.checkerboard); chi.checkerboard=psi.checkerboard; + std::vector ueec(Ls); + std::vector deec(Ls); + std::vector leec(Ls); + std::vector ueemc(Ls); + std::vector leemc(Ls); + for(int s=0;soSites();ss+=Ls){ // adds Ls + parallel_for(int ss=0;ssoSites();ss+=Ls){ // adds Ls auto tmp = psi._odata[0]; @@ -193,25 +204,25 @@ PARALLEL_FOR_LOOP chi[ss]=psi[ss]; for (int s=1;s=0;s--){ spProj5p(tmp,chi[ss+s+1]); - chi[ss+s] = chi[ss+s] - lee[s]*tmp; + chi[ss+s] = chi[ss+s] - leec[s]*tmp; } } @@ -226,6 +237,13 @@ PARALLEL_FOR_LOOP INSTANTIATE_DPERP(GparityWilsonImplD); INSTANTIATE_DPERP(ZWilsonImplF); INSTANTIATE_DPERP(ZWilsonImplD); + + INSTANTIATE_DPERP(WilsonImplFH); + INSTANTIATE_DPERP(WilsonImplDF); + INSTANTIATE_DPERP(GparityWilsonImplFH); + INSTANTIATE_DPERP(GparityWilsonImplDF); + INSTANTIATE_DPERP(ZWilsonImplFH); + INSTANTIATE_DPERP(ZWilsonImplDF); #endif }} diff --git a/lib/qcd/action/fermion/CayleyFermion5Ddense.cc b/lib/qcd/action/fermion/CayleyFermion5Ddense.cc index 5fa75b50..4014675a 100644 --- a/lib/qcd/action/fermion/CayleyFermion5Ddense.cc +++ b/lib/qcd/action/fermion/CayleyFermion5Ddense.cc @@ -29,8 +29,9 
@@ Author: paboyle *************************************************************************************/ /* END LEGAL */ -#include -#include +#include +#include +#include namespace Grid { @@ -38,20 +39,17 @@ namespace QCD { /* * Dense matrix versions of routines */ - - /* template void CayleyFermion5D::MooeeInvDag (const FermionField &psi, FermionField &chi) { this->MooeeInternal(psi,chi,DaggerYes,InverseYes); } - template void CayleyFermion5D::MooeeInv(const FermionField &psi, FermionField &chi) { this->MooeeInternal(psi,chi,DaggerNo,InverseYes); } - */ + template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv) { @@ -125,9 +123,34 @@ void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField } } +#ifdef CAYLEY_DPERP_DENSE +INSTANTIATE_DPERP(GparityWilsonImplF); +INSTANTIATE_DPERP(GparityWilsonImplD); +INSTANTIATE_DPERP(WilsonImplF); +INSTANTIATE_DPERP(WilsonImplD); +INSTANTIATE_DPERP(ZWilsonImplF); +INSTANTIATE_DPERP(ZWilsonImplD); + template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); + +INSTANTIATE_DPERP(GparityWilsonImplFH); +INSTANTIATE_DPERP(GparityWilsonImplDF); +INSTANTIATE_DPERP(WilsonImplFH); +INSTANTIATE_DPERP(WilsonImplDF); +INSTANTIATE_DPERP(ZWilsonImplFH); +INSTANTIATE_DPERP(ZWilsonImplDF); + +template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +template 
void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +#endif }} diff --git a/lib/qcd/action/fermion/CayleyFermion5Dssp.cc b/lib/qcd/action/fermion/CayleyFermion5Dssp.cc index ad7daddb..cb9b2957 100644 --- a/lib/qcd/action/fermion/CayleyFermion5Dssp.cc +++ b/lib/qcd/action/fermion/CayleyFermion5Dssp.cc @@ -29,14 +29,14 @@ Author: paboyle *************************************************************************************/ /* END LEGAL */ -#include +#include +#include namespace Grid { namespace QCD { // FIXME -- make a version of these routines with site loop outermost for cache reuse. 
- // Pminus fowards // Pplus backwards template @@ -47,17 +47,18 @@ void CayleyFermion5D::M5D(const FermionField &psi, std::vector &diag, std::vector &upper) { + Coeff_t one(1.0); int Ls=this->Ls; for(int s=0;s::M5Ddag(const FermionField &psi, std::vector &diag, std::vector &upper) { + Coeff_t one(1.0); int Ls=this->Ls; for(int s=0;s::M5Ddag(const FermionField &psi, template void CayleyFermion5D::MooeeInv (const FermionField &psi, FermionField &chi) { + Coeff_t one(1.0); + Coeff_t czero(0.0); chi.checkerboard=psi.checkerboard; int Ls=this->Ls; // Apply (L^{\prime})^{-1} - axpby_ssp (chi,1.0,psi, 0.0,psi,0,0); // chi[0]=psi[0] + axpby_ssp (chi,one,psi, czero,psi,0,0); // chi[0]=psi[0] for (int s=1;s=0;s--){ - axpby_ssp_pminus (chi,1.0,chi,-uee[s],chi,s,s+1); // chi[Ls] + axpby_ssp_pminus (chi,one,chi,-uee[s],chi,s,s+1); // chi[Ls] } } template void CayleyFermion5D::MooeeInvDag (const FermionField &psi, FermionField &chi) { + Coeff_t one(1.0); + Coeff_t czero(0.0); chi.checkerboard=psi.checkerboard; int Ls=this->Ls; // Apply (U^{\prime})^{-dagger} - axpby_ssp (chi,1.0,psi, 0.0,psi,0,0); // chi[0]=psi[0] + axpby_ssp (chi,one,psi, czero,psi,0,0); // chi[0]=psi[0] for (int s=1;s=0;s--){ - axpby_ssp_pplus (chi,1.0,chi,-lee[s],chi,s,s+1); // chi[Ls] + axpby_ssp_pplus (chi,one,chi,-conjugate(lee[s]),chi,s,s+1); // chi[Ls] } } #ifdef CAYLEY_DPERP_LINALG - INSTANTIATE(WilsonImplF); - INSTANTIATE(WilsonImplD); - INSTANTIATE(GparityWilsonImplF); - INSTANTIATE(GparityWilsonImplD); + INSTANTIATE_DPERP(WilsonImplF); + INSTANTIATE_DPERP(WilsonImplD); + INSTANTIATE_DPERP(GparityWilsonImplF); + INSTANTIATE_DPERP(GparityWilsonImplD); + INSTANTIATE_DPERP(ZWilsonImplF); + INSTANTIATE_DPERP(ZWilsonImplD); + + INSTANTIATE_DPERP(WilsonImplFH); + INSTANTIATE_DPERP(WilsonImplDF); + INSTANTIATE_DPERP(GparityWilsonImplFH); + INSTANTIATE_DPERP(GparityWilsonImplDF); + INSTANTIATE_DPERP(ZWilsonImplFH); + INSTANTIATE_DPERP(ZWilsonImplDF); #endif } diff --git 
a/lib/qcd/action/fermion/CayleyFermion5Dvec.cc b/lib/qcd/action/fermion/CayleyFermion5Dvec.cc index 35a10de2..653e6ab3 100644 --- a/lib/qcd/action/fermion/CayleyFermion5Dvec.cc +++ b/lib/qcd/action/fermion/CayleyFermion5Dvec.cc @@ -29,12 +29,13 @@ Author: paboyle *************************************************************************************/ /* END LEGAL */ -#include -#include + +#include +#include namespace Grid { -namespace QCD { +namespace QCD { /* * Dense matrix versions of routines */ @@ -92,8 +93,7 @@ void CayleyFermion5D::M5D(const FermionField &psi, assert(Nc==3); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss+=LLs){ // adds LLs + parallel_for(int ss=0;ssoSites();ss+=LLs){ // adds LLs #if 0 alignas(64) SiteHalfSpinor hp; alignas(64) SiteHalfSpinor hm; @@ -126,7 +126,6 @@ PARALLEL_FOR_LOOP for(int v=0;v(hp_00.v); hp_01.v = Optimization::Rotate::tRotate<2>(hp_01.v); @@ -165,42 +161,20 @@ PARALLEL_FOR_LOOP hm_12.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_12.v); } - /* - if ( ss==0) std::cout << " dphi_00 " <::mult(d[v]()()(), phi[ss+v]()(0)(0)) + switcheroo::mult(l[v]()()(),hm_00); + Simd p_01 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(1)) + switcheroo::mult(l[v]()()(),hm_01); + Simd p_02 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(2)) + switcheroo::mult(l[v]()()(),hm_02); + Simd p_10 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(0)) + switcheroo::mult(l[v]()()(),hm_10); + Simd p_11 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(1)) + switcheroo::mult(l[v]()()(),hm_11); + Simd p_12 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(2)) + switcheroo::mult(l[v]()()(),hm_12); + Simd p_20 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(0)) + switcheroo::mult(u[v]()()(),hp_00); + Simd p_21 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(1)) + switcheroo::mult(u[v]()()(),hp_01); + Simd p_22 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(2)) + switcheroo::mult(u[v]()()(),hp_02); + Simd p_30 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(0)) + 
switcheroo::mult(u[v]()()(),hp_10); + Simd p_31 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(1)) + switcheroo::mult(u[v]()()(),hp_11); + Simd p_32 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(2)) + switcheroo::mult(u[v]()()(),hp_12); - - // if ( ss==0){ - /* - std::cout << ss<<" "<< v<< " good "<< chi[ss+v]()(0)(0) << " bad "<::M5Ddag(const FermionField &psi, M5Dcalls++; M5Dtime-=usecond(); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss+=LLs){ // adds LLs - + parallel_for(int ss=0;ssoSites();ss+=LLs){ // adds LLs +#if 0 alignas(64) SiteHalfSpinor hp; alignas(64) SiteHalfSpinor hm; alignas(64) SiteSpinor fp; @@ -287,9 +260,504 @@ PARALLEL_FOR_LOOP chi[ss+v] = chi[ss+v] +l[v]*fm; } +#else + for(int v=0;v(hp_00.v); + hp_01.v = Optimization::Rotate::tRotate<2>(hp_01.v); + hp_02.v = Optimization::Rotate::tRotate<2>(hp_02.v); + hp_10.v = Optimization::Rotate::tRotate<2>(hp_10.v); + hp_11.v = Optimization::Rotate::tRotate<2>(hp_11.v); + hp_12.v = Optimization::Rotate::tRotate<2>(hp_12.v); + } + if ( vm>=v ) { + hm_00.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_00.v); + hm_01.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_01.v); + hm_02.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_02.v); + hm_10.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_10.v); + hm_11.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_11.v); + hm_12.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_12.v); + } + + Simd p_00 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(0)) + switcheroo::mult(u[v]()()(),hp_00); + Simd p_01 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(1)) + switcheroo::mult(u[v]()()(),hp_01); + Simd p_02 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(2)) + switcheroo::mult(u[v]()()(),hp_02); + Simd p_10 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(0)) + switcheroo::mult(u[v]()()(),hp_10); + Simd p_11 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(1)) + switcheroo::mult(u[v]()()(),hp_11); + Simd p_12 = switcheroo::mult(d[v]()()(), 
phi[ss+v]()(1)(2)) + switcheroo::mult(u[v]()()(),hp_12); + + Simd p_20 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(0)) + switcheroo::mult(l[v]()()(),hm_00); + Simd p_21 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(1)) + switcheroo::mult(l[v]()()(),hm_01); + Simd p_22 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(2)) + switcheroo::mult(l[v]()()(),hm_02); + Simd p_30 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(0)) + switcheroo::mult(l[v]()()(),hm_10); + Simd p_31 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(1)) + switcheroo::mult(l[v]()()(),hm_11); + Simd p_32 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(2)) + switcheroo::mult(l[v]()()(),hm_12); + + vstream(chi[ss+v]()(0)(0),p_00); + vstream(chi[ss+v]()(0)(1),p_01); + vstream(chi[ss+v]()(0)(2),p_02); + vstream(chi[ss+v]()(1)(0),p_10); + vstream(chi[ss+v]()(1)(1),p_11); + vstream(chi[ss+v]()(1)(2),p_12); + vstream(chi[ss+v]()(2)(0),p_20); + vstream(chi[ss+v]()(2)(1),p_21); + vstream(chi[ss+v]()(2)(2),p_22); + vstream(chi[ss+v]()(3)(0),p_30); + vstream(chi[ss+v]()(3)(1),p_31); + vstream(chi[ss+v]()(3)(2),p_32); + } +#endif } M5Dtime+=usecond(); } + + +#ifdef AVX512 +#include +#include +#include +#endif + +template +void CayleyFermion5D::MooeeInternalAsm(const FermionField &psi, FermionField &chi, + int LLs, int site, + Vector > &Matp, + Vector > &Matm) +{ +#ifndef AVX512 + { + SiteHalfSpinor BcastP; + SiteHalfSpinor BcastM; + SiteHalfSpinor SiteChiP; + SiteHalfSpinor SiteChiM; + + // Ls*Ls * 2 * 12 * vol flops + for(int s1=0;s1); + for(int s1=0;s1 +void CayleyFermion5D::MooeeInternalZAsm(const FermionField &psi, FermionField &chi, + int LLs, int site, Vector > &Matp, Vector > &Matm) +{ +#ifndef AVX512 + { + SiteHalfSpinor BcastP; + SiteHalfSpinor BcastM; + SiteHalfSpinor SiteChiP; + SiteHalfSpinor SiteChiM; + + // Ls*Ls * 2 * 12 * vol flops + for(int s1=0;s1); + for(int s1=0;s1 void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv) { @@ -299,108 +767,39 @@ void 
CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField chi.checkerboard=psi.checkerboard; - Eigen::MatrixXcd Pplus = Eigen::MatrixXcd::Zero(Ls,Ls); - Eigen::MatrixXcd Pminus = Eigen::MatrixXcd::Zero(Ls,Ls); + Vector > Matp; + Vector > Matm; + Vector > *_Matp; + Vector > *_Matm; - for(int s=0;s > Matp(Ls*LLs); - Vector > Matm(Ls*LLs); + assert(_Matp->size()==Ls*LLs); - for(int s2=0;s2 SitePplus(LLs); - Vector SitePminus(LLs); - Vector SiteChiP(LLs); - Vector SiteChiM(LLs); - Vector SiteChi(LLs); - - SiteHalfSpinor BcastP; - SiteHalfSpinor BcastM; - -#pragma omp for - for(auto site=0;site::iscomplex() ) { + parallel_for(auto site=0;site::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); +template void CayleyFermion5D::MooeeInternal(const FermionField &psi, FermionField &chi,int dag, int inv); + + + }} diff --git a/lib/qcd/action/fermion/ContinuedFractionFermion5D.cc b/lib/qcd/action/fermion/ContinuedFractionFermion5D.cc index e58ab4da..5d39ef9b 100644 --- a/lib/qcd/action/fermion/ContinuedFractionFermion5D.cc +++ b/lib/qcd/action/fermion/ContinuedFractionFermion5D.cc @@ -26,7 +26,8 @@ Author: Peter Boyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ 
-#include +#include +#include namespace Grid { namespace QCD { diff --git a/lib/qcd/action/fermion/ContinuedFractionFermion5D.h b/lib/qcd/action/fermion/ContinuedFractionFermion5D.h index 15d44945..e1e50aa5 100644 --- a/lib/qcd/action/fermion/ContinuedFractionFermion5D.h +++ b/lib/qcd/action/fermion/ContinuedFractionFermion5D.h @@ -29,6 +29,8 @@ Author: Peter Boyle #ifndef GRID_QCD_CONTINUED_FRACTION_H #define GRID_QCD_CONTINUED_FRACTION_H +#include + namespace Grid { namespace QCD { diff --git a/lib/qcd/action/fermion/DomainWallEOFAFermion.cc b/lib/qcd/action/fermion/DomainWallEOFAFermion.cc new file mode 100644 index 00000000..dd8a500d --- /dev/null +++ b/lib/qcd/action/fermion/DomainWallEOFAFermion.cc @@ -0,0 +1,438 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/DomainWallEOFAFermion.cc + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: Peter Boyle +Author: paboyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include +#include + +namespace Grid { +namespace QCD { + + template + DomainWallEOFAFermion::DomainWallEOFAFermion( + GaugeField &_Umu, + GridCartesian &FiveDimGrid, + GridRedBlackCartesian &FiveDimRedBlackGrid, + GridCartesian &FourDimGrid, + GridRedBlackCartesian &FourDimRedBlackGrid, + RealD _mq1, RealD _mq2, RealD _mq3, + RealD _shift, int _pm, RealD _M5, const ImplParams &p) : + AbstractEOFAFermion(_Umu, FiveDimGrid, FiveDimRedBlackGrid, + FourDimGrid, FourDimRedBlackGrid, _mq1, _mq2, _mq3, + _shift, _pm, _M5, 1.0, 0.0, p) + { + RealD eps = 1.0; + Approx::zolotarev_data *zdata = Approx::higham(eps,this->Ls); + assert(zdata->n == this->Ls); + + std::cout << GridLogMessage << "DomainWallEOFAFermion with Ls=" << this->Ls << std::endl; + this->SetCoefficientsTanh(zdata, 1.0, 0.0); + + Approx::zolotarev_free(zdata); + } + + /*************************************************************** + /* Additional EOFA operators only called outside the inverter. + /* Since speed is not essential, simple axpby-style + /* implementations should be fine. 
+ /***************************************************************/ + template + void DomainWallEOFAFermion::Omega(const FermionField& psi, FermionField& Din, int sign, int dag) + { + int Ls = this->Ls; + + Din = zero; + if((sign == 1) && (dag == 0)){ axpby_ssp(Din, 0.0, psi, 1.0, psi, Ls-1, 0); } + else if((sign == -1) && (dag == 0)){ axpby_ssp(Din, 0.0, psi, 1.0, psi, 0, 0); } + else if((sign == 1 ) && (dag == 1)){ axpby_ssp(Din, 0.0, psi, 1.0, psi, 0, Ls-1); } + else if((sign == -1) && (dag == 1)){ axpby_ssp(Din, 0.0, psi, 1.0, psi, 0, 0); } + } + + // This is just the identity for DWF + template + void DomainWallEOFAFermion::Dtilde(const FermionField& psi, FermionField& chi){ chi = psi; } + + // This is just the identity for DWF + template + void DomainWallEOFAFermion::DtildeInv(const FermionField& psi, FermionField& chi){ chi = psi; } + + /*****************************************************************************************************/ + + template + RealD DomainWallEOFAFermion::M(const FermionField& psi, FermionField& chi) + { + int Ls = this->Ls; + + FermionField Din(psi._grid); + + this->Meooe5D(psi, Din); + this->DW(Din, chi, DaggerNo); + axpby(chi, 1.0, 1.0, chi, psi); + this->M5D(psi, chi); + return(norm2(chi)); + } + + template + RealD DomainWallEOFAFermion::Mdag(const FermionField& psi, FermionField& chi) + { + int Ls = this->Ls; + + FermionField Din(psi._grid); + + this->DW(psi, Din, DaggerYes); + this->MeooeDag5D(Din, chi); + this->M5Ddag(psi, chi); + axpby(chi, 1.0, 1.0, chi, psi); + return(norm2(chi)); + } + + /******************************************************************** + /* Performance critical fermion operators called inside the inverter + /********************************************************************/ + + template + void DomainWallEOFAFermion::M5D(const FermionField& psi, FermionField& chi) + { + int Ls = this->Ls; + int pm = this->pm; + RealD shift = this->shift; + RealD mq1 = this->mq1; + RealD mq2 = this->mq2; + RealD mq3 
= this->mq3; + + // coefficients for shift operator ( = shift*\gamma_{5}*R_{5}*\Delta_{\pm}(mq2,mq3)*P_{\pm} ) + Coeff_t shiftp(0.0), shiftm(0.0); + if(shift != 0.0){ + if(pm == 1){ shiftp = shift*(mq3-mq2); } + else{ shiftm = -shift*(mq3-mq2); } + } + + std::vector diag(Ls,1.0); + std::vector upper(Ls,-1.0); upper[Ls-1] = mq1 + shiftm; + std::vector lower(Ls,-1.0); lower[0] = mq1 + shiftp; + + #if(0) + std::cout << GridLogMessage << "DomainWallEOFAFermion::M5D(FF&,FF&):" << std::endl; + for(int i=0; i::iscomplex()) { + sp[l] = PplusMat (l*istride+s1*ostride,s2); + sm[l] = PminusMat(l*istride+s1*ostride,s2); + } else { + // if real + scalar_type tmp; + tmp = PplusMat (l*istride+s1*ostride,s2); + sp[l] = scalar_type(tmp.real(),tmp.real()); + tmp = PminusMat(l*istride+s1*ostride,s2); + sm[l] = scalar_type(tmp.real(),tmp.real()); + } + } + Matp[LLs*s2+s1] = Vp; + Matm[LLs*s2+s1] = Vm; + }} + } + + FermOpTemplateInstantiate(DomainWallEOFAFermion); + GparityFermOpTemplateInstantiate(DomainWallEOFAFermion); + +}} diff --git a/lib/qcd/action/fermion/DomainWallEOFAFermion.h b/lib/qcd/action/fermion/DomainWallEOFAFermion.h new file mode 100644 index 00000000..5362cda8 --- /dev/null +++ b/lib/qcd/action/fermion/DomainWallEOFAFermion.h @@ -0,0 +1,115 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/DomainWallEOFAFermion.h + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef GRID_QCD_DOMAIN_WALL_EOFA_FERMION_H +#define GRID_QCD_DOMAIN_WALL_EOFA_FERMION_H + +#include + +namespace Grid { +namespace QCD { + + template + class DomainWallEOFAFermion : public AbstractEOFAFermion + { + public: + INHERIT_IMPL_TYPES(Impl); + + public: + // Modified (0,Ls-1) and (Ls-1,0) elements of Mooee + // for red-black preconditioned Shamir EOFA + Coeff_t dm; + Coeff_t dp; + + virtual void Instantiatable(void) {}; + + // EOFA-specific operations + virtual void Omega (const FermionField& in, FermionField& out, int sign, int dag); + virtual void Dtilde (const FermionField& in, FermionField& out); + virtual void DtildeInv (const FermionField& in, FermionField& out); + + // override multiply + virtual RealD M (const FermionField& in, FermionField& out); + virtual RealD Mdag (const FermionField& in, FermionField& out); + + // half checkerboard operations + virtual void Mooee (const FermionField& in, FermionField& out); + virtual void MooeeDag (const FermionField& in, FermionField& out); + virtual void MooeeInv (const FermionField& in, FermionField& out); + virtual void MooeeInvDag(const FermionField& in, FermionField& out); + + virtual void M5D (const FermionField& psi, FermionField& chi); + virtual void M5Ddag (const FermionField& psi, FermionField& chi); + + ///////////////////////////////////////////////////// + // Instantiate different versions depending on Impl + ///////////////////////////////////////////////////// + void M5D(const FermionField& psi, const FermionField& phi, FermionField& chi, + 
std::vector& lower, std::vector& diag, std::vector& upper); + + void M5Ddag(const FermionField& psi, const FermionField& phi, FermionField& chi, + std::vector& lower, std::vector& diag, std::vector& upper); + + void MooeeInternal(const FermionField& in, FermionField& out, int dag, int inv); + + void MooeeInternalCompute(int dag, int inv, Vector>& Matp, Vector>& Matm); + + void MooeeInternalAsm(const FermionField& in, FermionField& out, int LLs, int site, + Vector>& Matp, Vector>& Matm); + + void MooeeInternalZAsm(const FermionField& in, FermionField& out, int LLs, int site, + Vector>& Matp, Vector>& Matm); + + virtual void RefreshShiftCoefficients(RealD new_shift); + + // Constructors + DomainWallEOFAFermion(GaugeField& _Umu, GridCartesian& FiveDimGrid, GridRedBlackCartesian& FiveDimRedBlackGrid, + GridCartesian& FourDimGrid, GridRedBlackCartesian& FourDimRedBlackGrid, + RealD _mq1, RealD _mq2, RealD _mq3, RealD _shift, int pm, + RealD _M5, const ImplParams& p=ImplParams()); + + protected: + void SetCoefficientsInternal(RealD zolo_hi, std::vector& gamma, RealD b, RealD c); + }; +}} + +#define INSTANTIATE_DPERP_DWF_EOFA(A)\ +template void DomainWallEOFAFermion::M5D(const FermionField& psi, const FermionField& phi, FermionField& chi, \ + std::vector& lower, std::vector& diag, std::vector& upper); \ +template void DomainWallEOFAFermion::M5Ddag(const FermionField& psi, const FermionField& phi, FermionField& chi, \ + std::vector& lower, std::vector& diag, std::vector& upper); \ +template void DomainWallEOFAFermion::MooeeInv(const FermionField& psi, FermionField& chi); \ +template void DomainWallEOFAFermion::MooeeInvDag(const FermionField& psi, FermionField& chi); + +#undef DOMAIN_WALL_EOFA_DPERP_DENSE +#define DOMAIN_WALL_EOFA_DPERP_CACHE +#undef DOMAIN_WALL_EOFA_DPERP_LINALG +#define DOMAIN_WALL_EOFA_DPERP_VEC + +#endif diff --git a/lib/qcd/action/fermion/DomainWallEOFAFermioncache.cc b/lib/qcd/action/fermion/DomainWallEOFAFermioncache.cc new file mode 100644 index 
00000000..0b214d31 --- /dev/null +++ b/lib/qcd/action/fermion/DomainWallEOFAFermioncache.cc @@ -0,0 +1,248 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/DomainWallEOFAFermioncache.cc + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: Peter Boyle +Author: paboyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include + +namespace Grid { +namespace QCD { + + // FIXME -- make a version of these routines with site loop outermost for cache reuse. + + // Pminus fowards + // Pplus backwards.. 
+ template + void DomainWallEOFAFermion::M5D(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper) + { + int Ls = this->Ls; + GridBase* grid = psi._grid; + + assert(phi.checkerboard == psi.checkerboard); + chi.checkerboard = psi.checkerboard; + // Flops = 6.0*(Nc*Ns) *Ls*vol + this->M5Dcalls++; + this->M5Dtime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=Ls){ // adds Ls + for(int s=0; sM5Dtime += usecond(); + } + + template + void DomainWallEOFAFermion::M5Ddag(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper) + { + int Ls = this->Ls; + GridBase* grid = psi._grid; + assert(phi.checkerboard == psi.checkerboard); + chi.checkerboard=psi.checkerboard; + + // Flops = 6.0*(Nc*Ns) *Ls*vol + this->M5Dcalls++; + this->M5Dtime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=Ls){ // adds Ls + auto tmp = psi._odata[0]; + for(int s=0; sM5Dtime += usecond(); + } + + template + void DomainWallEOFAFermion::MooeeInv(const FermionField& psi, FermionField& chi) + { + GridBase* grid = psi._grid; + int Ls = this->Ls; + + chi.checkerboard = psi.checkerboard; + + this->MooeeInvCalls++; + this->MooeeInvTime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=Ls){ // adds Ls + + auto tmp1 = psi._odata[0]; + auto tmp2 = psi._odata[0]; + + // flops = 12*2*Ls + 12*2*Ls + 3*12*Ls + 12*2*Ls = 12*Ls * (9) = 108*Ls flops + // Apply (L^{\prime})^{-1} + chi[ss] = psi[ss]; // chi[0]=psi[0] + for(int s=1; slee[s-1]*tmp1; + } + + // L_m^{-1} + for(int s=0; sleem[s]*tmp1; + } + + // U_m^{-1} D^{-1} + for(int s=0; sdee[s])*chi[ss+s] - (this->ueem[s]/this->dee[Ls])*tmp1; + } + spProj5m(tmp2, chi[ss+Ls-1]); + chi[ss+Ls-1] = (1.0/this->dee[Ls])*tmp1 + (1.0/this->dee[Ls-1])*tmp2; + + // Apply U^{-1} + for(int s=Ls-2; s>=0; s--){ + spProj5m(tmp1, chi[ss+s+1]); + chi[ss+s] = chi[ss+s] - this->uee[s]*tmp1; + } + } + + 
this->MooeeInvTime += usecond(); + } + + template + void DomainWallEOFAFermion::MooeeInvDag(const FermionField& psi, FermionField& chi) + { + GridBase* grid = psi._grid; + int Ls = this->Ls; + + assert(psi.checkerboard == psi.checkerboard); + chi.checkerboard = psi.checkerboard; + + std::vector ueec(Ls); + std::vector deec(Ls+1); + std::vector leec(Ls); + std::vector ueemc(Ls); + std::vector leemc(Ls); + + for(int s=0; suee[s]); + deec[s] = conjugate(this->dee[s]); + leec[s] = conjugate(this->lee[s]); + ueemc[s] = conjugate(this->ueem[s]); + leemc[s] = conjugate(this->leem[s]); + } + deec[Ls] = conjugate(this->dee[Ls]); + + this->MooeeInvCalls++; + this->MooeeInvTime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=Ls){ // adds Ls + + auto tmp1 = psi._odata[0]; + auto tmp2 = psi._odata[0]; + + // Apply (U^{\prime})^{-dagger} + chi[ss] = psi[ss]; + for(int s=1; s=0; s--){ + spProj5p(tmp1, chi[ss+s+1]); + chi[ss+s] = chi[ss+s] - leec[s]*tmp1; + } + } + + this->MooeeInvTime += usecond(); + } + + #ifdef DOMAIN_WALL_EOFA_DPERP_CACHE + + INSTANTIATE_DPERP_DWF_EOFA(WilsonImplF); + INSTANTIATE_DPERP_DWF_EOFA(WilsonImplD); + INSTANTIATE_DPERP_DWF_EOFA(GparityWilsonImplF); + INSTANTIATE_DPERP_DWF_EOFA(GparityWilsonImplD); + INSTANTIATE_DPERP_DWF_EOFA(ZWilsonImplF); + INSTANTIATE_DPERP_DWF_EOFA(ZWilsonImplD); + + INSTANTIATE_DPERP_DWF_EOFA(WilsonImplFH); + INSTANTIATE_DPERP_DWF_EOFA(WilsonImplDF); + INSTANTIATE_DPERP_DWF_EOFA(GparityWilsonImplFH); + INSTANTIATE_DPERP_DWF_EOFA(GparityWilsonImplDF); + INSTANTIATE_DPERP_DWF_EOFA(ZWilsonImplFH); + INSTANTIATE_DPERP_DWF_EOFA(ZWilsonImplDF); + + #endif + +}} diff --git a/lib/qcd/action/fermion/DomainWallEOFAFermiondense.cc b/lib/qcd/action/fermion/DomainWallEOFAFermiondense.cc new file mode 100644 index 00000000..c27074d9 --- /dev/null +++ b/lib/qcd/action/fermion/DomainWallEOFAFermiondense.cc @@ -0,0 +1,159 @@ +/************************************************************************************* + +Grid physics library, 
www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/DomainWallEOFAFermiondense.cc + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: Peter Boyle +Author: paboyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include +#include + +namespace Grid { +namespace QCD { + + /* + * Dense matrix versions of routines + */ + template + void DomainWallEOFAFermion::MooeeInvDag(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerYes, InverseYes); + } + + template + void DomainWallEOFAFermion::MooeeInv(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerNo, InverseYes); + } + + template + void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv) + { + int Ls = this->Ls; + int LLs = psi._grid->_rdimensions[0]; + int vol = psi._grid->oSites()/LLs; + + chi.checkerboard = psi.checkerboard; + + assert(Ls==LLs); + + Eigen::MatrixXd Pplus = Eigen::MatrixXd::Zero(Ls,Ls); + Eigen::MatrixXd Pminus = Eigen::MatrixXd::Zero(Ls,Ls); + + for(int s=0;sbee[s]; + Pminus(s,s) = 
this->bee[s]; + } + + for(int s=0; scee[s]; + } + + for(int s=0; scee[s+1]; + } + + Pplus (0,Ls-1) = this->dp; + Pminus(Ls-1,0) = this->dm; + + Eigen::MatrixXd PplusMat ; + Eigen::MatrixXd PminusMat; + + if(inv) { + PplusMat = Pplus.inverse(); + PminusMat = Pminus.inverse(); + } else { + PplusMat = Pplus; + PminusMat = Pminus; + } + + if(dag){ + PplusMat.adjointInPlace(); + PminusMat.adjointInPlace(); + } + + // For the non-vectorised s-direction this is simple + + for(auto site=0; site::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + + INSTANTIATE_DPERP_DWF_EOFA(GparityWilsonImplFH); + INSTANTIATE_DPERP_DWF_EOFA(GparityWilsonImplDF); + INSTANTIATE_DPERP_DWF_EOFA(WilsonImplFH); + INSTANTIATE_DPERP_DWF_EOFA(WilsonImplDF); + INSTANTIATE_DPERP_DWF_EOFA(ZWilsonImplFH); + INSTANTIATE_DPERP_DWF_EOFA(ZWilsonImplDF); + + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int 
dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + + #endif + +}} diff --git a/lib/qcd/action/fermion/DomainWallEOFAFermionssp.cc b/lib/qcd/action/fermion/DomainWallEOFAFermionssp.cc new file mode 100644 index 00000000..80a4bf09 --- /dev/null +++ b/lib/qcd/action/fermion/DomainWallEOFAFermionssp.cc @@ -0,0 +1,168 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/DomainWallEOFAFermionssp.cc + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: Peter Boyle +Author: paboyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include + +namespace Grid { +namespace QCD { + + // FIXME -- make a version of these routines with site loop outermost for cache reuse. 
+ // Pminus fowards + // Pplus backwards + template + void DomainWallEOFAFermion::M5D(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper) + { + Coeff_t one(1.0); + int Ls = this->Ls; + for(int s=0; s + void DomainWallEOFAFermion::M5Ddag(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper) + { + Coeff_t one(1.0); + int Ls = this->Ls; + for(int s=0; s + void DomainWallEOFAFermion::MooeeInv(const FermionField& psi, FermionField& chi) + { + Coeff_t one(1.0); + Coeff_t czero(0.0); + chi.checkerboard = psi.checkerboard; + int Ls = this->Ls; + + FermionField tmp(psi._grid); + + // Apply (L^{\prime})^{-1} + axpby_ssp(chi, one, psi, czero, psi, 0, 0); // chi[0]=psi[0] + for(int s=1; slee[s-1], chi, s, s-1);// recursion Psi[s] -lee P_+ chi[s-1] + } + + // L_m^{-1} + for(int s=0; sleem[s], chi, Ls-1, s); + } + + // U_m^{-1} D^{-1} + for(int s=0; sdee[s], chi, -this->ueem[s]/this->dee[Ls], chi, s, Ls-1); + } + axpby_ssp_pminus(tmp, czero, chi, one/this->dee[Ls-1], chi, Ls-1, Ls-1); + axpby_ssp_pplus(chi, one, tmp, one/this->dee[Ls], chi, Ls-1, Ls-1); + + // Apply U^{-1} + for(int s=Ls-2; s>=0; s--){ + axpby_ssp_pminus(chi, one, chi, -this->uee[s], chi, s, s+1); // chi[Ls] + } + } + + template + void DomainWallEOFAFermion::MooeeInvDag(const FermionField& psi, FermionField& chi) + { + Coeff_t one(1.0); + Coeff_t czero(0.0); + chi.checkerboard = psi.checkerboard; + int Ls = this->Ls; + + FermionField tmp(psi._grid); + + // Apply (U^{\prime})^{-dagger} + axpby_ssp(chi, one, psi, czero, psi, 0, 0); // chi[0]=psi[0] + for(int s=1; suee[s-1]), chi, s, s-1); + } + + // U_m^{-\dagger} + for(int s=0; sueem[s]), chi, Ls-1, s); + } + + // L_m^{-\dagger} D^{-dagger} + for(int s=0; sdee[s]), chi, -conjugate(this->leem[s]/this->dee[Ls-1]), chi, s, Ls-1); + } + axpby_ssp_pminus(tmp, czero, chi, one/conjugate(this->dee[Ls-1]), chi, Ls-1, 
Ls-1); + axpby_ssp_pplus(chi, one, tmp, one/conjugate(this->dee[Ls]), chi, Ls-1, Ls-1); + + // Apply L^{-dagger} + for(int s=Ls-2; s>=0; s--){ + axpby_ssp_pplus(chi, one, chi, -conjugate(this->lee[s]), chi, s, s+1); // chi[Ls] + } + } + + #ifdef DOMAIN_WALL_EOFA_DPERP_LINALG + + INSTANTIATE_DPERP_DWF_EOFA(WilsonImplF); + INSTANTIATE_DPERP_DWF_EOFA(WilsonImplD); + INSTANTIATE_DPERP_DWF_EOFA(GparityWilsonImplF); + INSTANTIATE_DPERP_DWF_EOFA(GparityWilsonImplD); + INSTANTIATE_DPERP_DWF_EOFA(ZWilsonImplF); + INSTANTIATE_DPERP_DWF_EOFA(ZWilsonImplD); + + INSTANTIATE_DPERP_DWF_EOFA(WilsonImplFH); + INSTANTIATE_DPERP_DWF_EOFA(WilsonImplDF); + INSTANTIATE_DPERP_DWF_EOFA(GparityWilsonImplFH); + INSTANTIATE_DPERP_DWF_EOFA(GparityWilsonImplDF); + INSTANTIATE_DPERP_DWF_EOFA(ZWilsonImplFH); + INSTANTIATE_DPERP_DWF_EOFA(ZWilsonImplDF); + + #endif + +}} diff --git a/lib/qcd/action/fermion/DomainWallEOFAFermionvec.cc b/lib/qcd/action/fermion/DomainWallEOFAFermionvec.cc new file mode 100644 index 00000000..81ce448c --- /dev/null +++ b/lib/qcd/action/fermion/DomainWallEOFAFermionvec.cc @@ -0,0 +1,605 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/DomainWallEOFAFermionvec.cc + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: Peter Boyle +Author: paboyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. 
+ +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include + +namespace Grid { +namespace QCD { + + /* + * Dense matrix versions of routines + */ + template + void DomainWallEOFAFermion::MooeeInvDag(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerYes, InverseYes); + } + + template + void DomainWallEOFAFermion::MooeeInv(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerNo, InverseYes); + } + + template + void DomainWallEOFAFermion::M5D(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper) + { + GridBase* grid = psi._grid; + int Ls = this->Ls; + int LLs = grid->_rdimensions[0]; + const int nsimd = Simd::Nsimd(); + + Vector > u(LLs); + Vector > l(LLs); + Vector > d(LLs); + + assert(Ls/LLs == nsimd); + assert(phi.checkerboard == psi.checkerboard); + + chi.checkerboard = psi.checkerboard; + + // just directly address via type pun + typedef typename Simd::scalar_type scalar_type; + scalar_type* u_p = (scalar_type*) &u[0]; + scalar_type* l_p = (scalar_type*) &l[0]; + scalar_type* d_p = (scalar_type*) &d[0]; + + for(int o=0;oM5Dcalls++; + this->M5Dtime -= usecond(); + + assert(Nc == 3); + + parallel_for(int ss=0; ssoSites(); ss+=LLs){ // adds LLs + + #if 0 + + alignas(64) SiteHalfSpinor hp; + alignas(64) SiteHalfSpinor hm; + alignas(64) SiteSpinor fp; + alignas(64) SiteSpinor fm; + + for(int v=0; v= v){ rotate(hm, hm, nsimd-1); } + + hp = 0.5*hp; + hm = 0.5*hm; + + spRecon5m(fp, hp); + spRecon5p(fm, hm); + + chi[ss+v] = d[v]*phi[ss+v]; + chi[ss+v] = chi[ss+v] + u[v]*fp; 
+ chi[ss+v] = chi[ss+v] + l[v]*fm; + + } + + #else + + for(int v=0; v(hp_00.v); + hp_01.v = Optimization::Rotate::tRotate<2>(hp_01.v); + hp_02.v = Optimization::Rotate::tRotate<2>(hp_02.v); + hp_10.v = Optimization::Rotate::tRotate<2>(hp_10.v); + hp_11.v = Optimization::Rotate::tRotate<2>(hp_11.v); + hp_12.v = Optimization::Rotate::tRotate<2>(hp_12.v); + } + + if(vm >= v){ + hm_00.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_00.v); + hm_01.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_01.v); + hm_02.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_02.v); + hm_10.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_10.v); + hm_11.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_11.v); + hm_12.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_12.v); + } + + // Can force these to real arithmetic and save 2x. + Simd p_00 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(0)) + switcheroo::mult(l[v]()()(), hm_00); + Simd p_01 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(1)) + switcheroo::mult(l[v]()()(), hm_01); + Simd p_02 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(2)) + switcheroo::mult(l[v]()()(), hm_02); + Simd p_10 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(0)) + switcheroo::mult(l[v]()()(), hm_10); + Simd p_11 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(1)) + switcheroo::mult(l[v]()()(), hm_11); + Simd p_12 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(2)) + switcheroo::mult(l[v]()()(), hm_12); + Simd p_20 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(0)) + switcheroo::mult(u[v]()()(), hp_00); + Simd p_21 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(1)) + switcheroo::mult(u[v]()()(), hp_01); + Simd p_22 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(2)) + switcheroo::mult(u[v]()()(), hp_02); + Simd p_30 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(0)) + switcheroo::mult(u[v]()()(), hp_10); + Simd p_31 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(1)) + switcheroo::mult(u[v]()()(), hp_11); + Simd p_32 = 
switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(2)) + switcheroo::mult(u[v]()()(), hp_12); + + vstream(chi[ss+v]()(0)(0), p_00); + vstream(chi[ss+v]()(0)(1), p_01); + vstream(chi[ss+v]()(0)(2), p_02); + vstream(chi[ss+v]()(1)(0), p_10); + vstream(chi[ss+v]()(1)(1), p_11); + vstream(chi[ss+v]()(1)(2), p_12); + vstream(chi[ss+v]()(2)(0), p_20); + vstream(chi[ss+v]()(2)(1), p_21); + vstream(chi[ss+v]()(2)(2), p_22); + vstream(chi[ss+v]()(3)(0), p_30); + vstream(chi[ss+v]()(3)(1), p_31); + vstream(chi[ss+v]()(3)(2), p_32); + } + + #endif + } + + this->M5Dtime += usecond(); + } + + template + void DomainWallEOFAFermion::M5Ddag(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper) + { + GridBase* grid = psi._grid; + int Ls = this->Ls; + int LLs = grid->_rdimensions[0]; + int nsimd = Simd::Nsimd(); + + Vector > u(LLs); + Vector > l(LLs); + Vector > d(LLs); + + assert(Ls/LLs == nsimd); + assert(phi.checkerboard == psi.checkerboard); + + chi.checkerboard = psi.checkerboard; + + // just directly address via type pun + typedef typename Simd::scalar_type scalar_type; + scalar_type* u_p = (scalar_type*) &u[0]; + scalar_type* l_p = (scalar_type*) &l[0]; + scalar_type* d_p = (scalar_type*) &d[0]; + + for(int o=0; oM5Dcalls++; + this->M5Dtime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=LLs){ // adds LLs + + #if 0 + + alignas(64) SiteHalfSpinor hp; + alignas(64) SiteHalfSpinor hm; + alignas(64) SiteSpinor fp; + alignas(64) SiteSpinor fm; + + for(int v=0; v= v){ rotate(hm, hm, nsimd-1); } + + hp = hp*0.5; + hm = hm*0.5; + spRecon5p(fp, hp); + spRecon5m(fm, hm); + + chi[ss+v] = d[v]*phi[ss+v]+u[v]*fp; + chi[ss+v] = chi[ss+v] +l[v]*fm; + } + + #else + + for(int v=0; v(hp_00.v); + hp_01.v = Optimization::Rotate::tRotate<2>(hp_01.v); + hp_02.v = Optimization::Rotate::tRotate<2>(hp_02.v); + hp_10.v = Optimization::Rotate::tRotate<2>(hp_10.v); + hp_11.v = Optimization::Rotate::tRotate<2>(hp_11.v); + 
hp_12.v = Optimization::Rotate::tRotate<2>(hp_12.v); + } + + if(vm >= v){ + hm_00.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_00.v); + hm_01.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_01.v); + hm_02.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_02.v); + hm_10.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_10.v); + hm_11.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_11.v); + hm_12.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_12.v); + } + + Simd p_00 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(0)) + switcheroo::mult(u[v]()()(), hp_00); + Simd p_01 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(1)) + switcheroo::mult(u[v]()()(), hp_01); + Simd p_02 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(2)) + switcheroo::mult(u[v]()()(), hp_02); + Simd p_10 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(0)) + switcheroo::mult(u[v]()()(), hp_10); + Simd p_11 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(1)) + switcheroo::mult(u[v]()()(), hp_11); + Simd p_12 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(2)) + switcheroo::mult(u[v]()()(), hp_12); + Simd p_20 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(0)) + switcheroo::mult(l[v]()()(), hm_00); + Simd p_21 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(1)) + switcheroo::mult(l[v]()()(), hm_01); + Simd p_22 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(2)) + switcheroo::mult(l[v]()()(), hm_02); + Simd p_30 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(0)) + switcheroo::mult(l[v]()()(), hm_10); + Simd p_31 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(1)) + switcheroo::mult(l[v]()()(), hm_11); + Simd p_32 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(2)) + switcheroo::mult(l[v]()()(), hm_12); + + vstream(chi[ss+v]()(0)(0), p_00); + vstream(chi[ss+v]()(0)(1), p_01); + vstream(chi[ss+v]()(0)(2), p_02); + vstream(chi[ss+v]()(1)(0), p_10); + vstream(chi[ss+v]()(1)(1), p_11); + vstream(chi[ss+v]()(1)(2), p_12); + vstream(chi[ss+v]()(2)(0), p_20); + 
vstream(chi[ss+v]()(2)(1), p_21); + vstream(chi[ss+v]()(2)(2), p_22); + vstream(chi[ss+v]()(3)(0), p_30); + vstream(chi[ss+v]()(3)(1), p_31); + vstream(chi[ss+v]()(3)(2), p_32); + } + #endif + + } + + this->M5Dtime += usecond(); + } + + #ifdef AVX512 + #include + #include + #include + #endif + + template + void DomainWallEOFAFermion::MooeeInternalAsm(const FermionField& psi, FermionField& chi, + int LLs, int site, Vector >& Matp, Vector >& Matm) + { + #ifndef AVX512 + { + SiteHalfSpinor BcastP; + SiteHalfSpinor BcastM; + SiteHalfSpinor SiteChiP; + SiteHalfSpinor SiteChiM; + + // Ls*Ls * 2 * 12 * vol flops + for(int s1=0; s1); + for(int s1=0; s1 + void DomainWallEOFAFermion::MooeeInternalZAsm(const FermionField& psi, FermionField& chi, + int LLs, int site, Vector >& Matp, Vector >& Matm) + { + std::cout << "Error: zMobius not implemented for EOFA" << std::endl; + exit(-1); + }; + + template + void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv) + { + int Ls = this->Ls; + int LLs = psi._grid->_rdimensions[0]; + int vol = psi._grid->oSites()/LLs; + + chi.checkerboard = psi.checkerboard; + + Vector > Matp; + Vector > Matm; + Vector > *_Matp; + Vector > *_Matm; + + // MooeeInternalCompute(dag,inv,Matp,Matm); + if(inv && dag){ + _Matp = &this->MatpInvDag; + _Matm = &this->MatmInvDag; + } + + if(inv && (!dag)){ + _Matp = &this->MatpInv; + _Matm = &this->MatmInv; + } + + if(!inv){ + MooeeInternalCompute(dag, inv, Matp, Matm); + _Matp = &Matp; + _Matm = &Matm; + } + + assert(_Matp->size() == Ls*LLs); + + this->MooeeInvCalls++; + this->MooeeInvTime -= usecond(); + + if(switcheroo::iscomplex()){ + parallel_for(auto site=0; siteMooeeInvTime += usecond(); + } + + #ifdef DOMAIN_WALL_EOFA_DPERP_VEC + + INSTANTIATE_DPERP_DWF_EOFA(DomainWallVec5dImplD); + INSTANTIATE_DPERP_DWF_EOFA(DomainWallVec5dImplF); + INSTANTIATE_DPERP_DWF_EOFA(ZDomainWallVec5dImplD); + INSTANTIATE_DPERP_DWF_EOFA(ZDomainWallVec5dImplF); + + 
INSTANTIATE_DPERP_DWF_EOFA(DomainWallVec5dImplDF); + INSTANTIATE_DPERP_DWF_EOFA(DomainWallVec5dImplFH); + INSTANTIATE_DPERP_DWF_EOFA(ZDomainWallVec5dImplDF); + INSTANTIATE_DPERP_DWF_EOFA(ZDomainWallVec5dImplFH); + + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void DomainWallEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + + #endif + +}} diff --git a/lib/qcd/action/fermion/DomainWallFermion.h b/lib/qcd/action/fermion/DomainWallFermion.h index c0b6b6aa..72ce8f42 100644 --- a/lib/qcd/action/fermion/DomainWallFermion.h +++ b/lib/qcd/action/fermion/DomainWallFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef GRID_QCD_DOMAIN_WALL_FERMION_H #define GRID_QCD_DOMAIN_WALL_FERMION_H -#include +#include namespace Grid { @@ -68,7 +68,7 @@ namespace Grid { Approx::zolotarev_data *zdata = Approx::higham(eps,this->Ls);// eps is ignored for higham assert(zdata->n==this->Ls); - // std::cout< -#include - -//////////////////////////////////////////// -// Utility functions -//////////////////////////////////////////// -#include -#include - -#include //used by all wilson type fermions -#include -#include -#include //used by all wilson type fermions -#include //used by all 
wilson type fermions - -//////////////////////////////////////////// -// Gauge Actions -//////////////////////////////////////////// -#include -#include - -namespace Grid { -namespace QCD { - -typedef WilsonGaugeAction WilsonGaugeActionR; -typedef WilsonGaugeAction WilsonGaugeActionF; -typedef WilsonGaugeAction WilsonGaugeActionD; -typedef PlaqPlusRectangleAction PlaqPlusRectangleActionR; -typedef PlaqPlusRectangleAction PlaqPlusRectangleActionF; -typedef PlaqPlusRectangleAction PlaqPlusRectangleActionD; -typedef IwasakiGaugeAction IwasakiGaugeActionR; -typedef IwasakiGaugeAction IwasakiGaugeActionF; -typedef IwasakiGaugeAction IwasakiGaugeActionD; -typedef SymanzikGaugeAction SymanzikGaugeActionR; -typedef SymanzikGaugeAction SymanzikGaugeActionF; -typedef SymanzikGaugeAction SymanzikGaugeActionD; - - -typedef WilsonGaugeAction ConjugateWilsonGaugeActionR; -typedef WilsonGaugeAction ConjugateWilsonGaugeActionF; -typedef WilsonGaugeAction ConjugateWilsonGaugeActionD; -typedef PlaqPlusRectangleAction ConjugatePlaqPlusRectangleActionR; -typedef PlaqPlusRectangleAction ConjugatePlaqPlusRectangleActionF; -typedef PlaqPlusRectangleAction ConjugatePlaqPlusRectangleActionD; -typedef IwasakiGaugeAction ConjugateIwasakiGaugeActionR; -typedef IwasakiGaugeAction ConjugateIwasakiGaugeActionF; -typedef IwasakiGaugeAction ConjugateIwasakiGaugeActionD; -typedef SymanzikGaugeAction ConjugateSymanzikGaugeActionR; -typedef SymanzikGaugeAction ConjugateSymanzikGaugeActionF; -typedef SymanzikGaugeAction ConjugateSymanzikGaugeActionD; - -}} +#ifndef GRID_QCD_FERMION_H +#define GRID_QCD_FERMION_H //////////////////////////////////////////////////////////////////////////////////////////////////// // Explicit explicit template instantiation is still required in the .cc files @@ -103,45 +38,13 @@ typedef SymanzikGaugeAction ConjugateSymanzikGaugeAction // - ContinuedFractionFermion5D.cc // - WilsonFermion.cc // - WilsonKernels.cc +// - DomainWallEOFAFermion.cc +// - MobiusEOFAFermion.cc // 
// The explicit instantiation is only avoidable if we move this source to headers and end up with include/parse/recompile // for EVERY .cc file. This define centralises the list and restores global push of impl cases //////////////////////////////////////////////////////////////////////////////////////////////////// - -#define FermOpStaggeredTemplateInstantiate(A) \ - template class A; \ - template class A; - -#define FermOp4dVecTemplateInstantiate(A) \ - template class A; \ - template class A; \ - template class A; \ - template class A; \ - template class A; \ - template class A; - -#define AdjointFermOpTemplateInstantiate(A) \ - template class A; \ - template class A; - -#define TwoIndexFermOpTemplateInstantiate(A) \ - template class A; \ - template class A; - -#define FermOp5dVecTemplateInstantiate(A) \ - template class A; \ - template class A; \ - template class A; \ - template class A; - -#define FermOpTemplateInstantiate(A) \ - FermOp4dVecTemplateInstantiate(A) \ - FermOp5dVecTemplateInstantiate(A) - - -#define GparityFermOpTemplateInstantiate(A) - //////////////////////////////////////////// // Fermion operators / actions //////////////////////////////////////////// @@ -149,30 +52,31 @@ typedef SymanzikGaugeAction ConjugateSymanzikGaugeAction #include // 4d wilson like #include // 4d wilson like #include // 5d base used by all 5d overlap types - //#include - #include #include - #include // Cayley types #include -#include +#include #include +#include #include +#include #include #include #include #include #include - #include // Continued fraction #include #include - #include // Partial fraction #include #include +/////////////////////////////////////////////////////////////////////////////// +// G5 herm -- this has to live in QCD since dirac matrix is not in the broader sector of code +/////////////////////////////////////////////////////////////////////////////// +#include 
//////////////////////////////////////////////////////////////////////////////////////////////////// // More maintainable to maintain the following typedef list centrally, as more "impl" targets @@ -188,6 +92,10 @@ typedef WilsonFermion WilsonFermionR; typedef WilsonFermion WilsonFermionF; typedef WilsonFermion WilsonFermionD; +typedef WilsonFermion WilsonFermionRL; +typedef WilsonFermion WilsonFermionFH; +typedef WilsonFermion WilsonFermionDF; + typedef WilsonFermion WilsonAdjFermionR; typedef WilsonFermion WilsonAdjFermionF; typedef WilsonFermion WilsonAdjFermionD; @@ -204,27 +112,82 @@ typedef DomainWallFermion DomainWallFermionR; typedef DomainWallFermion DomainWallFermionF; typedef DomainWallFermion DomainWallFermionD; +typedef DomainWallFermion DomainWallFermionRL; +typedef DomainWallFermion DomainWallFermionFH; +typedef DomainWallFermion DomainWallFermionDF; + +typedef DomainWallEOFAFermion DomainWallEOFAFermionR; +typedef DomainWallEOFAFermion DomainWallEOFAFermionF; +typedef DomainWallEOFAFermion DomainWallEOFAFermionD; + +typedef DomainWallEOFAFermion DomainWallEOFAFermionRL; +typedef DomainWallEOFAFermion DomainWallEOFAFermionFH; +typedef DomainWallEOFAFermion DomainWallEOFAFermionDF; + typedef MobiusFermion MobiusFermionR; typedef MobiusFermion MobiusFermionF; typedef MobiusFermion MobiusFermionD; +typedef MobiusFermion MobiusFermionRL; +typedef MobiusFermion MobiusFermionFH; +typedef MobiusFermion MobiusFermionDF; + +typedef MobiusEOFAFermion MobiusEOFAFermionR; +typedef MobiusEOFAFermion MobiusEOFAFermionF; +typedef MobiusEOFAFermion MobiusEOFAFermionD; + +typedef MobiusEOFAFermion MobiusEOFAFermionRL; +typedef MobiusEOFAFermion MobiusEOFAFermionFH; +typedef MobiusEOFAFermion MobiusEOFAFermionDF; + typedef ZMobiusFermion ZMobiusFermionR; typedef ZMobiusFermion ZMobiusFermionF; typedef ZMobiusFermion ZMobiusFermionD; -// Ls vectorised +typedef ZMobiusFermion ZMobiusFermionRL; +typedef ZMobiusFermion ZMobiusFermionFH; +typedef ZMobiusFermion 
ZMobiusFermionDF; + +// Ls vectorised typedef DomainWallFermion DomainWallFermionVec5dR; typedef DomainWallFermion DomainWallFermionVec5dF; typedef DomainWallFermion DomainWallFermionVec5dD; +typedef DomainWallFermion DomainWallFermionVec5dRL; +typedef DomainWallFermion DomainWallFermionVec5dFH; +typedef DomainWallFermion DomainWallFermionVec5dDF; + +typedef DomainWallEOFAFermion DomainWallEOFAFermionVec5dR; +typedef DomainWallEOFAFermion DomainWallEOFAFermionVec5dF; +typedef DomainWallEOFAFermion DomainWallEOFAFermionVec5dD; + +typedef DomainWallEOFAFermion DomainWallEOFAFermionVec5dRL; +typedef DomainWallEOFAFermion DomainWallEOFAFermionVec5dFH; +typedef DomainWallEOFAFermion DomainWallEOFAFermionVec5dDF; + typedef MobiusFermion MobiusFermionVec5dR; typedef MobiusFermion MobiusFermionVec5dF; typedef MobiusFermion MobiusFermionVec5dD; +typedef MobiusFermion MobiusFermionVec5dRL; +typedef MobiusFermion MobiusFermionVec5dFH; +typedef MobiusFermion MobiusFermionVec5dDF; + +typedef MobiusEOFAFermion MobiusEOFAFermionVec5dR; +typedef MobiusEOFAFermion MobiusEOFAFermionVec5dF; +typedef MobiusEOFAFermion MobiusEOFAFermionVec5dD; + +typedef MobiusEOFAFermion MobiusEOFAFermionVec5dRL; +typedef MobiusEOFAFermion MobiusEOFAFermionVec5dFH; +typedef MobiusEOFAFermion MobiusEOFAFermionVec5dDF; + typedef ZMobiusFermion ZMobiusFermionVec5dR; typedef ZMobiusFermion ZMobiusFermionVec5dF; typedef ZMobiusFermion ZMobiusFermionVec5dD; +typedef ZMobiusFermion ZMobiusFermionVec5dRL; +typedef ZMobiusFermion ZMobiusFermionVec5dFH; +typedef ZMobiusFermion ZMobiusFermionVec5dDF; typedef ScaledShamirFermion ScaledShamirFermionR; typedef ScaledShamirFermion ScaledShamirFermionF; @@ -265,17 +228,51 @@ typedef OverlapWilsonPartialFractionZolotarevFermion OverlapWilsonP typedef WilsonFermion GparityWilsonFermionR; typedef WilsonFermion GparityWilsonFermionF; typedef WilsonFermion GparityWilsonFermionD; + +typedef WilsonFermion GparityWilsonFermionRL; +typedef WilsonFermion 
GparityWilsonFermionFH; +typedef WilsonFermion GparityWilsonFermionDF; + typedef DomainWallFermion GparityDomainWallFermionR; typedef DomainWallFermion GparityDomainWallFermionF; typedef DomainWallFermion GparityDomainWallFermionD; +typedef DomainWallFermion GparityDomainWallFermionRL; +typedef DomainWallFermion GparityDomainWallFermionFH; +typedef DomainWallFermion GparityDomainWallFermionDF; + +typedef DomainWallEOFAFermion GparityDomainWallEOFAFermionR; +typedef DomainWallEOFAFermion GparityDomainWallEOFAFermionF; +typedef DomainWallEOFAFermion GparityDomainWallEOFAFermionD; + +typedef DomainWallEOFAFermion GparityDomainWallEOFAFermionRL; +typedef DomainWallEOFAFermion GparityDomainWallEOFAFermionFH; +typedef DomainWallEOFAFermion GparityDomainWallEOFAFermionDF; + typedef WilsonTMFermion GparityWilsonTMFermionR; typedef WilsonTMFermion GparityWilsonTMFermionF; typedef WilsonTMFermion GparityWilsonTMFermionD; + +typedef WilsonTMFermion GparityWilsonTMFermionRL; +typedef WilsonTMFermion GparityWilsonTMFermionFH; +typedef WilsonTMFermion GparityWilsonTMFermionDF; + typedef MobiusFermion GparityMobiusFermionR; typedef MobiusFermion GparityMobiusFermionF; typedef MobiusFermion GparityMobiusFermionD; +typedef MobiusFermion GparityMobiusFermionRL; +typedef MobiusFermion GparityMobiusFermionFH; +typedef MobiusFermion GparityMobiusFermionDF; + +typedef MobiusEOFAFermion GparityMobiusEOFAFermionR; +typedef MobiusEOFAFermion GparityMobiusEOFAFermionF; +typedef MobiusEOFAFermion GparityMobiusEOFAFermionD; + +typedef MobiusEOFAFermion GparityMobiusEOFAFermionRL; +typedef MobiusEOFAFermion GparityMobiusEOFAFermionFH; +typedef MobiusEOFAFermion GparityMobiusEOFAFermionDF; + typedef ImprovedStaggeredFermion ImprovedStaggeredFermionR; typedef ImprovedStaggeredFermion ImprovedStaggeredFermionF; typedef ImprovedStaggeredFermion ImprovedStaggeredFermionD; @@ -284,26 +281,18 @@ typedef ImprovedStaggeredFermion5D ImprovedStaggeredFermion5DR; typedef ImprovedStaggeredFermion5D 
ImprovedStaggeredFermion5DF; typedef ImprovedStaggeredFermion5D ImprovedStaggeredFermion5DD; +typedef ImprovedStaggeredFermion5D ImprovedStaggeredFermionVec5dR; +typedef ImprovedStaggeredFermion5D ImprovedStaggeredFermionVec5dF; +typedef ImprovedStaggeredFermion5D ImprovedStaggeredFermionVec5dD; + }} -/////////////////////////////////////////////////////////////////////////////// -// G5 herm -- this has to live in QCD since dirac matrix is not in the broader sector of code -/////////////////////////////////////////////////////////////////////////////// -#include -//////////////////////////////////////// -// Pseudo fermion combinations for HMC -//////////////////////////////////////// -#include - -#include -#include -#include -#include - -#include -#include -#include -#include +//////////////////// +// Scalar QED actions +// TODO: this needs to move to another header after rename to Fermion.h +//////////////////// +#include +#include #endif diff --git a/lib/qcd/action/fermion/FermionCore.h b/lib/qcd/action/fermion/FermionCore.h new file mode 100644 index 00000000..17006961 --- /dev/null +++ b/lib/qcd/action/fermion/FermionCore.h @@ -0,0 +1,91 @@ + /************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + Source file: ./lib/qcd/action/fermion/Fermion_base_aggregate.h + + Copyright (C) 2015 + +Author: Peter Boyle + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. 
+ + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + + See the full license in the file "LICENSE" in the top level distribution directory + *************************************************************************************/ + /* END LEGAL */ +#ifndef GRID_QCD_FERMION_CORE_H +#define GRID_QCD_FERMION_CORE_H + +#include +#include +#include + +//////////////////////////////////////////// +// Fermion prereqs +//////////////////////////////////////////// +#include //used by all wilson type fermions +#include +#include +#include //used by all wilson type fermions +#include //used by all wilson type fermions + +#define FermOpStaggeredTemplateInstantiate(A) \ + template class A; \ + template class A; + +#define FermOpStaggeredVec5dTemplateInstantiate(A) \ + template class A; \ + template class A; + +#define FermOp4dVecTemplateInstantiate(A) \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; + + +#define AdjointFermOpTemplateInstantiate(A) \ + template class A; \ + template class A; + +#define TwoIndexFermOpTemplateInstantiate(A) \ + template class A; \ + template class A; + +#define FermOp5dVecTemplateInstantiate(A) \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; \ + template class A; + +#define FermOpTemplateInstantiate(A) \ + FermOp4dVecTemplateInstantiate(A) \ + FermOp5dVecTemplateInstantiate(A) + +#define GparityFermOpTemplateInstantiate(A) + +#endif diff --git a/lib/qcd/action/fermion/FermionOperator.h b/lib/qcd/action/fermion/FermionOperator.h index 742c6e08..676a0e83 100644 --- 
a/lib/qcd/action/fermion/FermionOperator.h +++ b/lib/qcd/action/fermion/FermionOperator.h @@ -48,6 +48,8 @@ namespace Grid { FermionOperator(const ImplParams &p= ImplParams()) : Impl(p) {}; + virtual FermionField &tmp(void) = 0; + GridBase * Grid(void) { return FermionGrid(); }; // this is all the linalg routines need to know GridBase * RedBlackGrid(void) { return FermionRedBlackGrid(); }; diff --git a/lib/qcd/action/fermion/FermionOperatorImpl.h b/lib/qcd/action/fermion/FermionOperatorImpl.h index 36ab35ca..9d24deb2 100644 --- a/lib/qcd/action/fermion/FermionOperatorImpl.h +++ b/lib/qcd/action/fermion/FermionOperatorImpl.h @@ -35,7 +35,6 @@ directory namespace Grid { namespace QCD { - ////////////////////////////////////////////// // Template parameter class constructs to package // externally control Fermion implementations @@ -44,12 +43,14 @@ namespace QCD { // Ultimately need Impl to always define types where XXX is opaque // // typedef typename XXX Simd; - // typedef typename XXX GaugeLinkField; + // typedef typename XXX GaugeLinkField; // typedef typename XXX GaugeField; // typedef typename XXX GaugeActField; // typedef typename XXX FermionField; + // typedef typename XXX PropagatorField; // typedef typename XXX DoubledGaugeField; // typedef typename XXX SiteSpinor; + // typedef typename XXX SitePropagator; // typedef typename XXX SiteHalfSpinor; // typedef typename XXX Compressor; // @@ -87,7 +88,53 @@ namespace QCD { // // } ////////////////////////////////////////////// - + + template struct SamePrecisionMapper { + typedef T HigherPrecVector ; + typedef T LowerPrecVector ; + }; + template struct LowerPrecisionMapper { }; + template <> struct LowerPrecisionMapper { + typedef vRealF HigherPrecVector ; + typedef vRealH LowerPrecVector ; + }; + template <> struct LowerPrecisionMapper { + typedef vRealD HigherPrecVector ; + typedef vRealF LowerPrecVector ; + }; + template <> struct LowerPrecisionMapper { + typedef vComplexF HigherPrecVector ; + typedef 
vComplexH LowerPrecVector ; + }; + template <> struct LowerPrecisionMapper { + typedef vComplexD HigherPrecVector ; + typedef vComplexF LowerPrecVector ; + }; + + struct CoeffReal { + public: + typedef RealD _Coeff_t; + static const int Nhcs = 2; + template using PrecisionMapper = SamePrecisionMapper; + }; + struct CoeffRealHalfComms { + public: + typedef RealD _Coeff_t; + static const int Nhcs = 1; + template using PrecisionMapper = LowerPrecisionMapper; + }; + struct CoeffComplex { + public: + typedef ComplexD _Coeff_t; + static const int Nhcs = 2; + template using PrecisionMapper = SamePrecisionMapper; + }; + struct CoeffComplexHalfComms { + public: + typedef ComplexD _Coeff_t; + static const int Nhcs = 1; + template using PrecisionMapper = LowerPrecisionMapper; + }; //////////////////////////////////////////////////////////////////////// // Implementation dependent fermion types @@ -95,64 +142,74 @@ namespace QCD { #define INHERIT_FIMPL_TYPES(Impl)\ typedef typename Impl::FermionField FermionField; \ + typedef typename Impl::PropagatorField PropagatorField; \ typedef typename Impl::DoubledGaugeField DoubledGaugeField; \ typedef typename Impl::SiteSpinor SiteSpinor; \ + typedef typename Impl::SitePropagator SitePropagator; \ typedef typename Impl::SiteHalfSpinor SiteHalfSpinor; \ typedef typename Impl::Compressor Compressor; \ typedef typename Impl::StencilImpl StencilImpl; \ - typedef typename Impl::ImplParams ImplParams; \ - typedef typename Impl::Coeff_t Coeff_t; + typedef typename Impl::ImplParams ImplParams; \ + typedef typename Impl::Coeff_t Coeff_t; \ #define INHERIT_IMPL_TYPES(Base) \ - INHERIT_GIMPL_TYPES(Base) \ + INHERIT_GIMPL_TYPES(Base) \ INHERIT_FIMPL_TYPES(Base) ///////////////////////////////////////////////////////////////////////////// // Single flavour four spinors with colour index ///////////////////////////////////////////////////////////////////////////// - template + template class WilsonImpl : public PeriodicGaugeImpl > { - public: 
static const int Dimension = Representation::Dimension; + static const bool LsVectorised=false; + static const int Nhcs = Options::Nhcs; + typedef PeriodicGaugeImpl > Gimpl; + INHERIT_GIMPL_TYPES(Gimpl); //Necessary? constexpr bool is_fundamental() const{return Dimension == Nc ? 1 : 0;} - const bool LsVectorised=false; - typedef _Coeff_t Coeff_t; - - INHERIT_GIMPL_TYPES(Gimpl); + typedef typename Options::_Coeff_t Coeff_t; + typedef typename Options::template PrecisionMapper::LowerPrecVector SimdL; template using iImplSpinor = iScalar, Ns> >; + template using iImplPropagator = iScalar, Ns> >; template using iImplHalfSpinor = iScalar, Nhs> >; + template using iImplHalfCommSpinor = iScalar, Nhcs> >; template using iImplDoubledGaugeField = iVector >, Nds>; typedef iImplSpinor SiteSpinor; + typedef iImplPropagator SitePropagator; typedef iImplHalfSpinor SiteHalfSpinor; + typedef iImplHalfCommSpinor SiteHalfCommSpinor; typedef iImplDoubledGaugeField SiteDoubledGaugeField; typedef Lattice FermionField; + typedef Lattice PropagatorField; typedef Lattice DoubledGaugeField; - typedef WilsonCompressor Compressor; + typedef WilsonCompressor Compressor; typedef WilsonImplParams ImplParams; typedef WilsonStencil StencilImpl; ImplParams Params; - WilsonImpl(const ImplParams &p = ImplParams()) : Params(p){}; + WilsonImpl(const ImplParams &p = ImplParams()) : Params(p){ + assert(Params.boundary_phases.size() == Nd); + }; bool overlapCommsCompute(void) { return Params.overlapCommsCompute; }; inline void multLink(SiteHalfSpinor &phi, - const SiteDoubledGaugeField &U, - const SiteHalfSpinor &chi, - int mu, - StencilEntry *SE, - StencilImpl &St) { + const SiteDoubledGaugeField &U, + const SiteHalfSpinor &chi, + int mu, + StencilEntry *SE, + StencilImpl &St) { mult(&phi(), &U(mu), &chi()); } @@ -162,16 +219,34 @@ namespace QCD { } inline void DoubleStore(GridBase *GaugeGrid, - DoubledGaugeField &Uds, - const GaugeField &Umu) { + DoubledGaugeField &Uds, + const GaugeField &Umu) + { + 
typedef typename Simd::scalar_type scalar_type; + conformable(Uds._grid, GaugeGrid); conformable(Umu._grid, GaugeGrid); + GaugeLinkField U(GaugeGrid); + GaugeLinkField tmp(GaugeGrid); + + Lattice > coor(GaugeGrid); for (int mu = 0; mu < Nd; mu++) { - U = PeekIndex(Umu, mu); - PokeIndex(Uds, U, mu); - U = adj(Cshift(U, mu, -1)); - PokeIndex(Uds, U, mu + 4); + + auto pha = Params.boundary_phases[mu]; + scalar_type phase( real(pha),imag(pha) ); + + int Lmu = GaugeGrid->GlobalDimensions()[mu] - 1; + + LatticeCoordinate(coor, mu); + + U = PeekIndex(Umu, mu); + tmp = where(coor == Lmu, phase * U, U); + PokeIndex(Uds, tmp, mu); + + U = adj(Cshift(U, mu, -1)); + U = where(coor == 0, conjugate(phase) * U, U); + PokeIndex(Uds, U, mu + 4); } } @@ -187,13 +262,12 @@ namespace QCD { GaugeLinkField tmp(mat._grid); tmp = zero; - PARALLEL_FOR_LOOP - for(int sss=0;sssoSites();sss++){ - int sU=sss; - for(int s=0;s(outerProduct(Btilde[sF],Atilde[sF])); // ordering here - } + parallel_for(int sss=0;sssoSites();sss++){ + int sU=sss; + for(int s=0;s(outerProduct(Btilde[sF],Atilde[sF])); // ordering here + } } PokeIndex(mat,tmp,mu); @@ -203,36 +277,44 @@ namespace QCD { //////////////////////////////////////////////////////////////////////////////////// // Single flavour four spinors with colour index, 5d redblack //////////////////////////////////////////////////////////////////////////////////// - -template +template class DomainWallVec5dImpl : public PeriodicGaugeImpl< GaugeImplTypes< S,Nrepresentation> > { public: - - static const int Dimension = Nrepresentation; - const bool LsVectorised=true; - typedef _Coeff_t Coeff_t; + typedef PeriodicGaugeImpl > Gimpl; - INHERIT_GIMPL_TYPES(Gimpl); + + static const int Dimension = Nrepresentation; + static const bool LsVectorised=true; + static const int Nhcs = Options::Nhcs; + + typedef typename Options::_Coeff_t Coeff_t; + typedef typename Options::template PrecisionMapper::LowerPrecVector SimdL; template using iImplSpinor = iScalar, Ns> >; + 
template using iImplPropagator = iScalar, Ns> >; template using iImplHalfSpinor = iScalar, Nhs> >; + template using iImplHalfCommSpinor = iScalar, Nhcs> >; template using iImplDoubledGaugeField = iVector >, Nds>; template using iImplGaugeField = iVector >, Nd>; template using iImplGaugeLink = iScalar > >; - typedef iImplSpinor SiteSpinor; - typedef iImplHalfSpinor SiteHalfSpinor; - typedef Lattice FermionField; - + typedef iImplSpinor SiteSpinor; + typedef iImplPropagator SitePropagator; + typedef iImplHalfSpinor SiteHalfSpinor; + typedef iImplHalfCommSpinor SiteHalfCommSpinor; + typedef Lattice FermionField; + typedef Lattice PropagatorField; + + ///////////////////////////////////////////////// // Make the doubled gauge field a *scalar* + ///////////////////////////////////////////////// typedef iImplDoubledGaugeField SiteDoubledGaugeField; // This is a scalar typedef iImplGaugeField SiteScalarGaugeField; // scalar typedef iImplGaugeLink SiteScalarGaugeLink; // scalar + typedef Lattice DoubledGaugeField; - typedef Lattice DoubledGaugeField; - - typedef WilsonCompressor Compressor; + typedef WilsonCompressor Compressor; typedef WilsonImplParams ImplParams; typedef WilsonStencil StencilImpl; @@ -248,12 +330,12 @@ class DomainWallVec5dImpl : public PeriodicGaugeImpl< GaugeImplTypes< S,Nrepres } inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U, - const SiteHalfSpinor &chi, int mu, StencilEntry *SE, - StencilImpl &St) { + const SiteHalfSpinor &chi, int mu, StencilEntry *SE, + StencilImpl &St) { SiteGaugeLink UU; for (int i = 0; i < Nrepresentation; i++) { for (int j = 0; j < Nrepresentation; j++) { - vsplat(UU()()(i, j), U(mu)()(i, j)); + vsplat(UU()()(i, j), U(mu)()(i, j)); } } mult(&phi(), &UU(), &chi()); @@ -261,11 +343,11 @@ class DomainWallVec5dImpl : public PeriodicGaugeImpl< GaugeImplTypes< S,Nrepres inline void DoubleStore(GridBase *GaugeGrid, DoubledGaugeField &Uds,const GaugeField &Umu) { - SiteScalarGaugeField ScalarUmu; + 
SiteScalarGaugeField ScalarUmu; SiteDoubledGaugeField ScalarUds; GaugeLinkField U(Umu._grid); - GaugeField Uadj(Umu._grid); + GaugeField Uadj(Umu._grid); for (int mu = 0; mu < Nd; mu++) { U = PeekIndex(Umu, mu); U = adj(Cshift(U, mu, -1)); @@ -290,42 +372,90 @@ class DomainWallVec5dImpl : public PeriodicGaugeImpl< GaugeImplTypes< S,Nrepres { assert(0); } - - inline void InsertForce5D(GaugeField &mat, FermionField &Btilde,FermionField Ã, int mu) - { - assert(0); + + inline void InsertForce5D(GaugeField &mat, FermionField &Btilde, FermionField Ã, int mu) { + + assert(0); + // Following lines to be revised after Peter's addition of half prec + // missing put lane... + /* + typedef decltype(traceIndex(outerProduct(Btilde[0], Atilde[0]))) result_type; + unsigned int LLs = Btilde._grid->_rdimensions[0]; + conformable(Atilde._grid,Btilde._grid); + GridBase* grid = mat._grid; + GridBase* Bgrid = Btilde._grid; + unsigned int dimU = grid->Nd(); + unsigned int dimF = Bgrid->Nd(); + GaugeLinkField tmp(grid); + tmp = zero; + + // FIXME + // Current implementation works, thread safe, probably suboptimal + // Passing through the local coordinate for grid transformation + // the force grid is in general very different from the Ls vectorized grid + + PARALLEL_FOR_LOOP + for (int so = 0; so < grid->oSites(); so++) { + std::vector vres(Bgrid->Nsimd()); + std::vector ocoor; grid->oCoorFromOindex(ocoor,so); + for (int si = 0; si < tmp._grid->iSites(); si++){ + typename result_type::scalar_object scalar_object; scalar_object = zero; + std::vector local_coor; + std::vector icoor; grid->iCoorFromIindex(icoor,si); + grid->InOutCoorToLocalCoor(ocoor, icoor, local_coor); + for (int s = 0; s < LLs; s++) { + std::vector slocal_coor(dimF); + slocal_coor[0] = s; + for (int s4d = 1; s4d< dimF; s4d++) slocal_coor[s4d] = local_coor[s4d-1]; + int sF = Bgrid->oIndexReduced(slocal_coor); + assert(sF < Bgrid->oSites()); + + extract(traceIndex(outerProduct(Btilde[sF], Atilde[sF])), vres); + // sum 
across the 5d dimension + for (auto v : vres) scalar_object += v; + } + tmp._odata[so].putlane(scalar_object, si); + } + } + PokeIndex(mat, tmp, mu); + */ } }; //////////////////////////////////////////////////////////////////////////////////////// // Flavour doubled spinors; is Gparity the only? what about C*? //////////////////////////////////////////////////////////////////////////////////////// - -template +template class GparityWilsonImpl : public ConjugateGaugeImpl > { public: static const int Dimension = Nrepresentation; + static const int Nhcs = Options::Nhcs; + static const bool LsVectorised=false; - const bool LsVectorised=false; - - typedef _Coeff_t Coeff_t; typedef ConjugateGaugeImpl< GaugeImplTypes > Gimpl; - INHERIT_GIMPL_TYPES(Gimpl); + + typedef typename Options::_Coeff_t Coeff_t; + typedef typename Options::template PrecisionMapper::LowerPrecVector SimdL; - template using iImplSpinor = iVector, Ns>, Ngp>; - template using iImplHalfSpinor = iVector, Nhs>, Ngp>; + template using iImplSpinor = iVector, Ns>, Ngp>; + template using iImplPropagator = iVector, Ns>, Ngp>; + template using iImplHalfSpinor = iVector, Nhs>, Ngp>; + template using iImplHalfCommSpinor = iVector, Nhcs>, Ngp>; template using iImplDoubledGaugeField = iVector >, Nds>, Ngp>; - - typedef iImplSpinor SiteSpinor; - typedef iImplHalfSpinor SiteHalfSpinor; + + typedef iImplSpinor SiteSpinor; + typedef iImplPropagator SitePropagator; + typedef iImplHalfSpinor SiteHalfSpinor; + typedef iImplHalfCommSpinor SiteHalfCommSpinor; typedef iImplDoubledGaugeField SiteDoubledGaugeField; - + typedef Lattice FermionField; + typedef Lattice PropagatorField; typedef Lattice DoubledGaugeField; - typedef WilsonCompressor Compressor; + typedef WilsonCompressor Compressor; typedef WilsonStencil StencilImpl; typedef GparityWilsonImplParams ImplParams; @@ -339,19 +469,19 @@ class GparityWilsonImpl : public ConjugateGaugeImplNsimd(); - + int direction = St._directions[mu]; int distance = St._distances[mu]; 
int ptype = St._permute_type[mu]; @@ -359,13 +489,13 @@ class GparityWilsonImpl : public ConjugateGaugeImpl icoor; - + if ( SE->_around_the_world && Params.twists[mmu] ) { if ( sl == 2 ) { @@ -375,25 +505,25 @@ class GparityWilsonImpl : public ConjugateGaugeImpliCoorFromIindex(icoor,s); - - assert((icoor[direction]==0)||(icoor[direction]==1)); - - int permute_lane; - if ( distance == 1) { - permute_lane = icoor[direction]?1:0; - } else { - permute_lane = icoor[direction]?0:1; - } - - if ( permute_lane ) { - stmp(0) = vals[s](1); - stmp(1) = vals[s](0); - vals[s] = stmp; - } + grid->iCoorFromIindex(icoor,s); + + assert((icoor[direction]==0)||(icoor[direction]==1)); + + int permute_lane; + if ( distance == 1) { + permute_lane = icoor[direction]?1:0; + } else { + permute_lane = icoor[direction]?0:1; + } + + if ( permute_lane ) { + stmp(0) = vals[s](1); + stmp(1) = vals[s](0); + vals[s] = stmp; + } } merge(vtmp,vals); - + } else { vtmp(0) = chi(1); vtmp(1) = chi(0); @@ -408,6 +538,12 @@ class GparityWilsonImpl : public ConjugateGaugeImpl + inline void loadLinkElement(Simd ®, ref &memory) { + reg = memory; + } + inline void DoubleStore(GridBase *GaugeGrid,DoubledGaugeField &Uds,const GaugeField &Umu) { conformable(Uds._grid,GaugeGrid); @@ -418,11 +554,11 @@ class GparityWilsonImpl : public ConjugateGaugeImpl > coor(GaugeGrid); - + for(int mu=0;mu(Umu,mu); Uconj = conjugate(U); @@ -432,12 +568,11 @@ class GparityWilsonImpl : public ConjugateGaugeImpl(outerProduct(Btilde, A)); -PARALLEL_FOR_LOOP - for (auto ss = tmp.begin(); ss < tmp.end(); ss++) { - link[ss]() = tmp[ss](0, 0) - conjugate(tmp[ss](1, 1)); + parallel_for(auto ss = tmp.begin(); ss < tmp.end(); ss++) { + link[ss]() = tmp[ss](0, 0) + conjugate(tmp[ss](1, 1)); } PokeIndex(mat, link, mu); return; @@ -482,11 +614,10 @@ PARALLEL_FOR_LOOP inline void InsertForce5D(GaugeField &mat, FermionField &Btilde, FermionField Ã, int mu) { int Ls = Btilde._grid->_fdimensions[0]; - + GaugeLinkField tmp(mat._grid); tmp = zero; 
-PARALLEL_FOR_LOOP - for (int ss = 0; ss < tmp._grid->oSites(); ss++) { + parallel_for(int ss = 0; ss < tmp._grid->oSites(); ss++) { for (int s = 0; s < Ls; s++) { int sF = s + Ls * ss; auto ttmp = traceIndex(outerProduct(Btilde[sF], Atilde[sF])); @@ -499,40 +630,39 @@ PARALLEL_FOR_LOOP }; - - ///////////////////////////////////////////////////////////////////////////// - // Single flavour one component spinors with colour index - ///////////////////////////////////////////////////////////////////////////// - template - class StaggeredImpl : public PeriodicGaugeImpl > { +///////////////////////////////////////////////////////////////////////////// +// Single flavour one component spinors with colour index +///////////////////////////////////////////////////////////////////////////// +template +class StaggeredImpl : public PeriodicGaugeImpl > { public: typedef RealD _Coeff_t ; static const int Dimension = Representation::Dimension; + static const bool LsVectorised=false; typedef PeriodicGaugeImpl > Gimpl; //Necessary? constexpr bool is_fundamental() const{return Dimension == Nc ? 
1 : 0;} - const bool LsVectorised=false; typedef _Coeff_t Coeff_t; INHERIT_GIMPL_TYPES(Gimpl); - template using iImplScalar = iScalar > >; template using iImplSpinor = iScalar > >; - template using iImplHalfSpinor = iVector >, Ngp>; + template using iImplHalfSpinor = iScalar > >; template using iImplDoubledGaugeField = iVector >, Nds>; + template using iImplPropagator = iScalar > >; - typedef iImplScalar SiteComplex; typedef iImplSpinor SiteSpinor; typedef iImplHalfSpinor SiteHalfSpinor; typedef iImplDoubledGaugeField SiteDoubledGaugeField; + typedef iImplPropagator SitePropagator; - typedef Lattice ComplexField; typedef Lattice FermionField; typedef Lattice DoubledGaugeField; + typedef Lattice PropagatorField; typedef SimpleCompressor Compressor; typedef StaggeredImplParams ImplParams; @@ -629,39 +759,236 @@ PARALLEL_FOR_LOOP } }; + ///////////////////////////////////////////////////////////////////////////// + // Single flavour one component spinors with colour index. 5d vec + ///////////////////////////////////////////////////////////////////////////// + template + class StaggeredVec5dImpl : public PeriodicGaugeImpl > { + + public: + + static const int Dimension = Representation::Dimension; + static const bool LsVectorised=true; + typedef RealD Coeff_t ; + typedef PeriodicGaugeImpl > Gimpl; + + //Necessary? + constexpr bool is_fundamental() const{return Dimension == Nc ? 1 : 0;} - typedef WilsonImpl WilsonImplR; // Real.. whichever prec - typedef WilsonImpl WilsonImplF; // Float - typedef WilsonImpl WilsonImplD; // Double + INHERIT_GIMPL_TYPES(Gimpl); - typedef WilsonImpl ZWilsonImplR; // Real.. whichever prec - typedef WilsonImpl ZWilsonImplF; // Float - typedef WilsonImpl ZWilsonImplD; // Double - - typedef WilsonImpl WilsonAdjImplR; // Real.. whichever prec - typedef WilsonImpl WilsonAdjImplF; // Float - typedef WilsonImpl WilsonAdjImplD; // Double - - typedef WilsonImpl WilsonTwoIndexSymmetricImplR; // Real.. 
whichever prec - typedef WilsonImpl WilsonTwoIndexSymmetricImplF; // Float - typedef WilsonImpl WilsonTwoIndexSymmetricImplD; // Double - - typedef DomainWallVec5dImpl DomainWallVec5dImplR; // Real.. whichever prec - typedef DomainWallVec5dImpl DomainWallVec5dImplF; // Float - typedef DomainWallVec5dImpl DomainWallVec5dImplD; // Double - - typedef DomainWallVec5dImpl ZDomainWallVec5dImplR; // Real.. whichever prec - typedef DomainWallVec5dImpl ZDomainWallVec5dImplF; // Float - typedef DomainWallVec5dImpl ZDomainWallVec5dImplD; // Double - - typedef GparityWilsonImpl GparityWilsonImplR; // Real.. whichever prec - typedef GparityWilsonImpl GparityWilsonImplF; // Float - typedef GparityWilsonImpl GparityWilsonImplD; // Double + template using iImplSpinor = iScalar > >; + template using iImplHalfSpinor = iScalar > >; + template using iImplDoubledGaugeField = iVector >, Nds>; + template using iImplGaugeField = iVector >, Nd>; + template using iImplGaugeLink = iScalar > >; + template using iImplPropagator = iScalar > >; - typedef StaggeredImpl StaggeredImplR; // Real.. 
whichever prec - typedef StaggeredImpl StaggeredImplF; // Float - typedef StaggeredImpl StaggeredImplD; // Double + // Make the doubled gauge field a *scalar* + typedef iImplDoubledGaugeField SiteDoubledGaugeField; // This is a scalar + typedef iImplGaugeField SiteScalarGaugeField; // scalar + typedef iImplGaugeLink SiteScalarGaugeLink; // scalar + typedef iImplPropagator SitePropagator; + + typedef Lattice DoubledGaugeField; + typedef Lattice PropagatorField; + + typedef iImplSpinor SiteSpinor; + typedef iImplHalfSpinor SiteHalfSpinor; + + + typedef Lattice FermionField; + + typedef SimpleCompressor Compressor; + typedef StaggeredImplParams ImplParams; + typedef CartesianStencil StencilImpl; + + ImplParams Params; + + StaggeredVec5dImpl(const ImplParams &p = ImplParams()) : Params(p){}; + + template + inline void loadLinkElement(Simd ®, ref &memory) { + vsplat(reg, memory); + } + + inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U, + const SiteHalfSpinor &chi, int mu) { + SiteGaugeLink UU; + for (int i = 0; i < Dimension; i++) { + for (int j = 0; j < Dimension; j++) { + vsplat(UU()()(i, j), U(mu)()(i, j)); + } + } + mult(&phi(), &UU(), &chi()); + } + inline void multLinkAdd(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U, + const SiteHalfSpinor &chi, int mu) { + SiteGaugeLink UU; + for (int i = 0; i < Dimension; i++) { + for (int j = 0; j < Dimension; j++) { + vsplat(UU()()(i, j), U(mu)()(i, j)); + } + } + mac(&phi(), &UU(), &chi()); + } + + inline void DoubleStore(GridBase *GaugeGrid, + DoubledGaugeField &UUUds, // for Naik term + DoubledGaugeField &Uds, + const GaugeField &Uthin, + const GaugeField &Ufat) + { + + GridBase * InputGrid = Uthin._grid; + conformable(InputGrid,Ufat._grid); + + GaugeLinkField U(InputGrid); + GaugeLinkField UU(InputGrid); + GaugeLinkField UUU(InputGrid); + GaugeLinkField Udag(InputGrid); + GaugeLinkField UUUdag(InputGrid); + + for (int mu = 0; mu < Nd; mu++) { + + // Staggered Phase. 
+ Lattice > coor(InputGrid); + Lattice > x(InputGrid); LatticeCoordinate(x,0); + Lattice > y(InputGrid); LatticeCoordinate(y,1); + Lattice > z(InputGrid); LatticeCoordinate(z,2); + Lattice > t(InputGrid); LatticeCoordinate(t,3); + + Lattice > lin_z(InputGrid); lin_z=x+y; + Lattice > lin_t(InputGrid); lin_t=x+y+z; + + ComplexField phases(InputGrid); phases=1.0; + + if ( mu == 1 ) phases = where( mod(x ,2)==(Integer)0, phases,-phases); + if ( mu == 2 ) phases = where( mod(lin_z,2)==(Integer)0, phases,-phases); + if ( mu == 3 ) phases = where( mod(lin_t,2)==(Integer)0, phases,-phases); + + // 1 hop based on fat links + U = PeekIndex(Ufat, mu); + Udag = adj( Cshift(U, mu, -1)); + + U = U *phases; + Udag = Udag *phases; + + + for (int lidx = 0; lidx < GaugeGrid->lSites(); lidx++) { + SiteScalarGaugeLink ScalarU; + SiteDoubledGaugeField ScalarUds; + + std::vector lcoor; + GaugeGrid->LocalIndexToLocalCoor(lidx, lcoor); + peekLocalSite(ScalarUds, Uds, lcoor); + + peekLocalSite(ScalarU, U, lcoor); + ScalarUds(mu) = ScalarU(); + + peekLocalSite(ScalarU, Udag, lcoor); + ScalarUds(mu + 4) = ScalarU(); + + pokeLocalSite(ScalarUds, Uds, lcoor); + } + + // 3 hop based on thin links. Crazy huh ? 
+ U = PeekIndex(Uthin, mu); + UU = Gimpl::CovShiftForward(U,mu,U); + UUU= Gimpl::CovShiftForward(U,mu,UU); + + UUUdag = adj( Cshift(UUU, mu, -3)); + + UUU = UUU *phases; + UUUdag = UUUdag *phases; + + for (int lidx = 0; lidx < GaugeGrid->lSites(); lidx++) { + + SiteScalarGaugeLink ScalarU; + SiteDoubledGaugeField ScalarUds; + + std::vector lcoor; + GaugeGrid->LocalIndexToLocalCoor(lidx, lcoor); + + peekLocalSite(ScalarUds, UUUds, lcoor); + + peekLocalSite(ScalarU, UUU, lcoor); + ScalarUds(mu) = ScalarU(); + + peekLocalSite(ScalarU, UUUdag, lcoor); + ScalarUds(mu + 4) = ScalarU(); + + pokeLocalSite(ScalarUds, UUUds, lcoor); + } + + } + } + + inline void InsertForce4D(GaugeField &mat, FermionField &Btilde, FermionField &A,int mu){ + assert(0); + } + + inline void InsertForce5D(GaugeField &mat, FermionField &Btilde, FermionField Ã,int mu){ + assert (0); + } + }; + +typedef WilsonImpl WilsonImplR; // Real.. whichever prec +typedef WilsonImpl WilsonImplF; // Float +typedef WilsonImpl WilsonImplD; // Double + +typedef WilsonImpl WilsonImplRL; // Real.. whichever prec +typedef WilsonImpl WilsonImplFH; // Float +typedef WilsonImpl WilsonImplDF; // Double + +typedef WilsonImpl ZWilsonImplR; // Real.. whichever prec +typedef WilsonImpl ZWilsonImplF; // Float +typedef WilsonImpl ZWilsonImplD; // Double + +typedef WilsonImpl ZWilsonImplRL; // Real.. whichever prec +typedef WilsonImpl ZWilsonImplFH; // Float +typedef WilsonImpl ZWilsonImplDF; // Double + +typedef WilsonImpl WilsonAdjImplR; // Real.. whichever prec +typedef WilsonImpl WilsonAdjImplF; // Float +typedef WilsonImpl WilsonAdjImplD; // Double + +typedef WilsonImpl WilsonTwoIndexSymmetricImplR; // Real.. whichever prec +typedef WilsonImpl WilsonTwoIndexSymmetricImplF; // Float +typedef WilsonImpl WilsonTwoIndexSymmetricImplD; // Double + +typedef DomainWallVec5dImpl DomainWallVec5dImplR; // Real.. 
whichever prec +typedef DomainWallVec5dImpl DomainWallVec5dImplF; // Float +typedef DomainWallVec5dImpl DomainWallVec5dImplD; // Double + +typedef DomainWallVec5dImpl DomainWallVec5dImplRL; // Real.. whichever prec +typedef DomainWallVec5dImpl DomainWallVec5dImplFH; // Float +typedef DomainWallVec5dImpl DomainWallVec5dImplDF; // Double + +typedef DomainWallVec5dImpl ZDomainWallVec5dImplR; // Real.. whichever prec +typedef DomainWallVec5dImpl ZDomainWallVec5dImplF; // Float +typedef DomainWallVec5dImpl ZDomainWallVec5dImplD; // Double + +typedef DomainWallVec5dImpl ZDomainWallVec5dImplRL; // Real.. whichever prec +typedef DomainWallVec5dImpl ZDomainWallVec5dImplFH; // Float +typedef DomainWallVec5dImpl ZDomainWallVec5dImplDF; // Double + +typedef GparityWilsonImpl GparityWilsonImplR; // Real.. whichever prec +typedef GparityWilsonImpl GparityWilsonImplF; // Float +typedef GparityWilsonImpl GparityWilsonImplD; // Double + +typedef GparityWilsonImpl GparityWilsonImplRL; // Real.. whichever prec +typedef GparityWilsonImpl GparityWilsonImplFH; // Float +typedef GparityWilsonImpl GparityWilsonImplDF; // Double + +typedef StaggeredImpl StaggeredImplR; // Real.. whichever prec +typedef StaggeredImpl StaggeredImplF; // Float +typedef StaggeredImpl StaggeredImplD; // Double + +typedef StaggeredVec5dImpl StaggeredVec5dImplR; // Real.. 
whichever prec +typedef StaggeredVec5dImpl StaggeredVec5dImplF; // Float +typedef StaggeredVec5dImpl StaggeredVec5dImplD; // Double }} diff --git a/lib/qcd/action/fermion/ImprovedStaggeredFermion.cc b/lib/qcd/action/fermion/ImprovedStaggeredFermion.cc index 42dff5b2..5ce2b335 100644 --- a/lib/qcd/action/fermion/ImprovedStaggeredFermion.cc +++ b/lib/qcd/action/fermion/ImprovedStaggeredFermion.cc @@ -40,10 +40,10 @@ ImprovedStaggeredFermionStatic::displacements({1, 1, 1, 1, -1, -1, -1, -1, 3, 3, // Constructor and gauge import ///////////////////////////////// + template -ImprovedStaggeredFermion::ImprovedStaggeredFermion(GaugeField &_Uthin, GaugeField &_Ufat, GridCartesian &Fgrid, - GridRedBlackCartesian &Hgrid, RealD _mass, - RealD _c1, RealD _c2,RealD _u0, +ImprovedStaggeredFermion::ImprovedStaggeredFermion(GridCartesian &Fgrid, GridRedBlackCartesian &Hgrid, + RealD _mass, const ImplParams &p) : Kernels(p), _grid(&Fgrid), @@ -52,9 +52,6 @@ ImprovedStaggeredFermion::ImprovedStaggeredFermion(GaugeField &_Uthin, Gau StencilEven(&Hgrid, npoint, Even, directions, displacements), // source is Even StencilOdd(&Hgrid, npoint, Odd, directions, displacements), // source is Odd mass(_mass), - c1(_c1), - c2(_c2), - u0(_u0), Lebesgue(_grid), LebesgueEvenOdd(_cbgrid), Umu(&Fgrid), @@ -62,11 +59,32 @@ ImprovedStaggeredFermion::ImprovedStaggeredFermion(GaugeField &_Uthin, Gau UmuOdd(&Hgrid), UUUmu(&Fgrid), UUUmuEven(&Hgrid), - UUUmuOdd(&Hgrid) + UUUmuOdd(&Hgrid) , + _tmp(&Hgrid) { - // Allocate the required comms buffer +} + +template +ImprovedStaggeredFermion::ImprovedStaggeredFermion(GaugeField &_Uthin, GaugeField &_Ufat, GridCartesian &Fgrid, + GridRedBlackCartesian &Hgrid, RealD _mass, + RealD _c1, RealD _c2,RealD _u0, + const ImplParams &p) + : ImprovedStaggeredFermion(Fgrid,Hgrid,_mass,p) +{ + c1=_c1; + c2=_c2; + u0=_u0; ImportGauge(_Uthin,_Ufat); } +template +ImprovedStaggeredFermion::ImprovedStaggeredFermion(GaugeField &_Uthin,GaugeField &_Utriple, GaugeField &_Ufat, 
GridCartesian &Fgrid, + GridRedBlackCartesian &Hgrid, RealD _mass, + const ImplParams &p) + : ImprovedStaggeredFermion(Fgrid,Hgrid,_mass,p) +{ + ImportGaugeSimple(_Utriple,_Ufat); +} + //////////////////////////////////////////////////////////// // Momentum space propagator should be @@ -85,6 +103,34 @@ void ImprovedStaggeredFermion::ImportGauge(const GaugeField &_Uthin) ImportGauge(_Uthin,_Uthin); }; template +void ImprovedStaggeredFermion::ImportGaugeSimple(const GaugeField &_Utriple,const GaugeField &_Ufat) +{ + ///////////////////////////////////////////////////////////////// + // Trivial import; phases and fattening and such like preapplied + ///////////////////////////////////////////////////////////////// + GaugeLinkField U(GaugeGrid()); + + for (int mu = 0; mu < Nd; mu++) { + + U = PeekIndex(_Utriple, mu); + PokeIndex(UUUmu, U, mu ); + + U = adj( Cshift(U, mu, -3)); + PokeIndex(UUUmu, -U, mu+4 ); + + U = PeekIndex(_Ufat, mu); + PokeIndex(Umu, U, mu); + + U = adj( Cshift(U, mu, -1)); + PokeIndex(Umu, -U, mu+4); + + } + pickCheckerboard(Even, UmuEven, Umu); + pickCheckerboard(Odd, UmuOdd , Umu); + pickCheckerboard(Even, UUUmuEven,UUUmu); + pickCheckerboard(Odd, UUUmuOdd, UUUmu); +} +template void ImprovedStaggeredFermion::ImportGauge(const GaugeField &_Uthin,const GaugeField &_Ufat) { GaugeLinkField U(GaugeGrid()); @@ -94,7 +140,6 @@ void ImprovedStaggeredFermion::ImportGauge(const GaugeField &_Uthin,const //////////////////////////////////////////////////////// Impl::DoubleStore(GaugeGrid(), UUUmu, Umu, _Uthin, _Ufat ); - //////////////////////////////////////////////////////// // Apply scale factors to get the right fermion Kinetic term // Could pass coeffs into the double store to save work. 
@@ -338,12 +383,12 @@ void ImprovedStaggeredFermion::DhopInternal(StencilImpl &st, LebesgueOrder if (dag == DaggerYes) { PARALLEL_FOR_LOOP for (int sss = 0; sss < in._grid->oSites(); sss++) { - Kernels::DhopSiteDag(st, lo, U, UUU, st.CommBuf(), sss, sss, in, out); + Kernels::DhopSiteDag(st, lo, U, UUU, st.CommBuf(), 1, sss, in, out); } } else { PARALLEL_FOR_LOOP for (int sss = 0; sss < in._grid->oSites(); sss++) { - Kernels::DhopSite(st, lo, U, UUU, st.CommBuf(), sss, sss, in, out); + Kernels::DhopSite(st, lo, U, UUU, st.CommBuf(), 1, sss, in, out); } } }; diff --git a/lib/qcd/action/fermion/ImprovedStaggeredFermion.h b/lib/qcd/action/fermion/ImprovedStaggeredFermion.h index ad298d29..7d1f2996 100644 --- a/lib/qcd/action/fermion/ImprovedStaggeredFermion.h +++ b/lib/qcd/action/fermion/ImprovedStaggeredFermion.h @@ -46,6 +46,9 @@ class ImprovedStaggeredFermion : public StaggeredKernels, public ImprovedS INHERIT_IMPL_TYPES(Impl); typedef StaggeredKernels Kernels; + FermionField _tmp; + FermionField &tmp(void) { return _tmp; } + /////////////////////////////////////////////////////////////// // Implement the abstract base /////////////////////////////////////////////////////////////// @@ -109,7 +112,16 @@ class ImprovedStaggeredFermion : public StaggeredKernels, public ImprovedS RealD _c1=9.0/8.0, RealD _c2=-1.0/24.0,RealD _u0=1.0, const ImplParams &p = ImplParams()); + ImprovedStaggeredFermion(GaugeField &_Uthin, GaugeField &_Utriple, GaugeField &_Ufat, GridCartesian &Fgrid, + GridRedBlackCartesian &Hgrid, RealD _mass, + const ImplParams &p = ImplParams()); + + ImprovedStaggeredFermion(GridCartesian &Fgrid, GridRedBlackCartesian &Hgrid, RealD _mass, + const ImplParams &p = ImplParams()); + + // DoubleStore impl dependent + void ImportGaugeSimple(const GaugeField &_Utriple, const GaugeField &_Ufat); void ImportGauge(const GaugeField &_Uthin, const GaugeField &_Ufat); void ImportGauge(const GaugeField &_Uthin); diff --git 
a/lib/qcd/action/fermion/ImprovedStaggeredFermion5D.cc b/lib/qcd/action/fermion/ImprovedStaggeredFermion5D.cc index 71a6bf06..7d988d89 100644 --- a/lib/qcd/action/fermion/ImprovedStaggeredFermion5D.cc +++ b/lib/qcd/action/fermion/ImprovedStaggeredFermion5D.cc @@ -26,8 +26,9 @@ Author: Peter Boyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include -#include +#include +#include +#include namespace Grid { namespace QCD { @@ -53,53 +54,74 @@ ImprovedStaggeredFermion5D::ImprovedStaggeredFermion5D(GaugeField &_Uthin, _FiveDimRedBlackGrid(&FiveDimRedBlackGrid), _FourDimGrid (&FourDimGrid), _FourDimRedBlackGrid(&FourDimRedBlackGrid), - Stencil (_FiveDimGrid,npoint,Even,directions,displacements), - StencilEven(_FiveDimRedBlackGrid,npoint,Even,directions,displacements), // source is Even - StencilOdd (_FiveDimRedBlackGrid,npoint,Odd ,directions,displacements), // source is Odd + Stencil (&FiveDimGrid,npoint,Even,directions,displacements), + StencilEven(&FiveDimRedBlackGrid,npoint,Even,directions,displacements), // source is Even + StencilOdd (&FiveDimRedBlackGrid,npoint,Odd ,directions,displacements), // source is Odd mass(_mass), c1(_c1), c2(_c2), u0(_u0), - Umu(_FourDimGrid), - UmuEven(_FourDimRedBlackGrid), - UmuOdd (_FourDimRedBlackGrid), - UUUmu(_FourDimGrid), - UUUmuEven(_FourDimRedBlackGrid), - UUUmuOdd(_FourDimRedBlackGrid), - Lebesgue(_FourDimGrid), - LebesgueEvenOdd(_FourDimRedBlackGrid) + Umu(&FourDimGrid), + UmuEven(&FourDimRedBlackGrid), + UmuOdd (&FourDimRedBlackGrid), + UUUmu(&FourDimGrid), + UUUmuEven(&FourDimRedBlackGrid), + UUUmuOdd(&FourDimRedBlackGrid), + Lebesgue(&FourDimGrid), + LebesgueEvenOdd(&FourDimRedBlackGrid), + _tmp(&FiveDimRedBlackGrid) { + // some assertions assert(FiveDimGrid._ndimension==5); assert(FourDimGrid._ndimension==4); - assert(FiveDimRedBlackGrid._ndimension==5); 
assert(FourDimRedBlackGrid._ndimension==4); - assert(FiveDimRedBlackGrid._checker_dim==1); - - // Dimension zero of the five-d is the Ls direction + assert(FiveDimRedBlackGrid._ndimension==5); + assert(FiveDimRedBlackGrid._checker_dim==1); // Don't checker the s direction + + // extent of fifth dim and not spread out Ls=FiveDimGrid._fdimensions[0]; assert(FiveDimRedBlackGrid._fdimensions[0]==Ls); - assert(FiveDimRedBlackGrid._processors[0] ==1); - assert(FiveDimRedBlackGrid._simd_layout[0]==1); assert(FiveDimGrid._processors[0] ==1); - assert(FiveDimGrid._simd_layout[0] ==1); - + assert(FiveDimRedBlackGrid._processors[0] ==1); + // Other dimensions must match the decomposition of the four-D fields for(int d=0;d<4;d++){ - assert(FourDimRedBlackGrid._fdimensions[d] ==FourDimGrid._fdimensions[d]); - assert(FiveDimRedBlackGrid._fdimensions[d+1]==FourDimGrid._fdimensions[d]); - - assert(FourDimRedBlackGrid._processors[d] ==FourDimGrid._processors[d]); - assert(FiveDimRedBlackGrid._processors[d+1] ==FourDimGrid._processors[d]); - - assert(FourDimRedBlackGrid._simd_layout[d] ==FourDimGrid._simd_layout[d]); - assert(FiveDimRedBlackGrid._simd_layout[d+1]==FourDimGrid._simd_layout[d]); - - assert(FiveDimGrid._fdimensions[d+1] ==FourDimGrid._fdimensions[d]); assert(FiveDimGrid._processors[d+1] ==FourDimGrid._processors[d]); + assert(FiveDimRedBlackGrid._processors[d+1] ==FourDimGrid._processors[d]); + assert(FourDimRedBlackGrid._processors[d] ==FourDimGrid._processors[d]); + + assert(FiveDimGrid._fdimensions[d+1] ==FourDimGrid._fdimensions[d]); + assert(FiveDimRedBlackGrid._fdimensions[d+1]==FourDimGrid._fdimensions[d]); + assert(FourDimRedBlackGrid._fdimensions[d] ==FourDimGrid._fdimensions[d]); + assert(FiveDimGrid._simd_layout[d+1] ==FourDimGrid._simd_layout[d]); + assert(FiveDimRedBlackGrid._simd_layout[d+1]==FourDimGrid._simd_layout[d]); + assert(FourDimRedBlackGrid._simd_layout[d] ==FourDimGrid._simd_layout[d]); } + + if (Impl::LsVectorised) { + + int nsimd = 
Simd::Nsimd(); + // Dimension zero of the five-d is the Ls direction + assert(FiveDimGrid._simd_layout[0] ==nsimd); + assert(FiveDimRedBlackGrid._simd_layout[0]==nsimd); + + for(int d=0;d<4;d++){ + assert(FourDimGrid._simd_layout[d]=1); + assert(FourDimRedBlackGrid._simd_layout[d]=1); + assert(FiveDimRedBlackGrid._simd_layout[d+1]==1); + } + + } else { + + // Dimension zero of the five-d is the Ls direction + assert(FiveDimRedBlackGrid._simd_layout[0]==1); + assert(FiveDimGrid._simd_layout[0] ==1); + + } + // Allocate the required comms buffer ImportGauge(_Uthin,_Ufat); } @@ -112,8 +134,6 @@ void ImprovedStaggeredFermion5D::ImportGauge(const GaugeField &_Uthin) template void ImprovedStaggeredFermion5D::ImportGauge(const GaugeField &_Uthin,const GaugeField &_Ufat) { - GaugeLinkField U(GaugeGrid()); - //////////////////////////////////////////////////////// // Double Store should take two fields for Naik and one hop separately. //////////////////////////////////////////////////////// @@ -126,7 +146,7 @@ void ImprovedStaggeredFermion5D::ImportGauge(const GaugeField &_Uthin,cons //////////////////////////////////////////////////////// for (int mu = 0; mu < Nd; mu++) { - U = PeekIndex(Umu, mu); + auto U = PeekIndex(Umu, mu); PokeIndex(Umu, U*( 0.5*c1/u0), mu ); U = PeekIndex(Umu, mu+4); @@ -153,8 +173,7 @@ void ImprovedStaggeredFermion5D::DhopDir(const FermionField &in, FermionFi Compressor compressor; Stencil.HaloExchange(in,compressor); - PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ for(int s=0;s::DhopInternal(StencilImpl & st, LebesgueOr const FermionField &in, FermionField &out,int dag) { Compressor compressor; - int LLs = in._grid->_rdimensions[0]; - + + + + DhopTotalTime -= usecond(); + DhopCommTime -= usecond(); st.HaloExchange(in,compressor); + DhopCommTime += usecond(); + DhopComputeTime -= usecond(); // Dhop takes the 4d grid from U, and makes a 5d index for fermion if (dag == DaggerYes) { - PARALLEL_FOR_LOOP - 
for (int ss = 0; ss < U._grid->oSites(); ss++) { - for(int s=0;soSites(); ss++) { int sU=ss; - int sF = s+LLs*sU; - Kernels::DhopSiteDag(st, lo, U, UUU, st.CommBuf(), sF, sU, in, out); - }} + Kernels::DhopSiteDag(st, lo, U, UUU, st.CommBuf(), LLs, sU,in, out); + } } else { - PARALLEL_FOR_LOOP - for (int ss = 0; ss < U._grid->oSites(); ss++) { - for(int s=0;soSites(); ss++) { int sU=ss; - int sF = s+LLs*sU; - Kernels::DhopSite(st,lo,U,UUU,st.CommBuf(),sF,sU,in,out); - }} + Kernels::DhopSite(st,lo,U,UUU,st.CommBuf(),LLs,sU,in,out); + } } + DhopComputeTime += usecond(); + DhopTotalTime += usecond(); } template void ImprovedStaggeredFermion5D::DhopOE(const FermionField &in, FermionField &out,int dag) { + DhopCalls+=1; conformable(in._grid,FermionRedBlackGrid()); // verifies half grid conformable(in._grid,out._grid); // drops the cb check @@ -250,6 +271,7 @@ void ImprovedStaggeredFermion5D::DhopOE(const FermionField &in, FermionFie template void ImprovedStaggeredFermion5D::DhopEO(const FermionField &in, FermionField &out,int dag) { + DhopCalls+=1; conformable(in._grid,FermionRedBlackGrid()); // verifies half grid conformable(in._grid,out._grid); // drops the cb check @@ -261,6 +283,7 @@ void ImprovedStaggeredFermion5D::DhopEO(const FermionField &in, FermionFie template void ImprovedStaggeredFermion5D::Dhop(const FermionField &in, FermionField &out,int dag) { + DhopCalls+=2; conformable(in._grid,FermionGrid()); // verifies full grid conformable(in._grid,out._grid); @@ -269,6 +292,54 @@ void ImprovedStaggeredFermion5D::Dhop(const FermionField &in, FermionField DhopInternal(Stencil,Lebesgue,Umu,UUUmu,in,out,dag); } +template +void ImprovedStaggeredFermion5D::Report(void) +{ + std::vector latt = GridDefaultLatt(); + RealD volume = Ls; for(int mu=0;mu_Nprocessors; + RealD NN = _FourDimGrid->NodeCount(); + + std::cout << GridLogMessage << "#### Dhop calls report " << std::endl; + + std::cout << GridLogMessage << "ImprovedStaggeredFermion5D Number of DhopEO Calls : " + << 
DhopCalls << std::endl; + std::cout << GridLogMessage << "ImprovedStaggeredFermion5D TotalTime /Calls : " + << DhopTotalTime / DhopCalls << " us" << std::endl; + std::cout << GridLogMessage << "ImprovedStaggeredFermion5D CommTime /Calls : " + << DhopCommTime / DhopCalls << " us" << std::endl; + std::cout << GridLogMessage << "ImprovedStaggeredFermion5D ComputeTime/Calls : " + << DhopComputeTime / DhopCalls << " us" << std::endl; + + // Average the compute time + _FourDimGrid->GlobalSum(DhopComputeTime); + DhopComputeTime/=NP; + + RealD mflops = 1154*volume*DhopCalls/DhopComputeTime/2; // 2 for red black counting + std::cout << GridLogMessage << "Average mflops/s per call : " << mflops << std::endl; + std::cout << GridLogMessage << "Average mflops/s per call per rank : " << mflops/NP << std::endl; + std::cout << GridLogMessage << "Average mflops/s per call per node : " << mflops/NN << std::endl; + + RealD Fullmflops = 1154*volume*DhopCalls/(DhopTotalTime)/2; // 2 for red black counting + std::cout << GridLogMessage << "Average mflops/s per call (full) : " << Fullmflops << std::endl; + std::cout << GridLogMessage << "Average mflops/s per call per rank (full): " << Fullmflops/NP << std::endl; + std::cout << GridLogMessage << "Average mflops/s per call per node (full): " << Fullmflops/NN << std::endl; + + std::cout << GridLogMessage << "ImprovedStaggeredFermion5D Stencil" < +void ImprovedStaggeredFermion5D::ZeroCounters(void) +{ + DhopCalls = 0; + DhopTotalTime = 0; + DhopCommTime = 0; + DhopComputeTime = 0; + Stencil.ZeroCounters(); + StencilEven.ZeroCounters(); + StencilOdd.ZeroCounters(); +} ///////////////////////////////////////////////////////////////////////// // Implement the general interface. 
Here we use SAME mass on all slices @@ -335,8 +406,8 @@ void ImprovedStaggeredFermion5D::MooeeInvDag(const FermionField &in, } - FermOpStaggeredTemplateInstantiate(ImprovedStaggeredFermion5D); +FermOpStaggeredVec5dTemplateInstantiate(ImprovedStaggeredFermion5D); }} diff --git a/lib/qcd/action/fermion/ImprovedStaggeredFermion5D.h b/lib/qcd/action/fermion/ImprovedStaggeredFermion5D.h index c3502229..ca1a955a 100644 --- a/lib/qcd/action/fermion/ImprovedStaggeredFermion5D.h +++ b/lib/qcd/action/fermion/ImprovedStaggeredFermion5D.h @@ -52,6 +52,19 @@ namespace QCD { INHERIT_IMPL_TYPES(Impl); typedef StaggeredKernels Kernels; + FermionField _tmp; + FermionField &tmp(void) { return _tmp; } + + //////////////////////////////////////// + // Performance monitoring + //////////////////////////////////////// + void Report(void); + void ZeroCounters(void); + double DhopTotalTime; + double DhopCalls; + double DhopCommTime; + double DhopComputeTime; + /////////////////////////////////////////////////////////////// // Implement the abstract base /////////////////////////////////////////////////////////////// diff --git a/lib/qcd/action/fermion/MobiusEOFAFermion.cc b/lib/qcd/action/fermion/MobiusEOFAFermion.cc new file mode 100644 index 00000000..085fa988 --- /dev/null +++ b/lib/qcd/action/fermion/MobiusEOFAFermion.cc @@ -0,0 +1,502 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/MobiusEOFAFermion.cc + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: Peter Boyle +Author: paboyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. 
+ +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include +#include + +namespace Grid { +namespace QCD { + + template + MobiusEOFAFermion::MobiusEOFAFermion( + GaugeField &_Umu, + GridCartesian &FiveDimGrid, + GridRedBlackCartesian &FiveDimRedBlackGrid, + GridCartesian &FourDimGrid, + GridRedBlackCartesian &FourDimRedBlackGrid, + RealD _mq1, RealD _mq2, RealD _mq3, + RealD _shift, int _pm, RealD _M5, + RealD _b, RealD _c, const ImplParams &p) : + AbstractEOFAFermion(_Umu, FiveDimGrid, FiveDimRedBlackGrid, + FourDimGrid, FourDimRedBlackGrid, _mq1, _mq2, _mq3, + _shift, _pm, _M5, _b, _c, p) + { + int Ls = this->Ls; + + RealD eps = 1.0; + Approx::zolotarev_data *zdata = Approx::higham(eps, this->Ls); + assert(zdata->n == this->Ls); + + std::cout << GridLogMessage << "MobiusEOFAFermion (b=" << _b << + ",c=" << _c << ") with Ls=" << Ls << std::endl; + this->SetCoefficientsTanh(zdata, _b, _c); + std::cout << GridLogMessage << "EOFA parameters: (mq1=" << _mq1 << + ",mq2=" << _mq2 << ",mq3=" << _mq3 << ",shift=" << _shift << + ",pm=" << _pm << ")" << std::endl; + + Approx::zolotarev_free(zdata); + + if(_shift != 0.0){ + SetCoefficientsPrecondShiftOps(); + } else { + Mooee_shift.resize(Ls, 0.0); + MooeeInv_shift_lc.resize(Ls, 0.0); + MooeeInv_shift_norm.resize(Ls, 0.0); + MooeeInvDag_shift_lc.resize(Ls, 0.0); + MooeeInvDag_shift_norm.resize(Ls, 0.0); + } + } + + 
/*************************************************************** + /* Additional EOFA operators only called outside the inverter. + /* Since speed is not essential, simple axpby-style + /* implementations should be fine. + /***************************************************************/ + template + void MobiusEOFAFermion::Omega(const FermionField& psi, FermionField& Din, int sign, int dag) + { + int Ls = this->Ls; + RealD alpha = this->alpha; + + Din = zero; + if((sign == 1) && (dag == 0)) { // \Omega_{+} + for(int s=0; s + void MobiusEOFAFermion::Dtilde(const FermionField& psi, FermionField& chi) + { + int Ls = this->Ls; + RealD b = 0.5 * ( 1.0 + this->alpha ); + RealD c = 0.5 * ( 1.0 - this->alpha ); + RealD mq1 = this->mq1; + + for(int s=0; s + void MobiusEOFAFermion::DtildeInv(const FermionField& psi, FermionField& chi) + { + int Ls = this->Ls; + RealD m = this->mq1; + RealD c = 0.5 * this->alpha; + RealD d = 0.5; + + RealD DtInv_p(0.0), DtInv_m(0.0); + RealD N = std::pow(c+d,Ls) + m*std::pow(c-d,Ls); + FermionField tmp(this->FermionGrid()); + + for(int s=0; s sp) ? 
0.0 : std::pow(-1.0,sp-s) * std::pow(c-d,sp-s) / std::pow(c+d,sp-s+1); + + if(sp == 0){ + axpby_ssp_pplus (tmp, 0.0, tmp, DtInv_p, psi, s, sp); + axpby_ssp_pminus(tmp, 0.0, tmp, DtInv_m, psi, s, sp); + } else { + axpby_ssp_pplus (tmp, 1.0, tmp, DtInv_p, psi, s, sp); + axpby_ssp_pminus(tmp, 1.0, tmp, DtInv_m, psi, s, sp); + } + + }} + } + + /*****************************************************************************************************/ + + template + RealD MobiusEOFAFermion::M(const FermionField& psi, FermionField& chi) + { + int Ls = this->Ls; + + FermionField Din(psi._grid); + + this->Meooe5D(psi, Din); + this->DW(Din, chi, DaggerNo); + axpby(chi, 1.0, 1.0, chi, psi); + this->M5D(psi, chi); + return(norm2(chi)); + } + + template + RealD MobiusEOFAFermion::Mdag(const FermionField& psi, FermionField& chi) + { + int Ls = this->Ls; + + FermionField Din(psi._grid); + + this->DW(psi, Din, DaggerYes); + this->MeooeDag5D(Din, chi); + this->M5Ddag(psi, chi); + axpby(chi, 1.0, 1.0, chi, psi); + return(norm2(chi)); + } + + /******************************************************************** + /* Performance critical fermion operators called inside the inverter + /********************************************************************/ + + template + void MobiusEOFAFermion::M5D(const FermionField& psi, FermionField& chi) + { + int Ls = this->Ls; + + std::vector diag(Ls,1.0); + std::vector upper(Ls,-1.0); upper[Ls-1] = this->mq1; + std::vector lower(Ls,-1.0); lower[0] = this->mq1; + + // no shift term + if(this->shift == 0.0){ this->M5D(psi, chi, chi, lower, diag, upper); } + + // fused M + shift operation + else{ this->M5D_shift(psi, chi, chi, lower, diag, upper, Mooee_shift); } + } + + template + void MobiusEOFAFermion::M5Ddag(const FermionField& psi, FermionField& chi) + { + int Ls = this->Ls; + + std::vector diag(Ls,1.0); + std::vector upper(Ls,-1.0); upper[Ls-1] = this->mq1; + std::vector lower(Ls,-1.0); lower[0] = this->mq1; + + // no shift term + if(this->shift == 
0.0){ this->M5Ddag(psi, chi, chi, lower, diag, upper); } + + // fused M + shift operation + else{ this->M5Ddag_shift(psi, chi, chi, lower, diag, upper, Mooee_shift); } + } + + // half checkerboard operations + template + void MobiusEOFAFermion::Mooee(const FermionField& psi, FermionField& chi) + { + int Ls = this->Ls; + + // coefficients of Mooee + std::vector diag = this->bee; + std::vector upper(Ls); + std::vector lower(Ls); + for(int s=0; scee[s]; + lower[s] = -this->cee[s]; + } + upper[Ls-1] *= -this->mq1; + lower[0] *= -this->mq1; + + // no shift term + if(this->shift == 0.0){ this->M5D(psi, psi, chi, lower, diag, upper); } + + // fused M + shift operation + else { this->M5D_shift(psi, psi, chi, lower, diag, upper, Mooee_shift); } + } + + template + void MobiusEOFAFermion::MooeeDag(const FermionField& psi, FermionField& chi) + { + int Ls = this->Ls; + + // coefficients of MooeeDag + std::vector diag = this->bee; + std::vector upper(Ls); + std::vector lower(Ls); + for(int s=0; scee[s+1]; + lower[s] = this->mq1*this->cee[Ls-1]; + } else if(s==(Ls-1)) { + upper[s] = this->mq1*this->cee[0]; + lower[s] = -this->cee[s-1]; + } else { + upper[s] = -this->cee[s+1]; + lower[s] = -this->cee[s-1]; + } + } + + // no shift term + if(this->shift == 0.0){ this->M5Ddag(psi, psi, chi, lower, diag, upper); } + + // fused M + shift operation + else{ this->M5Ddag_shift(psi, psi, chi, lower, diag, upper, Mooee_shift); } + } + + /****************************************************************************************/ + + // Computes coefficients for applying Cayley preconditioned shift operators + // (Mooee + \Delta) --> Mooee_shift + // (Mooee + \Delta)^{-1} --> MooeeInv_shift_lc, MooeeInv_shift_norm + // (Mooee + \Delta)^{-dag} --> MooeeInvDag_shift_lc, MooeeInvDag_shift_norm + // For the latter two cases, the operation takes the form + // [ (Mooee + \Delta)^{-1} \psi ]_{i} = Mooee_{ij} \psi_{j} + + // ( MooeeInv_shift_norm )_{i} ( \sum_{j} [ MooeeInv_shift_lc ]_{j} P_{pm} 
\psi_{j} ) + template + void MobiusEOFAFermion::SetCoefficientsPrecondShiftOps() + { + int Ls = this->Ls; + int pm = this->pm; + RealD alpha = this->alpha; + RealD k = this->k; + RealD mq1 = this->mq1; + RealD shift = this->shift; + + // Initialize + Mooee_shift.resize(Ls); + MooeeInv_shift_lc.resize(Ls); + MooeeInv_shift_norm.resize(Ls); + MooeeInvDag_shift_lc.resize(Ls); + MooeeInvDag_shift_norm.resize(Ls); + + // Construct Mooee_shift + int idx(0); + Coeff_t N = ( (pm == 1) ? 1.0 : -1.0 ) * (2.0*shift*k) * + ( std::pow(alpha+1.0,Ls) + mq1*std::pow(alpha-1.0,Ls) ); + for(int s=0; s d = Mooee_shift; + std::vector u(Ls,0.0); + std::vector y(Ls,0.0); + std::vector q(Ls,0.0); + if(pm == 1){ u[0] = 1.0; } + else{ u[Ls-1] = 1.0; } + + // Tridiagonal matrix algorithm + Sherman-Morrison formula + // + // We solve + // ( Mooee' + u \otimes v ) MooeeInvDag_shift_lc = Mooee_shift + // where Mooee' is the tridiagonal part of Mooee_{+}, and + // u = (1,0,...,0) and v = (0,...,0,mq1*cee[0]) are chosen + // so that the outer-product u \otimes v gives the (0,Ls-1) + // entry of Mooee_{+}. + // + // We do this as two solves: Mooee'*y = d and Mooee'*q = u, + // and then construct the solution to the original system + // MooeeInvDag_shift_lc = y - / ( 1 + ) q + if(pm == 1){ + for(int s=1; scee[s] / this->bee[s-1]; + d[s] -= m*d[s-1]; + u[s] -= m*u[s-1]; + } + } + y[Ls-1] = d[Ls-1] / this->bee[Ls-1]; + q[Ls-1] = u[Ls-1] / this->bee[Ls-1]; + for(int s=Ls-2; s>=0; --s){ + if(pm == 1){ + y[s] = d[s] / this->bee[s]; + q[s] = u[s] / this->bee[s]; + } else { + y[s] = ( d[s] + this->cee[s]*y[s+1] ) / this->bee[s]; + q[s] = ( u[s] + this->cee[s]*q[s+1] ) / this->bee[s]; + } + } + + // Construct MooeeInvDag_shift_lc + for(int s=0; scee[0]*y[Ls-1] / + (1.0+mq1*this->cee[0]*q[Ls-1]) * q[s]; + } else { + MooeeInvDag_shift_lc[s] = y[s] - mq1*this->cee[Ls-1]*y[0] / + (1.0+mq1*this->cee[Ls-1]*q[0]) * q[s]; + } + } + + // Compute remaining coefficients + N = (pm == 1) ? 
(1.0 + MooeeInvDag_shift_lc[Ls-1]) : (1.0 + MooeeInvDag_shift_lc[0]); + for(int s=0; sbee[s],s) * std::pow(this->cee[s],Ls-1-s); } + else{ MooeeInv_shift_lc[s] = std::pow(this->bee[s],Ls-1-s) * std::pow(this->cee[s],s); } + + // MooeeInv_shift_norm + MooeeInv_shift_norm[s] = -MooeeInvDag_shift_lc[s] / + ( std::pow(this->bee[s],Ls) + mq1*std::pow(this->cee[s],Ls) ) / N; + + // MooeeInvDag_shift_norm + if(pm == 1){ MooeeInvDag_shift_norm[s] = -std::pow(this->bee[s],s) * std::pow(this->cee[s],Ls-1-s) / + ( std::pow(this->bee[s],Ls) + mq1*std::pow(this->cee[s],Ls) ) / N; } + else{ MooeeInvDag_shift_norm[s] = -std::pow(this->bee[s],Ls-1-s) * std::pow(this->cee[s],s) / + ( std::pow(this->bee[s],Ls) + mq1*std::pow(this->cee[s],Ls) ) / N; } + } + } + } + + // Recompute coefficients for a different value of shift constant + template + void MobiusEOFAFermion::RefreshShiftCoefficients(RealD new_shift) + { + this->shift = new_shift; + if(new_shift != 0.0){ + SetCoefficientsPrecondShiftOps(); + } else { + int Ls = this->Ls; + Mooee_shift.resize(Ls,0.0); + MooeeInv_shift_lc.resize(Ls,0.0); + MooeeInv_shift_norm.resize(Ls,0.0); + MooeeInvDag_shift_lc.resize(Ls,0.0); + MooeeInvDag_shift_norm.resize(Ls,0.0); + } + } + + template + void MobiusEOFAFermion::MooeeInternalCompute(int dag, int inv, + Vector >& Matp, Vector >& Matm) + { + int Ls = this->Ls; + + GridBase* grid = this->FermionRedBlackGrid(); + int LLs = grid->_rdimensions[0]; + + if(LLs == Ls){ return; } // Not vectorised in 5th direction + + Eigen::MatrixXcd Pplus = Eigen::MatrixXcd::Zero(Ls,Ls); + Eigen::MatrixXcd Pminus = Eigen::MatrixXcd::Zero(Ls,Ls); + + for(int s=0; sbee[s]; + Pminus(s,s) = this->bee[s]; + } + + for(int s=0; scee[s]; + Pplus(s+1,s) = -this->cee[s+1]; + } + + Pplus (0,Ls-1) = this->mq1*this->cee[0]; + Pminus(Ls-1,0) = this->mq1*this->cee[Ls-1]; + + if(this->shift != 0.0){ + RealD c = 0.5 * this->alpha; + RealD d = 0.5; + RealD N = this->shift * this->k * ( std::pow(c+d,Ls) + this->mq1*std::pow(c-d,Ls) 
); + if(this->pm == 1) { + for(int s=0; s::iscomplex()) { + sp[l] = PplusMat (l*istride+s1*ostride,s2); + sm[l] = PminusMat(l*istride+s1*ostride,s2); + } else { + // if real + scalar_type tmp; + tmp = PplusMat (l*istride+s1*ostride,s2); + sp[l] = scalar_type(tmp.real(),tmp.real()); + tmp = PminusMat(l*istride+s1*ostride,s2); + sm[l] = scalar_type(tmp.real(),tmp.real()); + } + } + Matp[LLs*s2+s1] = Vp; + Matm[LLs*s2+s1] = Vm; + }} + } + + FermOpTemplateInstantiate(MobiusEOFAFermion); + GparityFermOpTemplateInstantiate(MobiusEOFAFermion); + +}} diff --git a/lib/qcd/action/fermion/MobiusEOFAFermion.h b/lib/qcd/action/fermion/MobiusEOFAFermion.h new file mode 100644 index 00000000..519b49e7 --- /dev/null +++ b/lib/qcd/action/fermion/MobiusEOFAFermion.h @@ -0,0 +1,133 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/MobiusEOFAFermion.h + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef GRID_QCD_MOBIUS_EOFA_FERMION_H +#define GRID_QCD_MOBIUS_EOFA_FERMION_H + +#include + +namespace Grid { +namespace QCD { + + template + class MobiusEOFAFermion : public AbstractEOFAFermion + { + public: + INHERIT_IMPL_TYPES(Impl); + + public: + // Shift operator coefficients for red-black preconditioned Mobius EOFA + std::vector Mooee_shift; + std::vector MooeeInv_shift_lc; + std::vector MooeeInv_shift_norm; + std::vector MooeeInvDag_shift_lc; + std::vector MooeeInvDag_shift_norm; + + virtual void Instantiatable(void) {}; + + // EOFA-specific operations + virtual void Omega (const FermionField& in, FermionField& out, int sign, int dag); + virtual void Dtilde (const FermionField& in, FermionField& out); + virtual void DtildeInv (const FermionField& in, FermionField& out); + + // override multiply + virtual RealD M (const FermionField& in, FermionField& out); + virtual RealD Mdag (const FermionField& in, FermionField& out); + + // half checkerboard operations + virtual void Mooee (const FermionField& in, FermionField& out); + virtual void MooeeDag (const FermionField& in, FermionField& out); + virtual void MooeeInv (const FermionField& in, FermionField& out); + virtual void MooeeInv_shift (const FermionField& in, FermionField& out); + virtual void MooeeInvDag (const FermionField& in, FermionField& out); + virtual void MooeeInvDag_shift(const FermionField& in, FermionField& out); + + virtual void M5D (const FermionField& psi, FermionField& chi); + virtual void M5Ddag (const FermionField& psi, FermionField& chi); + + ///////////////////////////////////////////////////// + // Instantiate different versions depending on Impl + ///////////////////////////////////////////////////// + void M5D(const FermionField& psi, const FermionField& phi, FermionField& chi, + std::vector& 
lower, std::vector& diag, std::vector& upper); + + void M5D_shift(const FermionField& psi, const FermionField& phi, FermionField& chi, + std::vector& lower, std::vector& diag, std::vector& upper, + std::vector& shift_coeffs); + + void M5Ddag(const FermionField& psi, const FermionField& phi, FermionField& chi, + std::vector& lower, std::vector& diag, std::vector& upper); + + void M5Ddag_shift(const FermionField& psi, const FermionField& phi, FermionField& chi, + std::vector& lower, std::vector& diag, std::vector& upper, + std::vector& shift_coeffs); + + void MooeeInternal(const FermionField& in, FermionField& out, int dag, int inv); + + void MooeeInternalCompute(int dag, int inv, Vector>& Matp, Vector>& Matm); + + void MooeeInternalAsm(const FermionField& in, FermionField& out, int LLs, int site, + Vector>& Matp, Vector>& Matm); + + void MooeeInternalZAsm(const FermionField& in, FermionField& out, int LLs, int site, + Vector>& Matp, Vector>& Matm); + + virtual void RefreshShiftCoefficients(RealD new_shift); + + // Constructors + MobiusEOFAFermion(GaugeField& _Umu, GridCartesian& FiveDimGrid, GridRedBlackCartesian& FiveDimRedBlackGrid, + GridCartesian& FourDimGrid, GridRedBlackCartesian& FourDimRedBlackGrid, + RealD _mq1, RealD _mq2, RealD _mq3, RealD _shift, int pm, + RealD _M5, RealD _b, RealD _c, const ImplParams& p=ImplParams()); + + protected: + void SetCoefficientsPrecondShiftOps(void); + }; +}} + +#define INSTANTIATE_DPERP_MOBIUS_EOFA(A)\ +template void MobiusEOFAFermion::M5D(const FermionField& psi, const FermionField& phi, FermionField& chi, \ + std::vector& lower, std::vector& diag, std::vector& upper); \ +template void MobiusEOFAFermion::M5D_shift(const FermionField& psi, const FermionField& phi, FermionField& chi, \ + std::vector& lower, std::vector& diag, std::vector& upper, std::vector& shift_coeffs); \ +template void MobiusEOFAFermion::M5Ddag(const FermionField& psi, const FermionField& phi, FermionField& chi, \ + std::vector& lower, std::vector& diag, 
std::vector& upper); \ +template void MobiusEOFAFermion::M5Ddag_shift(const FermionField& psi, const FermionField& phi, FermionField& chi, \ + std::vector& lower, std::vector& diag, std::vector& upper, std::vector& shift_coeffs); \ +template void MobiusEOFAFermion::MooeeInv(const FermionField& psi, FermionField& chi); \ +template void MobiusEOFAFermion::MooeeInv_shift(const FermionField& psi, FermionField& chi); \ +template void MobiusEOFAFermion::MooeeInvDag(const FermionField& psi, FermionField& chi); \ +template void MobiusEOFAFermion::MooeeInvDag_shift(const FermionField& psi, FermionField& chi); + +#undef MOBIUS_EOFA_DPERP_DENSE +#define MOBIUS_EOFA_DPERP_CACHE +#undef MOBIUS_EOFA_DPERP_LINALG +#define MOBIUS_EOFA_DPERP_VEC + +#endif diff --git a/lib/qcd/action/fermion/MobiusEOFAFermioncache.cc b/lib/qcd/action/fermion/MobiusEOFAFermioncache.cc new file mode 100644 index 00000000..420f6390 --- /dev/null +++ b/lib/qcd/action/fermion/MobiusEOFAFermioncache.cc @@ -0,0 +1,429 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/MobiusEOFAFermioncache.cc + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: Peter Boyle +Author: paboyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. 
+ +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include + +namespace Grid { +namespace QCD { + + // FIXME -- make a version of these routines with site loop outermost for cache reuse. + + template + void MobiusEOFAFermion::M5D(const FermionField &psi, const FermionField &phi, FermionField &chi, + std::vector &lower, std::vector &diag, std::vector &upper) + { + int Ls = this->Ls; + GridBase *grid = psi._grid; + + assert(phi.checkerboard == psi.checkerboard); + chi.checkerboard = psi.checkerboard; + + // Flops = 6.0*(Nc*Ns) *Ls*vol + this->M5Dcalls++; + this->M5Dtime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=Ls){ + for(int s=0; sM5Dtime += usecond(); + } + + template + void MobiusEOFAFermion::M5D_shift(const FermionField &psi, const FermionField &phi, FermionField &chi, + std::vector &lower, std::vector &diag, std::vector &upper, + std::vector &shift_coeffs) + { + int Ls = this->Ls; + int shift_s = (this->pm == 1) ? 
(Ls-1) : 0; // s-component modified by shift operator + GridBase *grid = psi._grid; + + assert(phi.checkerboard == psi.checkerboard); + chi.checkerboard = psi.checkerboard; + + // Flops = 6.0*(Nc*Ns) *Ls*vol + this->M5Dcalls++; + this->M5Dtime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=Ls){ + for(int s=0; spm == 1){ spProj5p(tmp, psi._odata[ss+shift_s]); } + else{ spProj5m(tmp, psi._odata[ss+shift_s]); } + chi[ss+s] = chi[ss+s] + shift_coeffs[s]*tmp; + } + } + + this->M5Dtime += usecond(); + } + + template + void MobiusEOFAFermion::M5Ddag(const FermionField &psi, const FermionField &phi, FermionField &chi, + std::vector &lower, std::vector &diag, std::vector &upper) + { + int Ls = this->Ls; + GridBase *grid = psi._grid; + + assert(phi.checkerboard == psi.checkerboard); + chi.checkerboard = psi.checkerboard; + + // Flops = 6.0*(Nc*Ns) *Ls*vol + this->M5Dcalls++; + this->M5Dtime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=Ls){ + auto tmp = psi._odata[0]; + for(int s=0; sM5Dtime += usecond(); + } + + template + void MobiusEOFAFermion::M5Ddag_shift(const FermionField &psi, const FermionField &phi, FermionField &chi, + std::vector &lower, std::vector &diag, std::vector &upper, + std::vector &shift_coeffs) + { + int Ls = this->Ls; + int shift_s = (this->pm == 1) ? 
(Ls-1) : 0; // s-component modified by shift operator + GridBase *grid = psi._grid; + + assert(phi.checkerboard == psi.checkerboard); + chi.checkerboard = psi.checkerboard; + + // Flops = 6.0*(Nc*Ns) *Ls*vol + this->M5Dcalls++; + this->M5Dtime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=Ls){ + chi[ss+Ls-1] = zero; + auto tmp = psi._odata[0]; + for(int s=0; spm == 1){ spProj5p(tmp, psi._odata[ss+s]); } + else{ spProj5m(tmp, psi._odata[ss+s]); } + chi[ss+shift_s] = chi[ss+shift_s] + shift_coeffs[s]*tmp; + } + } + + this->M5Dtime += usecond(); + } + + template + void MobiusEOFAFermion::MooeeInv(const FermionField &psi, FermionField &chi) + { + if(this->shift != 0.0){ MooeeInv_shift(psi,chi); return; } + + GridBase *grid = psi._grid; + int Ls = this->Ls; + + chi.checkerboard = psi.checkerboard; + + this->MooeeInvCalls++; + this->MooeeInvTime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=Ls){ + + auto tmp = psi._odata[0]; + + // Apply (L^{\prime})^{-1} + chi[ss] = psi[ss]; // chi[0]=psi[0] + for(int s=1; slee[s-1]*tmp; + } + + // L_m^{-1} + for(int s=0; sleem[s]*tmp; + } + + // U_m^{-1} D^{-1} + for(int s=0; sdee[s])*chi[ss+s] - (this->ueem[s]/this->dee[Ls-1])*tmp; + } + chi[ss+Ls-1] = (1.0/this->dee[Ls-1])*chi[ss+Ls-1]; + + // Apply U^{-1} + for(int s=Ls-2; s>=0; s--){ + spProj5m(tmp, chi[ss+s+1]); + chi[ss+s] = chi[ss+s] - this->uee[s]*tmp; + } + } + + this->MooeeInvTime += usecond(); + } + + template + void MobiusEOFAFermion::MooeeInv_shift(const FermionField &psi, FermionField &chi) + { + GridBase *grid = psi._grid; + int Ls = this->Ls; + + chi.checkerboard = psi.checkerboard; + + this->MooeeInvCalls++; + this->MooeeInvTime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=Ls){ + + auto tmp1 = psi._odata[0]; + auto tmp2 = psi._odata[0]; + auto tmp2_spProj = psi._odata[0]; + + // Apply (L^{\prime})^{-1} and accumulate MooeeInv_shift_lc[j]*psi[j] in tmp2 + chi[ss] = psi[ss]; // chi[0]=psi[0] + tmp2 = MooeeInv_shift_lc[0]*psi[ss]; + 
for(int s=1; slee[s-1]*tmp1; + tmp2 = tmp2 + MooeeInv_shift_lc[s]*psi[ss+s]; + } + if(this->pm == 1){ spProj5p(tmp2_spProj, tmp2);} + else{ spProj5m(tmp2_spProj, tmp2); } + + // L_m^{-1} + for(int s=0; sleem[s]*tmp1; + } + + // U_m^{-1} D^{-1} + for(int s=0; sdee[s])*chi[ss+s] - (this->ueem[s]/this->dee[Ls-1])*tmp1; + } + // chi[ss+Ls-1] = (1.0/this->dee[Ls-1])*chi[ss+Ls-1] + MooeeInv_shift_norm[Ls-1]*tmp2_spProj; + chi[ss+Ls-1] = (1.0/this->dee[Ls-1])*chi[ss+Ls-1]; + spProj5m(tmp1, chi[ss+Ls-1]); + chi[ss+Ls-1] = chi[ss+Ls-1] + MooeeInv_shift_norm[Ls-1]*tmp2_spProj; + + // Apply U^{-1} and add shift term + for(int s=Ls-2; s>=0; s--){ + chi[ss+s] = chi[ss+s] - this->uee[s]*tmp1; + spProj5m(tmp1, chi[ss+s]); + chi[ss+s] = chi[ss+s] + MooeeInv_shift_norm[s]*tmp2_spProj; + } + } + + this->MooeeInvTime += usecond(); + } + + template + void MobiusEOFAFermion::MooeeInvDag(const FermionField &psi, FermionField &chi) + { + if(this->shift != 0.0){ MooeeInvDag_shift(psi,chi); return; } + + GridBase *grid = psi._grid; + int Ls = this->Ls; + + chi.checkerboard = psi.checkerboard; + + this->MooeeInvCalls++; + this->MooeeInvTime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=Ls){ + + auto tmp = psi._odata[0]; + + // Apply (U^{\prime})^{-dag} + chi[ss] = psi[ss]; + for(int s=1; suee[s-1]*tmp; + } + + // U_m^{-\dag} + for(int s=0; sueem[s]*tmp; + } + + // L_m^{-\dag} D^{-dag} + for(int s=0; sdee[s])*chi[ss+s] - (this->leem[s]/this->dee[Ls-1])*tmp; + } + chi[ss+Ls-1] = (1.0/this->dee[Ls-1])*chi[ss+Ls-1]; + + // Apply L^{-dag} + for(int s=Ls-2; s>=0; s--){ + spProj5p(tmp, chi[ss+s+1]); + chi[ss+s] = chi[ss+s] - this->lee[s]*tmp; + } + } + + this->MooeeInvTime += usecond(); + } + + template + void MobiusEOFAFermion::MooeeInvDag_shift(const FermionField &psi, FermionField &chi) + { + GridBase *grid = psi._grid; + int Ls = this->Ls; + + chi.checkerboard = psi.checkerboard; + + this->MooeeInvCalls++; + this->MooeeInvTime -= usecond(); + + parallel_for(int ss=0; ssoSites(); 
ss+=Ls){ + + auto tmp1 = psi._odata[0]; + auto tmp2 = psi._odata[0]; + auto tmp2_spProj = psi._odata[0]; + + // Apply (U^{\prime})^{-dag} and accumulate MooeeInvDag_shift_lc[j]*psi[j] in tmp2 + chi[ss] = psi[ss]; + tmp2 = MooeeInvDag_shift_lc[0]*psi[ss]; + for(int s=1; suee[s-1]*tmp1; + tmp2 = tmp2 + MooeeInvDag_shift_lc[s]*psi[ss+s]; + } + if(this->pm == 1){ spProj5p(tmp2_spProj, tmp2);} + else{ spProj5m(tmp2_spProj, tmp2); } + + // U_m^{-\dag} + for(int s=0; sueem[s]*tmp1; + } + + // L_m^{-\dag} D^{-dag} + for(int s=0; sdee[s])*chi[ss+s] - (this->leem[s]/this->dee[Ls-1])*tmp1; + } + chi[ss+Ls-1] = (1.0/this->dee[Ls-1])*chi[ss+Ls-1]; + spProj5p(tmp1, chi[ss+Ls-1]); + chi[ss+Ls-1] = chi[ss+Ls-1] + MooeeInvDag_shift_norm[Ls-1]*tmp2_spProj; + + // Apply L^{-dag} + for(int s=Ls-2; s>=0; s--){ + chi[ss+s] = chi[ss+s] - this->lee[s]*tmp1; + spProj5p(tmp1, chi[ss+s]); + chi[ss+s] = chi[ss+s] + MooeeInvDag_shift_norm[s]*tmp2_spProj; + } + } + + this->MooeeInvTime += usecond(); + } + + #ifdef MOBIUS_EOFA_DPERP_CACHE + + INSTANTIATE_DPERP_MOBIUS_EOFA(WilsonImplF); + INSTANTIATE_DPERP_MOBIUS_EOFA(WilsonImplD); + INSTANTIATE_DPERP_MOBIUS_EOFA(GparityWilsonImplF); + INSTANTIATE_DPERP_MOBIUS_EOFA(GparityWilsonImplD); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZWilsonImplF); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZWilsonImplD); + + INSTANTIATE_DPERP_MOBIUS_EOFA(WilsonImplFH); + INSTANTIATE_DPERP_MOBIUS_EOFA(WilsonImplDF); + INSTANTIATE_DPERP_MOBIUS_EOFA(GparityWilsonImplFH); + INSTANTIATE_DPERP_MOBIUS_EOFA(GparityWilsonImplDF); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZWilsonImplFH); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZWilsonImplDF); + + #endif + +}} diff --git a/lib/qcd/action/fermion/MobiusEOFAFermiondense.cc b/lib/qcd/action/fermion/MobiusEOFAFermiondense.cc new file mode 100644 index 00000000..d66b8cd9 --- /dev/null +++ b/lib/qcd/action/fermion/MobiusEOFAFermiondense.cc @@ -0,0 +1,184 @@ +/************************************************************************************* + +Grid physics library, 
www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/MobiusEOFAFermiondense.cc + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: Peter Boyle +Author: paboyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include +#include + +namespace Grid { +namespace QCD { + + /* + * Dense matrix versions of routines + */ + template + void MobiusEOFAFermion::MooeeInv(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerNo, InverseYes); + } + + template + void MobiusEOFAFermion::MooeeInv_shift(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerNo, InverseYes); + } + + template + void MobiusEOFAFermion::MooeeInvDag(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerYes, InverseYes); + } + + template + void MobiusEOFAFermion::MooeeInvDag_shift(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerYes, InverseYes); + } + + template + void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv) + { + int Ls = 
this->Ls; + int LLs = psi._grid->_rdimensions[0]; + int vol = psi._grid->oSites()/LLs; + + int pm = this->pm; + RealD shift = this->shift; + RealD alpha = this->alpha; + RealD k = this->k; + RealD mq1 = this->mq1; + + chi.checkerboard = psi.checkerboard; + + assert(Ls==LLs); + + Eigen::MatrixXd Pplus = Eigen::MatrixXd::Zero(Ls,Ls); + Eigen::MatrixXd Pminus = Eigen::MatrixXd::Zero(Ls,Ls); + + for(int s=0;sbee[s]; + Pminus(s,s) = this->bee[s]; + } + + for(int s=0; scee[s]; + } + + for(int s=0; scee[s+1]; + } + Pplus (0,Ls-1) = mq1*this->cee[0]; + Pminus(Ls-1,0) = mq1*this->cee[Ls-1]; + + if(shift != 0.0){ + Coeff_t N = 2.0 * ( std::pow(alpha+1.0,Ls) + mq1*std::pow(alpha-1.0,Ls) ); + for(int s=0; s::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + + INSTANTIATE_DPERP_MOBIUS_EOFA(GparityWilsonImplFH); + INSTANTIATE_DPERP_MOBIUS_EOFA(GparityWilsonImplDF); + INSTANTIATE_DPERP_MOBIUS_EOFA(WilsonImplFH); + INSTANTIATE_DPERP_MOBIUS_EOFA(WilsonImplDF); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZWilsonImplFH); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZWilsonImplDF); + + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void 
MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + + #endif + +}} diff --git a/lib/qcd/action/fermion/MobiusEOFAFermionssp.cc b/lib/qcd/action/fermion/MobiusEOFAFermionssp.cc new file mode 100644 index 00000000..c86bb995 --- /dev/null +++ b/lib/qcd/action/fermion/MobiusEOFAFermionssp.cc @@ -0,0 +1,290 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/MobiusEOFAFermionssp.cc + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: Peter Boyle +Author: paboyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include + +namespace Grid { +namespace QCD { + + // FIXME -- make a version of these routines with site loop outermost for cache reuse. 
+ // Pminus fowards + // Pplus backwards + template + void MobiusEOFAFermion::M5D(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper) + { + Coeff_t one(1.0); + int Ls = this->Ls; + for(int s=0; s + void MobiusEOFAFermion::M5D_shift(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper, + std::vector& shift_coeffs) + { + Coeff_t one(1.0); + int Ls = this->Ls; + for(int s=0; spm == 1){ axpby_ssp_pplus(chi, one, chi, shift_coeffs[s], psi, s, Ls-1); } + else{ axpby_ssp_pminus(chi, one, chi, shift_coeffs[s], psi, s, 0); } + } + } + + template + void MobiusEOFAFermion::M5Ddag(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper) + { + Coeff_t one(1.0); + int Ls = this->Ls; + for(int s=0; s + void MobiusEOFAFermion::M5Ddag_shift(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper, + std::vector& shift_coeffs) + { + Coeff_t one(1.0); + int Ls = this->Ls; + for(int s=0; spm == 1){ axpby_ssp_pplus(chi, one, chi, shift_coeffs[s], psi, Ls-1, s); } + else{ axpby_ssp_pminus(chi, one, chi, shift_coeffs[s], psi, 0, s); } + } + } + + template + void MobiusEOFAFermion::MooeeInv(const FermionField& psi, FermionField& chi) + { + if(this->shift != 0.0){ MooeeInv_shift(psi,chi); return; } + + Coeff_t one(1.0); + Coeff_t czero(0.0); + chi.checkerboard = psi.checkerboard; + int Ls = this->Ls; + + // Apply (L^{\prime})^{-1} + axpby_ssp(chi, one, psi, czero, psi, 0, 0); // chi[0]=psi[0] + for(int s=1; slee[s-1], chi, s, s-1);// recursion Psi[s] -lee P_+ chi[s-1] + } + + // L_m^{-1} + for(int s=0; sleem[s], chi, Ls-1, s); + } + + // U_m^{-1} D^{-1} + for(int s=0; sdee[s], chi, -this->ueem[s]/this->dee[Ls-1], chi, s, Ls-1); + } + axpby_ssp(chi, one/this->dee[Ls-1], chi, czero, chi, Ls-1, 
Ls-1); + + // Apply U^{-1} + for(int s=Ls-2; s>=0; s--){ + axpby_ssp_pminus(chi, one, chi, -this->uee[s], chi, s, s+1); // chi[Ls] + } + } + + template + void MobiusEOFAFermion::MooeeInv_shift(const FermionField& psi, FermionField& chi) + { + Coeff_t one(1.0); + Coeff_t czero(0.0); + chi.checkerboard = psi.checkerboard; + int Ls = this->Ls; + + FermionField tmp(psi._grid); + + // Apply (L^{\prime})^{-1} + axpby_ssp(chi, one, psi, czero, psi, 0, 0); // chi[0]=psi[0] + axpby_ssp(tmp, czero, tmp, this->MooeeInv_shift_lc[0], psi, 0, 0); + for(int s=1; slee[s-1], chi, s, s-1);// recursion Psi[s] -lee P_+ chi[s-1] + axpby_ssp(tmp, one, tmp, this->MooeeInv_shift_lc[s], psi, 0, s); + } + + // L_m^{-1} + for(int s=0; sleem[s], chi, Ls-1, s); + } + + // U_m^{-1} D^{-1} + for(int s=0; sdee[s], chi, -this->ueem[s]/this->dee[Ls-1], chi, s, Ls-1); + } + axpby_ssp(chi, one/this->dee[Ls-1], chi, czero, chi, Ls-1, Ls-1); + + // Apply U^{-1} and add shift term + if(this->pm == 1){ axpby_ssp_pplus(chi, one, chi, this->MooeeInv_shift_norm[Ls-1], tmp, Ls-1, 0); } + else{ axpby_ssp_pminus(chi, one, chi, this->MooeeInv_shift_norm[Ls-1], tmp, Ls-1, 0); } + for(int s=Ls-2; s>=0; s--){ + axpby_ssp_pminus(chi, one, chi, -this->uee[s], chi, s, s+1); // chi[Ls] + if(this->pm == 1){ axpby_ssp_pplus(chi, one, chi, this->MooeeInv_shift_norm[s], tmp, s, 0); } + else{ axpby_ssp_pminus(chi, one, chi, this->MooeeInv_shift_norm[s], tmp, s, 0); } + } + } + + template + void MobiusEOFAFermion::MooeeInvDag(const FermionField& psi, FermionField& chi) + { + if(this->shift != 0.0){ MooeeInvDag_shift(psi,chi); return; } + + Coeff_t one(1.0); + Coeff_t czero(0.0); + chi.checkerboard = psi.checkerboard; + int Ls = this->Ls; + + // Apply (U^{\prime})^{-dagger} + axpby_ssp(chi, one, psi, czero, psi, 0, 0); // chi[0]=psi[0] + for(int s=1; suee[s-1]), chi, s, s-1); + } + + // U_m^{-\dagger} + for(int s=0; sueem[s]), chi, Ls-1, s); + } + + // L_m^{-\dagger} D^{-dagger} + for(int s=0; sdee[s]), chi, 
-conjugate(this->leem[s]/this->dee[Ls-1]), chi, s, Ls-1); + } + axpby_ssp(chi, one/conjugate(this->dee[Ls-1]), chi, czero, chi, Ls-1, Ls-1); + + // Apply L^{-dagger} + for(int s=Ls-2; s>=0; s--){ + axpby_ssp_pplus(chi, one, chi, -conjugate(this->lee[s]), chi, s, s+1); // chi[Ls] + } + } + + template + void MobiusEOFAFermion::MooeeInvDag_shift(const FermionField& psi, FermionField& chi) + { + Coeff_t one(1.0); + Coeff_t czero(0.0); + chi.checkerboard = psi.checkerboard; + int Ls = this->Ls; + + FermionField tmp(psi._grid); + + // Apply (U^{\prime})^{-dagger} and accumulate (MooeeInvDag_shift_lc)_{j} \psi_{j} in tmp[0] + axpby_ssp(chi, one, psi, czero, psi, 0, 0); // chi[0]=psi[0] + axpby_ssp(tmp, czero, tmp, this->MooeeInvDag_shift_lc[0], psi, 0, 0); + for(int s=1; suee[s-1]), chi, s, s-1); + axpby_ssp(tmp, one, tmp, this->MooeeInvDag_shift_lc[s], psi, 0, s); + } + + // U_m^{-\dagger} + for(int s=0; sueem[s]), chi, Ls-1, s); + } + + // L_m^{-\dagger} D^{-dagger} + for(int s=0; sdee[s]), chi, -conjugate(this->leem[s]/this->dee[Ls-1]), chi, s, Ls-1); + } + axpby_ssp(chi, one/conjugate(this->dee[Ls-1]), chi, czero, chi, Ls-1, Ls-1); + + // Apply L^{-dagger} and add shift + if(this->pm == 1){ axpby_ssp_pplus(chi, one, chi, this->MooeeInvDag_shift_norm[Ls-1], tmp, Ls-1, 0); } + else{ axpby_ssp_pminus(chi, one, chi, this->MooeeInvDag_shift_norm[Ls-1], tmp, Ls-1, 0); } + for(int s=Ls-2; s>=0; s--){ + axpby_ssp_pplus(chi, one, chi, -conjugate(this->lee[s]), chi, s, s+1); // chi[Ls] + if(this->pm == 1){ axpby_ssp_pplus(chi, one, chi, this->MooeeInvDag_shift_norm[s], tmp, s, 0); } + else{ axpby_ssp_pminus(chi, one, chi, this->MooeeInvDag_shift_norm[s], tmp, s, 0); } + } + } + + #ifdef MOBIUS_EOFA_DPERP_LINALG + + INSTANTIATE_DPERP_MOBIUS_EOFA(WilsonImplF); + INSTANTIATE_DPERP_MOBIUS_EOFA(WilsonImplD); + INSTANTIATE_DPERP_MOBIUS_EOFA(GparityWilsonImplF); + INSTANTIATE_DPERP_MOBIUS_EOFA(GparityWilsonImplD); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZWilsonImplF); + 
INSTANTIATE_DPERP_MOBIUS_EOFA(ZWilsonImplD); + + INSTANTIATE_DPERP_MOBIUS_EOFA(WilsonImplFH); + INSTANTIATE_DPERP_MOBIUS_EOFA(WilsonImplDF); + INSTANTIATE_DPERP_MOBIUS_EOFA(GparityWilsonImplFH); + INSTANTIATE_DPERP_MOBIUS_EOFA(GparityWilsonImplDF); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZWilsonImplFH); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZWilsonImplDF); + + #endif + +}} diff --git a/lib/qcd/action/fermion/MobiusEOFAFermionvec.cc b/lib/qcd/action/fermion/MobiusEOFAFermionvec.cc new file mode 100644 index 00000000..c4eaf0f3 --- /dev/null +++ b/lib/qcd/action/fermion/MobiusEOFAFermionvec.cc @@ -0,0 +1,983 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/fermion/MobiusEOFAFermionvec.cc + +Copyright (C) 2017 + +Author: Peter Boyle +Author: Peter Boyle +Author: Peter Boyle +Author: paboyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +#include +#include + +namespace Grid { +namespace QCD { + + /* + * Dense matrix versions of routines + */ + template + void MobiusEOFAFermion::MooeeInv(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerNo, InverseYes); + } + + template + void MobiusEOFAFermion::MooeeInv_shift(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerNo, InverseYes); + } + + template + void MobiusEOFAFermion::MooeeInvDag(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerYes, InverseYes); + } + + template + void MobiusEOFAFermion::MooeeInvDag_shift(const FermionField& psi, FermionField& chi) + { + this->MooeeInternal(psi, chi, DaggerYes, InverseYes); + } + + template + void MobiusEOFAFermion::M5D(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper) + { + GridBase* grid = psi._grid; + int Ls = this->Ls; + int LLs = grid->_rdimensions[0]; + const int nsimd = Simd::Nsimd(); + + Vector> u(LLs); + Vector> l(LLs); + Vector> d(LLs); + + assert(Ls/LLs == nsimd); + assert(phi.checkerboard == psi.checkerboard); + + chi.checkerboard = psi.checkerboard; + + // just directly address via type pun + typedef typename Simd::scalar_type scalar_type; + scalar_type* u_p = (scalar_type*) &u[0]; + scalar_type* l_p = (scalar_type*) &l[0]; + scalar_type* d_p = (scalar_type*) &d[0]; + + for(int o=0; oM5Dcalls++; + this->M5Dtime -= usecond(); + + assert(Nc == 3); + + parallel_for(int ss=0; ssoSites(); ss+=LLs){ // adds LLs + + #if 0 + + alignas(64) SiteHalfSpinor hp; + alignas(64) SiteHalfSpinor hm; + alignas(64) SiteSpinor fp; + alignas(64) SiteSpinor fm; + + for(int v=0; v= v){ rotate(hm, hm, nsimd-1); } + + hp = 0.5*hp; + hm = 0.5*hm; + + 
spRecon5m(fp, hp); + spRecon5p(fm, hm); + + chi[ss+v] = d[v]*phi[ss+v]; + chi[ss+v] = chi[ss+v] + u[v]*fp; + chi[ss+v] = chi[ss+v] + l[v]*fm; + + } + + #else + + for(int v=0; v(hp_00.v); + hp_01.v = Optimization::Rotate::tRotate<2>(hp_01.v); + hp_02.v = Optimization::Rotate::tRotate<2>(hp_02.v); + hp_10.v = Optimization::Rotate::tRotate<2>(hp_10.v); + hp_11.v = Optimization::Rotate::tRotate<2>(hp_11.v); + hp_12.v = Optimization::Rotate::tRotate<2>(hp_12.v); + } + + if(vm >= v){ + hm_00.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_00.v); + hm_01.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_01.v); + hm_02.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_02.v); + hm_10.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_10.v); + hm_11.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_11.v); + hm_12.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_12.v); + } + + // Can force these to real arithmetic and save 2x. + Simd p_00 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(0)) + switcheroo::mult(l[v]()()(), hm_00); + Simd p_01 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(1)) + switcheroo::mult(l[v]()()(), hm_01); + Simd p_02 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(2)) + switcheroo::mult(l[v]()()(), hm_02); + Simd p_10 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(0)) + switcheroo::mult(l[v]()()(), hm_10); + Simd p_11 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(1)) + switcheroo::mult(l[v]()()(), hm_11); + Simd p_12 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(2)) + switcheroo::mult(l[v]()()(), hm_12); + Simd p_20 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(0)) + switcheroo::mult(u[v]()()(), hp_00); + Simd p_21 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(1)) + switcheroo::mult(u[v]()()(), hp_01); + Simd p_22 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(2)) + switcheroo::mult(u[v]()()(), hp_02); + Simd p_30 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(0)) + switcheroo::mult(u[v]()()(), hp_10); + Simd p_31 = 
switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(1)) + switcheroo::mult(u[v]()()(), hp_11); + Simd p_32 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(2)) + switcheroo::mult(u[v]()()(), hp_12); + + vstream(chi[ss+v]()(0)(0), p_00); + vstream(chi[ss+v]()(0)(1), p_01); + vstream(chi[ss+v]()(0)(2), p_02); + vstream(chi[ss+v]()(1)(0), p_10); + vstream(chi[ss+v]()(1)(1), p_11); + vstream(chi[ss+v]()(1)(2), p_12); + vstream(chi[ss+v]()(2)(0), p_20); + vstream(chi[ss+v]()(2)(1), p_21); + vstream(chi[ss+v]()(2)(2), p_22); + vstream(chi[ss+v]()(3)(0), p_30); + vstream(chi[ss+v]()(3)(1), p_31); + vstream(chi[ss+v]()(3)(2), p_32); + } + + #endif + } + + this->M5Dtime += usecond(); + } + + template + void MobiusEOFAFermion::M5D_shift(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper, + std::vector& shift_coeffs) + { + #if 0 + + this->M5D(psi, phi, chi, lower, diag, upper); + + // FIXME: possible gain from vectorizing shift operation as well? + Coeff_t one(1.0); + int Ls = this->Ls; + for(int s=0; spm == 1){ axpby_ssp_pplus(chi, one, chi, shift_coeffs[s], psi, s, Ls-1); } + else{ axpby_ssp_pminus(chi, one, chi, shift_coeffs[s], psi, s, 0); } + } + + #else + + GridBase* grid = psi._grid; + int Ls = this->Ls; + int LLs = grid->_rdimensions[0]; + const int nsimd = Simd::Nsimd(); + + Vector> u(LLs); + Vector> l(LLs); + Vector> d(LLs); + Vector> s(LLs); + + assert(Ls/LLs == nsimd); + assert(phi.checkerboard == psi.checkerboard); + + chi.checkerboard = psi.checkerboard; + + // just directly address via type pun + typedef typename Simd::scalar_type scalar_type; + scalar_type* u_p = (scalar_type*) &u[0]; + scalar_type* l_p = (scalar_type*) &l[0]; + scalar_type* d_p = (scalar_type*) &d[0]; + scalar_type* s_p = (scalar_type*) &s[0]; + + for(int o=0; oM5Dcalls++; + this->M5Dtime -= usecond(); + + assert(Nc == 3); + + parallel_for(int ss=0; ssoSites(); ss+=LLs){ // adds LLs + + int vs = (this->pm == 1) ? 
LLs-1 : 0; + Simd hs_00 = (this->pm == 1) ? psi[ss+vs]()(2)(0) : psi[ss+vs]()(0)(0); + Simd hs_01 = (this->pm == 1) ? psi[ss+vs]()(2)(1) : psi[ss+vs]()(0)(1); + Simd hs_02 = (this->pm == 1) ? psi[ss+vs]()(2)(2) : psi[ss+vs]()(0)(2); + Simd hs_10 = (this->pm == 1) ? psi[ss+vs]()(3)(0) : psi[ss+vs]()(1)(0); + Simd hs_11 = (this->pm == 1) ? psi[ss+vs]()(3)(1) : psi[ss+vs]()(1)(1); + Simd hs_12 = (this->pm == 1) ? psi[ss+vs]()(3)(2) : psi[ss+vs]()(1)(2); + + for(int v=0; v(hp_00.v); + hp_01.v = Optimization::Rotate::tRotate<2>(hp_01.v); + hp_02.v = Optimization::Rotate::tRotate<2>(hp_02.v); + hp_10.v = Optimization::Rotate::tRotate<2>(hp_10.v); + hp_11.v = Optimization::Rotate::tRotate<2>(hp_11.v); + hp_12.v = Optimization::Rotate::tRotate<2>(hp_12.v); + } + + if(this->pm == 1 && vs <= v){ + hs_00.v = Optimization::Rotate::tRotate<2>(hs_00.v); + hs_01.v = Optimization::Rotate::tRotate<2>(hs_01.v); + hs_02.v = Optimization::Rotate::tRotate<2>(hs_02.v); + hs_10.v = Optimization::Rotate::tRotate<2>(hs_10.v); + hs_11.v = Optimization::Rotate::tRotate<2>(hs_11.v); + hs_12.v = Optimization::Rotate::tRotate<2>(hs_12.v); + } + + if(vm >= v){ + hm_00.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_00.v); + hm_01.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_01.v); + hm_02.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_02.v); + hm_10.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_10.v); + hm_11.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_11.v); + hm_12.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_12.v); + } + + if(this->pm == -1 && vs >= v){ + hs_00.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_00.v); + hs_01.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_01.v); + hs_02.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_02.v); + hs_10.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_10.v); + hs_11.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_11.v); + hs_12.v = 
Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_12.v); + } + + // Can force these to real arithmetic and save 2x. + Simd p_00 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(0)) + switcheroo::mult(l[v]()()(), hm_00) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(0)) + switcheroo::mult(l[v]()()(), hm_00) + + switcheroo::mult(s[v]()()(), hs_00); + Simd p_01 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(1)) + switcheroo::mult(l[v]()()(), hm_01) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(1)) + switcheroo::mult(l[v]()()(), hm_01) + + switcheroo::mult(s[v]()()(), hs_01); + Simd p_02 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(2)) + switcheroo::mult(l[v]()()(), hm_02) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(2)) + switcheroo::mult(l[v]()()(), hm_02) + + switcheroo::mult(s[v]()()(), hs_02); + Simd p_10 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(0)) + switcheroo::mult(l[v]()()(), hm_10) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(0)) + switcheroo::mult(l[v]()()(), hm_10) + + switcheroo::mult(s[v]()()(), hs_10); + Simd p_11 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(1)) + switcheroo::mult(l[v]()()(), hm_11) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(1)) + switcheroo::mult(l[v]()()(), hm_11) + + switcheroo::mult(s[v]()()(), hs_11); + Simd p_12 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(2)) + switcheroo::mult(l[v]()()(), hm_12) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(2)) + switcheroo::mult(l[v]()()(), hm_12) + + switcheroo::mult(s[v]()()(), hs_12); + Simd p_20 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(0)) + switcheroo::mult(u[v]()()(), hp_00) + + switcheroo::mult(s[v]()()(), hs_00) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(0)) + switcheroo::mult(u[v]()()(), hp_00); + Simd p_21 = (this->pm == 1) ? 
switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(1)) + switcheroo::mult(u[v]()()(), hp_01) + + switcheroo::mult(s[v]()()(), hs_01) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(1)) + switcheroo::mult(u[v]()()(), hp_01); + Simd p_22 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(2)) + switcheroo::mult(u[v]()()(), hp_02) + + switcheroo::mult(s[v]()()(), hs_02) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(2)) + switcheroo::mult(u[v]()()(), hp_02); + Simd p_30 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(0)) + switcheroo::mult(u[v]()()(), hp_10) + + switcheroo::mult(s[v]()()(), hs_10) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(0)) + switcheroo::mult(u[v]()()(), hp_10); + Simd p_31 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(1)) + switcheroo::mult(u[v]()()(), hp_11) + + switcheroo::mult(s[v]()()(), hs_11) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(1)) + switcheroo::mult(u[v]()()(), hp_11); + Simd p_32 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(2)) + switcheroo::mult(u[v]()()(), hp_12) + + switcheroo::mult(s[v]()()(), hs_12) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(2)) + switcheroo::mult(u[v]()()(), hp_12); + + vstream(chi[ss+v]()(0)(0), p_00); + vstream(chi[ss+v]()(0)(1), p_01); + vstream(chi[ss+v]()(0)(2), p_02); + vstream(chi[ss+v]()(1)(0), p_10); + vstream(chi[ss+v]()(1)(1), p_11); + vstream(chi[ss+v]()(1)(2), p_12); + vstream(chi[ss+v]()(2)(0), p_20); + vstream(chi[ss+v]()(2)(1), p_21); + vstream(chi[ss+v]()(2)(2), p_22); + vstream(chi[ss+v]()(3)(0), p_30); + vstream(chi[ss+v]()(3)(1), p_31); + vstream(chi[ss+v]()(3)(2), p_32); + } + } + + this->M5Dtime += usecond(); + + #endif + } + + template + void MobiusEOFAFermion::M5Ddag(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper) + { + GridBase* grid = psi._grid; + int Ls = this->Ls; + int LLs = grid->_rdimensions[0]; + int nsimd = Simd::Nsimd(); + 
+ Vector> u(LLs); + Vector> l(LLs); + Vector> d(LLs); + + assert(Ls/LLs == nsimd); + assert(phi.checkerboard == psi.checkerboard); + + chi.checkerboard = psi.checkerboard; + + // just directly address via type pun + typedef typename Simd::scalar_type scalar_type; + scalar_type* u_p = (scalar_type*) &u[0]; + scalar_type* l_p = (scalar_type*) &l[0]; + scalar_type* d_p = (scalar_type*) &d[0]; + + for(int o=0; oM5Dcalls++; + this->M5Dtime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=LLs){ // adds LLs + + #if 0 + + alignas(64) SiteHalfSpinor hp; + alignas(64) SiteHalfSpinor hm; + alignas(64) SiteSpinor fp; + alignas(64) SiteSpinor fm; + + for(int v=0; v= v){ rotate(hm, hm, nsimd-1); } + + hp = hp*0.5; + hm = hm*0.5; + spRecon5p(fp, hp); + spRecon5m(fm, hm); + + chi[ss+v] = d[v]*phi[ss+v]+u[v]*fp; + chi[ss+v] = chi[ss+v] +l[v]*fm; + + } + + #else + + for(int v=0; v(hp_00.v); + hp_01.v = Optimization::Rotate::tRotate<2>(hp_01.v); + hp_02.v = Optimization::Rotate::tRotate<2>(hp_02.v); + hp_10.v = Optimization::Rotate::tRotate<2>(hp_10.v); + hp_11.v = Optimization::Rotate::tRotate<2>(hp_11.v); + hp_12.v = Optimization::Rotate::tRotate<2>(hp_12.v); + } + + if(vm >= v){ + hm_00.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_00.v); + hm_01.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_01.v); + hm_02.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_02.v); + hm_10.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_10.v); + hm_11.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_11.v); + hm_12.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_12.v); + } + + Simd p_00 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(0)) + switcheroo::mult(u[v]()()(), hp_00); + Simd p_01 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(1)) + switcheroo::mult(u[v]()()(), hp_01); + Simd p_02 = switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(2)) + switcheroo::mult(u[v]()()(), hp_02); + Simd p_10 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(0)) + 
switcheroo::mult(u[v]()()(), hp_10); + Simd p_11 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(1)) + switcheroo::mult(u[v]()()(), hp_11); + Simd p_12 = switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(2)) + switcheroo::mult(u[v]()()(), hp_12); + Simd p_20 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(0)) + switcheroo::mult(l[v]()()(), hm_00); + Simd p_21 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(1)) + switcheroo::mult(l[v]()()(), hm_01); + Simd p_22 = switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(2)) + switcheroo::mult(l[v]()()(), hm_02); + Simd p_30 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(0)) + switcheroo::mult(l[v]()()(), hm_10); + Simd p_31 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(1)) + switcheroo::mult(l[v]()()(), hm_11); + Simd p_32 = switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(2)) + switcheroo::mult(l[v]()()(), hm_12); + + vstream(chi[ss+v]()(0)(0), p_00); + vstream(chi[ss+v]()(0)(1), p_01); + vstream(chi[ss+v]()(0)(2), p_02); + vstream(chi[ss+v]()(1)(0), p_10); + vstream(chi[ss+v]()(1)(1), p_11); + vstream(chi[ss+v]()(1)(2), p_12); + vstream(chi[ss+v]()(2)(0), p_20); + vstream(chi[ss+v]()(2)(1), p_21); + vstream(chi[ss+v]()(2)(2), p_22); + vstream(chi[ss+v]()(3)(0), p_30); + vstream(chi[ss+v]()(3)(1), p_31); + vstream(chi[ss+v]()(3)(2), p_32); + + } + + #endif + + } + + this->M5Dtime += usecond(); + } + + template + void MobiusEOFAFermion::M5Ddag_shift(const FermionField& psi, const FermionField& phi, + FermionField& chi, std::vector& lower, std::vector& diag, std::vector& upper, + std::vector& shift_coeffs) + { + #if 0 + + this->M5Ddag(psi, phi, chi, lower, diag, upper); + + // FIXME: possible gain from vectorizing shift operation as well? 
+ Coeff_t one(1.0); + int Ls = this->Ls; + for(int s=0; spm == 1){ axpby_ssp_pplus(chi, one, chi, shift_coeffs[s], psi, Ls-1, s); } + else{ axpby_ssp_pminus(chi, one, chi, shift_coeffs[s], psi, 0, s); } + } + + #else + + GridBase* grid = psi._grid; + int Ls = this->Ls; + int LLs = grid->_rdimensions[0]; + int nsimd = Simd::Nsimd(); + + Vector> u(LLs); + Vector> l(LLs); + Vector> d(LLs); + Vector> s(LLs); + + assert(Ls/LLs == nsimd); + assert(phi.checkerboard == psi.checkerboard); + + chi.checkerboard = psi.checkerboard; + + // just directly address via type pun + typedef typename Simd::scalar_type scalar_type; + scalar_type* u_p = (scalar_type*) &u[0]; + scalar_type* l_p = (scalar_type*) &l[0]; + scalar_type* d_p = (scalar_type*) &d[0]; + scalar_type* s_p = (scalar_type*) &s[0]; + + for(int o=0; oM5Dcalls++; + this->M5Dtime -= usecond(); + + parallel_for(int ss=0; ssoSites(); ss+=LLs){ // adds LLs + + int vs = (this->pm == 1) ? LLs-1 : 0; + Simd hs_00 = (this->pm == 1) ? psi[ss+vs]()(0)(0) : psi[ss+vs]()(2)(0); + Simd hs_01 = (this->pm == 1) ? psi[ss+vs]()(0)(1) : psi[ss+vs]()(2)(1); + Simd hs_02 = (this->pm == 1) ? psi[ss+vs]()(0)(2) : psi[ss+vs]()(2)(2); + Simd hs_10 = (this->pm == 1) ? psi[ss+vs]()(1)(0) : psi[ss+vs]()(3)(0); + Simd hs_11 = (this->pm == 1) ? psi[ss+vs]()(1)(1) : psi[ss+vs]()(3)(1); + Simd hs_12 = (this->pm == 1) ? 
psi[ss+vs]()(1)(2) : psi[ss+vs]()(3)(2); + + for(int v=0; v(hp_00.v); + hp_01.v = Optimization::Rotate::tRotate<2>(hp_01.v); + hp_02.v = Optimization::Rotate::tRotate<2>(hp_02.v); + hp_10.v = Optimization::Rotate::tRotate<2>(hp_10.v); + hp_11.v = Optimization::Rotate::tRotate<2>(hp_11.v); + hp_12.v = Optimization::Rotate::tRotate<2>(hp_12.v); + } + + if(this->pm == 1 && vs <= v){ + hs_00.v = Optimization::Rotate::tRotate<2>(hs_00.v); + hs_01.v = Optimization::Rotate::tRotate<2>(hs_01.v); + hs_02.v = Optimization::Rotate::tRotate<2>(hs_02.v); + hs_10.v = Optimization::Rotate::tRotate<2>(hs_10.v); + hs_11.v = Optimization::Rotate::tRotate<2>(hs_11.v); + hs_12.v = Optimization::Rotate::tRotate<2>(hs_12.v); + } + + if(vm >= v){ + hm_00.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_00.v); + hm_01.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_01.v); + hm_02.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_02.v); + hm_10.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_10.v); + hm_11.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_11.v); + hm_12.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hm_12.v); + } + + if(this->pm == -1 && vs >= v){ + hs_00.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_00.v); + hs_01.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_01.v); + hs_02.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_02.v); + hs_10.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_10.v); + hs_11.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_11.v); + hs_12.v = Optimization::Rotate::tRotate<2*Simd::Nsimd()-2>(hs_12.v); + } + + Simd p_00 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(0)) + switcheroo::mult(u[v]()()(), hp_00) + + switcheroo::mult(s[v]()()(), hs_00) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(0)) + switcheroo::mult(u[v]()()(), hp_00); + Simd p_01 = (this->pm == 1) ? 
switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(1)) + switcheroo::mult(u[v]()()(), hp_01) + + switcheroo::mult(s[v]()()(), hs_01) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(1)) + switcheroo::mult(u[v]()()(), hp_01); + Simd p_02 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(2)) + switcheroo::mult(u[v]()()(), hp_02) + + switcheroo::mult(s[v]()()(), hs_02) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(0)(2)) + switcheroo::mult(u[v]()()(), hp_02); + Simd p_10 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(0)) + switcheroo::mult(u[v]()()(), hp_10) + + switcheroo::mult(s[v]()()(), hs_10) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(0)) + switcheroo::mult(u[v]()()(), hp_10); + Simd p_11 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(1)) + switcheroo::mult(u[v]()()(), hp_11) + + switcheroo::mult(s[v]()()(), hs_11) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(1)) + switcheroo::mult(u[v]()()(), hp_11); + Simd p_12 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(2)) + switcheroo::mult(u[v]()()(), hp_12) + + switcheroo::mult(s[v]()()(), hs_12) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(1)(2)) + switcheroo::mult(u[v]()()(), hp_12); + Simd p_20 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(0)) + switcheroo::mult(l[v]()()(), hm_00) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(0)) + switcheroo::mult(l[v]()()(), hm_00) + + switcheroo::mult(s[v]()()(), hs_00); + Simd p_21 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(1)) + switcheroo::mult(l[v]()()(), hm_01) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(1)) + switcheroo::mult(l[v]()()(), hm_01) + + switcheroo::mult(s[v]()()(), hs_01); + Simd p_22 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(2)) + switcheroo::mult(l[v]()()(), hm_02) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(2)(2)) + switcheroo::mult(l[v]()()(), hm_02) + + switcheroo::mult(s[v]()()(), hs_02); + Simd p_30 = (this->pm == 1) ? 
switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(0)) + switcheroo::mult(l[v]()()(), hm_10) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(0)) + switcheroo::mult(l[v]()()(), hm_10) + + switcheroo::mult(s[v]()()(), hs_10); + Simd p_31 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(1)) + switcheroo::mult(l[v]()()(), hm_11) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(1)) + switcheroo::mult(l[v]()()(), hm_11) + + switcheroo::mult(s[v]()()(), hs_11); + Simd p_32 = (this->pm == 1) ? switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(2)) + switcheroo::mult(l[v]()()(), hm_12) + : switcheroo::mult(d[v]()()(), phi[ss+v]()(3)(2)) + switcheroo::mult(l[v]()()(), hm_12) + + switcheroo::mult(s[v]()()(), hs_12); + + vstream(chi[ss+v]()(0)(0), p_00); + vstream(chi[ss+v]()(0)(1), p_01); + vstream(chi[ss+v]()(0)(2), p_02); + vstream(chi[ss+v]()(1)(0), p_10); + vstream(chi[ss+v]()(1)(1), p_11); + vstream(chi[ss+v]()(1)(2), p_12); + vstream(chi[ss+v]()(2)(0), p_20); + vstream(chi[ss+v]()(2)(1), p_21); + vstream(chi[ss+v]()(2)(2), p_22); + vstream(chi[ss+v]()(3)(0), p_30); + vstream(chi[ss+v]()(3)(1), p_31); + vstream(chi[ss+v]()(3)(2), p_32); + + } + + } + + this->M5Dtime += usecond(); + + #endif + } + + #ifdef AVX512 + #include + #include + #include + #endif + + template + void MobiusEOFAFermion::MooeeInternalAsm(const FermionField& psi, FermionField& chi, + int LLs, int site, Vector >& Matp, Vector >& Matm) + { + #ifndef AVX512 + { + SiteHalfSpinor BcastP; + SiteHalfSpinor BcastM; + SiteHalfSpinor SiteChiP; + SiteHalfSpinor SiteChiM; + + // Ls*Ls * 2 * 12 * vol flops + for(int s1=0; s1); + + for(int s1=0; s1 + void MobiusEOFAFermion::MooeeInternalZAsm(const FermionField& psi, FermionField& chi, + int LLs, int site, Vector >& Matp, Vector >& Matm) + { + std::cout << "Error: zMobius not implemented for EOFA" << std::endl; + exit(-1); + }; + + template + void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv) + { + int Ls = this->Ls; 
+ int LLs = psi._grid->_rdimensions[0]; + int vol = psi._grid->oSites()/LLs; + + chi.checkerboard = psi.checkerboard; + + Vector> Matp; + Vector> Matm; + Vector>* _Matp; + Vector>* _Matm; + + // MooeeInternalCompute(dag,inv,Matp,Matm); + if(inv && dag){ + _Matp = &this->MatpInvDag; + _Matm = &this->MatmInvDag; + } + + if(inv && (!dag)){ + _Matp = &this->MatpInv; + _Matm = &this->MatmInv; + } + + if(!inv){ + MooeeInternalCompute(dag, inv, Matp, Matm); + _Matp = &Matp; + _Matm = &Matm; + } + + assert(_Matp->size() == Ls*LLs); + + this->MooeeInvCalls++; + this->MooeeInvTime -= usecond(); + + if(switcheroo::iscomplex()){ + parallel_for(auto site=0; siteMooeeInvTime += usecond(); + } + + #ifdef MOBIUS_EOFA_DPERP_VEC + + INSTANTIATE_DPERP_MOBIUS_EOFA(DomainWallVec5dImplD); + INSTANTIATE_DPERP_MOBIUS_EOFA(DomainWallVec5dImplF); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZDomainWallVec5dImplD); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZDomainWallVec5dImplF); + + INSTANTIATE_DPERP_MOBIUS_EOFA(DomainWallVec5dImplDF); + INSTANTIATE_DPERP_MOBIUS_EOFA(DomainWallVec5dImplFH); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZDomainWallVec5dImplDF); + INSTANTIATE_DPERP_MOBIUS_EOFA(ZDomainWallVec5dImplFH); + + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + template void 
MobiusEOFAFermion::MooeeInternal(const FermionField& psi, FermionField& chi, int dag, int inv); + + #endif + +}} diff --git a/lib/qcd/action/fermion/MobiusFermion.h b/lib/qcd/action/fermion/MobiusFermion.h index ade9ca4d..b61c26d5 100644 --- a/lib/qcd/action/fermion/MobiusFermion.h +++ b/lib/qcd/action/fermion/MobiusFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef GRID_QCD_MOBIUS_FERMION_H #define GRID_QCD_MOBIUS_FERMION_H -#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/MobiusZolotarevFermion.h b/lib/qcd/action/fermion/MobiusZolotarevFermion.h index 609d5cea..078d4f3e 100644 --- a/lib/qcd/action/fermion/MobiusZolotarevFermion.h +++ b/lib/qcd/action/fermion/MobiusZolotarevFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef GRID_QCD_MOBIUS_ZOLOTAREV_FERMION_H #define GRID_QCD_MOBIUS_ZOLOTAREV_FERMION_H -#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/OverlapWilsonCayleyTanhFermion.h b/lib/qcd/action/fermion/OverlapWilsonCayleyTanhFermion.h index 9cab0e22..f516c5d0 100644 --- a/lib/qcd/action/fermion/OverlapWilsonCayleyTanhFermion.h +++ b/lib/qcd/action/fermion/OverlapWilsonCayleyTanhFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef OVERLAP_WILSON_CAYLEY_TANH_FERMION_H #define OVERLAP_WILSON_CAYLEY_TANH_FERMION_H -#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/OverlapWilsonCayleyZolotarevFermion.h b/lib/qcd/action/fermion/OverlapWilsonCayleyZolotarevFermion.h index 048244cc..4f1adbbf 100644 --- a/lib/qcd/action/fermion/OverlapWilsonCayleyZolotarevFermion.h +++ b/lib/qcd/action/fermion/OverlapWilsonCayleyZolotarevFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef OVERLAP_WILSON_CAYLEY_ZOLOTAREV_FERMION_H #define OVERLAP_WILSON_CAYLEY_ZOLOTAREV_FERMION_H -#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/OverlapWilsonContfracTanhFermion.h b/lib/qcd/action/fermion/OverlapWilsonContfracTanhFermion.h index bbac735a..38d0fda2 100644 --- 
a/lib/qcd/action/fermion/OverlapWilsonContfracTanhFermion.h +++ b/lib/qcd/action/fermion/OverlapWilsonContfracTanhFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef OVERLAP_WILSON_CONTFRAC_TANH_FERMION_H #define OVERLAP_WILSON_CONTFRAC_TANH_FERMION_H -#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/OverlapWilsonContfracZolotarevFermion.h b/lib/qcd/action/fermion/OverlapWilsonContfracZolotarevFermion.h index 9da30f65..6773b4d2 100644 --- a/lib/qcd/action/fermion/OverlapWilsonContfracZolotarevFermion.h +++ b/lib/qcd/action/fermion/OverlapWilsonContfracZolotarevFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef OVERLAP_WILSON_CONTFRAC_ZOLOTAREV_FERMION_H #define OVERLAP_WILSON_CONTFRAC_ZOLOTAREV_FERMION_H -#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/OverlapWilsonPartialFractionTanhFermion.h b/lib/qcd/action/fermion/OverlapWilsonPartialFractionTanhFermion.h index 3b867174..84c4f597 100644 --- a/lib/qcd/action/fermion/OverlapWilsonPartialFractionTanhFermion.h +++ b/lib/qcd/action/fermion/OverlapWilsonPartialFractionTanhFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef OVERLAP_WILSON_PARTFRAC_TANH_FERMION_H #define OVERLAP_WILSON_PARTFRAC_TANH_FERMION_H -#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/OverlapWilsonPartialFractionZolotarevFermion.h b/lib/qcd/action/fermion/OverlapWilsonPartialFractionZolotarevFermion.h index e1d0763b..dc275852 100644 --- a/lib/qcd/action/fermion/OverlapWilsonPartialFractionZolotarevFermion.h +++ b/lib/qcd/action/fermion/OverlapWilsonPartialFractionZolotarevFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef OVERLAP_WILSON_PARTFRAC_ZOLOTAREV_FERMION_H #define OVERLAP_WILSON_PARTFRAC_ZOLOTAREV_FERMION_H -#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/PartialFractionFermion5D.cc b/lib/qcd/action/fermion/PartialFractionFermion5D.cc index 4fcb8784..3a78e043 100644 --- a/lib/qcd/action/fermion/PartialFractionFermion5D.cc +++ 
b/lib/qcd/action/fermion/PartialFractionFermion5D.cc @@ -26,7 +26,9 @@ Author: Peter Boyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include +#include +#include + namespace Grid { namespace QCD { diff --git a/lib/qcd/action/fermion/PartialFractionFermion5D.h b/lib/qcd/action/fermion/PartialFractionFermion5D.h index 126f3299..0ec72de4 100644 --- a/lib/qcd/action/fermion/PartialFractionFermion5D.h +++ b/lib/qcd/action/fermion/PartialFractionFermion5D.h @@ -29,6 +29,8 @@ Author: Peter Boyle #ifndef GRID_QCD_PARTIAL_FRACTION_H #define GRID_QCD_PARTIAL_FRACTION_H +#include + namespace Grid { namespace QCD { diff --git a/lib/qcd/action/fermion/ScaledShamirFermion.h b/lib/qcd/action/fermion/ScaledShamirFermion.h index f850ee4d..b779b9c0 100644 --- a/lib/qcd/action/fermion/ScaledShamirFermion.h +++ b/lib/qcd/action/fermion/ScaledShamirFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef GRID_QCD_SCALED_SHAMIR_FERMION_H #define GRID_QCD_SCALED_SHAMIR_FERMION_H -#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/SchurDiagTwoKappa.h b/lib/qcd/action/fermion/SchurDiagTwoKappa.h new file mode 100644 index 00000000..8305f98a --- /dev/null +++ b/lib/qcd/action/fermion/SchurDiagTwoKappa.h @@ -0,0 +1,102 @@ + /************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + Source file: SchurDiagTwoKappa.h + + Copyright (C) 2017 + +Author: Christoph Lehner +Author: Peter Boyle + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. 
+ + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + + See the full license in the file "LICENSE" in the top level distribution directory + *************************************************************************************/ + /* END LEGAL */ +#ifndef _SCHUR_DIAG_TWO_KAPPA_H +#define _SCHUR_DIAG_TWO_KAPPA_H + +namespace Grid { + + // This is specific to (Z)mobius fermions + template + class KappaSimilarityTransform { + public: + INHERIT_IMPL_TYPES(Matrix); + std::vector kappa, kappaDag, kappaInv, kappaInvDag; + + KappaSimilarityTransform (Matrix &zmob) { + for (int i=0;i<(int)zmob.bs.size();i++) { + Coeff_t k = 1.0 / ( 2.0 * (zmob.bs[i] *(4 - zmob.M5) + 1.0) ); + kappa.push_back( k ); + kappaDag.push_back( conj(k) ); + kappaInv.push_back( 1.0 / k ); + kappaInvDag.push_back( 1.0 / conj(k) ); + } + } + + template + void sscale(const Lattice& in, Lattice& out, Coeff_t* s) { + GridBase *grid=out._grid; + out.checkerboard = in.checkerboard; + assert(grid->_simd_layout[0] == 1); // should be fine for ZMobius for now + int Ls = grid->_rdimensions[0]; + parallel_for(int ss=0;ssoSites();ss++){ + vobj tmp = s[ss % Ls]*in._odata[ss]; + vstream(out._odata[ss],tmp); + } + } + + RealD sscale_norm(const Field& in, Field& out, Coeff_t* s) { + sscale(in,out,s); + return norm2(out); + } + + virtual RealD M (const Field& in, Field& out) { return sscale_norm(in,out,&kappa[0]); } + virtual RealD MDag (const Field& in, Field& out) { return sscale_norm(in,out,&kappaDag[0]);} + virtual RealD MInv (const Field& in, Field& out) { return sscale_norm(in,out,&kappaInv[0]);} + virtual RealD 
MInvDag (const Field& in, Field& out) { return sscale_norm(in,out,&kappaInvDag[0]);} + + }; + + template + class SchurDiagTwoKappaOperator : public SchurOperatorBase { + public: + KappaSimilarityTransform _S; + SchurDiagTwoOperator _Mat; + + SchurDiagTwoKappaOperator (Matrix &Mat): _S(Mat), _Mat(Mat) {}; + + virtual RealD Mpc (const Field &in, Field &out) { + Field tmp(in._grid); + + _S.MInv(in,out); + _Mat.Mpc(out,tmp); + return _S.M(tmp,out); + + } + virtual RealD MpcDag (const Field &in, Field &out){ + Field tmp(in._grid); + + _S.MDag(in,out); + _Mat.MpcDag(out,tmp); + return _S.MInvDag(tmp,out); + } + }; + +} + +#endif diff --git a/lib/qcd/action/fermion/ShamirZolotarevFermion.h b/lib/qcd/action/fermion/ShamirZolotarevFermion.h index 732afa0a..f9397911 100644 --- a/lib/qcd/action/fermion/ShamirZolotarevFermion.h +++ b/lib/qcd/action/fermion/ShamirZolotarevFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef GRID_QCD_SHAMIR_ZOLOTAREV_FERMION_H #define GRID_QCD_SHAMIR_ZOLOTAREV_FERMION_H -#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/StaggeredKernels.cc b/lib/qcd/action/fermion/StaggeredKernels.cc index 6608f8de..b6ec14c7 100644 --- a/lib/qcd/action/fermion/StaggeredKernels.cc +++ b/lib/qcd/action/fermion/StaggeredKernels.cc @@ -26,11 +26,12 @@ See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include +#include + namespace Grid { namespace QCD { -int StaggeredKernelsStatic::Opt; +int StaggeredKernelsStatic::Opt= StaggeredKernelsStatic::OptGeneric; template StaggeredKernels::StaggeredKernels(const ImplParams &p) : Base(p){}; @@ -182,48 +183,79 @@ void StaggeredKernels::DhopSiteDepth(StencilImpl &st, LebesgueOrder &lo, D vstream(out, Uchi); }; -// Need controls to do interior, exterior, or both template void StaggeredKernels::DhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, 
DoubledGaugeField &UUU, - SiteSpinor *buf, int sF, - int sU, const FermionField &in, FermionField &out) { + SiteSpinor *buf, int LLs, int sU, + const FermionField &in, FermionField &out) { SiteSpinor naik; SiteSpinor naive; int oneLink =0; int threeLink=1; + int dag=1; switch(Opt) { +#ifdef AVX512 + //FIXME; move the sign into the Asm routine + case OptInlineAsm: + DhopSiteAsm(st,lo,U,UUU,buf,LLs,sU,in,out); + for(int s=0;s + +#ifdef AVX512 +#include +#include +#endif + +// Interleave operations from two directions +// This looks just like a 2 spin multiply and reuse same sequence from the Wilson +// Kernel. But the spin index becomes a mu index instead. +#define Chi_00 %zmm0 +#define Chi_01 %zmm1 +#define Chi_02 %zmm2 +#define Chi_10 %zmm3 +#define Chi_11 %zmm4 +#define Chi_12 %zmm5 +#define Chi_20 %zmm6 +#define Chi_21 %zmm7 +#define Chi_22 %zmm8 +#define Chi_30 %zmm9 +#define Chi_31 %zmm10 +#define Chi_32 %zmm11 + +#define UChi_00 %zmm12 +#define UChi_01 %zmm13 +#define UChi_02 %zmm14 +#define UChi_10 %zmm15 +#define UChi_11 %zmm16 +#define UChi_12 %zmm17 +#define UChi_20 %zmm18 +#define UChi_21 %zmm19 +#define UChi_22 %zmm20 +#define UChi_30 %zmm21 +#define UChi_31 %zmm22 +#define UChi_32 %zmm23 + +#define pChi_00 %%zmm0 +#define pChi_01 %%zmm1 +#define pChi_02 %%zmm2 +#define pChi_10 %%zmm3 +#define pChi_11 %%zmm4 +#define pChi_12 %%zmm5 +#define pChi_20 %%zmm6 +#define pChi_21 %%zmm7 +#define pChi_22 %%zmm8 +#define pChi_30 %%zmm9 +#define pChi_31 %%zmm10 +#define pChi_32 %%zmm11 + +#define pUChi_00 %%zmm12 +#define pUChi_01 %%zmm13 +#define pUChi_02 %%zmm14 +#define pUChi_10 %%zmm15 +#define pUChi_11 %%zmm16 +#define pUChi_12 %%zmm17 +#define pUChi_20 %%zmm18 +#define pUChi_21 %%zmm19 +#define pUChi_22 %%zmm20 +#define pUChi_30 %%zmm21 +#define pUChi_31 %%zmm22 +#define pUChi_32 %%zmm23 + +#define T0 %zmm24 +#define T1 %zmm25 +#define T2 %zmm26 +#define T3 %zmm27 + +#define Z00 %zmm26 +#define Z10 %zmm27 +#define Z0 Z00 +#define Z1 %zmm28 +#define Z2 %zmm29 
+ +#define Z3 %zmm30 +#define Z4 %zmm31 +#define Z5 Chi_31 +#define Z6 Chi_32 + +#define MULT_ADD_LS(g0,g1,g2,g3) \ + asm ( "movq %0, %%r8 \n\t" \ + "movq %1, %%r9 \n\t" \ + "movq %2, %%r10 \n\t" \ + "movq %3, %%r11 \n\t" : : "r"(g0), "r"(g1), "r"(g2), "r"(g3) : "%r8","%r9","%r10","%r11" );\ + asm ( \ + VSHUF(Chi_00,T0) VSHUF(Chi_10,T1) \ + VSHUF(Chi_20,T2) VSHUF(Chi_30,T3) \ + VMADDSUBIDUP(0,%r8,T0,UChi_00) VMADDSUBIDUP(0,%r9,T1,UChi_10) \ + VMADDSUBIDUP(3,%r8,T0,UChi_01) VMADDSUBIDUP(3,%r9,T1,UChi_11) \ + VMADDSUBIDUP(6,%r8,T0,UChi_02) VMADDSUBIDUP(6,%r9,T1,UChi_12) \ + VMADDSUBIDUP(0,%r10,T2,UChi_20) VMADDSUBIDUP(0,%r11,T3,UChi_30) \ + VMADDSUBIDUP(3,%r10,T2,UChi_21) VMADDSUBIDUP(3,%r11,T3,UChi_31) \ + VMADDSUBIDUP(6,%r10,T2,UChi_22) VMADDSUBIDUP(6,%r11,T3,UChi_32) \ + VMADDSUBRDUP(0,%r8,Chi_00,UChi_00) VMADDSUBRDUP(0,%r9,Chi_10,UChi_10) \ + VMADDSUBRDUP(3,%r8,Chi_00,UChi_01) VMADDSUBRDUP(3,%r9,Chi_10,UChi_11) \ + VMADDSUBRDUP(6,%r8,Chi_00,UChi_02) VMADDSUBRDUP(6,%r9,Chi_10,UChi_12) \ + VMADDSUBRDUP(0,%r10,Chi_20,UChi_20) VMADDSUBRDUP(0,%r11,Chi_30,UChi_30) \ + VMADDSUBRDUP(3,%r10,Chi_20,UChi_21) VMADDSUBRDUP(3,%r11,Chi_30,UChi_31) \ + VMADDSUBRDUP(6,%r10,Chi_20,UChi_22) VMADDSUBRDUP(6,%r11,Chi_30,UChi_32) \ + VSHUF(Chi_01,T0) VSHUF(Chi_11,T1) \ + VSHUF(Chi_21,T2) VSHUF(Chi_31,T3) \ + VMADDSUBIDUP(1,%r8,T0,UChi_00) VMADDSUBIDUP(1,%r9,T1,UChi_10) \ + VMADDSUBIDUP(4,%r8,T0,UChi_01) VMADDSUBIDUP(4,%r9,T1,UChi_11) \ + VMADDSUBIDUP(7,%r8,T0,UChi_02) VMADDSUBIDUP(7,%r9,T1,UChi_12) \ + VMADDSUBIDUP(1,%r10,T2,UChi_20) VMADDSUBIDUP(1,%r11,T3,UChi_30) \ + VMADDSUBIDUP(4,%r10,T2,UChi_21) VMADDSUBIDUP(4,%r11,T3,UChi_31) \ + VMADDSUBIDUP(7,%r10,T2,UChi_22) VMADDSUBIDUP(7,%r11,T3,UChi_32) \ + VMADDSUBRDUP(1,%r8,Chi_01,UChi_00) VMADDSUBRDUP(1,%r9,Chi_11,UChi_10) \ + VMADDSUBRDUP(4,%r8,Chi_01,UChi_01) VMADDSUBRDUP(4,%r9,Chi_11,UChi_11) \ + VMADDSUBRDUP(7,%r8,Chi_01,UChi_02) VMADDSUBRDUP(7,%r9,Chi_11,UChi_12) \ + VMADDSUBRDUP(1,%r10,Chi_21,UChi_20) 
VMADDSUBRDUP(1,%r11,Chi_31,UChi_30) \ + VMADDSUBRDUP(4,%r10,Chi_21,UChi_21) VMADDSUBRDUP(4,%r11,Chi_31,UChi_31) \ + VMADDSUBRDUP(7,%r10,Chi_21,UChi_22) VMADDSUBRDUP(7,%r11,Chi_31,UChi_32) \ + VSHUF(Chi_02,T0) VSHUF(Chi_12,T1) \ + VSHUF(Chi_22,T2) VSHUF(Chi_32,T3) \ + VMADDSUBIDUP(2,%r8,T0,UChi_00) VMADDSUBIDUP(2,%r9,T1,UChi_10) \ + VMADDSUBIDUP(5,%r8,T0,UChi_01) VMADDSUBIDUP(5,%r9,T1,UChi_11) \ + VMADDSUBIDUP(8,%r8,T0,UChi_02) VMADDSUBIDUP(8,%r9,T1,UChi_12) \ + VMADDSUBIDUP(2,%r10,T2,UChi_20) VMADDSUBIDUP(2,%r11,T3,UChi_30) \ + VMADDSUBIDUP(5,%r10,T2,UChi_21) VMADDSUBIDUP(5,%r11,T3,UChi_31) \ + VMADDSUBIDUP(8,%r10,T2,UChi_22) VMADDSUBIDUP(8,%r11,T3,UChi_32) \ + VMADDSUBRDUP(2,%r8,Chi_02,UChi_00) VMADDSUBRDUP(2,%r9,Chi_12,UChi_10) \ + VMADDSUBRDUP(5,%r8,Chi_02,UChi_01) VMADDSUBRDUP(5,%r9,Chi_12,UChi_11) \ + VMADDSUBRDUP(8,%r8,Chi_02,UChi_02) VMADDSUBRDUP(8,%r9,Chi_12,UChi_12) \ + VMADDSUBRDUP(2,%r10,Chi_22,UChi_20) VMADDSUBRDUP(2,%r11,Chi_32,UChi_30) \ + VMADDSUBRDUP(5,%r10,Chi_22,UChi_21) VMADDSUBRDUP(5,%r11,Chi_32,UChi_31) \ + VMADDSUBRDUP(8,%r10,Chi_22,UChi_22) VMADDSUBRDUP(8,%r11,Chi_32,UChi_32) ); + +#define MULT_LS(g0,g1,g2,g3) \ + asm ( "movq %0, %%r8 \n\t" \ + "movq %1, %%r9 \n\t" \ + "movq %2, %%r10 \n\t" \ + "movq %3, %%r11 \n\t" : : "r"(g0), "r"(g1), "r"(g2), "r"(g3) : "%r8","%r9","%r10","%r11" );\ + asm ( \ + VSHUF(Chi_00,T0) VSHUF(Chi_10,T1) \ + VSHUF(Chi_20,T2) VSHUF(Chi_30,T3) \ + VMULIDUP(0,%r8,T0,UChi_00) VMULIDUP(0,%r9,T1,UChi_10) \ + VMULIDUP(3,%r8,T0,UChi_01) VMULIDUP(3,%r9,T1,UChi_11) \ + VMULIDUP(6,%r8,T0,UChi_02) VMULIDUP(6,%r9,T1,UChi_12) \ + VMULIDUP(0,%r10,T2,UChi_20) VMULIDUP(0,%r11,T3,UChi_30) \ + VMULIDUP(3,%r10,T2,UChi_21) VMULIDUP(3,%r11,T3,UChi_31) \ + VMULIDUP(6,%r10,T2,UChi_22) VMULIDUP(6,%r11,T3,UChi_32) \ + VMADDSUBRDUP(0,%r8,Chi_00,UChi_00) VMADDSUBRDUP(0,%r9,Chi_10,UChi_10) \ + VMADDSUBRDUP(3,%r8,Chi_00,UChi_01) VMADDSUBRDUP(3,%r9,Chi_10,UChi_11) \ + VMADDSUBRDUP(6,%r8,Chi_00,UChi_02) VMADDSUBRDUP(6,%r9,Chi_10,UChi_12) \ + 
VMADDSUBRDUP(0,%r10,Chi_20,UChi_20) VMADDSUBRDUP(0,%r11,Chi_30,UChi_30) \ + VMADDSUBRDUP(3,%r10,Chi_20,UChi_21) VMADDSUBRDUP(3,%r11,Chi_30,UChi_31) \ + VMADDSUBRDUP(6,%r10,Chi_20,UChi_22) VMADDSUBRDUP(6,%r11,Chi_30,UChi_32) \ + VSHUF(Chi_01,T0) VSHUF(Chi_11,T1) \ + VSHUF(Chi_21,T2) VSHUF(Chi_31,T3) \ + VMADDSUBIDUP(1,%r8,T0,UChi_00) VMADDSUBIDUP(1,%r9,T1,UChi_10) \ + VMADDSUBIDUP(4,%r8,T0,UChi_01) VMADDSUBIDUP(4,%r9,T1,UChi_11) \ + VMADDSUBIDUP(7,%r8,T0,UChi_02) VMADDSUBIDUP(7,%r9,T1,UChi_12) \ + VMADDSUBIDUP(1,%r10,T2,UChi_20) VMADDSUBIDUP(1,%r11,T3,UChi_30) \ + VMADDSUBIDUP(4,%r10,T2,UChi_21) VMADDSUBIDUP(4,%r11,T3,UChi_31) \ + VMADDSUBIDUP(7,%r10,T2,UChi_22) VMADDSUBIDUP(7,%r11,T3,UChi_32) \ + VMADDSUBRDUP(1,%r8,Chi_01,UChi_00) VMADDSUBRDUP(1,%r9,Chi_11,UChi_10) \ + VMADDSUBRDUP(4,%r8,Chi_01,UChi_01) VMADDSUBRDUP(4,%r9,Chi_11,UChi_11) \ + VMADDSUBRDUP(7,%r8,Chi_01,UChi_02) VMADDSUBRDUP(7,%r9,Chi_11,UChi_12) \ + VMADDSUBRDUP(1,%r10,Chi_21,UChi_20) VMADDSUBRDUP(1,%r11,Chi_31,UChi_30) \ + VMADDSUBRDUP(4,%r10,Chi_21,UChi_21) VMADDSUBRDUP(4,%r11,Chi_31,UChi_31) \ + VMADDSUBRDUP(7,%r10,Chi_21,UChi_22) VMADDSUBRDUP(7,%r11,Chi_31,UChi_32) \ + VSHUF(Chi_02,T0) VSHUF(Chi_12,T1) \ + VSHUF(Chi_22,T2) VSHUF(Chi_32,T3) \ + VMADDSUBIDUP(2,%r8,T0,UChi_00) VMADDSUBIDUP(2,%r9,T1,UChi_10) \ + VMADDSUBIDUP(5,%r8,T0,UChi_01) VMADDSUBIDUP(5,%r9,T1,UChi_11) \ + VMADDSUBIDUP(8,%r8,T0,UChi_02) VMADDSUBIDUP(8,%r9,T1,UChi_12) \ + VMADDSUBIDUP(2,%r10,T2,UChi_20) VMADDSUBIDUP(2,%r11,T3,UChi_30) \ + VMADDSUBIDUP(5,%r10,T2,UChi_21) VMADDSUBIDUP(5,%r11,T3,UChi_31) \ + VMADDSUBIDUP(8,%r10,T2,UChi_22) VMADDSUBIDUP(8,%r11,T3,UChi_32) \ + VMADDSUBRDUP(2,%r8,Chi_02,UChi_00) VMADDSUBRDUP(2,%r9,Chi_12,UChi_10) \ + VMADDSUBRDUP(5,%r8,Chi_02,UChi_01) VMADDSUBRDUP(5,%r9,Chi_12,UChi_11) \ + VMADDSUBRDUP(8,%r8,Chi_02,UChi_02) VMADDSUBRDUP(8,%r9,Chi_12,UChi_12) \ + VMADDSUBRDUP(2,%r10,Chi_22,UChi_20) VMADDSUBRDUP(2,%r11,Chi_32,UChi_30) \ + VMADDSUBRDUP(5,%r10,Chi_22,UChi_21) 
VMADDSUBRDUP(5,%r11,Chi_32,UChi_31) \ + VMADDSUBRDUP(8,%r10,Chi_22,UChi_22) VMADDSUBRDUP(8,%r11,Chi_32,UChi_32) ); + +#define MULT_ADD_XYZTa(g0,g1) \ + asm ( "movq %0, %%r8 \n\t" \ + "movq %1, %%r9 \n\t" : : "r"(g0), "r"(g1) : "%r8","%r9");\ + __asm__ ( \ + VSHUF(Chi_00,T0) \ + VSHUF(Chi_10,T1) \ + VMOVIDUP(0,%r8,Z0 ) \ + VMOVIDUP(3,%r8,Z1 ) \ + VMOVIDUP(6,%r8,Z2 ) \ + VMADDSUB(Z0,T0,UChi_00) \ + VMADDSUB(Z1,T0,UChi_01) \ + VMADDSUB(Z2,T0,UChi_02) \ + \ + VMOVIDUP(0,%r9,Z0 ) \ + VMOVIDUP(3,%r9,Z1 ) \ + VMOVIDUP(6,%r9,Z2 ) \ + VMADDSUB(Z0,T1,UChi_10) \ + VMADDSUB(Z1,T1,UChi_11) \ + VMADDSUB(Z2,T1,UChi_12) \ + \ + \ + VMOVRDUP(0,%r8,Z3 ) \ + VMOVRDUP(3,%r8,Z4 ) \ + VMOVRDUP(6,%r8,Z5 ) \ + VMADDSUB(Z3,Chi_00,UChi_00)/*rr * ir = ri rr*/ \ + VMADDSUB(Z4,Chi_00,UChi_01) \ + VMADDSUB(Z5,Chi_00,UChi_02) \ + \ + VMOVRDUP(0,%r9,Z3 ) \ + VMOVRDUP(3,%r9,Z4 ) \ + VMOVRDUP(6,%r9,Z5 ) \ + VMADDSUB(Z3,Chi_10,UChi_10) \ + VMADDSUB(Z4,Chi_10,UChi_11)\ + VMADDSUB(Z5,Chi_10,UChi_12) \ + \ + \ + VMOVIDUP(1,%r8,Z0 ) \ + VMOVIDUP(4,%r8,Z1 ) \ + VMOVIDUP(7,%r8,Z2 ) \ + VSHUF(Chi_01,T0) \ + VMADDSUB(Z0,T0,UChi_00) \ + VMADDSUB(Z1,T0,UChi_01) \ + VMADDSUB(Z2,T0,UChi_02) \ + \ + VMOVIDUP(1,%r9,Z0 ) \ + VMOVIDUP(4,%r9,Z1 ) \ + VMOVIDUP(7,%r9,Z2 ) \ + VSHUF(Chi_11,T1) \ + VMADDSUB(Z0,T1,UChi_10) \ + VMADDSUB(Z1,T1,UChi_11) \ + VMADDSUB(Z2,T1,UChi_12) \ + \ + VMOVRDUP(1,%r8,Z3 ) \ + VMOVRDUP(4,%r8,Z4 ) \ + VMOVRDUP(7,%r8,Z5 ) \ + VMADDSUB(Z3,Chi_01,UChi_00) \ + VMADDSUB(Z4,Chi_01,UChi_01) \ + VMADDSUB(Z5,Chi_01,UChi_02) \ + \ + VMOVRDUP(1,%r9,Z3 ) \ + VMOVRDUP(4,%r9,Z4 ) \ + VMOVRDUP(7,%r9,Z5 ) \ + VMADDSUB(Z3,Chi_11,UChi_10) \ + VMADDSUB(Z4,Chi_11,UChi_11) \ + VMADDSUB(Z5,Chi_11,UChi_12) \ + \ + VSHUF(Chi_02,T0) \ + VSHUF(Chi_12,T1) \ + VMOVIDUP(2,%r8,Z0 ) \ + VMOVIDUP(5,%r8,Z1 ) \ + VMOVIDUP(8,%r8,Z2 ) \ + VMADDSUB(Z0,T0,UChi_00) \ + VMADDSUB(Z1,T0,UChi_01) \ + VMADDSUB(Z2,T0,UChi_02) \ + VMOVIDUP(2,%r9,Z0 ) \ + VMOVIDUP(5,%r9,Z1 ) \ + VMOVIDUP(8,%r9,Z2 ) \ + VMADDSUB(Z0,T1,UChi_10) \ + 
VMADDSUB(Z1,T1,UChi_11) \ + VMADDSUB(Z2,T1,UChi_12) \ + /*55*/ \ + VMOVRDUP(2,%r8,Z3 ) \ + VMOVRDUP(5,%r8,Z4 ) \ + VMOVRDUP(8,%r8,Z5 ) \ + VMADDSUB(Z3,Chi_02,UChi_00) \ + VMADDSUB(Z4,Chi_02,UChi_01) \ + VMADDSUB(Z5,Chi_02,UChi_02) \ + VMOVRDUP(2,%r9,Z3 ) \ + VMOVRDUP(5,%r9,Z4 ) \ + VMOVRDUP(8,%r9,Z5 ) \ + VMADDSUB(Z3,Chi_12,UChi_10) \ + VMADDSUB(Z4,Chi_12,UChi_11) \ + VMADDSUB(Z5,Chi_12,UChi_12) \ + /*61 insns*/ ); + +#define MULT_ADD_XYZT(g0,g1) \ + asm ( "movq %0, %%r8 \n\t" \ + "movq %1, %%r9 \n\t" : : "r"(g0), "r"(g1) : "%r8","%r9");\ + __asm__ ( \ + VSHUFMEM(0,%r8,Z00) VSHUFMEM(0,%r9,Z10) \ + VRDUP(Chi_00,T0) VIDUP(Chi_00,Chi_00) \ + VRDUP(Chi_10,T1) VIDUP(Chi_10,Chi_10) \ + VMUL(Z00,Chi_00,Z1) VMUL(Z10,Chi_10,Z2) \ + VSHUFMEM(3,%r8,Z00) VSHUFMEM(3,%r9,Z10) \ + VMUL(Z00,Chi_00,Z3) VMUL(Z10,Chi_10,Z4) \ + VSHUFMEM(6,%r8,Z00) VSHUFMEM(6,%r9,Z10) \ + VMUL(Z00,Chi_00,Z5) VMUL(Z10,Chi_10,Z6) \ + VMADDMEM(0,%r8,T0,UChi_00) VMADDMEM(0,%r9,T1,UChi_10) \ + VMADDMEM(3,%r8,T0,UChi_01) VMADDMEM(3,%r9,T1,UChi_11) \ + VMADDMEM(6,%r8,T0,UChi_02) VMADDMEM(6,%r9,T1,UChi_12) \ + VSHUFMEM(1,%r8,Z00) VSHUFMEM(1,%r9,Z10) \ + VRDUP(Chi_01,T0) VIDUP(Chi_01,Chi_01) \ + VRDUP(Chi_11,T1) VIDUP(Chi_11,Chi_11) \ + VMADD(Z00,Chi_01,Z1) VMADD(Z10,Chi_11,Z2) \ + VSHUFMEM(4,%r8,Z00) VSHUFMEM(4,%r9,Z10) \ + VMADD(Z00,Chi_01,Z3) VMADD(Z10,Chi_11,Z4) \ + VSHUFMEM(7,%r8,Z00) VSHUFMEM(7,%r9,Z10) \ + VMADD(Z00,Chi_01,Z5) VMADD(Z10,Chi_11,Z6) \ + VMADDMEM(1,%r8,T0,UChi_00) VMADDMEM(1,%r9,T1,UChi_10) \ + VMADDMEM(4,%r8,T0,UChi_01) VMADDMEM(4,%r9,T1,UChi_11) \ + VMADDMEM(7,%r8,T0,UChi_02) VMADDMEM(7,%r9,T1,UChi_12) \ + VSHUFMEM(2,%r8,Z00) VSHUFMEM(2,%r9,Z10) \ + VRDUP(Chi_02,T0) VIDUP(Chi_02,Chi_02) \ + VRDUP(Chi_12,T1) VIDUP(Chi_12,Chi_12) \ + VMADD(Z00,Chi_02,Z1) VMADD(Z10,Chi_12,Z2) \ + VSHUFMEM(5,%r8,Z00) VSHUFMEM(5,%r9,Z10) \ + VMADD(Z00,Chi_02,Z3) VMADD(Z10,Chi_12,Z4) \ + VSHUFMEM(8,%r8,Z00) VSHUFMEM(8,%r9,Z10) \ + VMADD(Z00,Chi_02,Z5) VMADD(Z10,Chi_12,Z6) \ + VMADDSUBMEM(2,%r8,T0,Z1) 
VMADDSUBMEM(2,%r9,T1,Z2) \ + VMADDSUBMEM(5,%r8,T0,Z3) VMADDSUBMEM(5,%r9,T1,Z4) \ + VMADDSUBMEM(8,%r8,T0,Z5) VMADDSUBMEM(8,%r9,T1,Z6) \ + VADD(Z1,UChi_00,UChi_00) VADD(Z2,UChi_10,UChi_10) \ + VADD(Z3,UChi_01,UChi_01) VADD(Z4,UChi_11,UChi_11) \ + VADD(Z5,UChi_02,UChi_02) VADD(Z6,UChi_12,UChi_12) ); + +#define MULT_XYZT(g0,g1) \ + asm ( "movq %0, %%r8 \n\t" \ + "movq %1, %%r9 \n\t" : : "r"(g0), "r"(g1) : "%r8","%r9" ); \ + __asm__ ( \ + VSHUF(Chi_00,T0) \ + VSHUF(Chi_10,T1) \ + VMOVIDUP(0,%r8,Z0 ) \ + VMOVIDUP(3,%r8,Z1 ) \ + VMOVIDUP(6,%r8,Z2 ) \ + /*6*/ \ + VMUL(Z0,T0,UChi_00) \ + VMUL(Z1,T0,UChi_01) \ + VMUL(Z2,T0,UChi_02) \ + VMOVIDUP(0,%r9,Z0 ) \ + VMOVIDUP(3,%r9,Z1 ) \ + VMOVIDUP(6,%r9,Z2 ) \ + VMUL(Z0,T1,UChi_10) \ + VMUL(Z1,T1,UChi_11) \ + VMUL(Z2,T1,UChi_12) \ + VMOVRDUP(0,%r8,Z3 ) \ + VMOVRDUP(3,%r8,Z4 ) \ + VMOVRDUP(6,%r8,Z5 ) \ + /*18*/ \ + VMADDSUB(Z3,Chi_00,UChi_00) \ + VMADDSUB(Z4,Chi_00,UChi_01)\ + VMADDSUB(Z5,Chi_00,UChi_02) \ + VMOVRDUP(0,%r9,Z3 ) \ + VMOVRDUP(3,%r9,Z4 ) \ + VMOVRDUP(6,%r9,Z5 ) \ + VMADDSUB(Z3,Chi_10,UChi_10) \ + VMADDSUB(Z4,Chi_10,UChi_11)\ + VMADDSUB(Z5,Chi_10,UChi_12) \ + VMOVIDUP(1,%r8,Z0 ) \ + VMOVIDUP(4,%r8,Z1 ) \ + VMOVIDUP(7,%r8,Z2 ) \ + /*28*/ \ + VSHUF(Chi_01,T0) \ + VMADDSUB(Z0,T0,UChi_00) \ + VMADDSUB(Z1,T0,UChi_01) \ + VMADDSUB(Z2,T0,UChi_02) \ + VMOVIDUP(1,%r9,Z0 ) \ + VMOVIDUP(4,%r9,Z1 ) \ + VMOVIDUP(7,%r9,Z2 ) \ + VSHUF(Chi_11,T1) \ + VMADDSUB(Z0,T1,UChi_10) \ + VMADDSUB(Z1,T1,UChi_11) \ + VMADDSUB(Z2,T1,UChi_12) \ + VMOVRDUP(1,%r8,Z3 ) \ + VMOVRDUP(4,%r8,Z4 ) \ + VMOVRDUP(7,%r8,Z5 ) \ + /*38*/ \ + VMADDSUB(Z3,Chi_01,UChi_00) \ + VMADDSUB(Z4,Chi_01,UChi_01) \ + VMADDSUB(Z5,Chi_01,UChi_02) \ + VMOVRDUP(1,%r9,Z3 ) \ + VMOVRDUP(4,%r9,Z4 ) \ + VMOVRDUP(7,%r9,Z5 ) \ + VMADDSUB(Z3,Chi_11,UChi_10) \ + VMADDSUB(Z4,Chi_11,UChi_11) \ + VMADDSUB(Z5,Chi_11,UChi_12) \ + /*48*/ \ + VSHUF(Chi_02,T0) \ + VSHUF(Chi_12,T1) \ + VMOVIDUP(2,%r8,Z0 ) \ + VMOVIDUP(5,%r8,Z1 ) \ + VMOVIDUP(8,%r8,Z2 ) \ + VMADDSUB(Z0,T0,UChi_00) \ + 
VMADDSUB(Z1,T0,UChi_01) \ + VMADDSUB(Z2,T0,UChi_02) \ + VMOVIDUP(2,%r9,Z0 ) \ + VMOVIDUP(5,%r9,Z1 ) \ + VMOVIDUP(8,%r9,Z2 ) \ + VMADDSUB(Z0,T1,UChi_10) \ + VMADDSUB(Z1,T1,UChi_11) \ + VMADDSUB(Z2,T1,UChi_12) \ + /*55*/ \ + VMOVRDUP(2,%r8,Z3 ) \ + VMOVRDUP(5,%r8,Z4 ) \ + VMOVRDUP(8,%r8,Z5 ) \ + VMADDSUB(Z3,Chi_02,UChi_00) \ + VMADDSUB(Z4,Chi_02,UChi_01) \ + VMADDSUB(Z5,Chi_02,UChi_02) \ + VMOVRDUP(2,%r9,Z3 ) \ + VMOVRDUP(5,%r9,Z4 ) \ + VMOVRDUP(8,%r9,Z5 ) \ + VMADDSUB(Z3,Chi_12,UChi_10) \ + VMADDSUB(Z4,Chi_12,UChi_11) \ + VMADDSUB(Z5,Chi_12,UChi_12) \ + /*61 insns*/ ); + +#define MULT_XYZTa(g0,g1) \ + asm ( "movq %0, %%r8 \n\t" \ + "movq %1, %%r9 \n\t" : : "r"(g0), "r"(g1) : "%r8","%r9" ); \ + __asm__ ( \ + VSHUFMEM(0,%r8,Z00) VSHUFMEM(0,%r9,Z10) \ + VRDUP(Chi_00,T0) VIDUP(Chi_00,Chi_00) \ + VRDUP(Chi_10,T1) VIDUP(Chi_10,Chi_10) \ + VMUL(Z00,Chi_00,Z1) VMUL(Z10,Chi_10,Z2) \ + VSHUFMEM(3,%r8,Z00) VSHUFMEM(3,%r9,Z10) \ + VMUL(Z00,Chi_00,Z3) VMUL(Z10,Chi_10,Z4) \ + VSHUFMEM(6,%r8,Z00) VSHUFMEM(6,%r9,Z10) \ + VMUL(Z00,Chi_00,Z5) VMUL(Z10,Chi_10,Z6) \ + VMULMEM(0,%r8,T0,UChi_00) VMULMEM(0,%r9,T1,UChi_10) \ + VMULMEM(3,%r8,T0,UChi_01) VMULMEM(3,%r9,T1,UChi_11) \ + VMULMEM(6,%r8,T0,UChi_02) VMULMEM(6,%r9,T1,UChi_12) \ + VSHUFMEM(1,%r8,Z00) VSHUFMEM(1,%r9,Z10) \ + VRDUP(Chi_01,T0) VIDUP(Chi_01,Chi_01) \ + VRDUP(Chi_11,T1) VIDUP(Chi_11,Chi_11) \ + VMADD(Z00,Chi_01,Z1) VMADD(Z10,Chi_11,Z2) \ + VSHUFMEM(4,%r8,Z00) VSHUFMEM(4,%r9,Z10) \ + VMADD(Z00,Chi_01,Z3) VMADD(Z10,Chi_11,Z4) \ + VSHUFMEM(7,%r8,Z00) VSHUFMEM(7,%r9,Z10) \ + VMADD(Z00,Chi_01,Z5) VMADD(Z10,Chi_11,Z6) \ + VMADDMEM(1,%r8,T0,UChi_00) VMADDMEM(1,%r9,T1,UChi_10) \ + VMADDMEM(4,%r8,T0,UChi_01) VMADDMEM(4,%r9,T1,UChi_11) \ + VMADDMEM(7,%r8,T0,UChi_02) VMADDMEM(7,%r9,T1,UChi_12) \ + VSHUFMEM(2,%r8,Z00) VSHUFMEM(2,%r9,Z10) \ + VRDUP(Chi_02,T0) VIDUP(Chi_02,Chi_02) \ + VRDUP(Chi_12,T1) VIDUP(Chi_12,Chi_12) \ + VMADD(Z00,Chi_02,Z1) VMADD(Z10,Chi_12,Z2) \ + VSHUFMEM(5,%r8,Z00) VSHUFMEM(5,%r9,Z10) \ + VMADD(Z00,Chi_02,Z3) 
VMADD(Z10,Chi_12,Z4) \ + VSHUFMEM(8,%r8,Z00) VSHUFMEM(8,%r9,Z10) \ + VMADD(Z00,Chi_02,Z5) VMADD(Z10,Chi_12,Z6) \ + VMADDSUBMEM(2,%r8,T0,Z1) VMADDSUBMEM(2,%r9,T1,Z2) \ + VMADDSUBMEM(5,%r8,T0,Z3) VMADDSUBMEM(5,%r9,T1,Z4) \ + VMADDSUBMEM(8,%r8,T0,Z5) VMADDSUBMEM(8,%r9,T1,Z6) \ + VADD(Z1,UChi_00,UChi_00) VADD(Z2,UChi_10,UChi_10) \ + VADD(Z3,UChi_01,UChi_01) VADD(Z4,UChi_11,UChi_11) \ + VADD(Z5,UChi_02,UChi_02) VADD(Z6,UChi_12,UChi_12) ); + + +#define LOAD_CHI(a0,a1,a2,a3) \ + asm ( \ + "movq %0, %%r8 \n\t" \ + VLOAD(0,%%r8,pChi_00) \ + VLOAD(1,%%r8,pChi_01) \ + VLOAD(2,%%r8,pChi_02) \ + : : "r" (a0) : "%r8" ); \ + asm ( \ + "movq %0, %%r8 \n\t" \ + VLOAD(0,%%r8,pChi_10) \ + VLOAD(1,%%r8,pChi_11) \ + VLOAD(2,%%r8,pChi_12) \ + : : "r" (a1) : "%r8" ); \ + asm ( \ + "movq %0, %%r8 \n\t" \ + VLOAD(0,%%r8,pChi_20) \ + VLOAD(1,%%r8,pChi_21) \ + VLOAD(2,%%r8,pChi_22) \ + : : "r" (a2) : "%r8" ); \ + asm ( \ + "movq %0, %%r8 \n\t" \ + VLOAD(0,%%r8,pChi_30) \ + VLOAD(1,%%r8,pChi_31) \ + VLOAD(2,%%r8,pChi_32) \ + : : "r" (a3) : "%r8" ); + +#define LOAD_CHIa(a0,a1) \ + asm ( \ + "movq %0, %%r8 \n\t" \ + VLOAD(0,%%r8,pChi_00) \ + VLOAD(1,%%r8,pChi_01) \ + VLOAD(2,%%r8,pChi_02) \ + : : "r" (a0) : "%r8" ); \ + asm ( \ + "movq %0, %%r8 \n\t" \ + VLOAD(0,%%r8,pChi_10) \ + VLOAD(1,%%r8,pChi_11) \ + VLOAD(2,%%r8,pChi_12) \ + : : "r" (a1) : "%r8" ); + +#define PF_CHI(a0) +#define PF_CHIa(a0) \ + asm ( \ + "movq %0, %%r8 \n\t" \ + VPREFETCH1(0,%%r8) \ + VPREFETCH1(1,%%r8) \ + VPREFETCH1(2,%%r8) \ + : : "r" (a0) : "%r8" ); \ + +#define PF_GAUGE_XYZT(a0) +#define PF_GAUGE_XYZTa(a0) \ + asm ( \ + "movq %0, %%r8 \n\t" \ + VPREFETCH1(0,%%r8) \ + VPREFETCH1(1,%%r8) \ + VPREFETCH1(2,%%r8) \ + VPREFETCH1(3,%%r8) \ + VPREFETCH1(4,%%r8) \ + VPREFETCH1(5,%%r8) \ + VPREFETCH1(6,%%r8) \ + VPREFETCH1(7,%%r8) \ + VPREFETCH1(8,%%r8) \ + : : "r" (a0) : "%r8" ); \ + +#define PF_GAUGE_LS(a0) +#define PF_GAUGE_LSa(a0) \ + asm ( \ + "movq %0, %%r8 \n\t" \ + VPREFETCH1(0,%%r8) \ + VPREFETCH1(1,%%r8) \ + : : "r" 
(a0) : "%r8" ); \ + + +#define REDUCE(out) \ + asm ( \ + VADD(UChi_00,UChi_10,UChi_00) \ + VADD(UChi_01,UChi_11,UChi_01) \ + VADD(UChi_02,UChi_12,UChi_02) \ + VADD(UChi_30,UChi_20,UChi_30) \ + VADD(UChi_31,UChi_21,UChi_31) \ + VADD(UChi_32,UChi_22,UChi_32) \ + VADD(UChi_00,UChi_30,UChi_00) \ + VADD(UChi_01,UChi_31,UChi_01) \ + VADD(UChi_02,UChi_32,UChi_02) ); \ + asm ( \ + VSTORE(0,%0,pUChi_00) \ + VSTORE(1,%0,pUChi_01) \ + VSTORE(2,%0,pUChi_02) \ + : : "r" (out) : "memory" ); + +#define REDUCEa(out) \ + asm ( \ + VADD(UChi_00,UChi_10,UChi_00) \ + VADD(UChi_01,UChi_11,UChi_01) \ + VADD(UChi_02,UChi_12,UChi_02) ); \ + asm ( \ + VSTORE(0,%0,pUChi_00) \ + VSTORE(1,%0,pUChi_01) \ + VSTORE(2,%0,pUChi_02) \ + : : "r" (out) : "memory" ); + +#define PERMUTE_DIR(dir) \ + permute##dir(Chi_0,Chi_0);\ + permute##dir(Chi_1,Chi_1);\ + permute##dir(Chi_2,Chi_2); + +namespace Grid { +namespace QCD { + +template +void StaggeredKernels::DhopSiteAsm(StencilImpl &st, LebesgueOrder &lo, + DoubledGaugeField &U, + DoubledGaugeField &UUU, + SiteSpinor *buf, int LLs, + int sU, const FermionField &in, FermionField &out) +{ + assert(0); +}; + + +//#define CONDITIONAL_MOVE(l,o,out) if ( l ) { out = (uint64_t) &in._odata[o] ; } else { out =(uint64_t) &buf[o]; } + +#define CONDITIONAL_MOVE(l,o,out) { const SiteSpinor *ptr = l? 
in_p : buf; out = (uint64_t) &ptr[o]; } + +#define PREPARE_XYZT(X,Y,Z,T,skew,UU) \ + PREPARE(X,Y,Z,T,skew,UU); \ + PF_GAUGE_XYZT(gauge0); \ + PF_GAUGE_XYZT(gauge1); \ + PF_GAUGE_XYZT(gauge2); \ + PF_GAUGE_XYZT(gauge3); + +#define PREPARE_LS(X,Y,Z,T,skew,UU) \ + PREPARE(X,Y,Z,T,skew,UU); \ + PF_GAUGE_LS(gauge0); \ + PF_GAUGE_LS(gauge1); \ + PF_GAUGE_LS(gauge2); \ + PF_GAUGE_LS(gauge3); + +#define PREPARE(X,Y,Z,T,skew,UU) \ + SE0=st.GetEntry(ptype,X+skew,sF); \ + o0 = SE0->_offset; \ + l0 = SE0->_is_local; \ + p0 = SE0->_permute; \ + CONDITIONAL_MOVE(l0,o0,addr0); \ + PF_CHI(addr0); \ + \ + SE1=st.GetEntry(ptype,Y+skew,sF); \ + o1 = SE1->_offset; \ + l1 = SE1->_is_local; \ + p1 = SE1->_permute; \ + CONDITIONAL_MOVE(l1,o1,addr1); \ + PF_CHI(addr1); \ + \ + SE2=st.GetEntry(ptype,Z+skew,sF); \ + o2 = SE2->_offset; \ + l2 = SE2->_is_local; \ + p2 = SE2->_permute; \ + CONDITIONAL_MOVE(l2,o2,addr2); \ + PF_CHI(addr2); \ + \ + SE3=st.GetEntry(ptype,T+skew,sF); \ + o3 = SE3->_offset; \ + l3 = SE3->_is_local; \ + p3 = SE3->_permute; \ + CONDITIONAL_MOVE(l3,o3,addr3); \ + PF_CHI(addr3); \ + \ + gauge0 =(uint64_t)&UU._odata[sU]( X ); \ + gauge1 =(uint64_t)&UU._odata[sU]( Y ); \ + gauge2 =(uint64_t)&UU._odata[sU]( Z ); \ + gauge3 =(uint64_t)&UU._odata[sU]( T ); + + // This is the single precision 5th direction vectorised kernel +#include +template <> void StaggeredKernels::DhopSiteAsm(StencilImpl &st, LebesgueOrder &lo, + DoubledGaugeField &U, + DoubledGaugeField &UUU, + SiteSpinor *buf, int LLs, + int sU, const FermionField &in, FermionField &out) +{ +#ifdef AVX512 + uint64_t gauge0,gauge1,gauge2,gauge3; + uint64_t addr0,addr1,addr2,addr3; + const SiteSpinor *in_p; in_p = &in._odata[0]; + + int o0,o1,o2,o3; // offsets + int l0,l1,l2,l3; // local + int p0,p1,p2,p3; // perm + int ptype; + StencilEntry *SE0; + StencilEntry *SE1; + StencilEntry *SE2; + StencilEntry *SE3; + + for(int s=0;s +template <> void StaggeredKernels::DhopSiteAsm(StencilImpl &st, LebesgueOrder &lo, + 
DoubledGaugeField &U, + DoubledGaugeField &UUU, + SiteSpinor *buf, int LLs, + int sU, const FermionField &in, FermionField &out) +{ +#ifdef AVX512 + uint64_t gauge0,gauge1,gauge2,gauge3; + uint64_t addr0,addr1,addr2,addr3; + const SiteSpinor *in_p; in_p = &in._odata[0]; + + int o0,o1,o2,o3; // offsets + int l0,l1,l2,l3; // local + int p0,p1,p2,p3; // perm + int ptype; + StencilEntry *SE0; + StencilEntry *SE1; + StencilEntry *SE2; + StencilEntry *SE3; + + for(int s=0;s +template <> void StaggeredKernels::DhopSiteAsm(StencilImpl &st, LebesgueOrder &lo, + DoubledGaugeField &U, + DoubledGaugeField &UUU, + SiteSpinor *buf, int LLs, + int sU, const FermionField &in, FermionField &out) +{ +#ifdef AVX512 + uint64_t gauge0,gauge1,gauge2,gauge3; + uint64_t addr0,addr1,addr2,addr3; + const SiteSpinor *in_p; in_p = &in._odata[0]; + + int o0,o1,o2,o3; // offsets + int l0,l1,l2,l3; // local + int p0,p1,p2,p3; // perm + int ptype; + StencilEntry *SE0; + StencilEntry *SE1; + StencilEntry *SE2; + StencilEntry *SE3; + + for(int s=0;s +template <> void StaggeredKernels::DhopSiteAsm(StencilImpl &st, LebesgueOrder &lo, + DoubledGaugeField &U, + DoubledGaugeField &UUU, + SiteSpinor *buf, int LLs, + int sU, const FermionField &in, FermionField &out) +{ +#ifdef AVX512 + uint64_t gauge0,gauge1,gauge2,gauge3; + uint64_t addr0,addr1,addr2,addr3; + const SiteSpinor *in_p; in_p = &in._odata[0]; + + int o0,o1,o2,o3; // offsets + int l0,l1,l2,l3; // local + int p0,p1,p2,p3; // perm + int ptype; + StencilEntry *SE0; + StencilEntry *SE1; + StencilEntry *SE2; + StencilEntry *SE3; + + for(int s=0;s::FUNC(StencilImpl &st, LebesgueOrder &lo, \ + DoubledGaugeField &U, \ + DoubledGaugeField &UUU, \ + SiteSpinor *buf, int LLs, \ + int sU, const FermionField &in, FermionField &out); + +KERNEL_INSTANTIATE(StaggeredKernels,DhopSiteAsm,StaggeredImplD); +KERNEL_INSTANTIATE(StaggeredKernels,DhopSiteAsm,StaggeredImplF); +KERNEL_INSTANTIATE(StaggeredKernels,DhopSiteAsm,StaggeredVec5dImplD); 
+KERNEL_INSTANTIATE(StaggeredKernels,DhopSiteAsm,StaggeredVec5dImplF); + +}} + diff --git a/lib/qcd/action/fermion/StaggeredKernelsHand.cc b/lib/qcd/action/fermion/StaggeredKernelsHand.cc index 5f9e11e5..7de8480c 100644 --- a/lib/qcd/action/fermion/StaggeredKernelsHand.cc +++ b/lib/qcd/action/fermion/StaggeredKernelsHand.cc @@ -90,10 +90,32 @@ namespace Grid { namespace QCD { +template +void StaggeredKernels::DhopSiteHand(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U,DoubledGaugeField &UUU, + SiteSpinor *buf, int LLs, + int sU, const FermionField &in, FermionField &out, int dag) +{ + SiteSpinor naik; + SiteSpinor naive; + int oneLink =0; + int threeLink=1; + int skew(0); + Real scale(1.0); + + if(dag) scale = -1.0; + + for(int s=0;s void StaggeredKernels::DhopSiteDepthHand(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, - SiteSpinor *buf, int sF, - int sU, const FermionField &in, SiteSpinor &out,int threeLink) { + SiteSpinor *buf, int sF, + int sU, const FermionField &in, SiteSpinor &out,int threeLink) { typedef typename Simd::scalar_type S; typedef typename Simd::vector_type V; @@ -275,9 +297,26 @@ void StaggeredKernels::DhopSiteDepthHand(StencilImpl &st, LebesgueOrder &l vstream(out()()(1),even_1+odd_1); vstream(out()()(2),even_2+odd_2); - } } -FermOpStaggeredTemplateInstantiate(StaggeredKernels); +#define DHOP_SITE_HAND_INSTANTIATE(IMPL) \ + template void StaggeredKernels::DhopSiteHand(StencilImpl &st, LebesgueOrder &lo, \ + DoubledGaugeField &U,DoubledGaugeField &UUU, \ + SiteSpinor *buf, int LLs, \ + int sU, const FermionField &in, FermionField &out, int dag); + +#define DHOP_SITE_DEPTH_HAND_INSTANTIATE(IMPL) \ + template void StaggeredKernels::DhopSiteDepthHand(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, \ + SiteSpinor *buf, int sF, \ + int sU, const FermionField &in, SiteSpinor &out,int threeLink) ; +DHOP_SITE_HAND_INSTANTIATE(StaggeredImplD); +DHOP_SITE_HAND_INSTANTIATE(StaggeredImplF); 
+DHOP_SITE_HAND_INSTANTIATE(StaggeredVec5dImplD); +DHOP_SITE_HAND_INSTANTIATE(StaggeredVec5dImplF); + +DHOP_SITE_DEPTH_HAND_INSTANTIATE(StaggeredImplD); +DHOP_SITE_DEPTH_HAND_INSTANTIATE(StaggeredImplF); +DHOP_SITE_DEPTH_HAND_INSTANTIATE(StaggeredVec5dImplD); +DHOP_SITE_DEPTH_HAND_INSTANTIATE(StaggeredVec5dImplF); }} diff --git a/lib/qcd/action/fermion/WilsonCompressor.h b/lib/qcd/action/fermion/WilsonCompressor.h index 41f24e1b..cc5c3c63 100644 --- a/lib/qcd/action/fermion/WilsonCompressor.h +++ b/lib/qcd/action/fermion/WilsonCompressor.h @@ -33,227 +33,359 @@ Author: paboyle namespace Grid { namespace QCD { - template - class WilsonCompressor { - public: - int mu; - int dag; +///////////////////////////////////////////////////////////////////////////////////////////// +// optimised versions supporting half precision too +///////////////////////////////////////////////////////////////////////////////////////////// - WilsonCompressor(int _dag){ - mu=0; - dag=_dag; - assert((dag==0)||(dag==1)); - } - void Point(int p) { - mu=p; - }; +template +class WilsonCompressorTemplate; - inline SiteHalfSpinor operator () (const SiteSpinor &in) { - SiteHalfSpinor ret; - int mudag=mu; - if (!dag) { - mudag=(mu+Nd)%(2*Nd); - } - switch(mudag) { - case Xp: - spProjXp(ret,in); - break; - case Yp: - spProjYp(ret,in); - break; - case Zp: - spProjZp(ret,in); - break; - case Tp: - spProjTp(ret,in); - break; - case Xm: - spProjXm(ret,in); - break; - case Ym: - spProjYm(ret,in); - break; - case Zm: - spProjZm(ret,in); - break; - case Tm: - spProjTm(ret,in); - break; - default: - assert(0); - break; - } - return ret; - } - }; - ///////////////////////// - // optimised versions - ///////////////////////// +template +class WilsonCompressorTemplate< _HCspinor, _Hspinor, _Spinor, projector, + typename std::enable_if::value>::type > +{ + public: + + int mu,dag; - template - class WilsonXpCompressor { - public: - inline SiteHalfSpinor operator () (const SiteSpinor &in) { - SiteHalfSpinor ret; - 
spProjXp(ret,in); - return ret; - } - }; - template - class WilsonYpCompressor { - public: - inline SiteHalfSpinor operator () (const SiteSpinor &in) { - SiteHalfSpinor ret; - spProjYp(ret,in); - return ret; - } - }; - template - class WilsonZpCompressor { - public: - inline SiteHalfSpinor operator () (const SiteSpinor &in) { - SiteHalfSpinor ret; - spProjZp(ret,in); - return ret; - } - }; - template - class WilsonTpCompressor { - public: - inline SiteHalfSpinor operator () (const SiteSpinor &in) { - SiteHalfSpinor ret; - spProjTp(ret,in); - return ret; - } - }; + void Point(int p) { mu=p; }; - template - class WilsonXmCompressor { - public: - inline SiteHalfSpinor operator () (const SiteSpinor &in) { - SiteHalfSpinor ret; - spProjXm(ret,in); - return ret; - } - }; - template - class WilsonYmCompressor { - public: - inline SiteHalfSpinor operator () (const SiteSpinor &in) { - SiteHalfSpinor ret; - spProjYm(ret,in); - return ret; - } - }; - template - class WilsonZmCompressor { - public: - inline SiteHalfSpinor operator () (const SiteSpinor &in) { - SiteHalfSpinor ret; - spProjZm(ret,in); - return ret; - } - }; - template - class WilsonTmCompressor { - public: - inline SiteHalfSpinor operator () (const SiteSpinor &in) { - SiteHalfSpinor ret; - spProjTm(ret,in); - return ret; - } - }; + WilsonCompressorTemplate(int _dag=0){ + dag = _dag; + } - // Fast comms buffer manipulation which should inline right through (avoid direction - // dependent logic that prevents inlining - template - class WilsonStencil : public CartesianStencil { - public: + typedef _Spinor SiteSpinor; + typedef _Hspinor SiteHalfSpinor; + typedef _HCspinor SiteHalfCommSpinor; + typedef typename SiteHalfCommSpinor::vector_type vComplexLow; + typedef typename SiteHalfSpinor::vector_type vComplexHigh; + constexpr static int Nw=sizeof(SiteHalfSpinor)/sizeof(vComplexHigh); - WilsonStencil(GridBase *grid, + inline int CommDatumSize(void) { + return sizeof(SiteHalfCommSpinor); + } + + 
/*****************************************************/ + /* Compress includes precision change if mpi data is not same */ + /*****************************************************/ + inline void Compress(SiteHalfSpinor *buf,Integer o,const SiteSpinor &in) { + projector::Proj(buf[o],in,mu,dag); + } + + /*****************************************************/ + /* Exchange includes precision change if mpi data is not same */ + /*****************************************************/ + inline void Exchange(SiteHalfSpinor *mp, + SiteHalfSpinor *vp0, + SiteHalfSpinor *vp1, + Integer type,Integer o){ + exchange(mp[2*o],mp[2*o+1],vp0[o],vp1[o],type); + } + + /*****************************************************/ + /* Have a decompression step if mpi data is not same */ + /*****************************************************/ + inline void Decompress(SiteHalfSpinor *out, + SiteHalfSpinor *in, Integer o) { + assert(0); + } + + /*****************************************************/ + /* Compress Exchange */ + /*****************************************************/ + inline void CompressExchange(SiteHalfSpinor *out0, + SiteHalfSpinor *out1, + const SiteSpinor *in, + Integer j,Integer k, Integer m,Integer type){ + SiteHalfSpinor temp1, temp2,temp3,temp4; + projector::Proj(temp1,in[k],mu,dag); + projector::Proj(temp2,in[m],mu,dag); + exchange(out0[j],out1[j],temp1,temp2,type); + } + + /*****************************************************/ + /* Pass the info to the stencil */ + /*****************************************************/ + inline bool DecompressionStep(void) { return false; } + +}; + +template +class WilsonCompressorTemplate< _HCspinor, _Hspinor, _Spinor, projector, + typename std::enable_if::value>::type > +{ + public: + + int mu,dag; + + void Point(int p) { mu=p; }; + + WilsonCompressorTemplate(int _dag=0){ + dag = _dag; + } + + typedef _Spinor SiteSpinor; + typedef _Hspinor SiteHalfSpinor; + typedef _HCspinor SiteHalfCommSpinor; + typedef typename 
SiteHalfCommSpinor::vector_type vComplexLow; + typedef typename SiteHalfSpinor::vector_type vComplexHigh; + constexpr static int Nw=sizeof(SiteHalfSpinor)/sizeof(vComplexHigh); + + inline int CommDatumSize(void) { + return sizeof(SiteHalfCommSpinor); + } + + /*****************************************************/ + /* Compress includes precision change if mpi data is not same */ + /*****************************************************/ + inline void Compress(SiteHalfSpinor *buf,Integer o,const SiteSpinor &in) { + SiteHalfSpinor hsp; + SiteHalfCommSpinor *hbuf = (SiteHalfCommSpinor *)buf; + projector::Proj(hsp,in,mu,dag); + precisionChange((vComplexLow *)&hbuf[o],(vComplexHigh *)&hsp,Nw); + } + + /*****************************************************/ + /* Exchange includes precision change if mpi data is not same */ + /*****************************************************/ + inline void Exchange(SiteHalfSpinor *mp, + SiteHalfSpinor *vp0, + SiteHalfSpinor *vp1, + Integer type,Integer o){ + SiteHalfSpinor vt0,vt1; + SiteHalfCommSpinor *vpp0 = (SiteHalfCommSpinor *)vp0; + SiteHalfCommSpinor *vpp1 = (SiteHalfCommSpinor *)vp1; + precisionChange((vComplexHigh *)&vt0,(vComplexLow *)&vpp0[o],Nw); + precisionChange((vComplexHigh *)&vt1,(vComplexLow *)&vpp1[o],Nw); + exchange(mp[2*o],mp[2*o+1],vt0,vt1,type); + } + + /*****************************************************/ + /* Have a decompression step if mpi data is not same */ + /*****************************************************/ + inline void Decompress(SiteHalfSpinor *out, + SiteHalfSpinor *in, Integer o){ + SiteHalfCommSpinor *hin=(SiteHalfCommSpinor *)in; + precisionChange((vComplexHigh *)&out[o],(vComplexLow *)&hin[o],Nw); + } + + /*****************************************************/ + /* Compress Exchange */ + /*****************************************************/ + inline void CompressExchange(SiteHalfSpinor *out0, + SiteHalfSpinor *out1, + const SiteSpinor *in, + Integer j,Integer k, Integer m,Integer type){ 
+ SiteHalfSpinor temp1, temp2,temp3,temp4; + SiteHalfCommSpinor *hout0 = (SiteHalfCommSpinor *)out0; + SiteHalfCommSpinor *hout1 = (SiteHalfCommSpinor *)out1; + projector::Proj(temp1,in[k],mu,dag); + projector::Proj(temp2,in[m],mu,dag); + exchange(temp3,temp4,temp1,temp2,type); + precisionChange((vComplexLow *)&hout0[j],(vComplexHigh *)&temp3,Nw); + precisionChange((vComplexLow *)&hout1[j],(vComplexHigh *)&temp4,Nw); + } + + /*****************************************************/ + /* Pass the info to the stencil */ + /*****************************************************/ + inline bool DecompressionStep(void) { return true; } + +}; + +#define DECLARE_PROJ(Projector,Compressor,spProj) \ + class Projector { \ + public: \ + template \ + static void Proj(hsp &result,const fsp &in,int mu,int dag){ \ + spProj(result,in); \ + } \ + }; \ +template using Compressor = WilsonCompressorTemplate; + +DECLARE_PROJ(WilsonXpProjector,WilsonXpCompressor,spProjXp); +DECLARE_PROJ(WilsonYpProjector,WilsonYpCompressor,spProjYp); +DECLARE_PROJ(WilsonZpProjector,WilsonZpCompressor,spProjZp); +DECLARE_PROJ(WilsonTpProjector,WilsonTpCompressor,spProjTp); +DECLARE_PROJ(WilsonXmProjector,WilsonXmCompressor,spProjXm); +DECLARE_PROJ(WilsonYmProjector,WilsonYmCompressor,spProjYm); +DECLARE_PROJ(WilsonZmProjector,WilsonZmCompressor,spProjZm); +DECLARE_PROJ(WilsonTmProjector,WilsonTmCompressor,spProjTm); + +class WilsonProjector { + public: + template + static void Proj(hsp &result,const fsp &in,int mu,int dag){ + int mudag=dag? 
mu : (mu+Nd)%(2*Nd); + switch(mudag) { + case Xp: spProjXp(result,in); break; + case Yp: spProjYp(result,in); break; + case Zp: spProjZp(result,in); break; + case Tp: spProjTp(result,in); break; + case Xm: spProjXm(result,in); break; + case Ym: spProjYm(result,in); break; + case Zm: spProjZm(result,in); break; + case Tm: spProjTm(result,in); break; + default: assert(0); break; + } + } +}; +template using WilsonCompressor = WilsonCompressorTemplate; + +// Fast comms buffer manipulation which should inline right through (avoid direction +// dependent logic that prevents inlining +template +class WilsonStencil : public CartesianStencil { +public: + double timer0; + double timer1; + double timer2; + double timer3; + double timer4; + double timer5; + double timer6; + uint64_t callsi; + void ZeroCountersi(void) + { + timer0=0; + timer1=0; + timer2=0; + timer3=0; + timer4=0; + timer5=0; + timer6=0; + callsi=0; + } + void Reporti(int calls) + { + if ( timer0 ) std::cout << GridLogMessage << " timer0 (HaloGatherOpt) " < same_node; + std::vector surface_list; + + WilsonStencil(GridBase *grid, int npoints, int checkerboard, const std::vector &directions, - const std::vector &distances) : CartesianStencil (grid,npoints,checkerboard,directions,distances) - { }; - - template < class compressor> - std::thread HaloExchangeOptBegin(const Lattice &source,compressor &compress) { - this->Mergers.resize(0); - this->Packets.resize(0); - this->HaloGatherOpt(source,compress); - return std::thread([&] { this->Communicate(); }); - } - - template < class compressor> - void HaloExchangeOpt(const Lattice &source,compressor &compress) - { - auto thr = this->HaloExchangeOptBegin(source,compress); - this->HaloExchangeOptComplete(thr); - } - - void HaloExchangeOptComplete(std::thread &thr) - { - this->CommsMerge(); // spins - this->jointime-=usecond(); - thr.join(); - this->jointime+=usecond(); - } - - template < class compressor> - void HaloGatherOpt(const Lattice &source,compressor &compress) - 
{ - // conformable(source._grid,_grid); - assert(source._grid==this->_grid); - this->halogtime-=usecond(); - - assert (this->comm_buf.size() == this->_unified_buffer_size ); - this->u_comm_offset=0; - - int dag = compress.dag; - static std::vector dirs(Nd*2); - for(int mu=0;mu XpCompress; - this->HaloGatherDir(source,XpCompress,dirs[0]); - - WilsonYpCompressor YpCompress; - this->HaloGatherDir(source,YpCompress,dirs[1]); - - WilsonZpCompressor ZpCompress; - this->HaloGatherDir(source,ZpCompress,dirs[2]); - - WilsonTpCompressor TpCompress; - this->HaloGatherDir(source,TpCompress,dirs[3]); - - WilsonXmCompressor XmCompress; - this->HaloGatherDir(source,XmCompress,dirs[4]); - - WilsonYmCompressor YmCompress; - this->HaloGatherDir(source,YmCompress,dirs[5]); - - WilsonZmCompressor ZmCompress; - this->HaloGatherDir(source,ZmCompress,dirs[6]); - - WilsonTmCompressor TmCompress; - this->HaloGatherDir(source,TmCompress,dirs[7]); - - assert(this->u_comm_offset==this->_unified_buffer_size); - this->halogtime+=usecond(); - } - + const std::vector &distances) + : CartesianStencil (grid,npoints,checkerboard,directions,distances) , + same_node(npoints) + { + ZeroCountersi(); + surface_list.resize(0); }; + void BuildSurfaceList(int Ls,int vol4){ + + // find same node for SHM + // Here we know the distance is 1 for WilsonStencil + for(int point=0;point_npoints;point++){ + same_node[point] = this->SameNode(point); + } + + for(int site = 0 ;site< vol4;site++){ + int local = 1; + for(int point=0;point_npoints;point++){ + if( (!this->GetNodeLocal(site*Ls,point)) && (!same_node[point]) ){ + local = 0; + } + } + if(local == 0) { + surface_list.push_back(site); + } + } + } + + template < class compressor> + void HaloExchangeOpt(const Lattice &source,compressor &compress) + { + std::vector > reqs; + this->HaloExchangeOptGather(source,compress); + double t1=usecond(); + // Asynchronous MPI calls multidirectional, Isend etc... 
+ // this->CommunicateBegin(reqs); + // this->CommunicateComplete(reqs); + // Non-overlapped directions within a thread. Asynchronous calls except MPI3, threaded up to comm threads ways. + this->Communicate(); + double t2=usecond(); timer1 += t2-t1; + this->CommsMerge(compress); + double t3=usecond(); timer2 += t3-t2; + this->CommsMergeSHM(compress); + double t4=usecond(); timer3 += t4-t3; + } + + template + void HaloExchangeOptGather(const Lattice &source,compressor &compress) + { + this->Prepare(); + double t0=usecond(); + this->HaloGatherOpt(source,compress); + double t1=usecond(); + timer0 += t1-t0; + callsi++; + } + + template + void HaloGatherOpt(const Lattice &source,compressor &compress) + { + // Strategy. Inherit types from Compressor. + // Use types to select the write direction by directon compressor + typedef typename compressor::SiteSpinor SiteSpinor; + typedef typename compressor::SiteHalfSpinor SiteHalfSpinor; + typedef typename compressor::SiteHalfCommSpinor SiteHalfCommSpinor; + + this->mpi3synctime_g-=usecond(); + this->_grid->StencilBarrier(); + this->mpi3synctime_g+=usecond(); + + assert(source._grid==this->_grid); + this->halogtime-=usecond(); + + this->u_comm_offset=0; + + WilsonXpCompressor XpCompress; + WilsonYpCompressor YpCompress; + WilsonZpCompressor ZpCompress; + WilsonTpCompressor TpCompress; + WilsonXmCompressor XmCompress; + WilsonYmCompressor YmCompress; + WilsonZmCompressor ZmCompress; + WilsonTmCompressor TmCompress; + + int dag = compress.dag; + int face_idx=0; + if ( dag ) { + assert(same_node[Xp]==this->HaloGatherDir(source,XpCompress,Xp,face_idx)); + assert(same_node[Yp]==this->HaloGatherDir(source,YpCompress,Yp,face_idx)); + assert(same_node[Zp]==this->HaloGatherDir(source,ZpCompress,Zp,face_idx)); + assert(same_node[Tp]==this->HaloGatherDir(source,TpCompress,Tp,face_idx)); + assert(same_node[Xm]==this->HaloGatherDir(source,XmCompress,Xm,face_idx)); + assert(same_node[Ym]==this->HaloGatherDir(source,YmCompress,Ym,face_idx)); 
+ assert(same_node[Zm]==this->HaloGatherDir(source,ZmCompress,Zm,face_idx)); + assert(same_node[Tm]==this->HaloGatherDir(source,TmCompress,Tm,face_idx)); + } else { + assert(same_node[Xp]==this->HaloGatherDir(source,XmCompress,Xp,face_idx)); + assert(same_node[Yp]==this->HaloGatherDir(source,YmCompress,Yp,face_idx)); + assert(same_node[Zp]==this->HaloGatherDir(source,ZmCompress,Zp,face_idx)); + assert(same_node[Tp]==this->HaloGatherDir(source,TmCompress,Tp,face_idx)); + assert(same_node[Xm]==this->HaloGatherDir(source,XpCompress,Xm,face_idx)); + assert(same_node[Ym]==this->HaloGatherDir(source,YpCompress,Ym,face_idx)); + assert(same_node[Zm]==this->HaloGatherDir(source,ZpCompress,Zm,face_idx)); + assert(same_node[Tm]==this->HaloGatherDir(source,TpCompress,Tm,face_idx)); + } + this->face_table_computed=1; + assert(this->u_comm_offset==this->_unified_buffer_size); + this->halogtime+=usecond(); + } + + }; }} // namespace close #endif diff --git a/lib/qcd/action/fermion/WilsonFermion.cc b/lib/qcd/action/fermion/WilsonFermion.cc index 39a769d5..19f9674d 100644 --- a/lib/qcd/action/fermion/WilsonFermion.cc +++ b/lib/qcd/action/fermion/WilsonFermion.cc @@ -1,3 +1,4 @@ + /************************************************************************************* Grid physics library, www.github.com/paboyle/Grid @@ -29,14 +30,14 @@ See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include +#include +#include namespace Grid { namespace QCD { const std::vector WilsonFermionStatic::directions({0, 1, 2, 3, 0, 1, 2, 3}); const std::vector WilsonFermionStatic::displacements({1, 1, 1, 1, -1, -1, -1, -1}); - int WilsonFermionStatic::HandOptDslash; ///////////////////////////////// @@ -51,16 +52,16 @@ WilsonFermion::WilsonFermion(GaugeField &_Umu, GridCartesian &Fgrid, _grid(&Fgrid), _cbgrid(&Hgrid), Stencil(&Fgrid, npoint, Even, directions, 
displacements), - StencilEven(&Hgrid, npoint, Even, directions, - displacements), // source is Even - StencilOdd(&Hgrid, npoint, Odd, directions, - displacements), // source is Odd + StencilEven(&Hgrid, npoint, Even, directions,displacements), // source is Even + StencilOdd(&Hgrid, npoint, Odd, directions,displacements), // source is Odd mass(_mass), Lebesgue(_grid), LebesgueEvenOdd(_cbgrid), Umu(&Fgrid), UmuEven(&Hgrid), - UmuOdd(&Hgrid) { + UmuOdd(&Hgrid), + _tmp(&Hgrid) +{ // Allocate the required comms buffer ImportGauge(_Umu); } @@ -110,86 +111,84 @@ void WilsonFermion::MeooeDag(const FermionField &in, FermionField &out) { } } - template - void WilsonFermion::Mooee(const FermionField &in, FermionField &out) { - out.checkerboard = in.checkerboard; - typename FermionField::scalar_type scal(4.0 + mass); - out = scal * in; - } +template +void WilsonFermion::Mooee(const FermionField &in, FermionField &out) { + out.checkerboard = in.checkerboard; + typename FermionField::scalar_type scal(4.0 + mass); + out = scal * in; +} - template - void WilsonFermion::MooeeDag(const FermionField &in, FermionField &out) { - out.checkerboard = in.checkerboard; - Mooee(in, out); - } +template +void WilsonFermion::MooeeDag(const FermionField &in, FermionField &out) { + out.checkerboard = in.checkerboard; + Mooee(in, out); +} - template - void WilsonFermion::MooeeInv(const FermionField &in, FermionField &out) { - out.checkerboard = in.checkerboard; - out = (1.0/(4.0+mass))*in; +template +void WilsonFermion::MooeeInv(const FermionField &in, FermionField &out) { + out.checkerboard = in.checkerboard; + out = (1.0/(4.0+mass))*in; +} + +template +void WilsonFermion::MooeeInvDag(const FermionField &in, FermionField &out) { + out.checkerboard = in.checkerboard; + MooeeInv(in,out); +} +template +void WilsonFermion::MomentumSpacePropagator(FermionField &out, const FermionField &in,RealD _m) +{ + typedef typename FermionField::vector_type vector_type; + typedef typename 
FermionField::scalar_type ScalComplex; + typedef Lattice > LatComplex; + + // what type LatticeComplex + conformable(_grid,out._grid); + + Gamma::Algebra Gmu [] = { + Gamma::Algebra::GammaX, + Gamma::Algebra::GammaY, + Gamma::Algebra::GammaZ, + Gamma::Algebra::GammaT + }; + + std::vector latt_size = _grid->_fdimensions; + + FermionField num (_grid); num = zero; + LatComplex wilson(_grid); wilson= zero; + LatComplex one (_grid); one = ScalComplex(1.0,0.0); + + LatComplex denom(_grid); denom= zero; + LatComplex kmu(_grid); + ScalComplex ci(0.0,1.0); + // momphase = n * 2pi / L + for(int mu=0;mu - void WilsonFermion::MooeeInvDag(const FermionField &in, FermionField &out) { - out.checkerboard = in.checkerboard; - MooeeInv(in,out); - } - - template - void WilsonFermion::MomentumSpacePropagator(FermionField &out, const FermionField &in,RealD _m) { - - // what type LatticeComplex - conformable(_grid,out._grid); - - typedef typename FermionField::vector_type vector_type; - typedef typename FermionField::scalar_type ScalComplex; - - typedef Lattice > LatComplex; - - Gamma::GammaMatrix Gmu [] = { - Gamma::GammaX, - Gamma::GammaY, - Gamma::GammaZ, - Gamma::GammaT - }; - - std::vector latt_size = _grid->_fdimensions; - - FermionField num (_grid); num = zero; - LatComplex wilson(_grid); wilson= zero; - LatComplex one (_grid); one = ScalComplex(1.0,0.0); - - LatComplex denom(_grid); denom= zero; - LatComplex kmu(_grid); - ScalComplex ci(0.0,1.0); - // momphase = n * 2pi / L - for(int mu=0;mu::DerivInternal(StencilImpl &st, DoubledGaugeField &U, //////////////////////// // Call the single hop //////////////////////// - PARALLEL_FOR_LOOP - for (int sss = 0; sss < B._grid->oSites(); sss++) { - Kernels::DiracOptDhopDir(st, U, st.CommBuf(), sss, sss, B, Btilde, mu, gamma); + parallel_for (int sss = 0; sss < B._grid->oSites(); sss++) { + Kernels::DhopDir(st, U, st.CommBuf(), sss, sss, B, Btilde, mu, gamma); } ////////////////////////////////////////////////// @@ -232,8 +230,7 @@ void 
WilsonFermion::DerivInternal(StencilImpl &st, DoubledGaugeField &U, } template -void WilsonFermion::DhopDeriv(GaugeField &mat, const FermionField &U, - const FermionField &V, int dag) { +void WilsonFermion::DhopDeriv(GaugeField &mat, const FermionField &U, const FermionField &V, int dag) { conformable(U._grid, _grid); conformable(U._grid, V._grid); conformable(U._grid, mat._grid); @@ -244,12 +241,12 @@ void WilsonFermion::DhopDeriv(GaugeField &mat, const FermionField &U, } template -void WilsonFermion::DhopDerivOE(GaugeField &mat, const FermionField &U, - const FermionField &V, int dag) { +void WilsonFermion::DhopDerivOE(GaugeField &mat, const FermionField &U, const FermionField &V, int dag) { conformable(U._grid, _cbgrid); conformable(U._grid, V._grid); - conformable(U._grid, mat._grid); - + //conformable(U._grid, mat._grid); not general, leaving as a comment (Guido) + // Motivation: look at the SchurDiff operator + assert(V.checkerboard == Even); assert(U.checkerboard == Odd); mat.checkerboard = Odd; @@ -258,11 +255,10 @@ void WilsonFermion::DhopDerivOE(GaugeField &mat, const FermionField &U, } template -void WilsonFermion::DhopDerivEO(GaugeField &mat, const FermionField &U, - const FermionField &V, int dag) { +void WilsonFermion::DhopDerivEO(GaugeField &mat, const FermionField &U, const FermionField &V, int dag) { conformable(U._grid, _cbgrid); conformable(U._grid, V._grid); - conformable(U._grid, mat._grid); + //conformable(U._grid, mat._grid); assert(V.checkerboard == Odd); assert(U.checkerboard == Even); @@ -272,8 +268,7 @@ void WilsonFermion::DhopDerivEO(GaugeField &mat, const FermionField &U, } template -void WilsonFermion::Dhop(const FermionField &in, FermionField &out, - int dag) { +void WilsonFermion::Dhop(const FermionField &in, FermionField &out, int dag) { conformable(in._grid, _grid); // verifies full grid conformable(in._grid, out._grid); @@ -283,8 +278,7 @@ void WilsonFermion::Dhop(const FermionField &in, FermionField &out, } template -void 
WilsonFermion::DhopOE(const FermionField &in, FermionField &out, - int dag) { +void WilsonFermion::DhopOE(const FermionField &in, FermionField &out, int dag) { conformable(in._grid, _cbgrid); // verifies half grid conformable(in._grid, out._grid); // drops the cb check @@ -295,8 +289,7 @@ void WilsonFermion::DhopOE(const FermionField &in, FermionField &out, } template -void WilsonFermion::DhopEO(const FermionField &in, FermionField &out, - int dag) { +void WilsonFermion::DhopEO(const FermionField &in, FermionField &out,int dag) { conformable(in._grid, _cbgrid); // verifies half grid conformable(in._grid, out._grid); // drops the cb check @@ -307,14 +300,12 @@ void WilsonFermion::DhopEO(const FermionField &in, FermionField &out, } template -void WilsonFermion::Mdir(const FermionField &in, FermionField &out, - int dir, int disp) { +void WilsonFermion::Mdir(const FermionField &in, FermionField &out, int dir, int disp) { DhopDir(in, out, dir, disp); } template -void WilsonFermion::DhopDir(const FermionField &in, FermionField &out, - int dir, int disp) { +void WilsonFermion::DhopDir(const FermionField &in, FermionField &out, int dir, int disp) { int skip = (disp == 1) ? 
0 : 1; int dirdisp = dir + skip * 4; int gamma = dir + (1 - skip) * 4; @@ -323,15 +314,13 @@ void WilsonFermion::DhopDir(const FermionField &in, FermionField &out, }; template -void WilsonFermion::DhopDirDisp(const FermionField &in, FermionField &out, - int dirdisp, int gamma, int dag) { +void WilsonFermion::DhopDirDisp(const FermionField &in, FermionField &out,int dirdisp, int gamma, int dag) { Compressor compressor(dag); Stencil.HaloExchange(in, compressor); - PARALLEL_FOR_LOOP - for (int sss = 0; sss < in._grid->oSites(); sss++) { - Kernels::DiracOptDhopDir(Stencil, Umu, Stencil.CommBuf(), sss, sss, in, out, dirdisp, gamma); + parallel_for (int sss = 0; sss < in._grid->oSites(); sss++) { + Kernels::DhopDir(Stencil, Umu, Stencil.CommBuf(), sss, sss, in, out, dirdisp, gamma); } }; @@ -346,14 +335,12 @@ void WilsonFermion::DhopInternal(StencilImpl &st, LebesgueOrder &lo, st.HaloExchange(in, compressor); if (dag == DaggerYes) { - PARALLEL_FOR_LOOP - for (int sss = 0; sss < in._grid->oSites(); sss++) { - Kernels::DiracOptDhopSiteDag(st, lo, U, st.CommBuf(), sss, sss, 1, 1, in, out); + parallel_for (int sss = 0; sss < in._grid->oSites(); sss++) { + Kernels::DhopSiteDag(st, lo, U, st.CommBuf(), sss, sss, 1, 1, in, out); } } else { - PARALLEL_FOR_LOOP - for (int sss = 0; sss < in._grid->oSites(); sss++) { - Kernels::DiracOptDhopSite(st, lo, U, st.CommBuf(), sss, sss, 1, 1, in, out); + parallel_for (int sss = 0; sss < in._grid->oSites(); sss++) { + Kernels::DhopSite(st, lo, U, st.CommBuf(), sss, sss, 1, 1, in, out); } } }; diff --git a/lib/qcd/action/fermion/WilsonFermion.h b/lib/qcd/action/fermion/WilsonFermion.h index 40fbd1bf..933be732 100644 --- a/lib/qcd/action/fermion/WilsonFermion.h +++ b/lib/qcd/action/fermion/WilsonFermion.h @@ -58,6 +58,9 @@ class WilsonFermion : public WilsonKernels, public WilsonFermionStatic { GridBase *FermionGrid(void) { return _grid; } GridBase *FermionRedBlackGrid(void) { return _cbgrid; } + FermionField _tmp; + FermionField &tmp(void) { 
return _tmp; } + ////////////////////////////////////////////////////////////////// // override multiply; cut number routines if pass dagger argument // and also make interface more uniformly consistent diff --git a/lib/qcd/action/fermion/WilsonFermion5D.cc b/lib/qcd/action/fermion/WilsonFermion5D.cc index d2ac96e3..1da58ddb 100644 --- a/lib/qcd/action/fermion/WilsonFermion5D.cc +++ b/lib/qcd/action/fermion/WilsonFermion5D.cc @@ -11,6 +11,7 @@ Author: Peter Boyle Author: Peter Boyle Author: Peter Boyle Author: paboyle +Author: Guido Cossu This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -29,8 +30,9 @@ Author: paboyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include -#include +#include +#include +#include namespace Grid { namespace QCD { @@ -60,155 +62,115 @@ WilsonFermion5D::WilsonFermion5D(GaugeField &_Umu, UmuEven(_FourDimRedBlackGrid), UmuOdd (_FourDimRedBlackGrid), Lebesgue(_FourDimGrid), - LebesgueEvenOdd(_FourDimRedBlackGrid) + LebesgueEvenOdd(_FourDimRedBlackGrid), + _tmp(&FiveDimRedBlackGrid) { + // some assertions + assert(FiveDimGrid._ndimension==5); + assert(FourDimGrid._ndimension==4); + assert(FourDimRedBlackGrid._ndimension==4); + assert(FiveDimRedBlackGrid._ndimension==5); + assert(FiveDimRedBlackGrid._checker_dim==1); // Don't checker the s direction + + // extent of fifth dim and not spread out + Ls=FiveDimGrid._fdimensions[0]; + assert(FiveDimRedBlackGrid._fdimensions[0]==Ls); + assert(FiveDimGrid._processors[0] ==1); + assert(FiveDimRedBlackGrid._processors[0] ==1); + + // Other dimensions must match the decomposition of the four-D fields + for(int d=0;d<4;d++){ + + assert(FiveDimGrid._processors[d+1] ==FourDimGrid._processors[d]); + assert(FiveDimRedBlackGrid._processors[d+1] ==FourDimGrid._processors[d]); + 
assert(FourDimRedBlackGrid._processors[d] ==FourDimGrid._processors[d]); + + assert(FiveDimGrid._fdimensions[d+1] ==FourDimGrid._fdimensions[d]); + assert(FiveDimRedBlackGrid._fdimensions[d+1]==FourDimGrid._fdimensions[d]); + assert(FourDimRedBlackGrid._fdimensions[d] ==FourDimGrid._fdimensions[d]); + + assert(FiveDimGrid._simd_layout[d+1] ==FourDimGrid._simd_layout[d]); + assert(FiveDimRedBlackGrid._simd_layout[d+1]==FourDimGrid._simd_layout[d]); + assert(FourDimRedBlackGrid._simd_layout[d] ==FourDimGrid._simd_layout[d]); + } + if (Impl::LsVectorised) { int nsimd = Simd::Nsimd(); - // some assertions - assert(FiveDimGrid._ndimension==5); - assert(FiveDimRedBlackGrid._ndimension==5); - assert(FiveDimRedBlackGrid._checker_dim==1); // Don't checker the s direction - assert(FourDimGrid._ndimension==4); - // Dimension zero of the five-d is the Ls direction - Ls=FiveDimGrid._fdimensions[0]; - assert(FiveDimGrid._processors[0] ==1); assert(FiveDimGrid._simd_layout[0] ==nsimd); - - assert(FiveDimRedBlackGrid._fdimensions[0]==Ls); - assert(FiveDimRedBlackGrid._processors[0] ==1); assert(FiveDimRedBlackGrid._simd_layout[0]==nsimd); - // Other dimensions must match the decomposition of the four-D fields for(int d=0;d<4;d++){ - assert(FiveDimRedBlackGrid._fdimensions[d+1]==FourDimGrid._fdimensions[d]); - assert(FiveDimRedBlackGrid._processors[d+1] ==FourDimGrid._processors[d]); - assert(FourDimGrid._simd_layout[d]=1); assert(FourDimRedBlackGrid._simd_layout[d]=1); assert(FiveDimRedBlackGrid._simd_layout[d+1]==1); - - assert(FiveDimGrid._fdimensions[d+1] ==FourDimGrid._fdimensions[d]); - assert(FiveDimGrid._processors[d+1] ==FourDimGrid._processors[d]); - assert(FiveDimGrid._simd_layout[d+1] ==FourDimGrid._simd_layout[d]); } } else { - - // some assertions - assert(FiveDimGrid._ndimension==5); - assert(FourDimGrid._ndimension==4); - assert(FiveDimRedBlackGrid._ndimension==5); - assert(FourDimRedBlackGrid._ndimension==4); - assert(FiveDimRedBlackGrid._checker_dim==1); // 
Dimension zero of the five-d is the Ls direction - Ls=FiveDimGrid._fdimensions[0]; - assert(FiveDimRedBlackGrid._fdimensions[0]==Ls); - assert(FiveDimRedBlackGrid._processors[0] ==1); assert(FiveDimRedBlackGrid._simd_layout[0]==1); - assert(FiveDimGrid._processors[0] ==1); assert(FiveDimGrid._simd_layout[0] ==1); - - // Other dimensions must match the decomposition of the four-D fields - for(int d=0;d<4;d++){ - assert(FourDimRedBlackGrid._fdimensions[d] ==FourDimGrid._fdimensions[d]); - assert(FiveDimRedBlackGrid._fdimensions[d+1]==FourDimGrid._fdimensions[d]); - - assert(FourDimRedBlackGrid._processors[d] ==FourDimGrid._processors[d]); - assert(FiveDimRedBlackGrid._processors[d+1] ==FourDimGrid._processors[d]); - - assert(FourDimRedBlackGrid._simd_layout[d] ==FourDimGrid._simd_layout[d]); - assert(FiveDimRedBlackGrid._simd_layout[d+1]==FourDimGrid._simd_layout[d]); - - assert(FiveDimGrid._fdimensions[d+1] ==FourDimGrid._fdimensions[d]); - assert(FiveDimGrid._processors[d+1] ==FourDimGrid._processors[d]); - assert(FiveDimGrid._simd_layout[d+1] ==FourDimGrid._simd_layout[d]); - } + } // Allocate the required comms buffer ImportGauge(_Umu); + // Build lists of exterior only nodes + int LLs = FiveDimGrid._rdimensions[0]; + int vol4; + vol4=FourDimGrid.oSites(); + Stencil.BuildSurfaceList(LLs,vol4); + + vol4=FourDimRedBlackGrid.oSites(); + StencilEven.BuildSurfaceList(LLs,vol4); + StencilOdd.BuildSurfaceList(LLs,vol4); + + // std::cout << GridLogMessage << " SurfaceLists "<< Stencil.surface_list.size() + // <<" " << StencilEven.surface_list.size()< -WilsonFermion5D::WilsonFermion5D(int simd,GaugeField &_Umu, - GridCartesian &FiveDimGrid, - GridRedBlackCartesian &FiveDimRedBlackGrid, - GridCartesian &FourDimGrid, - RealD _M5,const ImplParams &p) : -{ - int nsimd = Simd::Nsimd(); - - // some assertions - assert(FiveDimGrid._ndimension==5); - assert(FiveDimRedBlackGrid._ndimension==5); - assert(FiveDimRedBlackGrid._checker_dim==0); // Checkerboard the s-direction - 
assert(FourDimGrid._ndimension==4); - - // Dimension zero of the five-d is the Ls direction - Ls=FiveDimGrid._fdimensions[0]; - assert(FiveDimGrid._processors[0] ==1); - assert(FiveDimGrid._simd_layout[0] ==nsimd); - - assert(FiveDimRedBlackGrid._fdimensions[0]==Ls); - assert(FiveDimRedBlackGrid._processors[0] ==1); - assert(FiveDimRedBlackGrid._simd_layout[0]==nsimd); - - // Other dimensions must match the decomposition of the four-D fields - for(int d=0;d<4;d++){ - assert(FiveDimRedBlackGrid._fdimensions[d+1]==FourDimGrid._fdimensions[d]); - assert(FiveDimRedBlackGrid._processors[d+1] ==FourDimGrid._processors[d]); - - assert(FourDimGrid._simd_layout[d]=1); - assert(FiveDimRedBlackGrid._simd_layout[d+1]==1); - - assert(FiveDimGrid._fdimensions[d+1] ==FourDimGrid._fdimensions[d]); - assert(FiveDimGrid._processors[d+1] ==FourDimGrid._processors[d]); - assert(FiveDimGrid._simd_layout[d+1] ==FourDimGrid._simd_layout[d]); - } - - { - } -} - */ template void WilsonFermion5D::Report(void) { - std::vector latt = GridDefaultLatt(); - RealD volume = Ls; for(int mu=0;mu_Nprocessors; + RealD NP = _FourDimGrid->_Nprocessors; + RealD NN = _FourDimGrid->NodeCount(); + RealD volume = Ls; + std::vector latt = _FourDimGrid->GlobalDimensions(); + for(int mu=0;mu 0 ) { std::cout << GridLogMessage << "#### Dhop calls report " << std::endl; - std::cout << GridLogMessage << "WilsonFermion5D Number of Dhop Calls : " << DhopCalls << std::endl; - std::cout << GridLogMessage << "WilsonFermion5D Total Communication time : " << DhopCommTime<< " us" << std::endl; - std::cout << GridLogMessage << "WilsonFermion5D CommTime/Calls : " << DhopCommTime / DhopCalls << " us" << std::endl; - std::cout << GridLogMessage << "WilsonFermion5D Total Compute time : " << DhopComputeTime << " us" << std::endl; - std::cout << GridLogMessage << "WilsonFermion5D ComputeTime/Calls : " << DhopComputeTime / DhopCalls << " us" << std::endl; + std::cout << GridLogMessage << "WilsonFermion5D Number of DhopEO Calls : " 
<< DhopCalls << std::endl; + std::cout << GridLogMessage << "WilsonFermion5D TotalTime /Calls : " << DhopTotalTime / DhopCalls << " us" << std::endl; + std::cout << GridLogMessage << "WilsonFermion5D CommTime /Calls : " << DhopCommTime / DhopCalls << " us" << std::endl; + std::cout << GridLogMessage << "WilsonFermion5D FaceTime /Calls : " << DhopFaceTime / DhopCalls << " us" << std::endl; + std::cout << GridLogMessage << "WilsonFermion5D ComputeTime1/Calls : " << DhopComputeTime / DhopCalls << " us" << std::endl; + std::cout << GridLogMessage << "WilsonFermion5D ComputeTime2/Calls : " << DhopComputeTime2/ DhopCalls << " us" << std::endl; + // Average the compute time + _FourDimGrid->GlobalSum(DhopComputeTime); + DhopComputeTime/=NP; RealD mflops = 1344*volume*DhopCalls/DhopComputeTime/2; // 2 for red black counting std::cout << GridLogMessage << "Average mflops/s per call : " << mflops << std::endl; std::cout << GridLogMessage << "Average mflops/s per call per rank : " << mflops/NP << std::endl; + std::cout << GridLogMessage << "Average mflops/s per call per node : " << mflops/NN << std::endl; - RealD Fullmflops = 1344*volume*DhopCalls/(DhopComputeTime+DhopCommTime)/2; // 2 for red black counting + RealD Fullmflops = 1344*volume*DhopCalls/(DhopTotalTime)/2; // 2 for red black counting std::cout << GridLogMessage << "Average mflops/s per call (full) : " << Fullmflops << std::endl; std::cout << GridLogMessage << "Average mflops/s per call per rank (full): " << Fullmflops/NP << std::endl; - + std::cout << GridLogMessage << "Average mflops/s per call per node (full): " << Fullmflops/NN << std::endl; } if ( DerivCalls > 0 ) { std::cout << GridLogMessage << "#### Deriv calls report "<< std::endl; std::cout << GridLogMessage << "WilsonFermion5D Number of Deriv Calls : " <::Report(void) std::cout << GridLogMessage << "WilsonFermion5D StencilEven"< 0){ + std::cout << GridLogMessage << "WilsonFermion5D Stencil Reporti()" < @@ -231,6 +198,9 @@ void 
WilsonFermion5D::ZeroCounters(void) { DhopCalls = 0; DhopCommTime = 0; DhopComputeTime = 0; + DhopComputeTime2= 0; + DhopFaceTime = 0; + DhopTotalTime = 0; DerivCalls = 0; DerivCommTime = 0; @@ -240,6 +210,9 @@ void WilsonFermion5D::ZeroCounters(void) { Stencil.ZeroCounters(); StencilEven.ZeroCounters(); StencilOdd.ZeroCounters(); + Stencil.ZeroCountersi(); + StencilEven.ZeroCountersi(); + StencilOdd.ZeroCountersi(); } @@ -271,12 +244,11 @@ void WilsonFermion5D::DhopDir(const FermionField &in, FermionField &out,in assert(dirdisp<=7); assert(dirdisp>=0); -PARALLEL_FOR_LOOP - for(int ss=0;ssoSites();ss++){ + parallel_for(int ss=0;ssoSites();ss++){ for(int s=0;s::DerivInternal(StencilImpl & st, DerivCommTime+=usecond(); Atilde=A; + int LLs = B._grid->_rdimensions[0]; + DerivComputeTime-=usecond(); for (int mu = 0; mu < Nd; mu++) { @@ -319,8 +293,7 @@ void WilsonFermion5D::DerivInternal(StencilImpl & st, //////////////////////// DerivDhopComputeTime -= usecond(); - PARALLEL_FOR_LOOP - for (int sss = 0; sss < U._grid->oSites(); sss++) { + parallel_for (int sss = 0; sss < U._grid->oSites(); sss++) { for (int s = 0; s < Ls; s++) { int sU = sss; int sF = s + Ls * sU; @@ -328,13 +301,16 @@ void WilsonFermion5D::DerivInternal(StencilImpl & st, assert(sF < B._grid->oSites()); assert(sU < U._grid->oSites()); - Kernels::DiracOptDhopDir(st, U, st.CommBuf(), sF, sU, B, Btilde, mu, gamma); + Kernels::DhopDir(st, U, st.CommBuf(), sF, sU, B, Btilde, mu, gamma); //////////////////////////// // spin trace outer product //////////////////////////// } } + //////////////////////////// + // spin trace outer product + //////////////////////////// DerivDhopComputeTime += usecond(); Impl::InsertForce5D(mat, Btilde, Atilde, mu); } @@ -343,13 +319,14 @@ void WilsonFermion5D::DerivInternal(StencilImpl & st, template void WilsonFermion5D::DhopDeriv(GaugeField &mat, - const FermionField &A, - const FermionField &B, - int dag) + const FermionField &A, + const FermionField &B, + int dag) { 
conformable(A._grid,FermionGrid()); conformable(A._grid,B._grid); - conformable(GaugeGrid(),mat._grid); + + //conformable(GaugeGrid(),mat._grid);// this is not general! leaving as a comment mat.checkerboard = A.checkerboard; @@ -358,12 +335,11 @@ void WilsonFermion5D::DhopDeriv(GaugeField &mat, template void WilsonFermion5D::DhopDerivEO(GaugeField &mat, - const FermionField &A, - const FermionField &B, - int dag) + const FermionField &A, + const FermionField &B, + int dag) { conformable(A._grid,FermionRedBlackGrid()); - conformable(GaugeRedBlackGrid(),mat._grid); conformable(A._grid,B._grid); assert(B.checkerboard==Odd); @@ -376,12 +352,11 @@ void WilsonFermion5D::DhopDerivEO(GaugeField &mat, template void WilsonFermion5D::DhopDerivOE(GaugeField &mat, - const FermionField &A, - const FermionField &B, - int dag) + const FermionField &A, + const FermionField &B, + int dag) { conformable(A._grid,FermionRedBlackGrid()); - conformable(GaugeRedBlackGrid(),mat._grid); conformable(A._grid,B._grid); assert(B.checkerboard==Even); @@ -393,6 +368,124 @@ void WilsonFermion5D::DhopDerivOE(GaugeField &mat, template void WilsonFermion5D::DhopInternal(StencilImpl & st, LebesgueOrder &lo, + DoubledGaugeField & U, + const FermionField &in, FermionField &out,int dag) +{ + DhopTotalTime-=usecond(); +#ifdef GRID_OMP + if ( WilsonKernelsStatic::Comms == WilsonKernelsStatic::CommsAndCompute ) + DhopInternalOverlappedComms(st,lo,U,in,out,dag); + else +#endif + DhopInternalSerialComms(st,lo,U,in,out,dag); + DhopTotalTime+=usecond(); +} + + +template +void WilsonFermion5D::DhopInternalOverlappedComms(StencilImpl & st, LebesgueOrder &lo, + DoubledGaugeField & U, + const FermionField &in, FermionField &out,int dag) +{ +#ifdef GRID_OMP + // assert((dag==DaggerNo) ||(dag==DaggerYes)); + + Compressor compressor(dag); + + int LLs = in._grid->_rdimensions[0]; + int len = U._grid->oSites(); + + DhopFaceTime-=usecond(); + st.HaloExchangeOptGather(in,compressor); + st.CommsMergeSHM(compressor);// 
Could do this inside parallel region overlapped with comms + DhopFaceTime+=usecond(); + + double ctime=0; + double ptime=0; + + ////////////////////////////////////////////////////////////////////////////////////////////////////// + // Ugly explicit thread mapping introduced for OPA reasons. + ////////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma omp parallel reduction(max:ctime) reduction(max:ptime) + { + int tid = omp_get_thread_num(); + int nthreads = omp_get_num_threads(); + int ncomms = CartesianCommunicator::nCommThreads; + if (ncomms == -1) ncomms = 1; + assert(nthreads > ncomms); + if (tid >= ncomms) { + double start = usecond(); + nthreads -= ncomms; + int ttid = tid - ncomms; + int n = U._grid->oSites(); + int chunk = n / nthreads; + int rem = n % nthreads; + int myblock, myn; + if (ttid < rem) { + myblock = ttid * chunk + ttid; + myn = chunk+1; + } else { + myblock = ttid*chunk + rem; + myn = chunk; + } + + // do the compute + if (dag == DaggerYes) { + for (int ss = myblock; ss < myblock+myn; ++ss) { + int sU = ss; + int sF = LLs * sU; + Kernels::DhopSiteDag(st,lo,U,st.CommBuf(),sF,sU,LLs,1,in,out,1,0); + } + } else { + for (int ss = myblock; ss < myblock+myn; ++ss) { + int sU = ss; + int sF = LLs * sU; + Kernels::DhopSite(st,lo,U,st.CommBuf(),sF,sU,LLs,1,in,out,1,0); + } + } + ptime = usecond() - start; + } + { + double start = usecond(); + st.CommunicateThreaded(); + ctime = usecond() - start; + } + } + DhopCommTime += ctime; + DhopComputeTime+=ptime; + + // First to enter, last to leave timing + st.CollateThreads(); + + DhopFaceTime-=usecond(); + st.CommsMerge(compressor); + DhopFaceTime+=usecond(); + + DhopComputeTime2-=usecond(); + if (dag == DaggerYes) { + int sz=st.surface_list.size(); + parallel_for (int ss = 0; ss < sz; ss++) { + int sU = st.surface_list[ss]; + int sF = LLs * sU; + Kernels::DhopSiteDag(st,lo,U,st.CommBuf(),sF,sU,LLs,1,in,out,0,1); + } + } else { + int 
sz=st.surface_list.size(); + parallel_for (int ss = 0; ss < sz; ss++) { + int sU = st.surface_list[ss]; + int sF = LLs * sU; + Kernels::DhopSite(st,lo,U,st.CommBuf(),sF,sU,LLs,1,in,out,0,1); + } + } + DhopComputeTime2+=usecond(); +#else + assert(0); +#endif +} + + +template +void WilsonFermion5D::DhopInternalSerialComms(StencilImpl & st, LebesgueOrder &lo, DoubledGaugeField & U, const FermionField &in, FermionField &out,int dag) { @@ -402,45 +495,23 @@ void WilsonFermion5D::DhopInternal(StencilImpl & st, LebesgueOrder &lo, int LLs = in._grid->_rdimensions[0]; DhopCommTime-=usecond(); - st.HaloExchange(in,compressor); + st.HaloExchangeOpt(in,compressor); DhopCommTime+=usecond(); DhopComputeTime-=usecond(); // Dhop takes the 4d grid from U, and makes a 5d index for fermion - if (dag == DaggerYes) { - PARALLEL_FOR_LOOP - for (int ss = 0; ss < U._grid->oSites(); ss++) { - int sU = ss; - int sF = LLs * sU; - Kernels::DiracOptDhopSiteDag(st, lo, U, st.CommBuf(), sF, sU, LLs, 1, in, out); - } -#ifdef AVX512 - } else if (stat.is_init() ) { - int nthreads; - stat.start(); -#pragma omp parallel - { -#pragma omp master - nthreads = omp_get_num_threads(); - int mythread = omp_get_thread_num(); - stat.enter(mythread); -#pragma omp for nowait - for(int ss=0;ssoSites();ss++) { - int sU=ss; - int sF=LLs*sU; - Kernels::DiracOptDhopSite(st,lo,U,st.CommBuf(),sF,sU,LLs,1,in,out); - } - stat.exit(mythread); - } - stat.accum(nthreads); -#endif - } else { - PARALLEL_FOR_LOOP - for (int ss = 0; ss < U._grid->oSites(); ss++) { + if (dag == DaggerYes) { + parallel_for (int ss = 0; ss < U._grid->oSites(); ss++) { int sU = ss; int sF = LLs * sU; - Kernels::DiracOptDhopSite(st,lo,U,st.CommBuf(),sF,sU,LLs,1,in,out); + Kernels::DhopSiteDag(st,lo,U,st.CommBuf(),sF,sU,LLs,1,in,out); + } + } else { + parallel_for (int ss = 0; ss < U._grid->oSites(); ss++) { + int sU = ss; + int sF = LLs * sU; + Kernels::DhopSite(st,lo,U,st.CommBuf(),sF,sU,LLs,1,in,out); } } DhopComputeTime+=usecond(); @@ -502,11 
+573,11 @@ void WilsonFermion5D::MomentumSpacePropagatorHt(FermionField &out,const Fe typedef iSinglet Tcomplex; typedef Lattice > LatComplex; - Gamma::GammaMatrix Gmu [] = { - Gamma::GammaX, - Gamma::GammaY, - Gamma::GammaZ, - Gamma::GammaT + Gamma::Algebra Gmu [] = { + Gamma::Algebra::GammaX, + Gamma::Algebra::GammaY, + Gamma::Algebra::GammaZ, + Gamma::Algebra::GammaT }; std::vector latt_size = _grid->_fdimensions; @@ -573,11 +644,11 @@ void WilsonFermion5D::MomentumSpacePropagatorHt(FermionField &out,const Fe template void WilsonFermion5D::MomentumSpacePropagatorHw(FermionField &out,const FermionField &in,RealD mass) { - Gamma::GammaMatrix Gmu [] = { - Gamma::GammaX, - Gamma::GammaY, - Gamma::GammaZ, - Gamma::GammaT + Gamma::Algebra Gmu [] = { + Gamma::Algebra::GammaX, + Gamma::Algebra::GammaY, + Gamma::Algebra::GammaZ, + Gamma::Algebra::GammaT }; GridBase *_grid = _FourDimGrid; @@ -631,7 +702,6 @@ void WilsonFermion5D::MomentumSpacePropagatorHw(FermionField &out,const Fe } - FermOpTemplateInstantiate(WilsonFermion5D); GparityFermOpTemplateInstantiate(WilsonFermion5D); diff --git a/lib/qcd/action/fermion/WilsonFermion5D.h b/lib/qcd/action/fermion/WilsonFermion5D.h index ffb5c58e..e87e927e 100644 --- a/lib/qcd/action/fermion/WilsonFermion5D.h +++ b/lib/qcd/action/fermion/WilsonFermion5D.h @@ -31,7 +31,7 @@ Author: paboyle #ifndef GRID_QCD_WILSON_FERMION_5D_H #define GRID_QCD_WILSON_FERMION_5D_H -#include +#include namespace Grid { namespace QCD { @@ -74,11 +74,17 @@ namespace QCD { typedef WilsonKernels Kernels; PmuStat stat; + FermionField _tmp; + FermionField &tmp(void) { return _tmp; } + void Report(void); void ZeroCounters(void); double DhopCalls; double DhopCommTime; double DhopComputeTime; + double DhopComputeTime2; + double DhopFaceTime; + double DhopTotalTime; double DerivCalls; double DerivCommTime; @@ -142,6 +148,20 @@ namespace QCD { const FermionField &in, FermionField &out, int dag); + + void DhopInternalOverlappedComms(StencilImpl & st, + 
LebesgueOrder &lo, + DoubledGaugeField &U, + const FermionField &in, + FermionField &out, + int dag); + + void DhopInternalSerialComms(StencilImpl & st, + LebesgueOrder &lo, + DoubledGaugeField &U, + const FermionField &in, + FermionField &out, + int dag); // Constructors WilsonFermion5D(GaugeField &_Umu, diff --git a/lib/qcd/action/fermion/WilsonKernels.cc b/lib/qcd/action/fermion/WilsonKernels.cc index 43776c86..03c066b0 100644 --- a/lib/qcd/action/fermion/WilsonKernels.cc +++ b/lib/qcd/action/fermion/WilsonKernels.cc @@ -28,11 +28,13 @@ See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include +#include + namespace Grid { namespace QCD { -int WilsonKernelsStatic::Opt; +int WilsonKernelsStatic::Opt = WilsonKernelsStatic::OptGeneric; +int WilsonKernelsStatic::Comms = WilsonKernelsStatic::CommsAndCompute; template WilsonKernels::WilsonKernels(const ImplParams &p) : Base(p){}; @@ -40,11 +42,72 @@ WilsonKernels::WilsonKernels(const ImplParams &p) : Base(p){}; //////////////////////////////////////////// // Generic implementation; move to different file? 
//////////////////////////////////////////// + +#define GENERIC_STENCIL_LEG(Dir,spProj,Recon) \ + SE = st.GetEntry(ptype, Dir, sF); \ + if (SE->_is_local) { \ + chi_p = χ \ + if (SE->_permute) { \ + spProj(tmp, in._odata[SE->_offset]); \ + permute(chi, tmp, ptype); \ + } else { \ + spProj(chi, in._odata[SE->_offset]); \ + } \ + } else { \ + chi_p = &buf[SE->_offset]; \ + } \ + Impl::multLink(Uchi, U._odata[sU], *chi_p, Dir, SE, st); \ + Recon(result, Uchi); + +#define GENERIC_STENCIL_LEG_INT(Dir,spProj,Recon) \ + SE = st.GetEntry(ptype, Dir, sF); \ + if (SE->_is_local) { \ + chi_p = χ \ + if (SE->_permute) { \ + spProj(tmp, in._odata[SE->_offset]); \ + permute(chi, tmp, ptype); \ + } else { \ + spProj(chi, in._odata[SE->_offset]); \ + } \ + } else if ( st.same_node[Dir] ) { \ + chi_p = &buf[SE->_offset]; \ + } \ + if (SE->_is_local || st.same_node[Dir] ) { \ + Impl::multLink(Uchi, U._odata[sU], *chi_p, Dir, SE, st); \ + Recon(result, Uchi); \ + } +#define GENERIC_STENCIL_LEG_EXT(Dir,spProj,Recon) \ + SE = st.GetEntry(ptype, Dir, sF); \ + if ((!SE->_is_local) && (!st.same_node[Dir]) ) { \ + chi_p = &buf[SE->_offset]; \ + Impl::multLink(Uchi, U._odata[sU], *chi_p, Dir, SE, st); \ + Recon(result, Uchi); \ + nmu++; \ + } + +#define GENERIC_DHOPDIR_LEG(Dir,spProj,Recon) \ + if (gamma == Dir) { \ + if (SE->_is_local && SE->_permute) { \ + spProj(tmp, in._odata[SE->_offset]); \ + permute(chi, tmp, ptype); \ + } else if (SE->_is_local) { \ + spProj(chi, in._odata[SE->_offset]); \ + } else { \ + chi = buf[SE->_offset]; \ + } \ + Impl::multLink(Uchi, U._odata[sU], chi, dir, SE, st); \ + Recon(result, Uchi); \ + } + + //////////////////////////////////////////////////////////////////// + // All legs kernels ; comms then compute + //////////////////////////////////////////////////////////////////// template -void WilsonKernels::DiracOptGenericDhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, - SiteHalfSpinor *buf, int sF, - int sU, const FermionField &in, 
FermionField &out) { +void WilsonKernels::GenericDhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, + SiteHalfSpinor *buf, int sF, + int sU, const FermionField &in, FermionField &out) +{ SiteHalfSpinor tmp; SiteHalfSpinor chi; SiteHalfSpinor *chi_p; @@ -53,174 +116,22 @@ void WilsonKernels::DiracOptGenericDhopSiteDag(StencilImpl &st, LebesgueOr StencilEntry *SE; int ptype; - /////////////////////////// - // Xp - /////////////////////////// - SE = st.GetEntry(ptype, Xp, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjXp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjXp(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Xp, SE, st); - spReconXp(result, Uchi); - - /////////////////////////// - // Yp - /////////////////////////// - SE = st.GetEntry(ptype, Yp, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjYp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjYp(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Yp, SE, st); - accumReconYp(result, Uchi); - - /////////////////////////// - // Zp - /////////////////////////// - SE = st.GetEntry(ptype, Zp, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjZp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjZp(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Zp, SE, st); - accumReconZp(result, Uchi); - - /////////////////////////// - // Tp - /////////////////////////// - SE = st.GetEntry(ptype, Tp, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjTp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjTp(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - 
Impl::multLink(Uchi, U._odata[sU], *chi_p, Tp, SE, st); - accumReconTp(result, Uchi); - - /////////////////////////// - // Xm - /////////////////////////// - SE = st.GetEntry(ptype, Xm, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjXm(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjXm(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Xm, SE, st); - accumReconXm(result, Uchi); - - /////////////////////////// - // Ym - /////////////////////////// - SE = st.GetEntry(ptype, Ym, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjYm(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjYm(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Ym, SE, st); - accumReconYm(result, Uchi); - - /////////////////////////// - // Zm - /////////////////////////// - SE = st.GetEntry(ptype, Zm, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjZm(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjZm(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Zm, SE, st); - accumReconZm(result, Uchi); - - /////////////////////////// - // Tm - /////////////////////////// - SE = st.GetEntry(ptype, Tm, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjTm(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjTm(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Tm, SE, st); - accumReconTm(result, Uchi); - + GENERIC_STENCIL_LEG(Xp,spProjXp,spReconXp); + GENERIC_STENCIL_LEG(Yp,spProjYp,accumReconYp); + GENERIC_STENCIL_LEG(Zp,spProjZp,accumReconZp); + GENERIC_STENCIL_LEG(Tp,spProjTp,accumReconTp); + 
GENERIC_STENCIL_LEG(Xm,spProjXm,accumReconXm); + GENERIC_STENCIL_LEG(Ym,spProjYm,accumReconYm); + GENERIC_STENCIL_LEG(Zm,spProjZm,accumReconZm); + GENERIC_STENCIL_LEG(Tm,spProjTm,accumReconTm); vstream(out._odata[sF], result); }; -// Need controls to do interior, exterior, or both template -void WilsonKernels::DiracOptGenericDhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, - SiteHalfSpinor *buf, int sF, - int sU, const FermionField &in, FermionField &out) { +void WilsonKernels::GenericDhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, + SiteHalfSpinor *buf, int sF, + int sU, const FermionField &in, FermionField &out) +{ SiteHalfSpinor tmp; SiteHalfSpinor chi; SiteHalfSpinor *chi_p; @@ -229,171 +140,126 @@ void WilsonKernels::DiracOptGenericDhopSite(StencilImpl &st, LebesgueOrder StencilEntry *SE; int ptype; - /////////////////////////// - // Xp - /////////////////////////// - SE = st.GetEntry(ptype, Xm, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjXp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjXp(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Xm, SE, st); - spReconXp(result, Uchi); - - /////////////////////////// - // Yp - /////////////////////////// - SE = st.GetEntry(ptype, Ym, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjYp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjYp(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Ym, SE, st); - accumReconYp(result, Uchi); - - /////////////////////////// - // Zp - /////////////////////////// - SE = st.GetEntry(ptype, Zm, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjZp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjZp(chi, in._odata[SE->_offset]); - } - } else { - 
chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Zm, SE, st); - accumReconZp(result, Uchi); - - /////////////////////////// - // Tp - /////////////////////////// - SE = st.GetEntry(ptype, Tm, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjTp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjTp(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Tm, SE, st); - accumReconTp(result, Uchi); - - /////////////////////////// - // Xm - /////////////////////////// - SE = st.GetEntry(ptype, Xp, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjXm(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjXm(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Xp, SE, st); - accumReconXm(result, Uchi); - - /////////////////////////// - // Ym - /////////////////////////// - SE = st.GetEntry(ptype, Yp, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjYm(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjYm(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Yp, SE, st); - accumReconYm(result, Uchi); - - /////////////////////////// - // Zm - /////////////////////////// - SE = st.GetEntry(ptype, Zp, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjZm(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjZm(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Zp, SE, st); - accumReconZm(result, Uchi); - - /////////////////////////// - // Tm - /////////////////////////// - SE = st.GetEntry(ptype, Tp, sF); - - if (SE->_is_local) { - chi_p = χ - if (SE->_permute) { - spProjTm(tmp, 
in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else { - spProjTm(chi, in._odata[SE->_offset]); - } - } else { - chi_p = &buf[SE->_offset]; - } - - Impl::multLink(Uchi, U._odata[sU], *chi_p, Tp, SE, st); - accumReconTm(result, Uchi); + GENERIC_STENCIL_LEG(Xm,spProjXp,spReconXp); + GENERIC_STENCIL_LEG(Ym,spProjYp,accumReconYp); + GENERIC_STENCIL_LEG(Zm,spProjZp,accumReconZp); + GENERIC_STENCIL_LEG(Tm,spProjTp,accumReconTp); + GENERIC_STENCIL_LEG(Xp,spProjXm,accumReconXm); + GENERIC_STENCIL_LEG(Yp,spProjYm,accumReconYm); + GENERIC_STENCIL_LEG(Zp,spProjZm,accumReconZm); + GENERIC_STENCIL_LEG(Tp,spProjTm,accumReconTm); + vstream(out._odata[sF], result); +}; + //////////////////////////////////////////////////////////////////// + // Interior kernels + //////////////////////////////////////////////////////////////////// +template +void WilsonKernels::GenericDhopSiteDagInt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, + SiteHalfSpinor *buf, int sF, + int sU, const FermionField &in, FermionField &out) +{ + SiteHalfSpinor tmp; + SiteHalfSpinor chi; + SiteHalfSpinor *chi_p; + SiteHalfSpinor Uchi; + SiteSpinor result; + StencilEntry *SE; + int ptype; + result=zero; + GENERIC_STENCIL_LEG_INT(Xp,spProjXp,accumReconXp); + GENERIC_STENCIL_LEG_INT(Yp,spProjYp,accumReconYp); + GENERIC_STENCIL_LEG_INT(Zp,spProjZp,accumReconZp); + GENERIC_STENCIL_LEG_INT(Tp,spProjTp,accumReconTp); + GENERIC_STENCIL_LEG_INT(Xm,spProjXm,accumReconXm); + GENERIC_STENCIL_LEG_INT(Ym,spProjYm,accumReconYm); + GENERIC_STENCIL_LEG_INT(Zm,spProjZm,accumReconZm); + GENERIC_STENCIL_LEG_INT(Tm,spProjTm,accumReconTm); vstream(out._odata[sF], result); }; template -void WilsonKernels::DiracOptDhopDir( StencilImpl &st, DoubledGaugeField &U,SiteHalfSpinor *buf, int sF, +void WilsonKernels::GenericDhopSiteInt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, + SiteHalfSpinor *buf, int sF, + int sU, const FermionField &in, FermionField &out) +{ + SiteHalfSpinor tmp; + SiteHalfSpinor chi; + 
SiteHalfSpinor *chi_p; + SiteHalfSpinor Uchi; + SiteSpinor result; + StencilEntry *SE; + int ptype; + result=zero; + GENERIC_STENCIL_LEG_INT(Xm,spProjXp,accumReconXp); + GENERIC_STENCIL_LEG_INT(Ym,spProjYp,accumReconYp); + GENERIC_STENCIL_LEG_INT(Zm,spProjZp,accumReconZp); + GENERIC_STENCIL_LEG_INT(Tm,spProjTp,accumReconTp); + GENERIC_STENCIL_LEG_INT(Xp,spProjXm,accumReconXm); + GENERIC_STENCIL_LEG_INT(Yp,spProjYm,accumReconYm); + GENERIC_STENCIL_LEG_INT(Zp,spProjZm,accumReconZm); + GENERIC_STENCIL_LEG_INT(Tp,spProjTm,accumReconTm); + vstream(out._odata[sF], result); +}; +//////////////////////////////////////////////////////////////////// +// Exterior kernels +//////////////////////////////////////////////////////////////////// +template +void WilsonKernels::GenericDhopSiteDagExt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, + SiteHalfSpinor *buf, int sF, + int sU, const FermionField &in, FermionField &out) +{ + SiteHalfSpinor tmp; + SiteHalfSpinor chi; + SiteHalfSpinor *chi_p; + SiteHalfSpinor Uchi; + SiteSpinor result; + StencilEntry *SE; + int ptype; + int nmu=0; + result=zero; + GENERIC_STENCIL_LEG_EXT(Xp,spProjXp,accumReconXp); + GENERIC_STENCIL_LEG_EXT(Yp,spProjYp,accumReconYp); + GENERIC_STENCIL_LEG_EXT(Zp,spProjZp,accumReconZp); + GENERIC_STENCIL_LEG_EXT(Tp,spProjTp,accumReconTp); + GENERIC_STENCIL_LEG_EXT(Xm,spProjXm,accumReconXm); + GENERIC_STENCIL_LEG_EXT(Ym,spProjYm,accumReconYm); + GENERIC_STENCIL_LEG_EXT(Zm,spProjZm,accumReconZm); + GENERIC_STENCIL_LEG_EXT(Tm,spProjTm,accumReconTm); + if ( nmu ) { + out._odata[sF] = out._odata[sF] + result; + } +}; + +template +void WilsonKernels::GenericDhopSiteExt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, + SiteHalfSpinor *buf, int sF, + int sU, const FermionField &in, FermionField &out) +{ + SiteHalfSpinor tmp; + SiteHalfSpinor chi; + SiteHalfSpinor *chi_p; + SiteHalfSpinor Uchi; + SiteSpinor result; + StencilEntry *SE; + int ptype; + int nmu=0; + result=zero; + 
GENERIC_STENCIL_LEG_EXT(Xm,spProjXp,accumReconXp); + GENERIC_STENCIL_LEG_EXT(Ym,spProjYp,accumReconYp); + GENERIC_STENCIL_LEG_EXT(Zm,spProjZp,accumReconZp); + GENERIC_STENCIL_LEG_EXT(Tm,spProjTp,accumReconTp); + GENERIC_STENCIL_LEG_EXT(Xp,spProjXm,accumReconXm); + GENERIC_STENCIL_LEG_EXT(Yp,spProjYm,accumReconYm); + GENERIC_STENCIL_LEG_EXT(Zp,spProjZm,accumReconZm); + GENERIC_STENCIL_LEG_EXT(Tp,spProjTm,accumReconTm); + if ( nmu ) { + out._odata[sF] = out._odata[sF] + result; + } +}; + +template +void WilsonKernels::DhopDir( StencilImpl &st, DoubledGaugeField &U,SiteHalfSpinor *buf, int sF, int sU, const FermionField &in, FermionField &out, int dir, int gamma) { SiteHalfSpinor tmp; @@ -404,119 +270,14 @@ void WilsonKernels::DiracOptDhopDir( StencilImpl &st, DoubledGaugeField &U int ptype; SE = st.GetEntry(ptype, dir, sF); - - // Xp - if (gamma == Xp) { - if (SE->_is_local && SE->_permute) { - spProjXp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else if (SE->_is_local) { - spProjXp(chi, in._odata[SE->_offset]); - } else { - chi = buf[SE->_offset]; - } - Impl::multLink(Uchi, U._odata[sU], chi, dir, SE, st); - spReconXp(result, Uchi); - } - - // Yp - if (gamma == Yp) { - if (SE->_is_local && SE->_permute) { - spProjYp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else if (SE->_is_local) { - spProjYp(chi, in._odata[SE->_offset]); - } else { - chi = buf[SE->_offset]; - } - Impl::multLink(Uchi, U._odata[sU], chi, dir, SE, st); - spReconYp(result, Uchi); - } - - // Zp - if (gamma == Zp) { - if (SE->_is_local && SE->_permute) { - spProjZp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else if (SE->_is_local) { - spProjZp(chi, in._odata[SE->_offset]); - } else { - chi = buf[SE->_offset]; - } - Impl::multLink(Uchi, U._odata[sU], chi, dir, SE, st); - spReconZp(result, Uchi); - } - - // Tp - if (gamma == Tp) { - if (SE->_is_local && SE->_permute) { - spProjTp(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else if 
(SE->_is_local) { - spProjTp(chi, in._odata[SE->_offset]); - } else { - chi = buf[SE->_offset]; - } - Impl::multLink(Uchi, U._odata[sU], chi, dir, SE, st); - spReconTp(result, Uchi); - } - - // Xm - if (gamma == Xm) { - if (SE->_is_local && SE->_permute) { - spProjXm(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else if (SE->_is_local) { - spProjXm(chi, in._odata[SE->_offset]); - } else { - chi = buf[SE->_offset]; - } - Impl::multLink(Uchi, U._odata[sU], chi, dir, SE, st); - spReconXm(result, Uchi); - } - - // Ym - if (gamma == Ym) { - if (SE->_is_local && SE->_permute) { - spProjYm(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else if (SE->_is_local) { - spProjYm(chi, in._odata[SE->_offset]); - } else { - chi = buf[SE->_offset]; - } - Impl::multLink(Uchi, U._odata[sU], chi, dir, SE, st); - spReconYm(result, Uchi); - } - - // Zm - if (gamma == Zm) { - if (SE->_is_local && SE->_permute) { - spProjZm(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else if (SE->_is_local) { - spProjZm(chi, in._odata[SE->_offset]); - } else { - chi = buf[SE->_offset]; - } - Impl::multLink(Uchi, U._odata[sU], chi, dir, SE, st); - spReconZm(result, Uchi); - } - - // Tm - if (gamma == Tm) { - if (SE->_is_local && SE->_permute) { - spProjTm(tmp, in._odata[SE->_offset]); - permute(chi, tmp, ptype); - } else if (SE->_is_local) { - spProjTm(chi, in._odata[SE->_offset]); - } else { - chi = buf[SE->_offset]; - } - Impl::multLink(Uchi, U._odata[sU], chi, dir, SE, st); - spReconTm(result, Uchi); - } - + GENERIC_DHOPDIR_LEG(Xp,spProjXp,spReconXp); + GENERIC_DHOPDIR_LEG(Yp,spProjYp,spReconYp); + GENERIC_DHOPDIR_LEG(Zp,spProjZp,spReconZp); + GENERIC_DHOPDIR_LEG(Tp,spProjTp,spReconTp); + GENERIC_DHOPDIR_LEG(Xm,spProjXm,spReconXm); + GENERIC_DHOPDIR_LEG(Ym,spProjYm,spReconYm); + GENERIC_DHOPDIR_LEG(Zm,spProjZm,spReconZm); + GENERIC_DHOPDIR_LEG(Tm,spProjTm,spReconTm); vstream(out._odata[sF], result); } diff --git a/lib/qcd/action/fermion/WilsonKernels.h 
b/lib/qcd/action/fermion/WilsonKernels.h index 47da2b14..2cf52660 100644 --- a/lib/qcd/action/fermion/WilsonKernels.h +++ b/lib/qcd/action/fermion/WilsonKernels.h @@ -41,8 +41,9 @@ namespace QCD { class WilsonKernelsStatic { public: enum { OptGeneric, OptHandUnroll, OptInlineAsm }; - // S-direction is INNERMOST and takes no part in the parity. - static int Opt; // these are a temporary hack + enum { CommsAndCompute, CommsThenCompute }; + static int Opt; + static int Comms; }; template class WilsonKernels : public FermionOperator , public WilsonKernelsStatic { @@ -55,19 +56,25 @@ public: template typename std::enable_if::type - DiracOptDhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, - int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out) + DhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out,int interior=1,int exterior=1) { + bgq_l1p_optimisation(1); switch(Opt) { -#ifdef AVX512 +#if defined(AVX512) || defined (QPX) case OptInlineAsm: - WilsonKernels::DiracOptAsmDhopSite(st,lo,U,buf,sF,sU,Ls,Ns,in,out); - break; + if(interior&&exterior) WilsonKernels::AsmDhopSite (st,lo,U,buf,sF,sU,Ls,Ns,in,out); + else if (interior) WilsonKernels::AsmDhopSiteInt(st,lo,U,buf,sF,sU,Ls,Ns,in,out); + else if (exterior) WilsonKernels::AsmDhopSiteExt(st,lo,U,buf,sF,sU,Ls,Ns,in,out); + else assert(0); + break; #endif case OptHandUnroll: for (int site = 0; site < Ns; site++) { for (int s = 0; s < Ls; s++) { - WilsonKernels::DiracOptHandDhopSite(st,lo,U,buf,sF,sU,in,out); + if(interior&&exterior) WilsonKernels::HandDhopSite(st,lo,U,buf,sF,sU,in,out); + else if (interior) WilsonKernels::HandDhopSiteInt(st,lo,U,buf,sF,sU,in,out); + else if (exterior) WilsonKernels::HandDhopSiteExt(st,lo,U,buf,sF,sU,in,out); sF++; } sU++; @@ -76,7 +83,10 @@ public: case OptGeneric: for (int site = 0; site < Ns; site++) { for (int s = 0; s 
< Ls; s++) { - WilsonKernels::DiracOptGenericDhopSite(st,lo,U,buf,sF,sU,in,out); + if(interior&&exterior) WilsonKernels::GenericDhopSite(st,lo,U,buf,sF,sU,in,out); + else if (interior) WilsonKernels::GenericDhopSiteInt(st,lo,U,buf,sF,sU,in,out); + else if (exterior) WilsonKernels::GenericDhopSiteExt(st,lo,U,buf,sF,sU,in,out); + else assert(0); sF++; } sU++; @@ -85,16 +95,20 @@ public: default: assert(0); } + bgq_l1p_optimisation(0); } template typename std::enable_if<(Impl::Dimension != 3 || (Impl::Dimension == 3 && Nc != 3)) && EnableBool, void>::type - DiracOptDhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, - int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out) { + DhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out,int interior=1,int exterior=1 ) { // no kernel choice for (int site = 0; site < Ns; site++) { for (int s = 0; s < Ls; s++) { - WilsonKernels::DiracOptGenericDhopSite(st, lo, U, buf, sF, sU, in, out); + if(interior&&exterior) WilsonKernels::GenericDhopSite(st,lo,U,buf,sF,sU,in,out); + else if (interior) WilsonKernels::GenericDhopSiteInt(st,lo,U,buf,sF,sU,in,out); + else if (exterior) WilsonKernels::GenericDhopSiteExt(st,lo,U,buf,sF,sU,in,out); + else assert(0); sF++; } sU++; @@ -103,19 +117,26 @@ public: template typename std::enable_if::type - DiracOptDhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, - int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out) { - + DhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out,int interior=1,int exterior=1) +{ + bgq_l1p_optimisation(1); switch(Opt) { -#ifdef AVX512 +#if defined(AVX512) || defined (QPX) case OptInlineAsm: - 
WilsonKernels::DiracOptAsmDhopSiteDag(st,lo,U,buf,sF,sU,Ls,Ns,in,out); + if(interior&&exterior) WilsonKernels::AsmDhopSiteDag (st,lo,U,buf,sF,sU,Ls,Ns,in,out); + else if (interior) WilsonKernels::AsmDhopSiteDagInt(st,lo,U,buf,sF,sU,Ls,Ns,in,out); + else if (exterior) WilsonKernels::AsmDhopSiteDagExt(st,lo,U,buf,sF,sU,Ls,Ns,in,out); + else assert(0); break; #endif case OptHandUnroll: for (int site = 0; site < Ns; site++) { for (int s = 0; s < Ls; s++) { - WilsonKernels::DiracOptHandDhopSiteDag(st,lo,U,buf,sF,sU,in,out); + if(interior&&exterior) WilsonKernels::HandDhopSiteDag(st,lo,U,buf,sF,sU,in,out); + else if (interior) WilsonKernels::HandDhopSiteDagInt(st,lo,U,buf,sF,sU,in,out); + else if (exterior) WilsonKernels::HandDhopSiteDagExt(st,lo,U,buf,sF,sU,in,out); + else assert(0); sF++; } sU++; @@ -124,7 +145,10 @@ public: case OptGeneric: for (int site = 0; site < Ns; site++) { for (int s = 0; s < Ls; s++) { - WilsonKernels::DiracOptGenericDhopSiteDag(st,lo,U,buf,sF,sU,in,out); + if(interior&&exterior) WilsonKernels::GenericDhopSiteDag(st,lo,U,buf,sF,sU,in,out); + else if (interior) WilsonKernels::GenericDhopSiteDagInt(st,lo,U,buf,sF,sU,in,out); + else if (exterior) WilsonKernels::GenericDhopSiteDagExt(st,lo,U,buf,sF,sU,in,out); + else assert(0); sF++; } sU++; @@ -133,45 +157,86 @@ public: default: assert(0); } + bgq_l1p_optimisation(0); } template typename std::enable_if<(Impl::Dimension != 3 || (Impl::Dimension == 3 && Nc != 3)) && EnableBool,void>::type - DiracOptDhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U,SiteHalfSpinor * buf, - int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out) { + DhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U,SiteHalfSpinor * buf, + int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out,int interior=1,int exterior=1) { for (int site = 0; site < Ns; site++) { for (int s = 0; s < Ls; s++) { - 
WilsonKernels::DiracOptGenericDhopSiteDag(st,lo,U,buf,sF,sU,in,out); + if(interior&&exterior) WilsonKernels::GenericDhopSiteDag(st,lo,U,buf,sF,sU,in,out); + else if (interior) WilsonKernels::GenericDhopSiteDagInt(st,lo,U,buf,sF,sU,in,out); + else if (exterior) WilsonKernels::GenericDhopSiteDagExt(st,lo,U,buf,sF,sU,in,out); + else assert(0); sF++; } sU++; } } - void DiracOptDhopDir(StencilImpl &st, DoubledGaugeField &U,SiteHalfSpinor * buf, + void DhopDir(StencilImpl &st, DoubledGaugeField &U,SiteHalfSpinor * buf, int sF, int sU, const FermionField &in, FermionField &out, int dirdisp, int gamma); private: // Specialised variants - void DiracOptGenericDhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, - int sF, int sU, const FermionField &in, FermionField &out); + void GenericDhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); - void DiracOptGenericDhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, - int sF, int sU, const FermionField &in, FermionField &out); + void GenericDhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); - void DiracOptAsmDhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, - int sF, int sU, int Ls, int Ns, const FermionField &in,FermionField &out); - - void DiracOptAsmDhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, - int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out); - - void DiracOptHandDhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, - int sF, int sU, const FermionField &in, FermionField &out); - - void DiracOptHandDhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, - int sF, int sU, 
const FermionField &in, FermionField &out); + void GenericDhopSiteInt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); + void GenericDhopSiteDagInt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); + + void GenericDhopSiteExt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); + + void GenericDhopSiteDagExt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); + + void AsmDhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, int Ls, int Ns, const FermionField &in,FermionField &out); + + void AsmDhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out); + + void AsmDhopSiteInt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, int Ls, int Ns, const FermionField &in,FermionField &out); + + void AsmDhopSiteDagInt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out); + + void AsmDhopSiteExt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, int Ls, int Ns, const FermionField &in,FermionField &out); + + void AsmDhopSiteDagExt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, int Ls, int Ns, const FermionField &in, FermionField &out); + + + void HandDhopSite(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); + + void 
HandDhopSiteDag(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); + + void HandDhopSiteInt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); + + void HandDhopSiteDagInt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); + + void HandDhopSiteExt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); + + void HandDhopSiteDagExt(StencilImpl &st, LebesgueOrder &lo, DoubledGaugeField &U, SiteHalfSpinor * buf, + int sF, int sU, const FermionField &in, FermionField &out); + public: WilsonKernels(const ImplParams &p = ImplParams()); diff --git a/lib/qcd/action/fermion/WilsonKernelsAsm.cc b/lib/qcd/action/fermion/WilsonKernelsAsm.cc index d7a9edd3..cd5d2430 100644 --- a/lib/qcd/action/fermion/WilsonKernelsAsm.cc +++ b/lib/qcd/action/fermion/WilsonKernelsAsm.cc @@ -30,165 +30,75 @@ Author: Guido Cossu *************************************************************************************/ /* END LEGAL */ -#include - +#include namespace Grid { namespace QCD { - + + /////////////////////////////////////////////////////////// // Default to no assembler implementation /////////////////////////////////////////////////////////// template void -WilsonKernels::DiracOptAsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) { assert(0); } template void -WilsonKernels::DiracOptAsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, +WilsonKernels::AsmDhopSiteDag(StencilImpl 
&st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) { assert(0); } -#if defined(AVX512) -#include +template void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +{ + assert(0); +} - /////////////////////////////////////////////////////////// - // If we are AVX512 specialise the single precision routine - /////////////////////////////////////////////////////////// +template void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +{ + assert(0); +} -#include - -static Vector signsF; +template void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +{ + assert(0); +} - template - int setupSigns(Vector& signs ){ - Vector bother(2); - signs = bother; - vrsign(signs[0]); - visign(signs[1]); - return 1; - } +template void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +{ + assert(0); +} - static int signInitF = setupSigns(signsF); - -#define label(A) ilabel(A) -#define ilabel(A) ".globl\n" #A ":\n" - -#define MAYBEPERM(A,perm) if (perm) { A ; } -#define MULT_2SPIN(ptr,pf) MULT_ADDSUB_2SPIN(ptr,pf) -#define FX(A) WILSONASM_ ##A -#define COMPLEX_TYPE vComplexF -#define signs signsF - -#undef KERNEL_DAG -template<> void -WilsonKernels::DiracOptAsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, - int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) -#include - -#define KERNEL_DAG -template<> void 
-WilsonKernels::DiracOptAsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, - int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) -#include - -#undef VMOVIDUP -#undef VMOVRDUP -#undef MAYBEPERM -#undef MULT_2SPIN -#undef FX -#define FX(A) DWFASM_ ## A -#define MAYBEPERM(A,B) -//#define VMOVIDUP(A,B,C) VBCASTIDUPf(A,B,C) -//#define VMOVRDUP(A,B,C) VBCASTRDUPf(A,B,C) -#define MULT_2SPIN(ptr,pf) MULT_ADDSUB_2SPIN_LS(ptr,pf) - -#undef KERNEL_DAG -template<> void -WilsonKernels::DiracOptAsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, - int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) -#include - -#define KERNEL_DAG -template<> void -WilsonKernels::DiracOptAsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, - int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) -#include -#undef COMPLEX_TYPE -#undef signs -#undef VMOVRDUP -#undef MAYBEPERM -#undef MULT_2SPIN -#undef FX - -/////////////////////////////////////////////////////////// -// If we are AVX512 specialise the double precision routine -/////////////////////////////////////////////////////////// - -#include - -static Vector signsD; -#define signs signsD -static int signInitD = setupSigns(signsD); - -#define MAYBEPERM(A,perm) if (perm) { A ; } -#define MULT_2SPIN(ptr,pf) MULT_ADDSUB_2SPIN(ptr,pf) -#define FX(A) WILSONASM_ ##A -#define COMPLEX_TYPE vComplexD - -#undef KERNEL_DAG -template<> void -WilsonKernels::DiracOptAsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, - int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) -#include - -#define KERNEL_DAG -template<> void -WilsonKernels::DiracOptAsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, - int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) -#include - -#undef 
VMOVIDUP -#undef VMOVRDUP -#undef MAYBEPERM -#undef MULT_2SPIN -#undef FX -#define FX(A) DWFASM_ ## A -#define MAYBEPERM(A,B) -//#define VMOVIDUP(A,B,C) VBCASTIDUPd(A,B,C) -//#define VMOVRDUP(A,B,C) VBCASTRDUPd(A,B,C) -#define MULT_2SPIN(ptr,pf) MULT_ADDSUB_2SPIN_LS(ptr,pf) - -#undef KERNEL_DAG -template<> void -WilsonKernels::DiracOptAsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, - int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) -#include - -#define KERNEL_DAG -template<> void -WilsonKernels::DiracOptAsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, - int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) -#include - -#undef COMPLEX_TYPE -#undef signs -#undef VMOVRDUP -#undef MAYBEPERM -#undef MULT_2SPIN -#undef FX - -#endif //AVX512 +#include +#include #define INSTANTIATE_ASM(A)\ -template void WilsonKernels::DiracOptAsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf,\ +template void WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf,\ int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out);\ \ -template void WilsonKernels::DiracOptAsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf,\ +template void WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf,\ + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out);\ +template void WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf,\ + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out);\ + \ +template void WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf,\ + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out);\ +template void 
WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf,\ + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out);\ + \ +template void WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf,\ int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out);\ INSTANTIATE_ASM(WilsonImplF); @@ -202,5 +112,16 @@ INSTANTIATE_ASM(DomainWallVec5dImplD); INSTANTIATE_ASM(ZDomainWallVec5dImplF); INSTANTIATE_ASM(ZDomainWallVec5dImplD); +INSTANTIATE_ASM(WilsonImplFH); +INSTANTIATE_ASM(WilsonImplDF); +INSTANTIATE_ASM(ZWilsonImplFH); +INSTANTIATE_ASM(ZWilsonImplDF); +INSTANTIATE_ASM(GparityWilsonImplFH); +INSTANTIATE_ASM(GparityWilsonImplDF); +INSTANTIATE_ASM(DomainWallVec5dImplFH); +INSTANTIATE_ASM(DomainWallVec5dImplDF); +INSTANTIATE_ASM(ZDomainWallVec5dImplFH); +INSTANTIATE_ASM(ZDomainWallVec5dImplDF); + }} diff --git a/lib/qcd/action/fermion/WilsonKernelsAsmAvx512.h b/lib/qcd/action/fermion/WilsonKernelsAsmAvx512.h new file mode 100644 index 00000000..948c16a2 --- /dev/null +++ b/lib/qcd/action/fermion/WilsonKernelsAsmAvx512.h @@ -0,0 +1,650 @@ +/************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + + + Source file: ./lib/qcd/action/fermion/WilsonKernelsAsmAvx512.h + + Copyright (C) 2015 + +Author: Peter Boyle +Author: paboyle + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. 
+ + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + + See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + + +#if defined(AVX512) + /////////////////////////////////////////////////////////// + // If we are AVX512 specialise the single precision routine + /////////////////////////////////////////////////////////// +#include +#include + +static Vector signsF; + + template + int setupSigns(Vector& signs ){ + Vector bother(2); + signs = bother; + vrsign(signs[0]); + visign(signs[1]); + return 1; + } + + static int signInitF = setupSigns(signsF); + +#define MAYBEPERM(A,perm) if (perm) { A ; } +#define MULT_2SPIN(ptr,pf) MULT_ADDSUB_2SPIN(ptr,pf) +#define COMPLEX_SIGNS(isigns) vComplexF *isigns = &signsF[0]; + +///////////////////////////////////////////////////////////////// +// XYZT vectorised, undag Kernel, single +///////////////////////////////////////////////////////////////// +#undef KERNEL_DAG +#define INTERIOR_AND_EXTERIOR +#undef INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField 
&U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#define INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + + +#undef INTERIOR_AND_EXTERIOR +#undef INTERIOR +#define EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + 
+///////////////////////////////////////////////////////////////// +// XYZT vectorised, dag Kernel, single +///////////////////////////////////////////////////////////////// +#define KERNEL_DAG +#define INTERIOR_AND_EXTERIOR +#undef INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#define INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField 
&in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#undef INTERIOR +#define EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef MAYBEPERM +#undef MULT_2SPIN +#define MAYBEPERM(A,B) +#define MULT_2SPIN(ptr,pf) MULT_ADDSUB_2SPIN_LS(ptr,pf) + +///////////////////////////////////////////////////////////////// +// Ls vectorised, undag Kernel, single +///////////////////////////////////////////////////////////////// +#undef KERNEL_DAG +#define INTERIOR_AND_EXTERIOR +#undef INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void 
+WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#define INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#undef INTERIOR +#define EXTERIOR +#undef MULT_2SPIN +#define MULT_2SPIN(ptr,pf) MULT_ADDSUB_2SPIN_LSNOPF(ptr,pf) +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, 
SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +///////////////////////////////////////////////////////////////// +// Ls vectorised, dag Kernel, single +///////////////////////////////////////////////////////////////// +#define KERNEL_DAG +#define INTERIOR_AND_EXTERIOR +#undef INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#define INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & 
lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#undef INTERIOR +#define EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef COMPLEX_SIGNS +#undef MAYBEPERM +#undef MULT_2SPIN + + + +/////////////////////////////////////////////////////////// +// If we are AVX512 specialise the double precision routine +/////////////////////////////////////////////////////////// + +#include + +static Vector signsD; +static int signInitD = setupSigns(signsD); + +#define MAYBEPERM(A,perm) if (perm) { A ; } +#define MULT_2SPIN(ptr,pf) MULT_ADDSUB_2SPIN(ptr,pf) +#define COMPLEX_SIGNS(isigns) vComplexD *isigns = &signsD[0]; + + +#define INTERIOR_AND_EXTERIOR +#undef INTERIOR +#undef EXTERIOR + +///////////////////////////////////////////////////////////////// +// XYZT vectorised, undag Kernel, double +///////////////////////////////////////////////////////////////// +#undef KERNEL_DAG +#define INTERIOR_AND_EXTERIOR +#undef INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, +
int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#define INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#undef INTERIOR +#define EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, 
SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +///////////////////////////////////////////////////////////////// +// XYZT vectorised, dag Kernel, single +///////////////////////////////////////////////////////////////// +#define KERNEL_DAG +#define INTERIOR_AND_EXTERIOR +#undef INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#define INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & 
lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#undef INTERIOR +#define EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef MAYBEPERM +#undef MULT_2SPIN +#define MAYBEPERM(A,B) +#define MULT_2SPIN(ptr,pf) MULT_ADDSUB_2SPIN_LS(ptr,pf) + +///////////////////////////////////////////////////////////////// +// Ls vectorised, undag Kernel, single +///////////////////////////////////////////////////////////////// +#undef KERNEL_DAG +#define INTERIOR_AND_EXTERIOR +#undef INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int 
ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#define INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#undef INTERIOR +#define EXTERIOR +#undef MULT_2SPIN +#define MULT_2SPIN(ptr,pf) MULT_ADDSUB_2SPIN_LSNOPF(ptr,pf) +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void 
+WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +///////////////////////////////////////////////////////////////// +// Ls vectorised, dag Kernel, single +///////////////////////////////////////////////////////////////// +#define KERNEL_DAG +#define INTERIOR_AND_EXTERIOR +#undef INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#define INTERIOR +#undef EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include 
+template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDagInt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef INTERIOR_AND_EXTERIOR +#undef INTERIOR +#define EXTERIOR +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +template<> void +WilsonKernels::AsmDhopSiteDagExt(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef COMPLEX_SIGNS +#undef MAYBEPERM +#undef MULT_2SPIN + +#endif //AVX512 diff --git a/lib/qcd/action/fermion/WilsonKernelsAsmBody.h b/lib/qcd/action/fermion/WilsonKernelsAsmBody.h index 72e13754..db8651ab 100644 --- a/lib/qcd/action/fermion/WilsonKernelsAsmBody.h +++ b/lib/qcd/action/fermion/WilsonKernelsAsmBody.h @@ -1,255 +1,196 @@ +#ifdef KERNEL_DAG +#define DIR0_PROJMEM(base) XP_PROJMEM(base); +#define DIR1_PROJMEM(base) 
YP_PROJMEM(base); +#define DIR2_PROJMEM(base) ZP_PROJMEM(base); +#define DIR3_PROJMEM(base) TP_PROJMEM(base); +#define DIR4_PROJMEM(base) XM_PROJMEM(base); +#define DIR5_PROJMEM(base) YM_PROJMEM(base); +#define DIR6_PROJMEM(base) ZM_PROJMEM(base); +#define DIR7_PROJMEM(base) TM_PROJMEM(base); +#define DIR0_RECON XP_RECON +#define DIR1_RECON YP_RECON_ACCUM +#define DIR2_RECON ZP_RECON_ACCUM +#define DIR3_RECON TP_RECON_ACCUM +#define DIR4_RECON XM_RECON_ACCUM +#define DIR5_RECON YM_RECON_ACCUM +#define DIR6_RECON ZM_RECON_ACCUM +#define DIR7_RECON TM_RECON_ACCUM +#else +#define DIR0_PROJMEM(base) XM_PROJMEM(base); +#define DIR1_PROJMEM(base) YM_PROJMEM(base); +#define DIR2_PROJMEM(base) ZM_PROJMEM(base); +#define DIR3_PROJMEM(base) TM_PROJMEM(base); +#define DIR4_PROJMEM(base) XP_PROJMEM(base); +#define DIR5_PROJMEM(base) YP_PROJMEM(base); +#define DIR6_PROJMEM(base) ZP_PROJMEM(base); +#define DIR7_PROJMEM(base) TP_PROJMEM(base); +#define DIR0_RECON XM_RECON +#define DIR1_RECON YM_RECON_ACCUM +#define DIR2_RECON ZM_RECON_ACCUM +#define DIR3_RECON TM_RECON_ACCUM +#define DIR4_RECON XP_RECON_ACCUM +#define DIR5_RECON YP_RECON_ACCUM +#define DIR6_RECON ZP_RECON_ACCUM +#define DIR7_RECON TP_RECON_ACCUM +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Comms then compute kernel +//////////////////////////////////////////////////////////////////////////////// +#ifdef INTERIOR_AND_EXTERIOR + +#define ASM_LEG(Dir,NxtDir,PERMUTE_DIR,PROJ,RECON) \ + basep = st.GetPFInfo(nent,plocal); nent++; \ + if ( local ) { \ + LOAD64(%r10,isigns); \ + PROJ(base); \ + MAYBEPERM(PERMUTE_DIR,perm); \ + } else { \ + LOAD_CHI(base); \ + } \ + base = st.GetInfo(ptype,local,perm,NxtDir,ent,plocal); ent++; \ + PREFETCH_CHIMU(base); \ + MULT_2SPIN_DIR_PF(Dir,basep); \ + LOAD64(%r10,isigns); \ + RECON; \ + +#define ASM_LEG_XP(Dir,NxtDir,PERMUTE_DIR,PROJ,RECON) \ + base = st.GetInfo(ptype,local,perm,Dir,ent,plocal); ent++; \ + PF_GAUGE(Xp); \ + 
PREFETCH1_CHIMU(base); \ + ASM_LEG(Dir,NxtDir,PERMUTE_DIR,PROJ,RECON) + +#define RESULT(base,basep) SAVE_RESULT(base,basep); + +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Pre comms kernel -- prefetch like normal because it is mostly right +//////////////////////////////////////////////////////////////////////////////// +#ifdef INTERIOR + +#define ASM_LEG(Dir,NxtDir,PERMUTE_DIR,PROJ,RECON) \ + basep = st.GetPFInfo(nent,plocal); nent++; \ + if ( local ) { \ + LOAD64(%r10,isigns); \ + PROJ(base); \ + MAYBEPERM(PERMUTE_DIR,perm); \ + }else if ( st.same_node[Dir] ) {LOAD_CHI(base);} \ + if ( local || st.same_node[Dir] ) { \ + MULT_2SPIN_DIR_PF(Dir,basep); \ + LOAD64(%r10,isigns); \ + RECON; \ + } \ + base = st.GetInfo(ptype,local,perm,NxtDir,ent,plocal); ent++; \ + PREFETCH_CHIMU(base); \ + +#define ASM_LEG_XP(Dir,NxtDir,PERMUTE_DIR,PROJ,RECON) \ + base = st.GetInfo(ptype,local,perm,Dir,ent,plocal); ent++; \ + PF_GAUGE(Xp); \ + PREFETCH1_CHIMU(base); \ + { ZERO_PSI; } \ + ASM_LEG(Dir,NxtDir,PERMUTE_DIR,PROJ,RECON) + +#define RESULT(base,basep) SAVE_RESULT(base,basep); + +#endif +//////////////////////////////////////////////////////////////////////////////// +// Post comms kernel +//////////////////////////////////////////////////////////////////////////////// +#ifdef EXTERIOR + + +#define ASM_LEG(Dir,NxtDir,PERMUTE_DIR,PROJ,RECON) \ + base = st.GetInfo(ptype,local,perm,Dir,ent,plocal); ent++; \ + if((!local)&&(!st.same_node[Dir]) ) { \ + LOAD_CHI(base); \ + MULT_2SPIN_DIR_PF(Dir,base); \ + LOAD64(%r10,isigns); \ + RECON; \ + nmu++; \ + } + +#define ASM_LEG_XP(Dir,NxtDir,PERMUTE_DIR,PROJ,RECON) \ + nmu=0; \ + { ZERO_PSI;} \ + base = st.GetInfo(ptype,local,perm,Dir,ent,plocal); ent++; \ + if((!local)&&(!st.same_node[Dir]) ) { \ + LOAD_CHI(base); \ + MULT_2SPIN_DIR_PF(Dir,base); \ + LOAD64(%r10,isigns); \ + RECON; \ + nmu++; \ + } + +#define RESULT(base,basep) if (nmu){ ADD_RESULT(base,base);} + +#endif { + int nmu; int 
local,perm, ptype; uint64_t base; uint64_t basep; const uint64_t plocal =(uint64_t) & in._odata[0]; - // vComplexF isigns[2] = { signs[0], signs[1] }; - //COMPLEX_TYPE is vComplexF of vComplexD depending - //on the chosen precision - COMPLEX_TYPE *isigns = &signs[0]; - + COMPLEX_SIGNS(isigns); MASK_REGS; int nmax=U._grid->oSites(); for(int site=0;site=nmax) ssn=0; - int sUn=lo.Reorder(ssn); - for(int s=0;s=nmax) ssn=0; + int sUn=lo.Reorder(ssn); + LOCK_GAUGE(0); +#else + int sU =ssU; + int ssn=ssU+1; if(ssn>=nmax) ssn=0; + int sUn=ssn; +#endif + for(int s=0;s shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - YP_PROJMEM(base); -#else - YM_PROJMEM(base); -#endif - MAYBEPERM(PERMUTE_DIR2,perm); - } else { - LOAD_CHI(base); - } - base = st.GetInfo(ptype,local,perm,Zp,ent,plocal); ent++; - PREFETCH_CHIMU(base); - { - MULT_2SPIN_DIR_PFYP(Yp,basep); - } - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - YP_RECON_ACCUM; -#else - YM_RECON_ACCUM; -#endif + ASM_LEG(Xm,Ym,PERMUTE_DIR3,DIR4_PROJMEM,DIR4_RECON); + ASM_LEG(Ym,Zm,PERMUTE_DIR2,DIR5_PROJMEM,DIR5_RECON); + ASM_LEG(Zm,Tm,PERMUTE_DIR1,DIR6_PROJMEM,DIR6_RECON); + ASM_LEG(Tm,Xp,PERMUTE_DIR0,DIR7_PROJMEM,DIR7_RECON); - //////////////////////////////// - // Zp - //////////////////////////////// - basep = st.GetPFInfo(nent,plocal); nent++; - if ( local ) { - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - ZP_PROJMEM(base); -#else - ZM_PROJMEM(base); +#ifdef EXTERIOR + if (nmu==0) break; + // if (nmu!=0) std::cout << "EXT "< shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - ZP_RECON_ACCUM; -#else - ZM_RECON_ACCUM; -#endif - - //////////////////////////////// - // Tp - //////////////////////////////// - basep = st.GetPFInfo(nent,plocal); nent++; - if ( local ) { - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - TP_PROJMEM(base); -#else - TM_PROJMEM(base); -#endif - 
MAYBEPERM(PERMUTE_DIR0,perm); - } else { - LOAD_CHI(base); - } - base = st.GetInfo(ptype,local,perm,Xm,ent,plocal); ent++; - PREFETCH_CHIMU(base); - { - MULT_2SPIN_DIR_PFTP(Tp,basep); - } - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - TP_RECON_ACCUM; -#else - TM_RECON_ACCUM; -#endif - - //////////////////////////////// - // Xm - //////////////////////////////// -#ifndef STREAM_STORE - basep= (uint64_t) &out._odata[ss]; -#endif - // basep= st.GetPFInfo(nent,plocal); nent++; - if ( local ) { - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - XM_PROJMEM(base); -#else - XP_PROJMEM(base); -#endif - MAYBEPERM(PERMUTE_DIR3,perm); - } else { - LOAD_CHI(base); - } - base = st.GetInfo(ptype,local,perm,Ym,ent,plocal); ent++; - PREFETCH_CHIMU(base); - { - MULT_2SPIN_DIR_PFXM(Xm,basep); - } - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - XM_RECON_ACCUM; -#else - XP_RECON_ACCUM; -#endif - - //////////////////////////////// - // Ym - //////////////////////////////// - basep= st.GetPFInfo(nent,plocal); nent++; - if ( local ) { - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - YM_PROJMEM(base); -#else - YP_PROJMEM(base); -#endif - MAYBEPERM(PERMUTE_DIR2,perm); - } else { - LOAD_CHI(base); - } - base = st.GetInfo(ptype,local,perm,Zm,ent,plocal); ent++; - PREFETCH_CHIMU(base); - { - MULT_2SPIN_DIR_PFYM(Ym,basep); - } - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - YM_RECON_ACCUM; -#else - YP_RECON_ACCUM; -#endif - - //////////////////////////////// - // Zm - //////////////////////////////// - basep= st.GetPFInfo(nent,plocal); nent++; - if ( local ) { - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - ZM_PROJMEM(base); -#else - ZP_PROJMEM(base); -#endif - MAYBEPERM(PERMUTE_DIR1,perm); - } else { - 
LOAD_CHI(base); - } - base = st.GetInfo(ptype,local,perm,Tm,ent,plocal); ent++; - PREFETCH_CHIMU(base); - { - MULT_2SPIN_DIR_PFZM(Zm,basep); - } - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - ZM_RECON_ACCUM; -#else - ZP_RECON_ACCUM; -#endif - - //////////////////////////////// - // Tm - //////////////////////////////// - basep= st.GetPFInfo(nent,plocal); nent++; - if ( local ) { - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - TM_PROJMEM(base); -#else - TP_PROJMEM(base); -#endif - MAYBEPERM(PERMUTE_DIR0,perm); - } else { - LOAD_CHI(base); - } - base= (uint64_t) &out._odata[ss]; -#ifndef STREAM_STORE - PREFETCH_CHIMU(base); -#endif - { - MULT_2SPIN_DIR_PFTM(Tm,basep); - } - LOAD64(%r10,isigns); // times i => shuffle and xor the real part sign bit -#ifdef KERNEL_DAG - TM_RECON_ACCUM; -#else - TP_RECON_ACCUM; -#endif - - basep= st.GetPFInfo(nent,plocal); nent++; - SAVE_RESULT(base,basep); - - } - ssU++; + base = (uint64_t) &out._odata[ss]; + basep= st.GetPFInfo(nent,plocal); nent++; + RESULT(base,basep); + } + ssU++; + UNLOCK_GAUGE(0); } } + +#undef DIR0_PROJMEM +#undef DIR1_PROJMEM +#undef DIR2_PROJMEM +#undef DIR3_PROJMEM +#undef DIR4_PROJMEM +#undef DIR5_PROJMEM +#undef DIR6_PROJMEM +#undef DIR7_PROJMEM +#undef DIR0_RECON +#undef DIR1_RECON +#undef DIR2_RECON +#undef DIR3_RECON +#undef DIR4_RECON +#undef DIR5_RECON +#undef DIR6_RECON +#undef DIR7_RECON +#undef ASM_LEG +#undef ASM_LEG_XP +#undef RESULT diff --git a/lib/qcd/action/fermion/WilsonKernelsAsmQPX.h b/lib/qcd/action/fermion/WilsonKernelsAsmQPX.h new file mode 100644 index 00000000..612234d7 --- /dev/null +++ b/lib/qcd/action/fermion/WilsonKernelsAsmQPX.h @@ -0,0 +1,150 @@ +/************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + + + Source file: ./lib/qcd/action/fermion/WilsonKernelsAsmQPX.h + + Copyright (C) 2015 + 
+Author: Peter Boyle +Author: paboyle + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + + See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + + +#if defined(QPX) + + /////////////////////////////////////////////////////////// + // If we are QPX specialise the single precision routine + /////////////////////////////////////////////////////////// + +#include +#include + +#define MAYBEPERM(A,perm) if (perm) { A ; } +#define MULT_2SPIN(ptr,pf) MULT_2SPIN_QPX(ptr,pf) +#define COMPLEX_SIGNS(isigns) + +#define INTERIOR_AND_EXTERIOR +#undef INTERIOR +#undef EXTERIOR + +///////////////////////////////////////////////////////////////// +// XYZT vectorised, undag Kernel, single +///////////////////////////////////////////////////////////////// +#undef KERNEL_DAG +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +///////////////////////////////////////////////////////////////// +// XYZT vectorised, dag Kernel, single +///////////////////////////////////////////////////////////////// +#define KERNEL_DAG +template<> void 
+WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +#undef MAYBEPERM +#undef MULT_2SPIN +#define MAYBEPERM(A,B) +#define MULT_2SPIN(ptr,pf) MULT_2SPIN_QPX_LS(ptr,pf) + +///////////////////////////////////////////////////////////////// +// Ls vectorised, undag Kernel, single +///////////////////////////////////////////////////////////////// +#undef KERNEL_DAG +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include + +///////////////////////////////////////////////////////////////// +// Ls vectorised, dag Kernel, single +///////////////////////////////////////////////////////////////// +#define KERNEL_DAG +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +#undef MAYBEPERM +#undef MULT_2SPIN + +/////////////////////////////////////////////////////////// +// DP routines +/////////////////////////////////////////////////////////// + +#include + +#define MAYBEPERM(A,perm) if (perm) { A ; } +#define MULT_2SPIN(ptr,pf) MULT_2SPIN_QPX(ptr,pf) + +///////////////////////////////////////////////////////////////// +// XYZT Vectorised, undag Kernel, double +///////////////////////////////////////////////////////////////// +#undef KERNEL_DAG +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +///////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////// +// XYZT Vectorised, dag Kernel, double 
+///////////////////////////////////////////////////////////////// +#define KERNEL_DAG +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +///////////////////////////////////////////////////////////////// + +#undef MAYBEPERM +#undef MULT_2SPIN +#define MAYBEPERM(A,B) +#define MULT_2SPIN(ptr,pf) MULT_2SPIN_QPX_LS(ptr,pf) +///////////////////////////////////////////////////////////////// +// Ls vectorised, undag Kernel, double +///////////////////////////////////////////////////////////////// +#undef KERNEL_DAG +template<> void +WilsonKernels::AsmDhopSite(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U, SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +///////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////// +// Ls vectorised, dag Kernel, double +///////////////////////////////////////////////////////////////// +#define KERNEL_DAG +template<> void +WilsonKernels::AsmDhopSiteDag(StencilImpl &st,LebesgueOrder & lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int ssU,int Ls,int Ns,const FermionField &in, FermionField &out) +#include +///////////////////////////////////////////////////////////////// + +#undef MAYBEPERM +#undef MULT_2SPIN + +#endif diff --git a/lib/qcd/action/fermion/WilsonKernelsHand.cc b/lib/qcd/action/fermion/WilsonKernelsHand.cc index f5900832..80b81714 100644 --- a/lib/qcd/action/fermion/WilsonKernelsHand.cc +++ b/lib/qcd/action/fermion/WilsonKernelsHand.cc @@ -26,64 +26,185 @@ Author: paboyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include +#include #define REGISTER -#define LOAD_CHIMU \ - const 
SiteSpinor & ref (in._odata[offset]); \ - Chimu_00=ref()(0)(0);\ - Chimu_01=ref()(0)(1);\ - Chimu_02=ref()(0)(2);\ - Chimu_10=ref()(1)(0);\ - Chimu_11=ref()(1)(1);\ - Chimu_12=ref()(1)(2);\ - Chimu_20=ref()(2)(0);\ - Chimu_21=ref()(2)(1);\ - Chimu_22=ref()(2)(2);\ - Chimu_30=ref()(3)(0);\ - Chimu_31=ref()(3)(1);\ - Chimu_32=ref()(3)(2); +#define LOAD_CHIMU_BODY(F) \ + Chimu_00=ref(F)(0)(0); \ + Chimu_01=ref(F)(0)(1); \ + Chimu_02=ref(F)(0)(2); \ + Chimu_10=ref(F)(1)(0); \ + Chimu_11=ref(F)(1)(1); \ + Chimu_12=ref(F)(1)(2); \ + Chimu_20=ref(F)(2)(0); \ + Chimu_21=ref(F)(2)(1); \ + Chimu_22=ref(F)(2)(2); \ + Chimu_30=ref(F)(3)(0); \ + Chimu_31=ref(F)(3)(1); \ + Chimu_32=ref(F)(3)(2) -#define LOAD_CHI\ - const SiteHalfSpinor &ref(buf[offset]); \ - Chi_00 = ref()(0)(0);\ - Chi_01 = ref()(0)(1);\ - Chi_02 = ref()(0)(2);\ - Chi_10 = ref()(1)(0);\ - Chi_11 = ref()(1)(1);\ - Chi_12 = ref()(1)(2); +#define LOAD_CHIMU(DIR,F,PERM) \ + { const SiteSpinor & ref (in._odata[offset]); LOAD_CHIMU_BODY(F); } + +#define LOAD_CHI_BODY(F) \ + Chi_00 = ref(F)(0)(0);\ + Chi_01 = ref(F)(0)(1);\ + Chi_02 = ref(F)(0)(2);\ + Chi_10 = ref(F)(1)(0);\ + Chi_11 = ref(F)(1)(1);\ + Chi_12 = ref(F)(1)(2) + +#define LOAD_CHI(DIR,F,PERM) \ + {const SiteHalfSpinor &ref(buf[offset]); LOAD_CHI_BODY(F); } + + +//G-parity implementations using in-place intrinsic ops + +//1l 1h -> 1h 1l +//0l 0h , 1h 1l -> 0l 1h 0h,1l +//0h,1l -> 1l,0h +//if( (distance == 1 && !perm_will_occur) || (distance == -1 && perm_will_occur) ) +//Pulled fermion through forwards face, GPBC on upper component +//Need 0= 0l 1h 1= 1l 0h +//else if( (distance == -1 && !perm) || (distance == 1 && perm) ) +//Pulled fermion through backwards face, GPBC on lower component +//Need 0= 1l 0h 1= 0l 1h + +//1l 1h -> 1h 1l +//0l 0h , 1h 1l -> 0l 1h 0h,1l +#define DO_TWIST_0L_1H(INTO,S,C,F, PERM, tmp1, tmp2, tmp3) \ + permute##PERM(tmp1, ref(1)(S)(C)); \ + exchange##PERM(tmp2,tmp3, ref(0)(S)(C), tmp1); \ + INTO = tmp2; + +//0l 0h -> 0h 0l +//1l 
1h, 0h 0l -> 1l 0h, 1h 0l +#define DO_TWIST_1L_0H(INTO,S,C,F, PERM, tmp1, tmp2, tmp3) \ + permute##PERM(tmp1, ref(0)(S)(C)); \ + exchange##PERM(tmp2,tmp3, ref(1)(S)(C), tmp1); \ + INTO = tmp2; + + + + +#define LOAD_CHI_SETUP(DIR,F) \ + g = F; \ + direction = st._directions[DIR]; \ + distance = st._distances[DIR]; \ + sl = st._grid->_simd_layout[direction]; \ + inplace_twist = 0; \ + if(SE->_around_the_world && this->Params.twists[DIR % 4]){ \ + if(sl == 1){ \ + g = (F+1) % 2; \ + }else{ \ + inplace_twist = 1; \ + } \ + } + +#define LOAD_CHIMU_GPARITY_INPLACE_TWIST(DIR,F,PERM) \ + { const SiteSpinor &ref(in._odata[offset]); \ + LOAD_CHI_SETUP(DIR,F); \ + if(!inplace_twist){ \ + LOAD_CHIMU_BODY(g); \ + }else{ \ + if( ( F==0 && ((distance == 1 && !perm) || (distance == -1 && perm)) ) || \ + ( F==1 && ((distance == -1 && !perm) || (distance == 1 && perm)) ) ){ \ + DO_TWIST_0L_1H(Chimu_00,0,0,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_0L_1H(Chimu_01,0,1,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_0L_1H(Chimu_02,0,2,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_0L_1H(Chimu_10,1,0,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_0L_1H(Chimu_11,1,1,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_0L_1H(Chimu_12,1,2,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_0L_1H(Chimu_20,2,0,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_0L_1H(Chimu_21,2,1,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_0L_1H(Chimu_22,2,2,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_0L_1H(Chimu_30,3,0,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_0L_1H(Chimu_31,3,1,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_0L_1H(Chimu_32,3,2,F,PERM, U_11,U_20,U_21); \ + }else{ \ + DO_TWIST_1L_0H(Chimu_00,0,0,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_1L_0H(Chimu_01,0,1,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_1L_0H(Chimu_02,0,2,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_1L_0H(Chimu_10,1,0,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_1L_0H(Chimu_11,1,1,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_1L_0H(Chimu_12,1,2,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_1L_0H(Chimu_20,2,0,F,PERM, U_00,U_01,U_10); \ + 
DO_TWIST_1L_0H(Chimu_21,2,1,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_1L_0H(Chimu_22,2,2,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_1L_0H(Chimu_30,3,0,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_1L_0H(Chimu_31,3,1,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_1L_0H(Chimu_32,3,2,F,PERM, U_11,U_20,U_21); \ + } \ + } \ + } + + +#define LOAD_CHI_GPARITY_INPLACE_TWIST(DIR,F,PERM) \ + { const SiteHalfSpinor &ref(buf[offset]); \ + LOAD_CHI_SETUP(DIR,F); \ + if(!inplace_twist){ \ + LOAD_CHI_BODY(g); \ + }else{ \ + if( ( F==0 && ((distance == 1 && !perm) || (distance == -1 && perm)) ) || \ + ( F==1 && ((distance == -1 && !perm) || (distance == 1 && perm)) ) ){ \ + DO_TWIST_0L_1H(Chi_00,0,0,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_0L_1H(Chi_01,0,1,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_0L_1H(Chi_02,0,2,F,PERM, UChi_00,UChi_01,UChi_02); \ + DO_TWIST_0L_1H(Chi_10,1,0,F,PERM, UChi_10,UChi_11,UChi_12); \ + DO_TWIST_0L_1H(Chi_11,1,1,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_0L_1H(Chi_12,1,2,F,PERM, U_11,U_20,U_21); \ + }else{ \ + DO_TWIST_1L_0H(Chi_00,0,0,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_1L_0H(Chi_01,0,1,F,PERM, U_11,U_20,U_21); \ + DO_TWIST_1L_0H(Chi_02,0,2,F,PERM, UChi_00,UChi_01,UChi_02); \ + DO_TWIST_1L_0H(Chi_10,1,0,F,PERM, UChi_10,UChi_11,UChi_12); \ + DO_TWIST_1L_0H(Chi_11,1,1,F,PERM, U_00,U_01,U_10); \ + DO_TWIST_1L_0H(Chi_12,1,2,F,PERM, U_11,U_20,U_21); \ + } \ + } \ + } + + +#define LOAD_CHI_GPARITY(DIR,F,PERM) LOAD_CHI_GPARITY_INPLACE_TWIST(DIR,F,PERM) +#define LOAD_CHIMU_GPARITY(DIR,F,PERM) LOAD_CHIMU_GPARITY_INPLACE_TWIST(DIR,F,PERM) // To splat or not to splat depends on the implementation -#define MULT_2SPIN(A)\ - auto & ref(U._odata[sU](A)); \ - Impl::loadLinkElement(U_00,ref()(0,0)); \ - Impl::loadLinkElement(U_10,ref()(1,0)); \ - Impl::loadLinkElement(U_20,ref()(2,0)); \ - Impl::loadLinkElement(U_01,ref()(0,1)); \ - Impl::loadLinkElement(U_11,ref()(1,1)); \ - Impl::loadLinkElement(U_21,ref()(2,1)); \ - UChi_00 = U_00*Chi_00;\ - UChi_10 = U_00*Chi_10;\ - UChi_01 = U_10*Chi_00;\ - UChi_11 
= U_10*Chi_10;\ - UChi_02 = U_20*Chi_00;\ - UChi_12 = U_20*Chi_10;\ - UChi_00+= U_01*Chi_01;\ - UChi_10+= U_01*Chi_11;\ - UChi_01+= U_11*Chi_01;\ - UChi_11+= U_11*Chi_11;\ - UChi_02+= U_21*Chi_01;\ - UChi_12+= U_21*Chi_11;\ - Impl::loadLinkElement(U_00,ref()(0,2)); \ - Impl::loadLinkElement(U_10,ref()(1,2)); \ - Impl::loadLinkElement(U_20,ref()(2,2)); \ - UChi_00+= U_00*Chi_02;\ - UChi_10+= U_00*Chi_12;\ - UChi_01+= U_10*Chi_02;\ - UChi_11+= U_10*Chi_12;\ - UChi_02+= U_20*Chi_02;\ - UChi_12+= U_20*Chi_12; +#define MULT_2SPIN_BODY \ + Impl::loadLinkElement(U_00,ref()(0,0)); \ + Impl::loadLinkElement(U_10,ref()(1,0)); \ + Impl::loadLinkElement(U_20,ref()(2,0)); \ + Impl::loadLinkElement(U_01,ref()(0,1)); \ + Impl::loadLinkElement(U_11,ref()(1,1)); \ + Impl::loadLinkElement(U_21,ref()(2,1)); \ + UChi_00 = U_00*Chi_00; \ + UChi_10 = U_00*Chi_10; \ + UChi_01 = U_10*Chi_00; \ + UChi_11 = U_10*Chi_10; \ + UChi_02 = U_20*Chi_00; \ + UChi_12 = U_20*Chi_10; \ + UChi_00+= U_01*Chi_01; \ + UChi_10+= U_01*Chi_11; \ + UChi_01+= U_11*Chi_01; \ + UChi_11+= U_11*Chi_11; \ + UChi_02+= U_21*Chi_01; \ + UChi_12+= U_21*Chi_11; \ + Impl::loadLinkElement(U_00,ref()(0,2)); \ + Impl::loadLinkElement(U_10,ref()(1,2)); \ + Impl::loadLinkElement(U_20,ref()(2,2)); \ + UChi_00+= U_00*Chi_02; \ + UChi_10+= U_00*Chi_12; \ + UChi_01+= U_10*Chi_02; \ + UChi_11+= U_10*Chi_12; \ + UChi_02+= U_20*Chi_02; \ + UChi_12+= U_20*Chi_12 + + +#define MULT_2SPIN(A,F) \ + {auto & ref(U._odata[sU](A)); MULT_2SPIN_BODY; } + +#define MULT_2SPIN_GPARITY(A,F) \ + {auto & ref(U._odata[sU](F)(A)); MULT_2SPIN_BODY; } #define PERMUTE_DIR(dir) \ @@ -307,538 +428,503 @@ Author: paboyle result_31-= UChi_11; \ result_32-= UChi_12; +#define HAND_STENCIL_LEG(PROJ,PERM,DIR,RECON,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL) \ + SE=st.GetEntry(ptype,DIR,ss); \ + offset = SE->_offset; \ + local = SE->_is_local; \ + perm = SE->_permute; \ + if ( local ) { \ + LOAD_CHIMU_IMPL(DIR,F,PERM); \ + PROJ; \ + if ( perm) { \ + 
PERMUTE_DIR(PERM); \ + } \ + } else { \ + LOAD_CHI_IMPL(DIR,F,PERM); \ + } \ + MULT_2SPIN_IMPL(DIR,F); \ + RECON; + + +#define HAND_STENCIL_LEG_INT(PROJ,PERM,DIR,RECON,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL) \ + SE=st.GetEntry(ptype,DIR,ss); \ + offset = SE->_offset; \ + local = SE->_is_local; \ + perm = SE->_permute; \ + if ( local ) { \ + LOAD_CHIMU_IMPL(DIR,F,PERM); \ + PROJ; \ + if ( perm) { \ + PERMUTE_DIR(PERM); \ + } \ + } else if ( st.same_node[DIR] ) { \ + LOAD_CHI_IMPL(DIR,F,PERM); \ + } \ + if (local || st.same_node[DIR] ) { \ + MULT_2SPIN_IMPL(DIR,F); \ + RECON; \ + } + +#define HAND_STENCIL_LEG_EXT(PROJ,PERM,DIR,RECON,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL) \ + SE=st.GetEntry(ptype,DIR,ss); \ + offset = SE->_offset; \ + local = SE->_is_local; \ + perm = SE->_permute; \ + if((!SE->_is_local)&&(!st.same_node[DIR]) ) { \ + LOAD_CHI_IMPL(DIR,F,PERM); \ + MULT_2SPIN_IMPL(DIR,F); \ + RECON; \ + nmu++; \ + } + +#define HAND_RESULT(ss,F) \ + { \ + SiteSpinor & ref (out._odata[ss]); \ + vstream(ref(F)(0)(0),result_00); \ + vstream(ref(F)(0)(1),result_01); \ + vstream(ref(F)(0)(2),result_02); \ + vstream(ref(F)(1)(0),result_10); \ + vstream(ref(F)(1)(1),result_11); \ + vstream(ref(F)(1)(2),result_12); \ + vstream(ref(F)(2)(0),result_20); \ + vstream(ref(F)(2)(1),result_21); \ + vstream(ref(F)(2)(2),result_22); \ + vstream(ref(F)(3)(0),result_30); \ + vstream(ref(F)(3)(1),result_31); \ + vstream(ref(F)(3)(2),result_32); \ + } + +#define HAND_RESULT_EXT(ss,F) \ + if (nmu){ \ + SiteSpinor & ref (out._odata[ss]); \ + ref(F)(0)(0)+=result_00; \ + ref(F)(0)(1)+=result_01; \ + ref(F)(0)(2)+=result_02; \ + ref(F)(1)(0)+=result_10; \ + ref(F)(1)(1)+=result_11; \ + ref(F)(1)(2)+=result_12; \ + ref(F)(2)(0)+=result_20; \ + ref(F)(2)(1)+=result_21; \ + ref(F)(2)(2)+=result_22; \ + ref(F)(3)(0)+=result_30; \ + ref(F)(3)(1)+=result_31; \ + ref(F)(3)(2)+=result_32; \ + } + + +#define HAND_DECLARATIONS(a) \ + Simd result_00; \ + Simd result_01; \ + Simd 
result_02; \ + Simd result_10; \ + Simd result_11; \ + Simd result_12; \ + Simd result_20; \ + Simd result_21; \ + Simd result_22; \ + Simd result_30; \ + Simd result_31; \ + Simd result_32; \ + Simd Chi_00; \ + Simd Chi_01; \ + Simd Chi_02; \ + Simd Chi_10; \ + Simd Chi_11; \ + Simd Chi_12; \ + Simd UChi_00; \ + Simd UChi_01; \ + Simd UChi_02; \ + Simd UChi_10; \ + Simd UChi_11; \ + Simd UChi_12; \ + Simd U_00; \ + Simd U_10; \ + Simd U_20; \ + Simd U_01; \ + Simd U_11; \ + Simd U_21; + +#define ZERO_RESULT \ + result_00=zero; \ + result_01=zero; \ + result_02=zero; \ + result_10=zero; \ + result_11=zero; \ + result_12=zero; \ + result_20=zero; \ + result_21=zero; \ + result_22=zero; \ + result_30=zero; \ + result_31=zero; \ + result_32=zero; + +#define Chimu_00 Chi_00 +#define Chimu_01 Chi_01 +#define Chimu_02 Chi_02 +#define Chimu_10 Chi_10 +#define Chimu_11 Chi_11 +#define Chimu_12 Chi_12 +#define Chimu_20 UChi_00 +#define Chimu_21 UChi_01 +#define Chimu_22 UChi_02 +#define Chimu_30 UChi_10 +#define Chimu_31 UChi_11 +#define Chimu_32 UChi_12 + namespace Grid { namespace QCD { - template void -WilsonKernels::DiracOptHandDhopSite(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, +WilsonKernels::HandDhopSite(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, int ss,int sU,const FermionField &in, FermionField &out) { +// T==0, Z==1, Y==2, X==3 expect 1,2,2,2 simd layout etc... 
typedef typename Simd::scalar_type S; typedef typename Simd::vector_type V; - REGISTER Simd result_00; // 12 regs on knc - REGISTER Simd result_01; - REGISTER Simd result_02; - - REGISTER Simd result_10; - REGISTER Simd result_11; - REGISTER Simd result_12; - - REGISTER Simd result_20; - REGISTER Simd result_21; - REGISTER Simd result_22; - - REGISTER Simd result_30; - REGISTER Simd result_31; - REGISTER Simd result_32; // 20 left - - REGISTER Simd Chi_00; // two spinor; 6 regs - REGISTER Simd Chi_01; - REGISTER Simd Chi_02; - - REGISTER Simd Chi_10; - REGISTER Simd Chi_11; - REGISTER Simd Chi_12; // 14 left - - REGISTER Simd UChi_00; // two spinor; 6 regs - REGISTER Simd UChi_01; - REGISTER Simd UChi_02; - - REGISTER Simd UChi_10; - REGISTER Simd UChi_11; - REGISTER Simd UChi_12; // 8 left - - REGISTER Simd U_00; // two rows of U matrix - REGISTER Simd U_10; - REGISTER Simd U_20; - REGISTER Simd U_01; - REGISTER Simd U_11; - REGISTER Simd U_21; // 2 reg left. - -#define Chimu_00 Chi_00 -#define Chimu_01 Chi_01 -#define Chimu_02 Chi_02 -#define Chimu_10 Chi_10 -#define Chimu_11 Chi_11 -#define Chimu_12 Chi_12 -#define Chimu_20 UChi_00 -#define Chimu_21 UChi_01 -#define Chimu_22 UChi_02 -#define Chimu_30 UChi_10 -#define Chimu_31 UChi_11 -#define Chimu_32 UChi_12 - + HAND_DECLARATIONS(ignore); int offset,local,perm, ptype; StencilEntry *SE; - // Xp - SE=st.GetEntry(ptype,Xp,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - XM_PROJ; - if ( perm) { - PERMUTE_DIR(3); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... - } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Xp); - } - XM_RECON; - - // Yp - SE=st.GetEntry(ptype,Yp,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - YM_PROJ; - if ( perm) { - PERMUTE_DIR(2); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... 
- } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Yp); - } - YM_RECON_ACCUM; +#define HAND_DOP_SITE(F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL) \ + HAND_STENCIL_LEG(XM_PROJ,3,Xp,XM_RECON,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG(YM_PROJ,2,Yp,YM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG(ZM_PROJ,1,Zp,ZM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG(TM_PROJ,0,Tp,TM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG(XP_PROJ,3,Xm,XP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG(YP_PROJ,2,Ym,YP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG(ZP_PROJ,1,Zm,ZP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG(TP_PROJ,0,Tm,TP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_RESULT(ss,F) - - // Zp - SE=st.GetEntry(ptype,Zp,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - ZM_PROJ; - if ( perm) { - PERMUTE_DIR(1); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... - } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Zp); - } - ZM_RECON_ACCUM; - - // Tp - SE=st.GetEntry(ptype,Tp,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - TM_PROJ; - if ( perm) { - PERMUTE_DIR(0); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... - } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Tp); - } - TM_RECON_ACCUM; - - // Xm - SE=st.GetEntry(ptype,Xm,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - XP_PROJ; - if ( perm) { - PERMUTE_DIR(3); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... 
- } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Xm); - } - XP_RECON_ACCUM; - - - // Ym - SE=st.GetEntry(ptype,Ym,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - YP_PROJ; - if ( perm) { - PERMUTE_DIR(2); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... - } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Ym); - } - YP_RECON_ACCUM; - - // Zm - SE=st.GetEntry(ptype,Zm,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - ZP_PROJ; - if ( perm) { - PERMUTE_DIR(1); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... - } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Zm); - } - ZP_RECON_ACCUM; - - // Tm - SE=st.GetEntry(ptype,Tm,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - TP_PROJ; - if ( perm) { - PERMUTE_DIR(0); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... - } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Tm); - } - TP_RECON_ACCUM; - - { - SiteSpinor & ref (out._odata[ss]); - vstream(ref()(0)(0),result_00); - vstream(ref()(0)(1),result_01); - vstream(ref()(0)(2),result_02); - vstream(ref()(1)(0),result_10); - vstream(ref()(1)(1),result_11); - vstream(ref()(1)(2),result_12); - vstream(ref()(2)(0),result_20); - vstream(ref()(2)(1),result_21); - vstream(ref()(2)(2),result_22); - vstream(ref()(3)(0),result_30); - vstream(ref()(3)(1),result_31); - vstream(ref()(3)(2),result_32); - } + HAND_DOP_SITE(, LOAD_CHI,LOAD_CHIMU,MULT_2SPIN); } template -void WilsonKernels::DiracOptHandDhopSiteDag(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, +void WilsonKernels::HandDhopSiteDag(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, int ss,int sU,const FermionField &in, FermionField &out) { - // std::cout << "Hand op Dhop "< void +WilsonKernels::HandDhopSiteInt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField 
&U,SiteHalfSpinor *buf, + int ss,int sU,const FermionField &in, FermionField &out) +{ +// T==0, Z==1, Y==2, X==3 expect 1,2,2,2 simd layout etc... + typedef typename Simd::scalar_type S; + typedef typename Simd::vector_type V; + + HAND_DECLARATIONS(ignore); + + int offset,local,perm, ptype; + StencilEntry *SE; + +#define HAND_DOP_SITE_INT(F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL) \ + ZERO_RESULT; \ + HAND_STENCIL_LEG_INT(XM_PROJ,3,Xp,XM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(YM_PROJ,2,Yp,YM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(ZM_PROJ,1,Zp,ZM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(TM_PROJ,0,Tp,TM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(XP_PROJ,3,Xm,XP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(YP_PROJ,2,Ym,YP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(ZP_PROJ,1,Zm,ZP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(TP_PROJ,0,Tm,TP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_RESULT(ss,F) + + HAND_DOP_SITE_INT(, LOAD_CHI,LOAD_CHIMU,MULT_2SPIN); +} + +template +void WilsonKernels::HandDhopSiteDagInt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int sU,const FermionField &in, FermionField &out) +{ + typedef typename Simd::scalar_type S; + typedef typename Simd::vector_type V; + + HAND_DECLARATIONS(ignore); + + StencilEntry *SE; + int offset,local,perm, ptype; + +#define HAND_DOP_SITE_DAG_INT(F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL) \ + ZERO_RESULT; \ + HAND_STENCIL_LEG_INT(XP_PROJ,3,Xp,XP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(YP_PROJ,2,Yp,YP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + 
HAND_STENCIL_LEG_INT(ZP_PROJ,1,Zp,ZP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(TP_PROJ,0,Tp,TP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(XM_PROJ,3,Xm,XM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(YM_PROJ,2,Ym,YM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(ZM_PROJ,1,Zm,ZM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_INT(TM_PROJ,0,Tm,TM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_RESULT(ss,F) - // Xp - SE=st.GetEntry(ptype,Xp,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - XP_PROJ; - if ( perm) { - PERMUTE_DIR(3); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... - } - } else { - LOAD_CHI; - } + HAND_DOP_SITE_DAG_INT(, LOAD_CHI,LOAD_CHIMU,MULT_2SPIN); +} - { - MULT_2SPIN(Xp); - } - XP_RECON; +template void +WilsonKernels::HandDhopSiteExt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int sU,const FermionField &in, FermionField &out) +{ +// T==0, Z==1, Y==2, X==3 expect 1,2,2,2 simd layout etc... + typedef typename Simd::scalar_type S; + typedef typename Simd::vector_type V; - // Yp - SE=st.GetEntry(ptype,Yp,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - YP_PROJ; - if ( perm) { - PERMUTE_DIR(2); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... 
- } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Yp); - } - YP_RECON_ACCUM; + HAND_DECLARATIONS(ignore); + int offset,local,perm, ptype; + StencilEntry *SE; + int nmu=0; - // Zp - SE=st.GetEntry(ptype,Zp,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - ZP_PROJ; - if ( perm) { - PERMUTE_DIR(1); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... - } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Zp); - } - ZP_RECON_ACCUM; +#define HAND_DOP_SITE_EXT(F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL) \ + ZERO_RESULT; \ + HAND_STENCIL_LEG_EXT(XM_PROJ,3,Xp,XM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(YM_PROJ,2,Yp,YM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(ZM_PROJ,1,Zp,ZM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(TM_PROJ,0,Tp,TM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(XP_PROJ,3,Xm,XP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(YP_PROJ,2,Ym,YP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(ZP_PROJ,1,Zm,ZP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(TP_PROJ,0,Tm,TP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_RESULT_EXT(ss,F) - // Tp - SE=st.GetEntry(ptype,Tp,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - TP_PROJ; - if ( perm) { - PERMUTE_DIR(0); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... 
- } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Tp); - } - TP_RECON_ACCUM; - - // Xm - SE=st.GetEntry(ptype,Xm,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - XM_PROJ; - if ( perm) { - PERMUTE_DIR(3); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... - } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Xm); - } - XM_RECON_ACCUM; - - // Ym - SE=st.GetEntry(ptype,Ym,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; - - if ( local ) { - LOAD_CHIMU; - YM_PROJ; - if ( perm) { - PERMUTE_DIR(2); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... - } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Ym); - } - YM_RECON_ACCUM; + HAND_DOP_SITE_EXT(, LOAD_CHI,LOAD_CHIMU,MULT_2SPIN); +} - // Zm - SE=st.GetEntry(ptype,Zm,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; +template +void WilsonKernels::HandDhopSiteDagExt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, + int ss,int sU,const FermionField &in, FermionField &out) +{ + typedef typename Simd::scalar_type S; + typedef typename Simd::vector_type V; - if ( local ) { - LOAD_CHIMU; - ZM_PROJ; - if ( perm) { - PERMUTE_DIR(1); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... - } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Zm); - } - ZM_RECON_ACCUM; + HAND_DECLARATIONS(ignore); - // Tm - SE=st.GetEntry(ptype,Tm,ss); - offset = SE->_offset; - local = SE->_is_local; - perm = SE->_permute; + StencilEntry *SE; + int offset,local,perm, ptype; + int nmu=0; - if ( local ) { - LOAD_CHIMU; - TM_PROJ; - if ( perm) { - PERMUTE_DIR(0); // T==0, Z==1, Y==2, Z==3 expect 1,2,2,2 simd layout etc... 
- } - } else { - LOAD_CHI; - } - { - MULT_2SPIN(Tm); - } - TM_RECON_ACCUM; +#define HAND_DOP_SITE_DAG_EXT(F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL) \ + ZERO_RESULT; \ + HAND_STENCIL_LEG_EXT(XP_PROJ,3,Xp,XP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(YP_PROJ,2,Yp,YP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(ZP_PROJ,1,Zp,ZP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(TP_PROJ,0,Tp,TP_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(XM_PROJ,3,Xm,XM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(YM_PROJ,2,Ym,YM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(ZM_PROJ,1,Zm,ZM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_STENCIL_LEG_EXT(TM_PROJ,0,Tm,TM_RECON_ACCUM,F,LOAD_CHI_IMPL,LOAD_CHIMU_IMPL,MULT_2SPIN_IMPL); \ + HAND_RESULT_EXT(ss,F) - { - SiteSpinor & ref (out._odata[ss]); - vstream(ref()(0)(0),result_00); - vstream(ref()(0)(1),result_01); - vstream(ref()(0)(2),result_02); - vstream(ref()(1)(0),result_10); - vstream(ref()(1)(1),result_11); - vstream(ref()(1)(2),result_12); - vstream(ref()(2)(0),result_20); - vstream(ref()(2)(1),result_21); - vstream(ref()(2)(2),result_22); - vstream(ref()(3)(0),result_30); - vstream(ref()(3)(1),result_31); - vstream(ref()(3)(2),result_32); - } + HAND_DOP_SITE_DAG_EXT(, LOAD_CHI,LOAD_CHIMU,MULT_2SPIN); } //////////////////////////////////////////////// // Specialise Gparity to simple implementation //////////////////////////////////////////////// -template<> void -WilsonKernels::DiracOptHandDhopSite(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U, - SiteHalfSpinor *buf, - int sF,int sU,const FermionField &in, FermionField &out) -{ - assert(0); -} - -template<> void -WilsonKernels::DiracOptHandDhopSiteDag(StencilImpl &st,LebesgueOrder 
&lo,DoubledGaugeField &U, - SiteHalfSpinor *buf, - int sF,int sU,const FermionField &in, FermionField &out) -{ - assert(0); -} - -template<> void -WilsonKernels::DiracOptHandDhopSite(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, - int sF,int sU,const FermionField &in, FermionField &out) -{ - assert(0); -} - -template<> void -WilsonKernels::DiracOptHandDhopSiteDag(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, - int sF,int sU,const FermionField &in, FermionField &out) -{ - assert(0); -} +#define HAND_SPECIALISE_EMPTY(IMPL) \ + template<> void \ + WilsonKernels::HandDhopSite(StencilImpl &st, \ + LebesgueOrder &lo, \ + DoubledGaugeField &U, \ + SiteHalfSpinor *buf, \ + int sF,int sU, \ + const FermionField &in, \ + FermionField &out){ assert(0); } \ + template<> void \ + WilsonKernels::HandDhopSiteDag(StencilImpl &st, \ + LebesgueOrder &lo, \ + DoubledGaugeField &U, \ + SiteHalfSpinor *buf, \ + int sF,int sU, \ + const FermionField &in, \ + FermionField &out){ assert(0); } \ + template<> void \ + WilsonKernels::HandDhopSiteInt(StencilImpl &st, \ + LebesgueOrder &lo, \ + DoubledGaugeField &U, \ + SiteHalfSpinor *buf, \ + int sF,int sU, \ + const FermionField &in, \ + FermionField &out){ assert(0); } \ + template<> void \ + WilsonKernels::HandDhopSiteExt(StencilImpl &st, \ + LebesgueOrder &lo, \ + DoubledGaugeField &U, \ + SiteHalfSpinor *buf, \ + int sF,int sU, \ + const FermionField &in, \ + FermionField &out){ assert(0); } \ + template<> void \ + WilsonKernels::HandDhopSiteDagInt(StencilImpl &st, \ + LebesgueOrder &lo, \ + DoubledGaugeField &U, \ + SiteHalfSpinor *buf, \ + int sF,int sU, \ + const FermionField &in, \ + FermionField &out){ assert(0); } \ + template<> void \ + WilsonKernels::HandDhopSiteDagExt(StencilImpl &st, \ + LebesgueOrder &lo, \ + DoubledGaugeField &U, \ + SiteHalfSpinor *buf, \ + int sF,int sU, \ + const FermionField &in, \ + FermionField &out){ assert(0); } \ +#define 
HAND_SPECIALISE_GPARITY(IMPL) \ + template<> void \ + WilsonKernels::HandDhopSite(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, \ + int ss,int sU,const FermionField &in, FermionField &out) \ + { \ + typedef IMPL Impl; \ + typedef typename Simd::scalar_type S; \ + typedef typename Simd::vector_type V; \ + \ + HAND_DECLARATIONS(ignore); \ + \ + int offset,local,perm, ptype, g, direction, distance, sl, inplace_twist; \ + StencilEntry *SE; \ + HAND_DOP_SITE(0, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + HAND_DOP_SITE(1, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + } \ + \ + template<> \ + void WilsonKernels::HandDhopSiteDag(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, \ + int ss,int sU,const FermionField &in, FermionField &out) \ + { \ + typedef IMPL Impl; \ + typedef typename Simd::scalar_type S; \ + typedef typename Simd::vector_type V; \ + \ + HAND_DECLARATIONS(ignore); \ + \ + StencilEntry *SE; \ + int offset,local,perm, ptype, g, direction, distance, sl, inplace_twist; \ + HAND_DOP_SITE_DAG(0, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + HAND_DOP_SITE_DAG(1, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + } \ + \ + template<> void \ + WilsonKernels::HandDhopSiteInt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, \ + int ss,int sU,const FermionField &in, FermionField &out) \ + { \ + typedef IMPL Impl; \ + typedef typename Simd::scalar_type S; \ + typedef typename Simd::vector_type V; \ + \ + HAND_DECLARATIONS(ignore); \ + \ + int offset,local,perm, ptype, g, direction, distance, sl, inplace_twist; \ + StencilEntry *SE; \ + HAND_DOP_SITE_INT(0, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + HAND_DOP_SITE_INT(1, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + } \ + \ + template<> \ + void WilsonKernels::HandDhopSiteDagInt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor 
*buf, \ + int ss,int sU,const FermionField &in, FermionField &out) \ + { \ + typedef IMPL Impl; \ + typedef typename Simd::scalar_type S; \ + typedef typename Simd::vector_type V; \ + \ + HAND_DECLARATIONS(ignore); \ + \ + StencilEntry *SE; \ + int offset,local,perm, ptype, g, direction, distance, sl, inplace_twist; \ + HAND_DOP_SITE_DAG_INT(0, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + HAND_DOP_SITE_DAG_INT(1, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + } \ + \ + template<> void \ + WilsonKernels::HandDhopSiteExt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, \ + int ss,int sU,const FermionField &in, FermionField &out) \ + { \ + typedef IMPL Impl; \ + typedef typename Simd::scalar_type S; \ + typedef typename Simd::vector_type V; \ + \ + HAND_DECLARATIONS(ignore); \ + \ + int offset,local,perm, ptype, g, direction, distance, sl, inplace_twist; \ + StencilEntry *SE; \ + int nmu=0; \ + HAND_DOP_SITE_EXT(0, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + nmu = 0; \ + HAND_DOP_SITE_EXT(1, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + } \ + template<> \ + void WilsonKernels::HandDhopSiteDagExt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, \ + int ss,int sU,const FermionField &in, FermionField &out) \ + { \ + typedef IMPL Impl; \ + typedef typename Simd::scalar_type S; \ + typedef typename Simd::vector_type V; \ + \ + HAND_DECLARATIONS(ignore); \ + \ + StencilEntry *SE; \ + int offset,local,perm, ptype, g, direction, distance, sl, inplace_twist; \ + int nmu=0; \ + HAND_DOP_SITE_DAG_EXT(0, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + nmu = 0; \ + HAND_DOP_SITE_DAG_EXT(1, LOAD_CHI_GPARITY,LOAD_CHIMU_GPARITY,MULT_2SPIN_GPARITY); \ + } + + +HAND_SPECIALISE_GPARITY(GparityWilsonImplF); +HAND_SPECIALISE_GPARITY(GparityWilsonImplD); +HAND_SPECIALISE_GPARITY(GparityWilsonImplFH); +HAND_SPECIALISE_GPARITY(GparityWilsonImplDF); + + + + + + + + 
+ + + ////////////// Wilson ; uses this implementation ///////////////////// -// Need Nc=3 though // #define INSTANTIATE_THEM(A) \ -template void WilsonKernels::DiracOptHandDhopSite(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf,\ - int ss,int sU,const FermionField &in, FermionField &out); \ -template void WilsonKernels::DiracOptHandDhopSiteDag(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf,\ - int ss,int sU,const FermionField &in, FermionField &out); +template void WilsonKernels::HandDhopSite(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf,\ + int ss,int sU,const FermionField &in, FermionField &out); \ +template void WilsonKernels::HandDhopSiteDag(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, \ + int ss,int sU,const FermionField &in, FermionField &out);\ +template void WilsonKernels::HandDhopSiteInt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf,\ + int ss,int sU,const FermionField &in, FermionField &out); \ +template void WilsonKernels::HandDhopSiteDagInt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, \ + int ss,int sU,const FermionField &in, FermionField &out); \ +template void WilsonKernels::HandDhopSiteExt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf,\ + int ss,int sU,const FermionField &in, FermionField &out); \ +template void WilsonKernels::HandDhopSiteDagExt(StencilImpl &st,LebesgueOrder &lo,DoubledGaugeField &U,SiteHalfSpinor *buf, \ + int ss,int sU,const FermionField &in, FermionField &out); INSTANTIATE_THEM(WilsonImplF); INSTANTIATE_THEM(WilsonImplD); @@ -850,5 +936,15 @@ INSTANTIATE_THEM(DomainWallVec5dImplF); INSTANTIATE_THEM(DomainWallVec5dImplD); INSTANTIATE_THEM(ZDomainWallVec5dImplF); INSTANTIATE_THEM(ZDomainWallVec5dImplD); +INSTANTIATE_THEM(WilsonImplFH); +INSTANTIATE_THEM(WilsonImplDF); +INSTANTIATE_THEM(ZWilsonImplFH); 
+INSTANTIATE_THEM(ZWilsonImplDF); +INSTANTIATE_THEM(GparityWilsonImplFH); +INSTANTIATE_THEM(GparityWilsonImplDF); +INSTANTIATE_THEM(DomainWallVec5dImplFH); +INSTANTIATE_THEM(DomainWallVec5dImplDF); +INSTANTIATE_THEM(ZDomainWallVec5dImplFH); +INSTANTIATE_THEM(ZDomainWallVec5dImplDF); }} diff --git a/lib/qcd/action/fermion/WilsonTMFermion.cc b/lib/qcd/action/fermion/WilsonTMFermion.cc index f74f9f00..d4604b10 100644 --- a/lib/qcd/action/fermion/WilsonTMFermion.cc +++ b/lib/qcd/action/fermion/WilsonTMFermion.cc @@ -25,7 +25,8 @@ Author: paboyle See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#include +#include +#include namespace Grid { namespace QCD { diff --git a/lib/qcd/action/fermion/WilsonTMFermion.h b/lib/qcd/action/fermion/WilsonTMFermion.h index 5901cb2f..f75c287b 100644 --- a/lib/qcd/action/fermion/WilsonTMFermion.h +++ b/lib/qcd/action/fermion/WilsonTMFermion.h @@ -28,7 +28,8 @@ Author: paboyle #ifndef GRID_QCD_WILSON_TM_FERMION_H #define GRID_QCD_WILSON_TM_FERMION_H -#include +#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/ZMobiusFermion.h b/lib/qcd/action/fermion/ZMobiusFermion.h index d0e00657..32ff7670 100644 --- a/lib/qcd/action/fermion/ZMobiusFermion.h +++ b/lib/qcd/action/fermion/ZMobiusFermion.h @@ -29,7 +29,7 @@ Author: Peter Boyle #ifndef GRID_QCD_ZMOBIUS_FERMION_H #define GRID_QCD_ZMOBIUS_FERMION_H -#include +#include namespace Grid { diff --git a/lib/qcd/action/fermion/g5HermitianLinop.h b/lib/qcd/action/fermion/g5HermitianLinop.h index af23c36f..cca7a113 100644 --- a/lib/qcd/action/fermion/g5HermitianLinop.h +++ b/lib/qcd/action/fermion/g5HermitianLinop.h @@ -80,7 +80,7 @@ class Gamma5HermitianLinearOperator : public LinearOperatorBase { Matrix &_Mat; Gamma g5; public: - Gamma5HermitianLinearOperator(Matrix &Mat): _Mat(Mat), g5(Gamma::Gamma5) {}; + Gamma5HermitianLinearOperator(Matrix 
&Mat): _Mat(Mat), g5(Gamma::Algebra::Gamma5) {}; void Op (const Field &in, Field &out){ HermOp(in,out); } diff --git a/lib/qcd/action/gauge/Gauge.h b/lib/qcd/action/gauge/Gauge.h new file mode 100644 index 00000000..6f94cf00 --- /dev/null +++ b/lib/qcd/action/gauge/Gauge.h @@ -0,0 +1,70 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/gauge/Gauge.h + +Copyright (C) 2015 + +Author: paboyle + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef GRID_QCD_GAUGE_H +#define GRID_QCD_GAUGE_H + +#include +#include +#include +#include + +namespace Grid { +namespace QCD { + +typedef WilsonGaugeAction WilsonGaugeActionR; +typedef WilsonGaugeAction WilsonGaugeActionF; +typedef WilsonGaugeAction WilsonGaugeActionD; +typedef PlaqPlusRectangleAction PlaqPlusRectangleActionR; +typedef PlaqPlusRectangleAction PlaqPlusRectangleActionF; +typedef PlaqPlusRectangleAction PlaqPlusRectangleActionD; +typedef IwasakiGaugeAction IwasakiGaugeActionR; +typedef IwasakiGaugeAction IwasakiGaugeActionF; +typedef IwasakiGaugeAction IwasakiGaugeActionD; +typedef SymanzikGaugeAction SymanzikGaugeActionR; +typedef SymanzikGaugeAction SymanzikGaugeActionF; +typedef SymanzikGaugeAction SymanzikGaugeActionD; + + +typedef WilsonGaugeAction ConjugateWilsonGaugeActionR; +typedef WilsonGaugeAction ConjugateWilsonGaugeActionF; +typedef WilsonGaugeAction ConjugateWilsonGaugeActionD; +typedef PlaqPlusRectangleAction ConjugatePlaqPlusRectangleActionR; +typedef PlaqPlusRectangleAction ConjugatePlaqPlusRectangleActionF; +typedef PlaqPlusRectangleAction ConjugatePlaqPlusRectangleActionD; +typedef IwasakiGaugeAction ConjugateIwasakiGaugeActionR; +typedef IwasakiGaugeAction ConjugateIwasakiGaugeActionF; +typedef IwasakiGaugeAction ConjugateIwasakiGaugeActionD; +typedef SymanzikGaugeAction ConjugateSymanzikGaugeActionR; +typedef SymanzikGaugeAction ConjugateSymanzikGaugeActionF; +typedef SymanzikGaugeAction ConjugateSymanzikGaugeActionD; + +}} + + +#endif diff --git a/lib/qcd/action/gauge/GaugeImplTypes.h b/lib/qcd/action/gauge/GaugeImplTypes.h new file mode 100644 index 00000000..9e3e0d68 --- /dev/null +++ b/lib/qcd/action/gauge/GaugeImplTypes.h @@ -0,0 +1,153 @@ +/************************************************************************************* + +Grid 
physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/gauge/GaugeImplTypes.h + +Copyright (C) 2015 + +Author: paboyle + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef GRID_GAUGE_IMPL_TYPES_H +#define GRID_GAUGE_IMPL_TYPES_H + +namespace Grid { +namespace QCD { + +//////////////////////////////////////////////////////////////////////// +// Implementation dependent gauge types +//////////////////////////////////////////////////////////////////////// + +#define INHERIT_GIMPL_TYPES(GImpl) \ + typedef typename GImpl::Simd Simd; \ + typedef typename GImpl::LinkField GaugeLinkField; \ + typedef typename GImpl::Field GaugeField; \ + typedef typename GImpl::ComplexField ComplexField;\ + typedef typename GImpl::SiteField SiteGaugeField; \ + typedef typename GImpl::SiteComplex SiteComplex; \ + typedef typename GImpl::SiteLink SiteGaugeLink; + +#define INHERIT_FIELD_TYPES(Impl) \ + typedef typename Impl::Simd Simd; \ + typedef typename Impl::ComplexField ComplexField; \ + typedef typename Impl::SiteField SiteField; \ + typedef typename Impl::Field Field; + +// hardcodes the exponential approximation in the template +template 
class GaugeImplTypes { +public: + typedef S Simd; + + template using iImplScalar = iScalar > >; + template using iImplGaugeLink = iScalar > >; + template using iImplGaugeField = iVector >, Nd>; + + typedef iImplScalar SiteComplex; + typedef iImplGaugeLink SiteLink; + typedef iImplGaugeField SiteField; + + typedef Lattice ComplexField; + typedef Lattice LinkField; + typedef Lattice Field; + + // Guido: we can probably separate the types from the HMC functions + // this will create 2 kind of implementations + // probably confusing the users + // Now keeping only one class + + + // Move this elsewhere? FIXME + static inline void AddLink(Field &U, LinkField &W, + int mu) { // U[mu] += W + PARALLEL_FOR_LOOP + for (auto ss = 0; ss < U._grid->oSites(); ss++) { + U._odata[ss]._internal[mu] = + U._odata[ss]._internal[mu] + W._odata[ss]._internal; + } + } + + /////////////////////////////////////////////////////////// + // Move these to another class + // HMC auxiliary functions + static inline void generate_momenta(Field &P, GridParallelRNG &pRNG) { + // specific for SU gauge fields + LinkField Pmu(P._grid); + Pmu = zero; + for (int mu = 0; mu < Nd; mu++) { + SU::GaussianFundamentalLieAlgebraMatrix(pRNG, Pmu); + PokeIndex(P, Pmu, mu); + } + } + + static inline Field projectForce(Field &P) { return Ta(P); } + + static inline void update_field(Field& P, Field& U, double ep){ + //static std::chrono::duration diff; + + //auto start = std::chrono::high_resolution_clock::now(); + parallel_for(int ss=0;ssoSites();ss++){ + for (int mu = 0; mu < Nd; mu++) + U[ss]._internal[mu] = ProjectOnGroup(Exponentiate(P[ss]._internal[mu], ep, Nexp) * U[ss]._internal[mu]); + } + + //auto end = std::chrono::high_resolution_clock::now(); + // diff += end - start; + // std::cout << "Time to exponentiate matrix " << diff.count() << " s\n"; + } + + static inline RealD FieldSquareNorm(Field& U){ + LatticeComplex Hloc(U._grid); + Hloc = zero; + for (int mu = 0; mu < Nd; mu++) { + auto Umu = 
PeekIndex(U, mu); + Hloc += trace(Umu * Umu); + } + Complex Hsum = sum(Hloc); + return Hsum.real(); + } + + static inline void HotConfiguration(GridParallelRNG &pRNG, Field &U) { + SU::HotConfiguration(pRNG, U); + } + + static inline void TepidConfiguration(GridParallelRNG &pRNG, Field &U) { + SU::TepidConfiguration(pRNG, U); + } + + static inline void ColdConfiguration(GridParallelRNG &pRNG, Field &U) { + SU::ColdConfiguration(pRNG, U); + } +}; + + +typedef GaugeImplTypes GimplTypesR; +typedef GaugeImplTypes GimplTypesF; +typedef GaugeImplTypes GimplTypesD; + +typedef GaugeImplTypes::AdjointDimension> GimplAdjointTypesR; +typedef GaugeImplTypes::AdjointDimension> GimplAdjointTypesF; +typedef GaugeImplTypes::AdjointDimension> GimplAdjointTypesD; + + +} // QCD +} // Grid + +#endif // GRID_GAUGE_IMPL_TYPES_H diff --git a/lib/qcd/action/gauge/GaugeImpl.h b/lib/qcd/action/gauge/GaugeImplementations.h similarity index 70% rename from lib/qcd/action/gauge/GaugeImpl.h rename to lib/qcd/action/gauge/GaugeImplementations.h index 400381bb..2d7464a9 100644 --- a/lib/qcd/action/gauge/GaugeImpl.h +++ b/lib/qcd/action/gauge/GaugeImplementations.h @@ -2,7 +2,7 @@ Grid physics library, www.github.com/paboyle/Grid -Source file: ./lib/qcd/action/gauge/GaugeImpl.h +Source file: ./lib/qcd/action/gauge/GaugeImplementations.h Copyright (C) 2015 @@ -26,54 +26,14 @@ See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ -#ifndef GRID_QCD_GAUGE_IMPL_H -#define GRID_QCD_GAUGE_IMPL_H +#ifndef GRID_QCD_GAUGE_IMPLEMENTATIONS_H +#define GRID_QCD_GAUGE_IMPLEMENTATIONS_H + +#include "GaugeImplTypes.h" namespace Grid { namespace QCD { -//////////////////////////////////////////////////////////////////////// -// Implementation dependent gauge types -//////////////////////////////////////////////////////////////////////// - -template class WilsonLoops; - -#define 
INHERIT_GIMPL_TYPES(GImpl) \ - typedef typename GImpl::Simd Simd; \ - typedef typename GImpl::GaugeLinkField GaugeLinkField; \ - typedef typename GImpl::GaugeField GaugeField; \ - typedef typename GImpl::SiteGaugeField SiteGaugeField; \ - typedef typename GImpl::SiteGaugeLink SiteGaugeLink; - -// -template class GaugeImplTypes { -public: - typedef S Simd; - - template - using iImplGaugeLink = iScalar>>; - template - using iImplGaugeField = iVector>, Nd>; - - typedef iImplGaugeLink SiteGaugeLink; - typedef iImplGaugeField SiteGaugeField; - - typedef Lattice GaugeLinkField; // bit ugly naming; polarised - // gauge field, lorentz... all - // ugly - typedef Lattice GaugeField; - - // Move this elsewhere? FIXME - static inline void AddGaugeLink(GaugeField &U, GaugeLinkField &W, - int mu) { // U[mu] += W - PARALLEL_FOR_LOOP - for (auto ss = 0; ss < U._grid->oSites(); ss++) { - U._odata[ss]._internal[mu] = - U._odata[ss]._internal[mu] + W._odata[ss]._internal; - } - } -}; - // Composition with smeared link, bc's etc.. probably need multiple inheritance // Variable precision "S" and variable Nc template class PeriodicGaugeImpl : public GimplTypes { @@ -169,14 +129,6 @@ public: static inline bool isPeriodicGaugeField(void) { return false; } }; -typedef GaugeImplTypes GimplTypesR; -typedef GaugeImplTypes GimplTypesF; -typedef GaugeImplTypes GimplTypesD; - -typedef GaugeImplTypes::AdjointDimension> GimplAdjointTypesR; -typedef GaugeImplTypes::AdjointDimension> GimplAdjointTypesF; -typedef GaugeImplTypes::AdjointDimension> GimplAdjointTypesD; - typedef PeriodicGaugeImpl PeriodicGimplR; // Real.. whichever prec typedef PeriodicGaugeImpl PeriodicGimplF; // Float typedef PeriodicGaugeImpl PeriodicGimplD; // Double @@ -188,6 +140,8 @@ typedef PeriodicGaugeImpl PeriodicGimplAdjD; // Double typedef ConjugateGaugeImpl ConjugateGimplR; // Real.. 
whichever prec typedef ConjugateGaugeImpl ConjugateGimplF; // Float typedef ConjugateGaugeImpl ConjugateGimplD; // Double + + } } diff --git a/lib/qcd/action/gauge/Photon.h b/lib/qcd/action/gauge/Photon.h new file mode 100644 index 00000000..7e21a1de --- /dev/null +++ b/lib/qcd/action/gauge/Photon.h @@ -0,0 +1,286 @@ +/************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + Source file: ./lib/qcd/action/gauge/Photon.h + + Copyright (C) 2015 + + Author: Peter Boyle + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ + See the full license in the file "LICENSE" in the top level distribution directory + *************************************************************************************/ +/* END LEGAL */ +#ifndef QCD_PHOTON_ACTION_H +#define QCD_PHOTON_ACTION_H + +namespace Grid{ +namespace QCD{ + template + class QedGimpl + { + public: + typedef S Simd; + + template + using iImplGaugeLink = iScalar>>; + template + using iImplGaugeField = iVector>, Nd>; + + typedef iImplGaugeLink SiteLink; + typedef iImplGaugeField SiteField; + typedef SiteField SiteComplex; + + typedef Lattice LinkField; + typedef Lattice Field; + typedef Field ComplexField; + }; + + typedef QedGimpl QedGimplR; + + template + class Photon + { + public: + INHERIT_GIMPL_TYPES(Gimpl); + GRID_SERIALIZABLE_ENUM(Gauge, undef, feynman, 1, coulomb, 2, landau, 3); + GRID_SERIALIZABLE_ENUM(ZmScheme, undef, qedL, 1, qedTL, 2); + public: + Photon(Gauge gauge, ZmScheme zmScheme); + virtual ~Photon(void) = default; + void FreePropagator(const GaugeField &in, GaugeField &out); + void MomentumSpacePropagator(const GaugeField &in, GaugeField &out); + void StochasticWeight(GaugeLinkField &weight); + void StochasticField(GaugeField &out, GridParallelRNG &rng); + void StochasticField(GaugeField &out, GridParallelRNG &rng, + const GaugeLinkField &weight); + private: + void invKHatSquared(GaugeLinkField &out); + void zmSub(GaugeLinkField &out); + private: + Gauge gauge_; + ZmScheme zmScheme_; + }; + + typedef Photon PhotonR; + + template + Photon::Photon(Gauge gauge, ZmScheme zmScheme) + : gauge_(gauge), zmScheme_(zmScheme) + {} + + template + void Photon::FreePropagator (const GaugeField &in,GaugeField &out) + { + FFT theFFT(in._grid); + + GaugeField in_k(in._grid); + GaugeField prop_k(in._grid); + + theFFT.FFT_all_dim(in_k,in,FFT::forward); + MomentumSpacePropagator(prop_k,in_k); + theFFT.FFT_all_dim(out,prop_k,FFT::backward); + } + + template + void Photon::invKHatSquared(GaugeLinkField &out) + { + GridBase *grid = out._grid; 
+ GaugeLinkField kmu(grid), one(grid); + const unsigned int nd = grid->_ndimension; + std::vector &l = grid->_fdimensions; + std::vector zm(nd,0); + TComplex Tone = Complex(1.0,0.0); + TComplex Tzero= Complex(0.0,0.0); + + one = Complex(1.0,0.0); + out = zero; + for(int mu = 0; mu < nd; mu++) + { + Real twoPiL = M_PI*2./l[mu]; + + LatticeCoordinate(kmu,mu); + kmu = 2.*sin(.5*twoPiL*kmu); + out = out + kmu*kmu; + } + pokeSite(Tone, out, zm); + out = one/out; + pokeSite(Tzero, out, zm); + } + + template + void Photon::zmSub(GaugeLinkField &out) + { + GridBase *grid = out._grid; + const unsigned int nd = grid->_ndimension; + + switch (zmScheme_) + { + case ZmScheme::qedTL: + { + std::vector zm(nd,0); + TComplex Tzero = Complex(0.0,0.0); + + pokeSite(Tzero, out, zm); + + break; + } + case ZmScheme::qedL: + { + LatticeInteger spNrm(grid), coor(grid); + GaugeLinkField z(grid); + + spNrm = zero; + for(int d = 0; d < grid->_ndimension - 1; d++) + { + LatticeCoordinate(coor,d); + spNrm = spNrm + coor*coor; + } + out = where(spNrm == Integer(0), 0.*out, out); + + break; + } + default: + break; + } + } + + template + void Photon::MomentumSpacePropagator(const GaugeField &in, + GaugeField &out) + { + GridBase *grid = out._grid; + LatticeComplex k2Inv(grid); + + invKHatSquared(k2Inv); + zmSub(k2Inv); + + out = in*k2Inv; + } + + template + void Photon::StochasticWeight(GaugeLinkField &weight) + { + auto *grid = dynamic_cast(weight._grid); + const unsigned int nd = grid->_ndimension; + std::vector latt_size = grid->_fdimensions; + + Integer vol = 1; + for(int d = 0; d < nd; d++) + { + vol = vol * latt_size[d]; + } + invKHatSquared(weight); + weight = sqrt(vol*real(weight)); + zmSub(weight); + } + + template + void Photon::StochasticField(GaugeField &out, GridParallelRNG &rng) + { + auto *grid = dynamic_cast(out._grid); + GaugeLinkField weight(grid); + + StochasticWeight(weight); + StochasticField(out, rng, weight); + } + + template + void Photon::StochasticField(GaugeField &out, 
GridParallelRNG &rng, + const GaugeLinkField &weight) + { + auto *grid = dynamic_cast(out._grid); + const unsigned int nd = grid->_ndimension; + GaugeLinkField r(grid); + GaugeField aTilde(grid); + FFT fft(grid); + + for(int mu = 0; mu < nd; mu++) + { + gaussian(rng, r); + r = weight*r; + pokeLorentz(aTilde, r, mu); + } + fft.FFT_all_dim(out, aTilde, FFT::backward); + + out = real(out); + } +// template +// void Photon::FeynmanGaugeMomentumSpacePropagator_L(GaugeField &out, +// const GaugeField &in) +// { +// +// FeynmanGaugeMomentumSpacePropagator_TL(out,in); +// +// GridBase *grid = out._grid; +// LatticeInteger coor(grid); +// GaugeField zz(grid); zz=zero; +// +// // xyzt +// for(int d = 0; d < grid->_ndimension-1;d++){ +// LatticeCoordinate(coor,d); +// out = where(coor==Integer(0),zz,out); +// } +// } +// +// template +// void Photon::FeynmanGaugeMomentumSpacePropagator_TL(GaugeField &out, +// const GaugeField &in) +// { +// +// // what type LatticeComplex +// GridBase *grid = out._grid; +// int nd = grid->_ndimension; +// +// typedef typename GaugeField::vector_type vector_type; +// typedef typename GaugeField::scalar_type ScalComplex; +// typedef Lattice > LatComplex; +// +// std::vector latt_size = grid->_fdimensions; +// +// LatComplex denom(grid); denom= zero; +// LatComplex one(grid); one = ScalComplex(1.0,0.0); +// LatComplex kmu(grid); +// +// ScalComplex ci(0.0,1.0); +// // momphase = n * 2pi / L +// for(int mu=0;mu zero_mode(nd,0); +// TComplexD Tone = ComplexD(1.0,0.0); +// TComplexD Tzero= ComplexD(0.0,0.0); +// +// pokeSite(Tone,denom,zero_mode); +// +// denom= one/denom; +// +// pokeSite(Tzero,denom,zero_mode); +// +// out = zero; +// out = in*denom; +// }; + +}} +#endif diff --git a/lib/qcd/action/gauge/PlaqPlusRectangleAction.h b/lib/qcd/action/gauge/PlaqPlusRectangleAction.h index 6193bedb..5bfd39b2 100644 --- a/lib/qcd/action/gauge/PlaqPlusRectangleAction.h +++ b/lib/qcd/action/gauge/PlaqPlusRectangleAction.h @@ -47,9 +47,19 @@ namespace 
Grid{ public: PlaqPlusRectangleAction(RealD b,RealD c): c_plaq(b),c_rect(c){}; + + virtual std::string action_name(){return "PlaqPlusRectangleAction";} virtual void refresh(const GaugeField &U, GridParallelRNG& pRNG) {}; // noop as no pseudoferms + virtual std::string LogParameters(){ + std::stringstream sstream; + sstream << GridLogMessage << "["<gSites(); @@ -108,32 +118,32 @@ namespace Grid{ class RBCGaugeAction : public PlaqPlusRectangleAction { public: INHERIT_GIMPL_TYPES(Gimpl); - RBCGaugeAction(RealD beta,RealD c1) : PlaqPlusRectangleAction(beta*(1.0-8.0*c1), beta*c1) { - }; + RBCGaugeAction(RealD beta,RealD c1) : PlaqPlusRectangleAction(beta*(1.0-8.0*c1), beta*c1) {}; + virtual std::string action_name(){return "RBCGaugeAction";} }; template class IwasakiGaugeAction : public RBCGaugeAction { public: INHERIT_GIMPL_TYPES(Gimpl); - IwasakiGaugeAction(RealD beta) : RBCGaugeAction(beta,-0.331) { - }; + IwasakiGaugeAction(RealD beta) : RBCGaugeAction(beta,-0.331) {}; + virtual std::string action_name(){return "IwasakiGaugeAction";} }; template class SymanzikGaugeAction : public RBCGaugeAction { public: INHERIT_GIMPL_TYPES(Gimpl); - SymanzikGaugeAction(RealD beta) : RBCGaugeAction(beta,-1.0/12.0) { - }; + SymanzikGaugeAction(RealD beta) : RBCGaugeAction(beta,-1.0/12.0) {}; + virtual std::string action_name(){return "SymanzikGaugeAction";} }; template class DBW2GaugeAction : public RBCGaugeAction { public: INHERIT_GIMPL_TYPES(Gimpl); - DBW2GaugeAction(RealD beta) : RBCGaugeAction(beta,-1.4067) { - }; + DBW2GaugeAction(RealD beta) : RBCGaugeAction(beta,-1.4067) {}; + virtual std::string action_name(){return "DBW2GaugeAction";} }; } diff --git a/lib/qcd/action/gauge/WilsonGaugeAction.h b/lib/qcd/action/gauge/WilsonGaugeAction.h index aff67c67..1ea780b7 100644 --- a/lib/qcd/action/gauge/WilsonGaugeAction.h +++ b/lib/qcd/action/gauge/WilsonGaugeAction.h @@ -1,86 +1,99 @@ - /************************************************************************************* 
+/************************************************************************************* - Grid physics library, www.github.com/paboyle/Grid +Grid physics library, www.github.com/paboyle/Grid - Source file: ./lib/qcd/action/gauge/WilsonGaugeAction.h +Source file: ./lib/qcd/action/gauge/WilsonGaugeAction.h - Copyright (C) 2015 +Copyright (C) 2015 Author: Azusa Yamaguchi Author: Peter Boyle Author: neo Author: paboyle +Author: Guido Cossu - This program is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. - You should have received a copy of the GNU General Public License along - with this program; if not, write to the Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
- See the full license in the file "LICENSE" in the top level distribution directory - *************************************************************************************/ - /* END LEGAL */ +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ #ifndef QCD_WILSON_GAUGE_ACTION_H #define QCD_WILSON_GAUGE_ACTION_H -namespace Grid{ - namespace QCD{ - - //////////////////////////////////////////////////////////////////////// - // Wilson Gauge Action .. should I template the Nc etc.. - //////////////////////////////////////////////////////////////////////// - template - class WilsonGaugeAction : public Action { - public: +namespace Grid { +namespace QCD { - INHERIT_GIMPL_TYPES(Gimpl); +//////////////////////////////////////////////////////////////////////// +// Wilson Gauge Action .. should I template the Nc etc.. +//////////////////////////////////////////////////////////////////////// +template +class WilsonGaugeAction : public Action { + public: + INHERIT_GIMPL_TYPES(Gimpl); - // typedef LorentzScalar GaugeLinkField; + /////////////////////////// constructors + explicit WilsonGaugeAction(RealD beta_):beta(beta_){}; - private: - RealD beta; - public: - WilsonGaugeAction(RealD b):beta(b){}; - - virtual void refresh(const GaugeField &U, GridParallelRNG& pRNG) {}; // noop as no pseudoferms - - virtual RealD S(const GaugeField &U) { - RealD plaq = WilsonLoops::avgPlaquette(U); - RealD vol = U._grid->gSites(); - RealD action=beta*(1.0 -plaq)*(Nd*(Nd-1.0))*vol*0.5; - return action; - }; + virtual std::string action_name() {return "WilsonGaugeAction";} - virtual void deriv(const GaugeField &U,GaugeField & dSdU) { - //not optimal implementation FIXME - //extend Ta to include Lorentz indexes - - //RealD factor = 0.5*beta/RealD(Nc); - RealD factor = 0.5*beta/RealD(Nc); - - GaugeLinkField Umu(U._grid); - GaugeLinkField dSdU_mu(U._grid); - for 
(int mu=0; mu < Nd; mu++){ - - Umu = PeekIndex(U,mu); - - // Staple in direction mu - WilsonLoops::Staple(dSdU_mu,U,mu); - dSdU_mu = Ta(Umu*dSdU_mu)*factor; - PokeIndex(dSdU, dSdU_mu, mu); - } - }; - }; - + virtual std::string LogParameters(){ + std::stringstream sstream; + sstream << GridLogMessage << "[WilsonGaugeAction] Beta: " << beta << std::endl; + return sstream.str(); } + + virtual void refresh(const GaugeField &U, + GridParallelRNG &pRNG){}; // noop as no pseudoferms + + virtual RealD S(const GaugeField &U) { + RealD plaq = WilsonLoops::avgPlaquette(U); + RealD vol = U._grid->gSites(); + RealD action = beta * (1.0 - plaq) * (Nd * (Nd - 1.0)) * vol * 0.5; + return action; + }; + + virtual void deriv(const GaugeField &U, GaugeField &dSdU) { + // not optimal implementation FIXME + // extend Ta to include Lorentz indexes + + RealD factor = 0.5 * beta / RealD(Nc); + + //GaugeLinkField Umu(U._grid); + GaugeLinkField dSdU_mu(U._grid); + for (int mu = 0; mu < Nd; mu++) { + //Umu = PeekIndex(U, mu); + + // Staple in direction mu + //WilsonLoops::Staple(dSdU_mu, U, mu); + //dSdU_mu = Ta(Umu * dSdU_mu) * factor; + + + WilsonLoops::StapleMult(dSdU_mu, U, mu); + dSdU_mu = Ta(dSdU_mu) * factor; + + PokeIndex(dSdU, dSdU_mu, mu); + } + } +private: + RealD beta; +}; + + + +} } #endif diff --git a/lib/qcd/action/pseudofermion/EvenOddSchurDifferentiable.h b/lib/qcd/action/pseudofermion/EvenOddSchurDifferentiable.h index 6837bb19..90b6dbaa 100644 --- a/lib/qcd/action/pseudofermion/EvenOddSchurDifferentiable.h +++ b/lib/qcd/action/pseudofermion/EvenOddSchurDifferentiable.h @@ -7,6 +7,7 @@ Copyright (C) 2015 Author: Peter Boyle +Author: Guido Cossu This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -45,92 +46,97 @@ namespace Grid{ public: INHERIT_IMPL_TYPES(Impl); - typedef FermionOperator Matrix; + typedef FermionOperator Matrix; - SchurDifferentiableOperator (Matrix &Mat) : 
SchurDiagMooeeOperator(Mat) {}; + SchurDifferentiableOperator (Matrix &Mat) : SchurDiagMooeeOperator(Mat) {}; - void MpcDeriv(GaugeField &Force,const FermionField &U,const FermionField &V) { - - GridBase *fgrid = this->_Mat.FermionGrid(); - GridBase *fcbgrid = this->_Mat.FermionRedBlackGrid(); - GridBase *ugrid = this->_Mat.GaugeGrid(); - GridBase *ucbgrid = this->_Mat.GaugeRedBlackGrid(); + void MpcDeriv(GaugeField &Force,const FermionField &U,const FermionField &V) { + + GridBase *fgrid = this->_Mat.FermionGrid(); + GridBase *fcbgrid = this->_Mat.FermionRedBlackGrid(); - Real coeff = 1.0; + FermionField tmp1(fcbgrid); + FermionField tmp2(fcbgrid); - FermionField tmp1(fcbgrid); - FermionField tmp2(fcbgrid); + conformable(fcbgrid,U._grid); + conformable(fcbgrid,V._grid); - conformable(fcbgrid,U._grid); - conformable(fcbgrid,V._grid); + // Assert the checkerboard?? or code for either + assert(U.checkerboard==Odd); + assert(V.checkerboard==U.checkerboard); - // Assert the checkerboard?? or code for either - assert(U.checkerboard==Odd); - assert(V.checkerboard==U.checkerboard); + // NOTE Guido: WE DO NOT WANT TO USE THE ucbgrid GRID FOR THE FORCE + // it is not conformable with the HMC force field + // Case: Ls vectorised fields + // INHERIT FROM THE Force field instead + GridRedBlackCartesian* forcecb = new GridRedBlackCartesian(Force._grid); + GaugeField ForceO(forcecb); + GaugeField ForceE(forcecb); - GaugeField ForceO(ucbgrid); - GaugeField ForceE(ucbgrid); - // X^dag Der_oe MeeInv Meo Y - // Use Mooee as nontrivial but gauge field indept - this->_Mat.Meooe (V,tmp1); // odd->even -- implicit -0.5 factor to be applied + // X^dag Der_oe MeeInv Meo Y + // Use Mooee as nontrivial but gauge field indept + this->_Mat.Meooe (V,tmp1); // odd->even -- implicit -0.5 factor to be applied this->_Mat.MooeeInv(tmp1,tmp2); // even->even - this->_Mat.MoeDeriv(ForceO,U,tmp2,DaggerNo); - - // Accumulate X^dag M_oe MeeInv Der_eo Y - this->_Mat.MeooeDag (U,tmp1); // even->odd -- 
implicit -0.5 factor to be applied - this->_Mat.MooeeInvDag(tmp1,tmp2); // even->even - this->_Mat.MeoDeriv(ForceE,tmp2,V,DaggerNo); - - assert(ForceE.checkerboard==Even); - assert(ForceO.checkerboard==Odd); + this->_Mat.MoeDeriv(ForceO,U,tmp2,DaggerNo); + // Accumulate X^dag M_oe MeeInv Der_eo Y + this->_Mat.MeooeDag (U,tmp1); // even->odd -- implicit -0.5 factor to be applied + this->_Mat.MooeeInvDag(tmp1,tmp2); // even->even + this->_Mat.MeoDeriv(ForceE,tmp2,V,DaggerNo); + + assert(ForceE.checkerboard==Even); + assert(ForceO.checkerboard==Odd); - setCheckerboard(Force,ForceE); - setCheckerboard(Force,ForceO); - Force=-Force; - } + setCheckerboard(Force,ForceE); + setCheckerboard(Force,ForceO); + Force=-Force; + + delete forcecb; + } - void MpcDagDeriv(GaugeField &Force,const FermionField &U,const FermionField &V) { - - GridBase *fgrid = this->_Mat.FermionGrid(); - GridBase *fcbgrid = this->_Mat.FermionRedBlackGrid(); - GridBase *ugrid = this->_Mat.GaugeGrid(); - GridBase *ucbgrid = this->_Mat.GaugeRedBlackGrid(); + void MpcDagDeriv(GaugeField &Force,const FermionField &U,const FermionField &V) { + + GridBase *fgrid = this->_Mat.FermionGrid(); + GridBase *fcbgrid = this->_Mat.FermionRedBlackGrid(); - Real coeff = 1.0; + FermionField tmp1(fcbgrid); + FermionField tmp2(fcbgrid); - FermionField tmp1(fcbgrid); - FermionField tmp2(fcbgrid); + conformable(fcbgrid,U._grid); + conformable(fcbgrid,V._grid); - conformable(fcbgrid,U._grid); - conformable(fcbgrid,V._grid); + // Assert the checkerboard?? or code for either + assert(V.checkerboard==Odd); + assert(V.checkerboard==V.checkerboard); - // Assert the checkerboard?? 
or code for either - assert(V.checkerboard==Odd); - assert(V.checkerboard==V.checkerboard); + // NOTE Guido: WE DO NOT WANT TO USE THE ucbgrid GRID FOR THE FORCE + // it is not conformable with the HMC force field + // INHERIT FROM THE Force field instead + GridRedBlackCartesian* forcecb = new GridRedBlackCartesian(Force._grid); + GaugeField ForceO(forcecb); + GaugeField ForceE(forcecb); - GaugeField ForceO(ucbgrid); - GaugeField ForceE(ucbgrid); + // X^dag Der_oe MeeInv Meo Y + // Use Mooee as nontrivial but gauge field indept + this->_Mat.MeooeDag (V,tmp1); // odd->even -- implicit -0.5 factor to be applied + this->_Mat.MooeeInvDag(tmp1,tmp2); // even->even + this->_Mat.MoeDeriv(ForceO,U,tmp2,DaggerYes); + + // Accumulate X^dag M_oe MeeInv Der_eo Y + this->_Mat.Meooe (U,tmp1); // even->odd -- implicit -0.5 factor to be applied + this->_Mat.MooeeInv(tmp1,tmp2); // even->even + this->_Mat.MeoDeriv(ForceE,tmp2,V,DaggerYes); - // X^dag Der_oe MeeInv Meo Y - // Use Mooee as nontrivial but gauge field indept - this->_Mat.MeooeDag (V,tmp1); // odd->even -- implicit -0.5 factor to be applied - this->_Mat.MooeeInvDag(tmp1,tmp2); // even->even - this->_Mat.MoeDeriv(ForceO,U,tmp2,DaggerYes); - - // Accumulate X^dag M_oe MeeInv Der_eo Y - this->_Mat.Meooe (U,tmp1); // even->odd -- implicit -0.5 factor to be applied - this->_Mat.MooeeInv(tmp1,tmp2); // even->even - this->_Mat.MeoDeriv(ForceE,tmp2,V,DaggerYes); + assert(ForceE.checkerboard==Even); + assert(ForceO.checkerboard==Odd); - assert(ForceE.checkerboard==Even); - assert(ForceO.checkerboard==Odd); + setCheckerboard(Force,ForceE); + setCheckerboard(Force,ForceO); + Force=-Force; - setCheckerboard(Force,ForceE); - setCheckerboard(Force,ForceO); - Force=-Force; - } + delete forcecb; + } }; diff --git a/lib/qcd/action/pseudofermion/ExactOneFlavourRatio.h b/lib/qcd/action/pseudofermion/ExactOneFlavourRatio.h new file mode 100644 index 00000000..9c1e2921 --- /dev/null +++ b/lib/qcd/action/pseudofermion/ExactOneFlavourRatio.h 
@@ -0,0 +1,264 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/action/pseudofermion/ExactOneFlavourRatio.h + +Copyright (C) 2017 + +Author: Peter Boyle +Author: David Murphy + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution directory +*************************************************************************************/ +/* END LEGAL */ + +///////////////////////////////////////////////////////////////// +// Implementation of exact one flavour algorithm (EOFA) // +// using fermion classes defined in: // +// Grid/qcd/action/fermion/DomainWallEOFAFermion.h (Shamir) // +// Grid/qcd/action/fermion/MobiusEOFAFermion.h (Mobius) // +// arXiv: 1403.1683, 1706.05843 // +///////////////////////////////////////////////////////////////// + +#ifndef QCD_PSEUDOFERMION_EXACT_ONE_FLAVOUR_RATIO_H +#define QCD_PSEUDOFERMION_EXACT_ONE_FLAVOUR_RATIO_H + +namespace Grid{ +namespace QCD{ + + /////////////////////////////////////////////////////////////// + // Exact one flavour implementation of DWF determinant ratio // + /////////////////////////////////////////////////////////////// + + template + class ExactOneFlavourRatioPseudoFermionAction : public Action + 
{ + public: + INHERIT_IMPL_TYPES(Impl); + typedef OneFlavourRationalParams Params; + Params param; + MultiShiftFunction PowerNegHalf; + + private: + bool use_heatbath_forecasting; + AbstractEOFAFermion& Lop; // the basic LH operator + AbstractEOFAFermion& Rop; // the basic RH operator + SchurRedBlackDiagMooeeSolve Solver; + FermionField Phi; // the pseudofermion field for this trajectory + + public: + ExactOneFlavourRatioPseudoFermionAction(AbstractEOFAFermion& _Lop, AbstractEOFAFermion& _Rop, + OperatorFunction& S, Params& p, bool use_fc=false) : Lop(_Lop), Rop(_Rop), Solver(S), + Phi(_Lop.FermionGrid()), param(p), use_heatbath_forecasting(use_fc) + { + AlgRemez remez(param.lo, param.hi, param.precision); + + // MdagM^(+- 1/2) + std::cout << GridLogMessage << "Generating degree " << param.degree << " for x^(-1/2)" << std::endl; + remez.generateApprox(param.degree, 1, 2); + PowerNegHalf.Init(remez, param.tolerance, true); + }; + + virtual std::string action_name() { return "ExactOneFlavourRatioPseudoFermionAction"; } + + virtual std::string LogParameters() { + std::stringstream sstream; + sstream << GridLogMessage << "[" << action_name() << "] Low :" << param.lo << std::endl; + sstream << GridLogMessage << "[" << action_name() << "] High :" << param.hi << std::endl; + sstream << GridLogMessage << "[" << action_name() << "] Max iterations :" << param.MaxIter << std::endl; + sstream << GridLogMessage << "[" << action_name() << "] Tolerance :" << param.tolerance << std::endl; + sstream << GridLogMessage << "[" << action_name() << "] Degree :" << param.degree << std::endl; + sstream << GridLogMessage << "[" << action_name() << "] Precision :" << param.precision << std::endl; + return sstream.str(); + } + + // Spin projection + void spProj(const FermionField& in, FermionField& out, int sign, int Ls) + { + if(sign == 1){ for(int s=0; s tmp(2, Lop.FermionGrid()); + + // Use chronological inverter to forecast solutions across poles + std::vector prev_solns; + 
if(use_heatbath_forecasting){ prev_solns.reserve(param.degree); } + ChronoForecast, FermionField> Forecast; + + // Seed with Gaussian noise vector (var = 0.5) + RealD scale = std::sqrt(0.5); + gaussian(pRNG,eta); + eta = eta * scale; + printf("Heatbath source vector: <\\eta|\\eta> = %1.15e\n", norm2(eta)); + + // \Phi = ( \alpha_{0} + \sum_{k=1}^{N_{p}} \alpha_{l} * \gamma_{l} ) * \eta + RealD N(PowerNegHalf.norm); + for(int k=0; k tmp(2, Lop.FermionGrid()); + + // S = <\Phi|\Phi> + RealD action(norm2(Phi)); + + // LH term: S = S - k <\Phi| P_{-} \Omega_{-}^{\dagger} H(mf)^{-1} \Omega_{-} P_{-} |\Phi> + spProj(Phi, spProj_Phi, -1, Lop.Ls); + Lop.Omega(spProj_Phi, tmp[0], -1, 0); + G5R5(tmp[1], tmp[0]); + tmp[0] = zero; + Solver(Lop, tmp[1], tmp[0]); + Lop.Dtilde(tmp[0], tmp[1]); // We actually solved Cayley preconditioned system: transform back + Lop.Omega(tmp[1], tmp[0], -1, 1); + action -= Lop.k * innerProduct(spProj_Phi, tmp[0]).real(); + + // RH term: S = S + k <\Phi| P_{+} \Omega_{+}^{\dagger} ( H(mb) + // - \Delta_{+}(mf,mb) P_{+} )^{-1} \Omega_{-} P_{-} |\Phi> + spProj(Phi, spProj_Phi, 1, Rop.Ls); + Rop.Omega(spProj_Phi, tmp[0], 1, 0); + G5R5(tmp[1], tmp[0]); + tmp[0] = zero; + Solver(Rop, tmp[1], tmp[0]); + Rop.Dtilde(tmp[0], tmp[1]); + Rop.Omega(tmp[1], tmp[0], 1, 1); + action += Rop.k * innerProduct(spProj_Phi, tmp[0]).real(); + + return action; + }; + + // EOFA pseudofermion force: see Eqns. 
(34)-(36) of arXiv:1706.05843 + virtual void deriv(const GaugeField& U, GaugeField& dSdU) + { + Lop.ImportGauge(U); + Rop.ImportGauge(U); + + FermionField spProj_Phi (Lop.FermionGrid()); + FermionField Omega_spProj_Phi(Lop.FermionGrid()); + FermionField CG_src (Lop.FermionGrid()); + FermionField Chi (Lop.FermionGrid()); + FermionField g5_R5_Chi (Lop.FermionGrid()); + + GaugeField force(Lop.GaugeGrid()); + + // LH: dSdU = k \chi_{L}^{\dagger} \gamma_{5} R_{5} ( \partial_{x,\mu} D_{w} ) \chi_{L} + // \chi_{L} = H(mf)^{-1} \Omega_{-} P_{-} \Phi + spProj(Phi, spProj_Phi, -1, Lop.Ls); + Lop.Omega(spProj_Phi, Omega_spProj_Phi, -1, 0); + G5R5(CG_src, Omega_spProj_Phi); + spProj_Phi = zero; + Solver(Lop, CG_src, spProj_Phi); + Lop.Dtilde(spProj_Phi, Chi); + G5R5(g5_R5_Chi, Chi); + Lop.MDeriv(force, g5_R5_Chi, Chi, DaggerNo); + dSdU = Lop.k * force; + + // RH: dSdU = dSdU - k \chi_{R}^{\dagger} \gamma_{5} R_{5} ( \partial_{x,\mu} D_{w} ) \chi_{} + // \chi_{R} = ( H(mb) - \Delta_{+}(mf,mb) P_{+} )^{-1} \Omega_{+} P_{+} \Phi + spProj(Phi, spProj_Phi, 1, Rop.Ls); + Rop.Omega(spProj_Phi, Omega_spProj_Phi, 1, 0); + G5R5(CG_src, Omega_spProj_Phi); + spProj_Phi = zero; + Solver(Rop, CG_src, spProj_Phi); + Rop.Dtilde(spProj_Phi, Chi); + G5R5(g5_R5_Chi, Chi); + Lop.MDeriv(force, g5_R5_Chi, Chi, DaggerNo); + dSdU = dSdU - Rop.k * force; + }; + }; +}} + +#endif diff --git a/lib/qcd/action/pseudofermion/OneFlavourEvenOddRational.h b/lib/qcd/action/pseudofermion/OneFlavourEvenOddRational.h index 080b1be2..9b89959e 100644 --- a/lib/qcd/action/pseudofermion/OneFlavourEvenOddRational.h +++ b/lib/qcd/action/pseudofermion/OneFlavourEvenOddRational.h @@ -1,3 +1,4 @@ + /************************************************************************************* Grid physics library, www.github.com/paboyle/Grid @@ -90,6 +91,19 @@ class OneFlavourEvenOddRationalPseudoFermionAction PowerNegQuarter.Init(remez, param.tolerance, true); }; + virtual std::string action_name(){return 
"OneFlavourEvenOddRationalPseudoFermionAction";} + + virtual std::string LogParameters(){ + std::stringstream sstream; + sstream << GridLogMessage << "["< + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef QCD_PSEUDOFERMION_AGGREGATE_H +#define QCD_PSEUDOFERMION_AGGREGATE_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#endif diff --git a/lib/qcd/action/pseudofermion/TwoFlavour.h b/lib/qcd/action/pseudofermion/TwoFlavour.h index ddc17d42..17d307a4 100644 --- a/lib/qcd/action/pseudofermion/TwoFlavour.h +++ b/lib/qcd/action/pseudofermion/TwoFlavour.h @@ -62,6 +62,15 @@ class TwoFlavourPseudoFermionAction : public Action { ActionSolver(AS), Phi(Op.FermionGrid()){}; + + virtual std::string action_name(){return "TwoFlavourPseudoFermionAction";} + + virtual std::string LogParameters(){ + std::stringstream sstream; + sstream << GridLogMessage << "["< { // in the Phi integral, and thus is only an irrelevant prefactor for // the partition function. 
// + RealD scale = std::sqrt(0.5); + FermionField eta(FermOp.FermionGrid()); gaussian(pRNG, eta); diff --git a/lib/qcd/action/pseudofermion/TwoFlavourEvenOdd.h b/lib/qcd/action/pseudofermion/TwoFlavourEvenOdd.h index 5af1761e..0bbc0ae6 100644 --- a/lib/qcd/action/pseudofermion/TwoFlavourEvenOdd.h +++ b/lib/qcd/action/pseudofermion/TwoFlavourEvenOdd.h @@ -31,80 +31,89 @@ directory #define QCD_PSEUDOFERMION_TWO_FLAVOUR_EVEN_ODD_H namespace Grid { -namespace QCD { + namespace QCD { -//////////////////////////////////////////////////////////////////////// -// Two flavour pseudofermion action for any EO prec dop -//////////////////////////////////////////////////////////////////////// -template -class TwoFlavourEvenOddPseudoFermionAction - : public Action { - public: - INHERIT_IMPL_TYPES(Impl); + //////////////////////////////////////////////////////////////////////// + // Two flavour pseudofermion action for any EO prec dop + //////////////////////////////////////////////////////////////////////// + template + class TwoFlavourEvenOddPseudoFermionAction + : public Action { + public: + INHERIT_IMPL_TYPES(Impl); - private: - FermionOperator &FermOp; // the basic operator + private: + FermionOperator &FermOp; // the basic operator - OperatorFunction &DerivativeSolver; - OperatorFunction &ActionSolver; + OperatorFunction &DerivativeSolver; + OperatorFunction &ActionSolver; - FermionField PhiOdd; // the pseudo fermion field for this trajectory - FermionField PhiEven; // the pseudo fermion field for this trajectory + FermionField PhiOdd; // the pseudo fermion field for this trajectory + FermionField PhiEven; // the pseudo fermion field for this trajectory - public: - ///////////////////////////////////////////////// - // Pass in required objects. 
- ///////////////////////////////////////////////// - TwoFlavourEvenOddPseudoFermionAction(FermionOperator &Op, - OperatorFunction &DS, - OperatorFunction &AS) - : FermOp(Op), - DerivativeSolver(DS), - ActionSolver(AS), - PhiEven(Op.FermionRedBlackGrid()), - PhiOdd(Op.FermionRedBlackGrid()) - {}; + public: + ///////////////////////////////////////////////// + // Pass in required objects. + ///////////////////////////////////////////////// + TwoFlavourEvenOddPseudoFermionAction(FermionOperator &Op, + OperatorFunction &DS, + OperatorFunction &AS) + : FermOp(Op), + DerivativeSolver(DS), + ActionSolver(AS), + PhiEven(Op.FermionRedBlackGrid()), + PhiOdd(Op.FermionRedBlackGrid()) + {}; + + virtual std::string action_name(){return "TwoFlavourEvenOddPseudoFermionAction";} + virtual std::string LogParameters(){ + std::stringstream sstream; + sstream << GridLogMessage << "["< sig^2 = 0.5. - + RealD scale = std::sqrt(0.5); - + FermionField eta (FermOp.FermionGrid()); FermionField etaOdd (FermOp.FermionRedBlackGrid()); FermionField etaEven(FermOp.FermionRedBlackGrid()); - + gaussian(pRNG,eta); pickCheckerboard(Even,etaEven,eta); pickCheckerboard(Odd,etaOdd,eta); - + FermOp.ImportGauge(U); SchurDifferentiableOperator PCop(FermOp); - - + + PCop.MpcDag(etaOdd,PhiOdd); - + FermOp.MooeeDag(etaEven,PhiEven); - + PhiOdd =PhiOdd*scale; PhiEven=PhiEven*scale; - + }; - + ////////////////////////////////////////////////////// // S = phi^dag (Mdag M)^-1 phi (odd) // + phi^dag (Mdag M)^-1 phi (even) ////////////////////////////////////////////////////// virtual RealD S(const GaugeField &U) { - + FermOp.ImportGauge(U); FermionField X(FermOp.FermionRedBlackGrid()); @@ -135,7 +144,6 @@ class TwoFlavourEvenOddPseudoFermionAction // ////////////////////////////////////////////////////// virtual void deriv(const GaugeField &U,GaugeField & dSdU) { - FermOp.ImportGauge(U); FermionField X(FermOp.FermionRedBlackGrid()); @@ -150,8 +158,8 @@ class TwoFlavourEvenOddPseudoFermionAction X=zero; 
DerivativeSolver(Mpc,PhiOdd,X); Mpc.Mpc(X,Y); - Mpc.MpcDeriv(tmp , Y, X ); dSdU=tmp; - Mpc.MpcDagDeriv(tmp , X, Y); dSdU=dSdU+tmp; + Mpc.MpcDeriv(tmp , Y, X ); dSdU=tmp; + Mpc.MpcDagDeriv(tmp , X, Y); dSdU=dSdU+tmp; // Treat the EE case. (MdagM)^-1 = Minv Minvdag // Deriv defaults to zero. @@ -163,10 +171,10 @@ class TwoFlavourEvenOddPseudoFermionAction assert(FermOp.ConstEE() == 1); /* - FermOp.MooeeInvDag(PhiOdd,Y); - FermOp.MooeeInv(Y,X); - FermOp.MeeDeriv(tmp , Y, X,DaggerNo ); dSdU=tmp; - FermOp.MeeDeriv(tmp , X, Y,DaggerYes); dSdU=dSdU+tmp; + FermOp.MooeeInvDag(PhiOdd,Y); + FermOp.MooeeInv(Y,X); + FermOp.MeeDeriv(tmp , Y, X,DaggerNo ); dSdU=tmp; + FermOp.MeeDeriv(tmp , X, Y,DaggerYes); dSdU=dSdU+tmp; */ //dSdU = Ta(dSdU); diff --git a/lib/qcd/action/pseudofermion/TwoFlavourEvenOddRatio.h b/lib/qcd/action/pseudofermion/TwoFlavourEvenOddRatio.h index 5e3b80d9..0f57fe9c 100644 --- a/lib/qcd/action/pseudofermion/TwoFlavourEvenOddRatio.h +++ b/lib/qcd/action/pseudofermion/TwoFlavourEvenOddRatio.h @@ -52,66 +52,75 @@ namespace Grid{ public: TwoFlavourEvenOddRatioPseudoFermionAction(FermionOperator &_NumOp, - FermionOperator &_DenOp, - OperatorFunction & DS, - OperatorFunction & AS) : + FermionOperator &_DenOp, + OperatorFunction & DS, + OperatorFunction & AS) : NumOp(_NumOp), DenOp(_DenOp), DerivativeSolver(DS), ActionSolver(AS), PhiEven(_NumOp.FermionRedBlackGrid()), PhiOdd(_NumOp.FermionRedBlackGrid()) - { - conformable(_NumOp.FermionGrid(), _DenOp.FermionGrid()); - conformable(_NumOp.FermionRedBlackGrid(), _DenOp.FermionRedBlackGrid()); - conformable(_NumOp.GaugeGrid(), _DenOp.GaugeGrid()); - conformable(_NumOp.GaugeRedBlackGrid(), _DenOp.GaugeRedBlackGrid()); - }; + { + conformable(_NumOp.FermionGrid(), _DenOp.FermionGrid()); + conformable(_NumOp.FermionRedBlackGrid(), _DenOp.FermionRedBlackGrid()); + conformable(_NumOp.GaugeGrid(), _DenOp.GaugeGrid()); + conformable(_NumOp.GaugeRedBlackGrid(), _DenOp.GaugeRedBlackGrid()); + }; + + virtual std::string 
action_name(){return "TwoFlavourEvenOddRatioPseudoFermionAction";} + + virtual std::string LogParameters(){ + std::stringstream sstream; + sstream << GridLogMessage << "["< sig^2 = 0.5. - // - RealD scale = std::sqrt(0.5); + // P(phi) = e^{- phi^dag Vpc (MpcdagMpc)^-1 Vpcdag phi} + // + // NumOp == V + // DenOp == M + // + // Take phi_o = Vpcdag^{-1} Mpcdag eta_o ; eta_o = Mpcdag^{-1} Vpcdag Phi + // + // P(eta_o) = e^{- eta_o^dag eta_o} + // + // e^{x^2/2 sig^2} => sig^2 = 0.5. + // + RealD scale = std::sqrt(0.5); - FermionField eta (NumOp.FermionGrid()); - FermionField etaOdd (NumOp.FermionRedBlackGrid()); - FermionField etaEven(NumOp.FermionRedBlackGrid()); - FermionField tmp (NumOp.FermionRedBlackGrid()); + FermionField eta (NumOp.FermionGrid()); + FermionField etaOdd (NumOp.FermionRedBlackGrid()); + FermionField etaEven(NumOp.FermionRedBlackGrid()); + FermionField tmp (NumOp.FermionRedBlackGrid()); - gaussian(pRNG,eta); + gaussian(pRNG,eta); - pickCheckerboard(Even,etaEven,eta); - pickCheckerboard(Odd,etaOdd,eta); + pickCheckerboard(Even,etaEven,eta); + pickCheckerboard(Odd,etaOdd,eta); - NumOp.ImportGauge(U); - DenOp.ImportGauge(U); + NumOp.ImportGauge(U); + DenOp.ImportGauge(U); - SchurDifferentiableOperator Mpc(DenOp); - SchurDifferentiableOperator Vpc(NumOp); + SchurDifferentiableOperator Mpc(DenOp); + SchurDifferentiableOperator Vpc(NumOp); - // Odd det factors - Mpc.MpcDag(etaOdd,PhiOdd); - tmp=zero; - ActionSolver(Vpc,PhiOdd,tmp); - Vpc.Mpc(tmp,PhiOdd); + // Odd det factors + Mpc.MpcDag(etaOdd,PhiOdd); + tmp=zero; + ActionSolver(Vpc,PhiOdd,tmp); + Vpc.Mpc(tmp,PhiOdd); - // Even det factors - DenOp.MooeeDag(etaEven,tmp); - NumOp.MooeeInvDag(tmp,PhiEven); + // Even det factors + DenOp.MooeeDag(etaEven,tmp); + NumOp.MooeeInvDag(tmp,PhiEven); - PhiOdd =PhiOdd*scale; - PhiEven=PhiEven*scale; - + PhiOdd =PhiOdd*scale; + PhiEven=PhiEven*scale; + }; ////////////////////////////////////////////////////// @@ -119,33 +128,33 @@ namespace Grid{ 
////////////////////////////////////////////////////// virtual RealD S(const GaugeField &U) { - NumOp.ImportGauge(U); - DenOp.ImportGauge(U); + NumOp.ImportGauge(U); + DenOp.ImportGauge(U); - SchurDifferentiableOperator Mpc(DenOp); - SchurDifferentiableOperator Vpc(NumOp); + SchurDifferentiableOperator Mpc(DenOp); + SchurDifferentiableOperator Vpc(NumOp); - FermionField X(NumOp.FermionRedBlackGrid()); - FermionField Y(NumOp.FermionRedBlackGrid()); + FermionField X(NumOp.FermionRedBlackGrid()); + FermionField Y(NumOp.FermionRedBlackGrid()); - Vpc.MpcDag(PhiOdd,Y); // Y= Vdag phi - X=zero; - ActionSolver(Mpc,Y,X); // X= (MdagM)^-1 Vdag phi - //Mpc.Mpc(X,Y); // Y= Mdag^-1 Vdag phi - // Multiply by Ydag - RealD action = real(innerProduct(Y,X)); + Vpc.MpcDag(PhiOdd,Y); // Y= Vdag phi + X=zero; + ActionSolver(Mpc,Y,X); // X= (MdagM)^-1 Vdag phi + //Mpc.Mpc(X,Y); // Y= Mdag^-1 Vdag phi + // Multiply by Ydag + RealD action = real(innerProduct(Y,X)); - //RealD action = norm2(Y); + //RealD action = norm2(Y); - // The EE factorised block; normally can replace with zero if det is constant (gauge field indept) - // Only really clover term that creates this. Leave the EE portion as a future to do to make most - // rapid progresss on DWF for now. - // - NumOp.MooeeDag(PhiEven,X); - DenOp.MooeeInvDag(X,Y); - action = action + norm2(Y); + // The EE factorised block; normally can replace with zero if det is constant (gauge field indept) + // Only really clover term that creates this. Leave the EE portion as a future to do to make most + // rapid progresss on DWF for now. 
+ // + NumOp.MooeeDag(PhiEven,X); + DenOp.MooeeInvDag(X,Y); + action = action + norm2(Y); - return action; + return action; }; ////////////////////////////////////////////////////// @@ -155,44 +164,44 @@ namespace Grid{ ////////////////////////////////////////////////////// virtual void deriv(const GaugeField &U,GaugeField & dSdU) { - NumOp.ImportGauge(U); - DenOp.ImportGauge(U); + NumOp.ImportGauge(U); + DenOp.ImportGauge(U); - SchurDifferentiableOperator Mpc(DenOp); - SchurDifferentiableOperator Vpc(NumOp); + SchurDifferentiableOperator Mpc(DenOp); + SchurDifferentiableOperator Vpc(NumOp); - FermionField X(NumOp.FermionRedBlackGrid()); - FermionField Y(NumOp.FermionRedBlackGrid()); + FermionField X(NumOp.FermionRedBlackGrid()); + FermionField Y(NumOp.FermionRedBlackGrid()); - GaugeField force(NumOp.GaugeGrid()); + // This assignment is necessary to be compliant with the HMC grids + GaugeField force(dSdU._grid); - //Y=Vdag phi - //X = (Mdag M)^-1 V^dag phi - //Y = (Mdag)^-1 V^dag phi - Vpc.MpcDag(PhiOdd,Y); // Y= Vdag phi - X=zero; - DerivativeSolver(Mpc,Y,X); // X= (MdagM)^-1 Vdag phi - Mpc.Mpc(X,Y); // Y= Mdag^-1 Vdag phi + //Y=Vdag phi + //X = (Mdag M)^-1 V^dag phi + //Y = (Mdag)^-1 V^dag phi + Vpc.MpcDag(PhiOdd,Y); // Y= Vdag phi + X=zero; + DerivativeSolver(Mpc,Y,X); // X= (MdagM)^-1 Vdag phi + Mpc.Mpc(X,Y); // Y= Mdag^-1 Vdag phi - // phi^dag V (Mdag M)^-1 dV^dag phi - Vpc.MpcDagDeriv(force , X, PhiOdd ); dSdU=force; + // phi^dag V (Mdag M)^-1 dV^dag phi + Vpc.MpcDagDeriv(force , X, PhiOdd ); dSdU = force; - // phi^dag dV (Mdag M)^-1 V^dag phi - Vpc.MpcDeriv(force , PhiOdd, X ); dSdU=dSdU+force; + // phi^dag dV (Mdag M)^-1 V^dag phi + Vpc.MpcDeriv(force , PhiOdd, X ); dSdU = dSdU+force; - // - phi^dag V (Mdag M)^-1 Mdag dM (Mdag M)^-1 V^dag phi - // - phi^dag V (Mdag M)^-1 dMdag M (Mdag M)^-1 V^dag phi - Mpc.MpcDeriv(force,Y,X); dSdU=dSdU-force; - Mpc.MpcDagDeriv(force,X,Y); dSdU=dSdU-force; + // - phi^dag V (Mdag M)^-1 Mdag dM (Mdag M)^-1 V^dag phi + // - 
phi^dag V (Mdag M)^-1 dMdag M (Mdag M)^-1 V^dag phi + Mpc.MpcDeriv(force,Y,X); dSdU = dSdU-force; + Mpc.MpcDagDeriv(force,X,Y); dSdU = dSdU-force; - // FIXME No force contribution from EvenEven assumed here - // Needs a fix for clover. - assert(NumOp.ConstEE() == 1); - assert(DenOp.ConstEE() == 1); + // FIXME No force contribution from EvenEven assumed here + // Needs a fix for clover. + assert(NumOp.ConstEE() == 1); + assert(DenOp.ConstEE() == 1); - //dSdU = -Ta(dSdU); - dSdU = -dSdU; - + dSdU = -dSdU; + }; }; } diff --git a/lib/qcd/action/pseudofermion/TwoFlavourRatio.h b/lib/qcd/action/pseudofermion/TwoFlavourRatio.h index 26d21094..bcbf9364 100644 --- a/lib/qcd/action/pseudofermion/TwoFlavourRatio.h +++ b/lib/qcd/action/pseudofermion/TwoFlavourRatio.h @@ -57,6 +57,14 @@ namespace Grid{ OperatorFunction & AS ) : NumOp(_NumOp), DenOp(_DenOp), DerivativeSolver(DS), ActionSolver(AS), Phi(_NumOp.FermionGrid()) {}; + virtual std::string action_name(){return "TwoFlavourRatioPseudoFermionAction";} + + virtual std::string LogParameters(){ + std::stringstream sstream; + sstream << GridLogMessage << "["< + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef GRID_QCD_SCALAR_H +#define GRID_QCD_SCALAR_H + +#include +#include +#include + +namespace Grid { +namespace QCD { + + typedef ScalarAction ScalarActionR; + typedef ScalarAction ScalarActionF; + typedef ScalarAction ScalarActionD; + + template using ScalarAdjActionR = ScalarInteractionAction, Dimensions>; + template using ScalarAdjActionF = ScalarInteractionAction, Dimensions>; + template using ScalarAdjActionD = ScalarInteractionAction, Dimensions>; + +} +} + +#endif // GRID_QCD_SCALAR_H diff --git a/lib/qcd/action/scalar/ScalarAction.h b/lib/qcd/action/scalar/ScalarAction.h new file mode 100644 index 00000000..2c82d2e3 --- /dev/null +++ b/lib/qcd/action/scalar/ScalarAction.h @@ -0,0 +1,83 @@ +/************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + Source file: ./lib/qcd/action/gauge/WilsonGaugeAction.h + + Copyright (C) 2015 + + Author: Azusa Yamaguchi + Author: Peter Boyle + Author: neo + Author: paboyle + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ + See the full license in the file "LICENSE" in the top level distribution +directory + *************************************************************************************/ +/* END LEGAL */ + +#ifndef SCALAR_ACTION_H +#define SCALAR_ACTION_H + +namespace Grid { + // FIXME drop the QCD namespace everywhere here + +template +class ScalarAction : public QCD::Action { + public: + INHERIT_FIELD_TYPES(Impl); + + private: + RealD mass_square; + RealD lambda; + + public: + ScalarAction(RealD ms, RealD l) : mass_square(ms), lambda(l) {} + + virtual std::string LogParameters() { + std::stringstream sstream; + sstream << GridLogMessage << "[ScalarAction] lambda : " << lambda << std::endl; + sstream << GridLogMessage << "[ScalarAction] mass_square : " << mass_square << std::endl; + return sstream.str(); + } + virtual std::string action_name() {return "ScalarAction";} + + virtual void refresh(const Field &U, GridParallelRNG &pRNG) {} // noop as no pseudoferms + + virtual RealD S(const Field &p) { + return (mass_square * 0.5 + QCD::Nd) * ScalarObs::sumphisquared(p) + + (lambda / 24.) * ScalarObs::sumphifourth(p) + + ScalarObs::sumphider(p); + }; + + virtual void deriv(const Field &p, + Field &force) { + Field tmp(p._grid); + Field p2(p._grid); + ScalarObs::phisquared(p2, p); + tmp = -(Cshift(p, 0, -1) + Cshift(p, 0, 1)); + for (int mu = 1; mu < QCD::Nd; mu++) tmp -= Cshift(p, mu, -1) + Cshift(p, mu, 1); + + force =+(mass_square + 2. * QCD::Nd) * p + (lambda / 6.) 
* p2 * p + tmp; + } +}; + + + +} // namespace Grid + +#endif // SCALAR_ACTION_H diff --git a/lib/qcd/action/scalar/ScalarImpl.h b/lib/qcd/action/scalar/ScalarImpl.h new file mode 100644 index 00000000..f85ab840 --- /dev/null +++ b/lib/qcd/action/scalar/ScalarImpl.h @@ -0,0 +1,162 @@ +#ifndef SCALAR_IMPL +#define SCALAR_IMPL + + +namespace Grid { + //namespace QCD { + +template +class ScalarImplTypes { + public: + typedef S Simd; + + template + using iImplField = iScalar > >; + + typedef iImplField SiteField; + typedef SiteField SitePropagator; + typedef SiteField SiteComplex; + + typedef Lattice Field; + typedef Field ComplexField; + typedef Field FermionField; + typedef Field PropagatorField; + + static inline void generate_momenta(Field& P, GridParallelRNG& pRNG){ + gaussian(pRNG, P); + } + + static inline Field projectForce(Field& P){return P;} + + static inline void update_field(Field& P, Field& U, double ep) { + U += P*ep; + } + + static inline RealD FieldSquareNorm(Field& U) { + return (- sum(trace(U*U))/2.0); + } + + static inline void HotConfiguration(GridParallelRNG &pRNG, Field &U) { + gaussian(pRNG, U); + } + + static inline void TepidConfiguration(GridParallelRNG &pRNG, Field &U) { + gaussian(pRNG, U); + } + + static inline void ColdConfiguration(GridParallelRNG &pRNG, Field &U) { + U = 1.0; + } + + static void MomentumSpacePropagator(Field &out, RealD m) + { + GridBase *grid = out._grid; + Field kmu(grid), one(grid); + const unsigned int nd = grid->_ndimension; + std::vector &l = grid->_fdimensions; + + one = Complex(1.0,0.0); + out = m*m; + for(int mu = 0; mu < nd; mu++) + { + Real twoPiL = M_PI*2./l[mu]; + + LatticeCoordinate(kmu,mu); + kmu = 2.*sin(.5*twoPiL*kmu); + out = out + kmu*kmu; + } + out = one/out; + } + + static void FreePropagator(const Field &in, Field &out, + const Field &momKernel) + { + FFT fft((GridCartesian *)in._grid); + Field inFT(in._grid); + + fft.FFT_all_dim(inFT, in, FFT::forward); + inFT = inFT*momKernel; + 
fft.FFT_all_dim(out, inFT, FFT::backward); + } + + static void FreePropagator(const Field &in, Field &out, RealD m) + { + Field momKernel(in._grid); + + MomentumSpacePropagator(momKernel, m); + FreePropagator(in, out, momKernel); + } + + }; + + template + class ScalarAdjMatrixImplTypes { + public: + typedef S Simd; + typedef QCD::SU Group; + + template + using iImplField = iScalar>>; + template + using iImplComplex = iScalar>>; + + typedef iImplField SiteField; + typedef SiteField SitePropagator; + typedef iImplComplex SiteComplex; + + typedef Lattice Field; + typedef Lattice ComplexField; + typedef Field FermionField; + typedef Field PropagatorField; + + static inline void generate_momenta(Field& P, GridParallelRNG& pRNG) { + Group::GaussianFundamentalLieAlgebraMatrix(pRNG, P); + } + + static inline Field projectForce(Field& P) {return P;} + + static inline void update_field(Field& P, Field& U, double ep) { + U += P*ep; + } + + static inline RealD FieldSquareNorm(Field& U) { + return (TensorRemove(sum(trace(U*U))).real()); + } + + static inline void HotConfiguration(GridParallelRNG &pRNG, Field &U) { + Group::GaussianFundamentalLieAlgebraMatrix(pRNG, U); + } + + static inline void TepidConfiguration(GridParallelRNG &pRNG, Field &U) { + Group::GaussianFundamentalLieAlgebraMatrix(pRNG, U, 0.01); + } + + static inline void ColdConfiguration(GridParallelRNG &pRNG, Field &U) { + U = zero; + } + + }; + + + + + typedef ScalarImplTypes ScalarImplR; + typedef ScalarImplTypes ScalarImplF; + typedef ScalarImplTypes ScalarImplD; + typedef ScalarImplTypes ScalarImplCR; + typedef ScalarImplTypes ScalarImplCF; + typedef ScalarImplTypes ScalarImplCD; + + // Hardcoding here the size of the matrices + typedef ScalarAdjMatrixImplTypes ScalarAdjImplR; + typedef ScalarAdjMatrixImplTypes ScalarAdjImplF; + typedef ScalarAdjMatrixImplTypes ScalarAdjImplD; + + template using ScalarNxNAdjImplR = ScalarAdjMatrixImplTypes; + template using ScalarNxNAdjImplF = ScalarAdjMatrixImplTypes; + 
template using ScalarNxNAdjImplD = ScalarAdjMatrixImplTypes; + + //} +} + +#endif diff --git a/lib/qcd/action/scalar/ScalarInteractionAction.h b/lib/qcd/action/scalar/ScalarInteractionAction.h new file mode 100644 index 00000000..4d189352 --- /dev/null +++ b/lib/qcd/action/scalar/ScalarInteractionAction.h @@ -0,0 +1,148 @@ +/************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + Source file: ./lib/qcd/action/gauge/WilsonGaugeAction.h + + Copyright (C) 2015 + + Author: Guido Cossu + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ + See the full license in the file "LICENSE" in the top level distribution +directory + *************************************************************************************/ +/* END LEGAL */ + +#ifndef SCALAR_INT_ACTION_H +#define SCALAR_INT_ACTION_H + + +// Note: this action can completely absorb the ScalarAction for real float fields +// use the scalarObjs to generalise the structure + +namespace Grid { + // FIXME drop the QCD namespace everywhere here + + template + class ScalarInteractionAction : public QCD::Action { + public: + INHERIT_FIELD_TYPES(Impl); + private: + RealD mass_square; + RealD lambda; + + + typedef typename Field::vector_object vobj; + typedef CartesianStencil Stencil; + + SimpleCompressor compressor; + int npoint = 2*Ndim; + std::vector directions;// = {0,1,2,3,0,1,2,3}; // forcing 4 dimensions + std::vector displacements;// = {1,1,1,1, -1,-1,-1,-1}; + + + public: + + ScalarInteractionAction(RealD ms, RealD l) : mass_square(ms), lambda(l), displacements(2*Ndim,0), directions(2*Ndim,0){ + for (int mu = 0 ; mu < Ndim; mu++){ + directions[mu] = mu; directions[mu+Ndim] = mu; + displacements[mu] = 1; displacements[mu+Ndim] = -1; + } + } + + virtual std::string LogParameters() { + std::stringstream sstream; + sstream << GridLogMessage << "[ScalarAction] lambda : " << lambda << std::endl; + sstream << GridLogMessage << "[ScalarAction] mass_square : " << mass_square << std::endl; + return sstream.str(); + } + + virtual std::string action_name() {return "ScalarAction";} + + virtual void refresh(const Field &U, GridParallelRNG &pRNG) {} + + virtual RealD S(const Field &p) { + assert(p._grid->Nd() == Ndim); + static Stencil phiStencil(p._grid, npoint, 0, directions, displacements); + phiStencil.HaloExchange(p, compressor); + Field action(p._grid), pshift(p._grid), phisquared(p._grid); + phisquared = p*p; + action = (2.0*Ndim + mass_square)*phisquared - lambda/24.*phisquared*phisquared; + for (int mu = 0; mu < Ndim; mu++) { + // pshift = Cshift(p, mu, 
+1); // not efficient, implement with stencils + parallel_for (int i = 0; i < p._grid->oSites(); i++) { + int permute_type; + StencilEntry *SE; + vobj temp2; + const vobj *temp, *t_p; + + SE = phiStencil.GetEntry(permute_type, mu, i); + t_p = &p._odata[i]; + if ( SE->_is_local ) { + temp = &p._odata[SE->_offset]; + if ( SE->_permute ) { + permute(temp2, *temp, permute_type); + action._odata[i] -= temp2*(*t_p) + (*t_p)*temp2; + } else { + action._odata[i] -= (*temp)*(*t_p) + (*t_p)*(*temp); + } + } else { + action._odata[i] -= phiStencil.CommBuf()[SE->_offset]*(*t_p) + (*t_p)*phiStencil.CommBuf()[SE->_offset]; + } + } + // action -= pshift*p + p*pshift; + } + // NB the trace in the algebra is normalised to 1/2 + // minus sign coming from the antihermitian fields + return -(TensorRemove(sum(trace(action)))).real(); + }; + + virtual void deriv(const Field &p, Field &force) { + assert(p._grid->Nd() == Ndim); + force = (2.0*Ndim + mass_square)*p - lambda/12.*p*p*p; + // move this outside + static Stencil phiStencil(p._grid, npoint, 0, directions, displacements); + phiStencil.HaloExchange(p, compressor); + + //for (int mu = 0; mu < QCD::Nd; mu++) force -= Cshift(p, mu, -1) + Cshift(p, mu, 1); + for (int point = 0; point < npoint; point++) { + parallel_for (int i = 0; i < p._grid->oSites(); i++) { + const vobj *temp; + vobj temp2; + int permute_type; + StencilEntry *SE; + SE = phiStencil.GetEntry(permute_type, point, i); + + if ( SE->_is_local ) { + temp = &p._odata[SE->_offset]; + if ( SE->_permute ) { + permute(temp2, *temp, permute_type); + force._odata[i] -= temp2; + } else { + force._odata[i] -= *temp; + } + } else { + force._odata[i] -= phiStencil.CommBuf()[SE->_offset]; + } + } + } + } + }; + +} // namespace Grid + +#endif // SCALAR_INT_ACTION_H diff --git a/lib/qcd/hmc/GenericHMCrunner.h b/lib/qcd/hmc/GenericHMCrunner.h new file mode 100644 index 00000000..4f6c1af0 --- /dev/null +++ b/lib/qcd/hmc/GenericHMCrunner.h @@ -0,0 +1,219 @@ 
+/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/hmc/GenericHmcRunner.h + +Copyright (C) 2015 + +Author: paboyle +Author: Guido Cossu + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + + See the full license in the file "LICENSE" in the top level distribution + directory + *************************************************************************************/ +/* END LEGAL */ +#ifndef GRID_GENERIC_HMC_RUNNER +#define GRID_GENERIC_HMC_RUNNER + +#include + +namespace Grid { +namespace QCD { + + +// very ugly here but possibly resolved if we had a base Reader class +template < class ReaderClass > +class HMCRunnerBase { +public: + virtual void Run() = 0; + virtual void initialize(ReaderClass& ) = 0; +}; + + +template class Integrator, + class RepresentationsPolicy = NoHirep, class ReaderClass = XmlReader> +class HMCWrapperTemplate: public HMCRunnerBase { + public: + INHERIT_FIELD_TYPES(Implementation); + typedef Implementation ImplPolicy; // visible from outside + template > + using IntegratorType = Integrator; + + HMCparameters Parameters; + std::string ParameterFile; + HMCResourceManager Resources; + + // The set of actions (keep here for lower level users, for now) + ActionSet TheAction; + + HMCWrapperTemplate() = 
default; + + HMCWrapperTemplate(HMCparameters Par){ + Parameters = Par; + } + + void initialize(ReaderClass & TheReader){ + std::cout << "Initialization of the HMC" << std::endl; + Resources.initialize(TheReader); + + // eventually add smearing + + Resources.GetActionSet(TheAction); + } + + + void ReadCommandLine(int argc, char **argv) { + std::string arg; + + if (GridCmdOptionExists(argv, argv + argc, "--StartingType")) { + arg = GridCmdOptionPayload(argv, argv + argc, "--StartingType"); + + if (arg != "HotStart" && arg != "ColdStart" && arg != "TepidStart" && + arg != "CheckpointStart") { + std::cout << GridLogError << "Unrecognized option in --StartingType\n"; + std::cout + << GridLogError + << "Valid [HotStart, ColdStart, TepidStart, CheckpointStart]\n"; + exit(1); + } + Parameters.StartingType = arg; + } + + if (GridCmdOptionExists(argv, argv + argc, "--StartingTrajectory")) { + arg = GridCmdOptionPayload(argv, argv + argc, "--StartingTrajectory"); + std::vector ivec(0); + GridCmdOptionIntVector(arg, ivec); + Parameters.StartTrajectory = ivec[0]; + } + + if (GridCmdOptionExists(argv, argv + argc, "--Trajectories")) { + arg = GridCmdOptionPayload(argv, argv + argc, "--Trajectories"); + std::vector ivec(0); + GridCmdOptionIntVector(arg, ivec); + Parameters.Trajectories = ivec[0]; + } + + if (GridCmdOptionExists(argv, argv + argc, "--Thermalizations")) { + arg = GridCmdOptionPayload(argv, argv + argc, "--Thermalizations"); + std::vector ivec(0); + GridCmdOptionIntVector(arg, ivec); + Parameters.NoMetropolisUntil = ivec[0]; + } + if (GridCmdOptionExists(argv, argv + argc, "--ParameterFile")) { + arg = GridCmdOptionPayload(argv, argv + argc, "--ParameterFile"); + ParameterFile = arg; + } + } + + + template + void Run(SmearingPolicy &S) { + Runner(S); + } + + void Run(){ + NoSmearing S; + Runner(S); + } + + ////////////////////////////////////////////////////////////////// + + private: + template + void Runner(SmearingPolicy &Smearing) { + auto UGrid = 
Resources.GetCartesian(); + Resources.AddRNGs(); + Field U(UGrid); + + // Can move this outside? + typedef IntegratorType TheIntegrator; + TheIntegrator MDynamics(UGrid, Parameters.MD, TheAction, Smearing); + + if (Parameters.StartingType == "HotStart") { + // Hot start + Resources.SeedFixedIntegers(); + Implementation::HotConfiguration(Resources.GetParallelRNG(), U); + } else if (Parameters.StartingType == "ColdStart") { + // Cold start + Resources.SeedFixedIntegers(); + Implementation::ColdConfiguration(Resources.GetParallelRNG(), U); + } else if (Parameters.StartingType == "TepidStart") { + // Tepid start + Resources.SeedFixedIntegers(); + Implementation::TepidConfiguration(Resources.GetParallelRNG(), U); + } else if (Parameters.StartingType == "CheckpointStart") { + // CheckpointRestart + Resources.GetCheckPointer()->CheckpointRestore(Parameters.StartTrajectory, U, + Resources.GetSerialRNG(), + Resources.GetParallelRNG()); + } + + Smearing.set_Field(U); + + HybridMonteCarlo HMC(Parameters, MDynamics, + Resources.GetSerialRNG(), + Resources.GetParallelRNG(), + Resources.GetObservables(), U); + + // Run it + HMC.evolve(); + } +}; + +// These are for gauge fields, default integrator MinimumNorm2 +template