diff --git a/src/include/proxy.h b/src/include/proxy.h index 604399e5..225e9fae 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -5,9 +5,15 @@ #include "comm.h" #include +typedef enum { + MSCCLPP_PROXY_RUN_STATE_IDLE = 0, + MSCCLPP_PROXY_RUN_STATE_RUNNING, + MSCCLPP_PROXY_RUN_STATE_EXITING, +} mscclppProxyRunState_t; + struct mscclppProxyState { pthread_t *threads; - int *runs; + mscclppProxyRunState_t *runs; }; mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm); diff --git a/src/proxy.cc b/src/proxy.cc index df607541..462db3fc 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -33,7 +33,7 @@ struct proxyArgs { struct mscclppComm* comm; struct mscclppIbContext* ibCtx; cudaStream_t stream; - volatile int* run; + volatile mscclppProxyRunState_t* run; int connIdx; }; @@ -41,8 +41,7 @@ struct proxyArgs { void* mscclppProxyServiceP2P(void* _args) { struct proxyArgs *args = (struct proxyArgs *)_args; struct mscclppComm *comm = args->comm; - // TODO(saemal): we perhaps need a finite state for run instead of just 0 and 1 - volatile int *run = args->run; + volatile mscclppProxyRunState_t *run = args->run; struct mscclppConn *conn = &comm->conns[args->connIdx]; cudaStream_t stream = args->stream; free(_args); @@ -57,7 +56,7 @@ void* mscclppProxyServiceP2P(void* _args) { PROXYCUDACHECK(cudaSetDevice(comm->cudaDev)); PROXYCUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); - while (*run) { + while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) { // Poll to see if we are ready to send anything trigger.value = *(volatile uint64_t *)conn->cpuTrigger; if (trigger.value == 0) continue; @@ -80,7 +79,7 @@ void* mscclppProxyServiceP2P(void* _args) { volatile uint64_t *tmp = (volatile uint64_t *)conn->cpuTrigger; *tmp = 0; } - *run = 1; + *run = MSCCLPP_PROXY_RUN_STATE_IDLE; PROXYCUDACHECK(cudaStreamDestroy(stream)); // WARN("Proxy exits: rank %d", rank); @@ -94,7 +93,7 @@ void* mscclppProxyServiceIb(void* _args) { struct proxyArgs *args = (struct proxyArgs *)_args; struct mscclppComm *comm = args->comm; struct mscclppIbContext *ibCtx = args->ibCtx; - volatile int *run = args->run; + volatile mscclppProxyRunState_t *run = args->run; struct mscclppConn *conn = &comm->conns[args->connIdx]; free(_args); uint64_t currentProxyFlagVlaue = *conn->cpuProxyFlag; @@ -114,7 +113,7 @@ void* mscclppProxyServiceIb(void* _args) { WARN("postRecv failed: errno %d", errno); } - while (*run) { + while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) { // Try send if (sendState == SEND_STATE_INIT) { trigger.value = *(volatile uint64_t *)conn->cpuTrigger; @@ -162,7 +161,7 @@ void* mscclppProxyServiceIb(void* _args) { } } } - *run = 1; + *run = MSCCLPP_PROXY_RUN_STATE_IDLE; // WARN("Proxy exits: rank %d", rank); return NULL; } @@ -174,7 +173,7 @@ void* mscclppProxyServiceIb(void* _args) { struct proxyArgs *args = (struct proxyArgs *)_args; struct mscclppComm *comm = args->comm; struct mscclppIbContext *ibCtx = args->ibCtx; - volatile int *run = args->run; + volatile mscclppProxyRunState_t *run = args->run; struct mscclppConn *conn = &comm->conns[args->connIdx]; free(_args); @@ -184,7 +183,7 @@ void* mscclppProxyServiceIb(void* _args) { NumaBind(ibCtx->numaNode); - while (*run) { + while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) { // Poll to see if we are ready to send anything trigger.value = *(volatile uint64_t *)conn->cpuTrigger; if (trigger.value == 0) continue; @@ -234,7 +233,7 @@ void* mscclppProxyServiceIb(void* _args) { volatile uint64_t *tmp = (volatile uint64_t *)conn->cpuTrigger; *tmp = 0; } - *run = 1; + *run = MSCCLPP_PROXY_RUN_STATE_IDLE; // WARN("Proxy exits: rank %d", rank); return NULL; } @@ -280,7 +279,7 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) { args->ibCtx = comm->ibContext[i]; args->run = &comm->proxyState[i].runs[j]; args->connIdx = j; - *args->run = 1; + *args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING; pthread_create(&comm->proxyState[i].threads[j], NULL, mscclppProxyService, args); mscclppSetThreadName(comm->proxyState[i].threads[j], "MSCCLPP Service %2d - %4d", i, j); } @@ -303,7 +302,7 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) { args->run = &proxyState->runs[j]; args->connIdx = j; CUDACHECK(cudaStreamCreateWithFlags(&args->stream, cudaStreamNonBlocking)); - *args->run = 1; + *args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING; pthread_create(&proxyState->threads[j], NULL, mscclppProxyService, args); mscclppSetThreadName(proxyState->threads[j], "MSCCLPP Service %2d - %4d", MSCCLPP_IB_MAX_DEVS + 1, j); } @@ -312,12 +311,11 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) { static void _stopProxy(struct mscclppComm* comm, int devIdx, int connIdx) { volatile int *run = (volatile int *)&comm->proxyState[devIdx].runs[connIdx]; - if (*run == 0) return; - *run = 0; - while (*run == 0 && *comm->abortFlag == 0) { + if (*run == MSCCLPP_PROXY_RUN_STATE_IDLE) return; + *run = MSCCLPP_PROXY_RUN_STATE_EXITING; + while (*run == MSCCLPP_PROXY_RUN_STATE_EXITING && *comm->abortFlag == 0) { usleep(1000); } - *run = 0; } mscclppResult_t mscclppProxyDestroy(struct mscclppComm* comm) {