From ba9bbe0221798d953d1916365e98ac9fba9613ed Mon Sep 17 00:00:00 2001 From: Peter Boyle Date: Wed, 12 Feb 2025 19:34:59 +0000 Subject: [PATCH] Bounce MPI through host --- Grid/lattice/PaddedCell.h | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/Grid/lattice/PaddedCell.h b/Grid/lattice/PaddedCell.h index fb533212..bdb0301c 100644 --- a/Grid/lattice/PaddedCell.h +++ b/Grid/lattice/PaddedCell.h @@ -466,6 +466,12 @@ public: static deviceVector recv_buf; send_buf.resize(buffer_size*2*depth); recv_buf.resize(buffer_size*2*depth); +#ifndef ACCELERATOR_AWARE_MPI + static hostVector hsend_buf; + static hostVector hrecv_buf; + hsend_buf.resize(buffer_size*2*depth); + hrecv_buf.resize(buffer_size*2*depth); +#endif std::vector fwd_req; std::vector bwd_req; @@ -495,9 +501,17 @@ public: t_gather+=usecond()-t; t=usecond(); +#ifdef ACCELERATOR_AWARE_MPI grid->SendToRecvFromBegin(fwd_req, (void *)&send_buf[d*buffer_size], xmit_to_rank, (void *)&recv_buf[d*buffer_size], recv_from_rank, bytes, tag); +#else + acceleratorCopyFromDevice(&send_buf[d*buffer_size],&hsend_buf[d*buffer_size],bytes); + grid->SendToRecvFromBegin(fwd_req, + (void *)&hsend_buf[d*buffer_size], xmit_to_rank, + (void *)&hrecv_buf[d*buffer_size], recv_from_rank, bytes, tag); + acceleratorCopyToDevice(&hrecv_buf[d*buffer_size],&recv_buf[d*buffer_size],bytes); +#endif t_comms+=usecond()-t; } for ( int d=0;d < depth ; d ++ ) { @@ -508,9 +522,17 @@ public: t_gather+= usecond() - t; t=usecond(); +#ifdef ACCELERATOR_AWARE_MPI grid->SendToRecvFromBegin(bwd_req, (void *)&send_buf[(d+depth)*buffer_size], recv_from_rank, (void *)&recv_buf[(d+depth)*buffer_size], xmit_to_rank, bytes,tag); +#else + acceleratorCopyFromDevice(&send_buf[(d+depth)*buffer_size],&hsend_buf[(d+depth)*buffer_size],bytes); + grid->SendToRecvFromBegin(bwd_req, + (void *)&hsend_buf[(d+depth)*buffer_size], recv_from_rank, + (void *)&hrecv_buf[(d+depth)*buffer_size], xmit_to_rank, bytes,tag); + acceleratorCopyToDevice(&hrecv_buf[(d+depth)*buffer_size],&recv_buf[(d+depth)*buffer_size],bytes); +#endif t_comms+=usecond()-t; }