diff --git a/src/include/proxy.h b/src/include/proxy.h index 604399e5..225e9fae 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -5,9 +5,15 @@ #include "comm.h" #include +typedef enum { + MSCCLPP_PROXY_RUN_STATE_IDLE = 0, + MSCCLPP_PROXY_RUN_STATE_RUNNING, + MSCCLPP_PROXY_RUN_STATE_EXITING, +} mscclppProxyRunState_t; + struct mscclppProxyState { pthread_t *threads; - int *runs; + mscclppProxyRunState_t *runs; }; mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm); diff --git a/src/proxy.cc b/src/proxy.cc index cf20c7a7..8c8e1fc8 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -34,7 +34,7 @@ struct proxyArgs { struct mscclppComm* comm; struct mscclppIbContext* ibCtx; cudaStream_t stream; - volatile int* run; + volatile mscclppProxyRunState_t* run; int connIdx; }; @@ -42,8 +42,7 @@ struct proxyArgs { void* mscclppProxyServiceP2P(void* _args) { struct proxyArgs *args = (struct proxyArgs *)_args; struct mscclppComm *comm = args->comm; - // TODO(saemal): we perhaps need a finite state for run instead of just 0 and 1 - volatile int *run = args->run; + volatile mscclppProxyRunState_t *run = args->run; struct mscclppConn *conn = &comm->conns[args->connIdx]; cudaStream_t stream = args->stream; free(_args); @@ -58,10 +57,7 @@ void* mscclppProxyServiceP2P(void* _args) { PROXYCUDACHECK(cudaSetDevice(comm->cudaDev)); PROXYCUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); - cudaStreamCaptureStatus stat; - cudaStreamIsCapturing(stream, &stat); - - while (*run) { + while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) { // Poll to see if we are ready to send anything trigger.value = *(volatile uint64_t *)conn->cpuTrigger; if (trigger.value == 0) continue; @@ -84,7 +80,7 @@ void* mscclppProxyServiceP2P(void* _args) { volatile uint64_t *tmp = (volatile uint64_t *)conn->cpuTrigger; *tmp = 0; } - *run = 1; + *run = MSCCLPP_PROXY_RUN_STATE_IDLE; PROXYCUDACHECK(cudaStreamDestroy(stream)); // WARN("Proxy exits: rank %d", rank); @@ -95,7 +91,7 @@ void* mscclppProxyServiceIb(void* _args) { struct proxyArgs *args = (struct proxyArgs *)_args; struct mscclppComm *comm = args->comm; struct mscclppIbContext *ibCtx = args->ibCtx; - volatile int *run = args->run; + volatile mscclppProxyRunState_t *run = args->run; struct mscclppConn *conn = &comm->conns[args->connIdx]; free(_args); @@ -120,7 +116,7 @@ void* mscclppProxyServiceIb(void* _args) { } #endif - while (*run) { + while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) { #if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 0) // Try send if (sendState == SEND_STATE_INIT) { @@ -219,7 +215,7 @@ void* mscclppProxyServiceIb(void* _args) { *tmp = 0; #endif } - *run = 1; + *run = MSCCLPP_PROXY_RUN_STATE_IDLE; // WARN("Proxy exits: rank %d", rank); return NULL; } @@ -263,7 +259,7 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) { args->ibCtx = comm->ibContext[i]; args->run = &comm->proxyState[i].runs[j]; args->connIdx = j; - *args->run = 1; + *args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING; pthread_create(&comm->proxyState[i].threads[j], NULL, mscclppProxyService, args); mscclppSetThreadName(comm->proxyState[i].threads[j], "MSCCLPP Service %2d - %4d", i, j); } @@ -286,7 +282,7 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) { args->run = &proxyState->runs[j]; args->connIdx = j; CUDACHECK(cudaStreamCreateWithFlags(&args->stream, cudaStreamNonBlocking)); - *args->run = 1; + *args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING; pthread_create(&proxyState->threads[j], NULL, mscclppProxyService, args); mscclppSetThreadName(proxyState->threads[j], "MSCCLPP Service %2d - %4d", MSCCLPP_IB_MAX_DEVS + 1, j); } @@ -295,12 +291,11 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) { static void _stopProxy(struct mscclppComm* comm, int devIdx, int connIdx) { volatile int *run = (volatile int *)&comm->proxyState[devIdx].runs[connIdx]; - if (*run == 0) return; - *run = 0; - while (*run == 0 && *comm->abortFlag == 0) { + if (*run == MSCCLPP_PROXY_RUN_STATE_IDLE) return; + *run = MSCCLPP_PROXY_RUN_STATE_EXITING; + while (*run == MSCCLPP_PROXY_RUN_STATE_EXITING && *comm->abortFlag == 0) { usleep(1000); } - *run = 0; } mscclppResult_t mscclppProxyDestroy(struct mscclppComm* comm) {