1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-10 07:55:35 +00:00

Initialise CUDA device prior to entering MPI.

This may or may not interact with Summit which configures MPI - CUDA mapping with jsrun.
TBD
Cases of OpenMPI and MVAPICH are covered, and default to cudaSetDevice(0) otherwise
This commit is contained in:
Peter Boyle 2019-07-11 03:14:23 +01:00
parent 4c3225412b
commit 44170cc15f

View File

@ -239,48 +239,62 @@ static int Grid_is_initialised;
///////////////////////////////////////////////////////// /////////////////////////////////////////////////////////
// Reinit guard // Reinit guard
///////////////////////////////////////////////////////// /////////////////////////////////////////////////////////
#ifdef GRID_NVCC
void GridGpuInit(void) void GridGpuInit(void)
{ {
int nDevices; #ifdef GRID_NVCC
int nDevices = 1;
cudaGetDeviceCount(&nDevices); cudaGetDeviceCount(&nDevices);
char * localRankStr = NULL;
int rank = 0, device = 0, world_rank=0;
#define ENV_LOCAL_RANK_OMPI "OMPI_COMM_WORLD_LOCAL_RANK"
#define ENV_LOCAL_RANK_MVAPICH "MV2_COMM_WORLD_LOCAL_RANK"
#define ENV_RANK_OMPI "OMPI_COMM_WORLD_RANK"
#define ENV_RANK_MVAPICH "MV2_COMM_WORLD_RANK"
// We extract the local rank initialization using an environment variable
if ((localRankStr = getenv(ENV_LOCAL_RANK_OMPI)) != NULL)
{
rank = atoi(localRankStr);
device = rank %nDevices;
}
if ((localRankStr = getenv(ENV_LOCAL_RANK_MVAPICH)) != NULL)
{
rank = atoi(localRankStr);
device = rank %nDevices;
}
if ((localRankStr = getenv(ENV_RANK_OMPI )) != NULL) { world_rank = atoi(localRankStr);}
if ((localRankStr = getenv(ENV_RANK_MVAPICH)) != NULL) { world_rank = atoi(localRankStr);}
cudaSetDevice(device);
for (int i = 0; i < nDevices; i++) { for (int i = 0; i < nDevices; i++) {
cudaDeviceProp prop; cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, i); cudaGetDeviceProperties(&prop, i);
/*
printf("Device Number: %d\n", i); if ( world_rank == 0) {
printf(" Device name: %s\n", prop.name); printf("Device Number: %d\n", i);
printf(" Memory Clock Rate (KHz): %d\n", printf(" Device name: %s\n", prop.name);
prop.memoryClockRate); printf(" Peak Memory Bandwidth (GB/s): %f\n\n",2.0*prop.memoryClockRate*(prop.memoryBusWidth/8)/1.0e6);
printf(" Memory Bus Width (bits): %d\n",
prop.memoryBusWidth);
printf(" Peak Memory Bandwidth (GB/s): %f\n\n",
2.0*prop.memoryClockRate*(prop.memoryBusWidth/8)/1.0e6);
#define GPU_PROP_FMT(canMapHostMemory,FMT) printf(" " #canMapHostMemory ": " FMT" \n",prop.canMapHostMemory); #define GPU_PROP_FMT(canMapHostMemory,FMT) printf(" " #canMapHostMemory ": " FMT" \n",prop.canMapHostMemory);
#define GPU_PROP(canMapHostMemory) printf(" " #canMapHostMemory ": %d \n",prop.canMapHostMemory); #define GPU_PROP(canMapHostMemory) printf(" " #canMapHostMemory ": %d \n",prop.canMapHostMemory);
GPU_PROP(canMapHostMemory);
GPU_PROP(canUseHostPointerForRegisteredMem);
GPU_PROP(globalL1CacheSupported);
GPU_PROP(isMultiGpuBoard); GPU_PROP(isMultiGpuBoard);
GPU_PROP(kernelExecTimeoutEnabled);
GPU_PROP(l2CacheSize); GPU_PROP(l2CacheSize);
GPU_PROP(managedMemory); GPU_PROP(managedMemory);
GPU_PROP(pageableMemoryAccess);
GPU_PROP(regsPerMultiprocessor);
GPU_PROP_FMT(sharedMemPerBlock,"%lx");
GPU_PROP_FMT(sharedMemPerMultiprocessor,"%lx");
GPU_PROP(singleToDoublePrecisionPerfRatio); GPU_PROP(singleToDoublePrecisionPerfRatio);
GPU_PROP(unifiedAddressing); GPU_PROP(unifiedAddressing);
GPU_PROP(warpSize); GPU_PROP(warpSize);
*/ }
} }
}
#endif #endif
}
void Grid_init(int *argc,char ***argv) void Grid_init(int *argc,char ***argv)
{ {
GridGpuInit(); // Must come first to set device prior to MPI init
assert(Grid_is_initialised == 0); assert(Grid_is_initialised == 0);
GridLogger::GlobalStopWatch.Start(); GridLogger::GlobalStopWatch.Start();
@ -307,9 +321,6 @@ void Grid_init(int *argc,char ***argv)
Grid_debug_handler_init(); Grid_debug_handler_init();
} }
#ifdef GRID_NVCC
GridGpuInit();
#endif
CartesianCommunicator::Init(argc,argv); CartesianCommunicator::Init(argc,argv);
if( !GridCmdOptionExists(*argv,*argv+*argc,"--debug-stdout") ){ if( !GridCmdOptionExists(*argv,*argv+*argc,"--debug-stdout") ){