diff --git a/src/include/comm.h b/src/include/comm.h
index 3637d0ca..5c652af7 100644
--- a/src/include/comm.h
+++ b/src/include/comm.h
@@ -54,6 +54,7 @@ struct mscclppComm
   int rank; // my rank in the communicator
   int nRanks; // number of GPUs in communicator
   int cudaDev; // my cuda device index
+  int numaNode; // NUMA node set by mscclppNumaBind(), or -1 if none was set
 
   // Flag to ask MSCCLPP kernels to abort
   volatile uint32_t* abortFlag;
diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h
index 4f978486..012cffc8 100644
--- a/src/include/mscclpp.h
+++ b/src/include/mscclpp.h
@@ -373,6 +373,18 @@ void mscclppDefaultLogHandler(const char* msg);
  */
 mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler);
 
+/* Select the NUMA node that the proxy service of a communicator binds to.
+ *
+ * This call only records the value in the communicator; mscclppProxyService()
+ * reads it and performs the actual binding, so call it before the proxies start.
+ *
+ * Inputs:
+ *   comm:     the communicator to configure
+ *   numaNode: the NUMA node to be bound, or -1 (the initial value) to keep the
+ *             built-in selection
+ */
+mscclppResult_t mscclppNumaBind(mscclppComm_t comm, int numaNode);
+
 #ifdef __cplusplus
 } // end extern "C"
 #endif
diff --git a/src/init.cc b/src/init.cc
index 3b5be4c1..a1d7dffc 100644
--- a/src/init.cc
+++ b/src/init.cc
@@ -75,6 +75,7 @@ mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, const char*
   MSCCLPPCHECKGOTO(mscclppCalloc(&_comm, 1), res, fail);
   _comm->rank = rank;
   _comm->nRanks = nranks;
+  _comm->numaNode = -1; // no NUMA node requested yet; see mscclppNumaBind()
   // We assume that the user has set the device to the intended one already
   CUDACHECK(cudaGetDevice(&_comm->cudaDev));
 
@@ -547,3 +548,13 @@ mscclppResult_t mscclppSetBootstrapConnTimeout(int timeout)
   config->setBootstrapConnectionTimeoutConfig(timeout);
   return mscclppSuccess;
 }
+
+// Record the NUMA node requested for this communicator. Nothing is bound here:
+// mscclppProxyService() reads comm->numaNode and performs the NumaBind() call.
+// A value of -1 leaves the built-in NUMA selection in effect.
+MSCCLPP_API(mscclppResult_t, mscclppNumaBind, mscclppComm_t comm, int numaNode);
+mscclppResult_t mscclppNumaBind(mscclppComm_t comm, int numaNode)
+{
+  comm->numaNode = numaNode;
+  return mscclppSuccess;
+}
diff --git a/src/proxy.cc b/src/proxy.cc
index 44a17b27..8381b5c2 100644
--- a/src/proxy.cc
+++ b/src/proxy.cc
@@ -74,14 +74,22 @@ void* mscclppProxyService(void* _args)
   PROXYCUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
 
   bool isP2pProxy = (ibCtx == nullptr);
-  if (isP2pProxy) {
-    // TODO(chhwang): find numa node
-    // Current mapping is based on NDv4: GPU [0,1,2,3,4,5,6,7] -> NUMA [1,1,0,0,3,3,2,2]
-    // TODO(saemal): either ask user or detect it automatically
-    NumaBind((comm->cudaDev / 2) ^ 1);
-    p2pStream = args->proxyState->stream;
-  } else {
-    NumaBind(ibCtx->numaNode);
-  }
+  // A node chosen through mscclppNumaBind() wins; -1 means none was chosen.
+  int numaNode = comm->numaNode;
+  if (numaNode == -1) {
+    if (isP2pProxy) {
+      // TODO(chhwang): find numa node
+      // Current mapping is based on NDv4: GPU [0,1,2,3,4,5,6,7] -> NUMA [1,1,0,0,3,3,2,2]
+      // TODO(saemal): either ask user or detect it automatically
+      numaNode = (comm->cudaDev / 2) ^ 1;
+    } else {
+      numaNode = ibCtx->numaNode;
+    }
+  }
+  NumaBind(numaNode);
+  if (isP2pProxy) {
+    // Set for every P2P proxy, as before this change, regardless of how the NUMA node was picked.
+    p2pStream = args->proxyState->stream;
+  }
 
   free(_args); // allocated in mscclppProxyCreate