Merge pull request #7 from microsoft/chhwang/proxy-run-states

Add proxy run states
This commit is contained in:
Saeed Maleki
2023-03-07 00:22:52 -05:00
committed by GitHub
2 changed files with 22 additions and 18 deletions

View File

@@ -5,9 +5,15 @@
#include "comm.h"
#include <pthread.h>
// Lifecycle state of a single proxy service thread. One value is kept per
// connection in mscclppProxyState::runs and is read/written through a
// volatile pointer by both the service thread and the stop helper.
typedef enum {
// Not running. A service thread writes this once it has left its polling
// loop; _stopProxy returns immediately when it observes this value.
MSCCLPP_PROXY_RUN_STATE_IDLE = 0,
// Written by mscclppProxyCreate right before the thread is spawned; the
// service loops keep polling only while the state equals this value.
MSCCLPP_PROXY_RUN_STATE_RUNNING,
// Stop requested. _stopProxy writes this and then sleeps in 1 ms steps until
// the state changes away from it (or the communicator's abort flag is set).
MSCCLPP_PROXY_RUN_STATE_EXITING,
} mscclppProxyRunState_t;
// Proxy bookkeeping: threads[j] is the pthread handle and runs[j] the run
// state for connection index j (both are indexed that way where the service
// threads are created in mscclppProxyCreate).
// NOTE(review): two `runs` declarations appear below. This looks like the
// removed (`int *`) and the added (enum-typed) line of the change shown
// together -- confirm that only the enum-typed member exists in the header.
struct mscclppProxyState {
pthread_t *threads;
int *runs;
mscclppProxyRunState_t *runs;
};
mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm);

View File

@@ -33,7 +33,7 @@ struct proxyArgs {
struct mscclppComm* comm;
struct mscclppIbContext* ibCtx;
cudaStream_t stream;
volatile int* run;
volatile mscclppProxyRunState_t* run;
int connIdx;
};
@@ -41,8 +41,7 @@ struct proxyArgs {
void* mscclppProxyServiceP2P(void* _args) {
struct proxyArgs *args = (struct proxyArgs *)_args;
struct mscclppComm *comm = args->comm;
// TODO(saemal): we perhaps need a finite state for run instead of just 0 and 1
volatile int *run = args->run;
volatile mscclppProxyRunState_t *run = args->run;
struct mscclppConn *conn = &comm->conns[args->connIdx];
cudaStream_t stream = args->stream;
free(_args);
@@ -57,7 +56,7 @@ void* mscclppProxyServiceP2P(void* _args) {
PROXYCUDACHECK(cudaSetDevice(comm->cudaDev));
PROXYCUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
while (*run) {
while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) {
// Poll to see if we are ready to send anything
trigger.value = *(volatile uint64_t *)conn->cpuTrigger;
if (trigger.value == 0) continue;
@@ -80,7 +79,7 @@ void* mscclppProxyServiceP2P(void* _args) {
volatile uint64_t *tmp = (volatile uint64_t *)conn->cpuTrigger;
*tmp = 0;
}
*run = 1;
*run = MSCCLPP_PROXY_RUN_STATE_IDLE;
PROXYCUDACHECK(cudaStreamDestroy(stream));
// WARN("Proxy exits: rank %d", rank);
@@ -94,7 +93,7 @@ void* mscclppProxyServiceIb(void* _args) {
struct proxyArgs *args = (struct proxyArgs *)_args;
struct mscclppComm *comm = args->comm;
struct mscclppIbContext *ibCtx = args->ibCtx;
volatile int *run = args->run;
volatile mscclppProxyRunState_t *run = args->run;
struct mscclppConn *conn = &comm->conns[args->connIdx];
free(_args);
uint64_t currentProxyFlagVlaue = *conn->cpuProxyFlag;
@@ -114,7 +113,7 @@ void* mscclppProxyServiceIb(void* _args) {
WARN("postRecv failed: errno %d", errno);
}
while (*run) {
while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) {
// Try send
if (sendState == SEND_STATE_INIT) {
trigger.value = *(volatile uint64_t *)conn->cpuTrigger;
@@ -162,7 +161,7 @@ void* mscclppProxyServiceIb(void* _args) {
}
}
}
*run = 1;
*run = MSCCLPP_PROXY_RUN_STATE_IDLE;
// WARN("Proxy exits: rank %d", rank);
return NULL;
}
@@ -174,7 +173,7 @@ void* mscclppProxyServiceIb(void* _args) {
struct proxyArgs *args = (struct proxyArgs *)_args;
struct mscclppComm *comm = args->comm;
struct mscclppIbContext *ibCtx = args->ibCtx;
volatile int *run = args->run;
volatile mscclppProxyRunState_t *run = args->run;
struct mscclppConn *conn = &comm->conns[args->connIdx];
free(_args);
@@ -184,7 +183,7 @@ void* mscclppProxyServiceIb(void* _args) {
NumaBind(ibCtx->numaNode);
while (*run) {
while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) {
// Poll to see if we are ready to send anything
trigger.value = *(volatile uint64_t *)conn->cpuTrigger;
if (trigger.value == 0) continue;
@@ -234,7 +233,7 @@ void* mscclppProxyServiceIb(void* _args) {
volatile uint64_t *tmp = (volatile uint64_t *)conn->cpuTrigger;
*tmp = 0;
}
*run = 1;
*run = MSCCLPP_PROXY_RUN_STATE_IDLE;
// WARN("Proxy exits: rank %d", rank);
return NULL;
}
@@ -280,7 +279,7 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) {
args->ibCtx = comm->ibContext[i];
args->run = &comm->proxyState[i].runs[j];
args->connIdx = j;
*args->run = 1;
*args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING;
pthread_create(&comm->proxyState[i].threads[j], NULL, mscclppProxyService, args);
mscclppSetThreadName(comm->proxyState[i].threads[j], "MSCCLPP Service %2d - %4d", i, j);
}
@@ -303,7 +302,7 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) {
args->run = &proxyState->runs[j];
args->connIdx = j;
CUDACHECK(cudaStreamCreateWithFlags(&args->stream, cudaStreamNonBlocking));
*args->run = 1;
*args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING;
pthread_create(&proxyState->threads[j], NULL, mscclppProxyService, args);
mscclppSetThreadName(proxyState->threads[j], "MSCCLPP Service %2d - %4d", MSCCLPP_IB_MAX_DEVS + 1, j);
}
@@ -312,12 +311,11 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) {
static void _stopProxy(struct mscclppComm* comm, int devIdx, int connIdx) {
volatile int *run = (volatile int *)&comm->proxyState[devIdx].runs[connIdx];
if (*run == 0) return;
*run = 0;
while (*run == 0 && *comm->abortFlag == 0) {
if (*run == MSCCLPP_PROXY_RUN_STATE_IDLE) return;
*run = MSCCLPP_PROXY_RUN_STATE_EXITING;
while (*run == MSCCLPP_PROXY_RUN_STATE_EXITING && *comm->abortFlag == 0) {
usleep(1000);
}
*run = 0;
}
mscclppResult_t mscclppProxyDestroy(struct mscclppComm* comm) {