Merge branch 'main' into chhwang/ib-proxy-merge

This commit is contained in:
Changho Hwang
2023-03-07 13:47:04 +08:00
committed by GitHub
2 changed files with 19 additions and 18 deletions

View File

@@ -5,9 +5,15 @@
#include "comm.h"
#include <pthread.h>
// Lifecycle of a single proxy service thread slot.
typedef enum {
  // No service loop is active for the slot. A service thread stores this value
  // once it has left its polling loop, and _stopProxy returns immediately when
  // it observes it.
  MSCCLPP_PROXY_RUN_STATE_IDLE = 0,
  // Stored by mscclppProxyCreate right before the service thread is spawned;
  // the service loops keep polling only while the slot holds this value.
  MSCCLPP_PROXY_RUN_STATE_RUNNING = 1,
  // Stored by _stopProxy to ask the service thread to leave its loop; the
  // requester then waits for the slot to change away from this value.
  MSCCLPP_PROXY_RUN_STATE_EXITING = 2,
} mscclppProxyRunState_t;
// Bookkeeping for the proxy service threads of one proxy device entry
// (comm->proxyState[devIdx]). Both arrays are indexed by connection index.
struct mscclppProxyState {
  // Handles of the service threads, filled in by pthread_create.
  pthread_t *threads;
  // Run state of each service thread. The slot is shared between the thread
  // that owns it and the code that creates/stops proxies, so it must be
  // declared exactly once, with the state enum type.
  mscclppProxyRunState_t *runs;
};
// Creates the proxy service threads for `comm`: for each connection slot the
// run state is set to MSCCLPP_PROXY_RUN_STATE_RUNNING and a pthread is started
// for it. Returns an mscclppResult_t status code.
// NOTE(review): summary based on a partial reading of the implementation --
// confirm the full set of side effects and the possible error codes.
mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm);

View File

@@ -34,7 +34,7 @@ struct proxyArgs {
struct mscclppComm* comm;
struct mscclppIbContext* ibCtx;
cudaStream_t stream;
volatile int* run;
volatile mscclppProxyRunState_t* run;
int connIdx;
};
@@ -42,8 +42,7 @@ struct proxyArgs {
void* mscclppProxyServiceP2P(void* _args) {
struct proxyArgs *args = (struct proxyArgs *)_args;
struct mscclppComm *comm = args->comm;
// TODO(saemal): we perhaps need a finite state for run instead of just 0 and 1
volatile int *run = args->run;
volatile mscclppProxyRunState_t *run = args->run;
struct mscclppConn *conn = &comm->conns[args->connIdx];
cudaStream_t stream = args->stream;
free(_args);
@@ -58,10 +57,7 @@ void* mscclppProxyServiceP2P(void* _args) {
PROXYCUDACHECK(cudaSetDevice(comm->cudaDev));
PROXYCUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
cudaStreamCaptureStatus stat;
cudaStreamIsCapturing(stream, &stat);
while (*run) {
while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) {
// Poll to see if we are ready to send anything
trigger.value = *(volatile uint64_t *)conn->cpuTrigger;
if (trigger.value == 0) continue;
@@ -84,7 +80,7 @@ void* mscclppProxyServiceP2P(void* _args) {
volatile uint64_t *tmp = (volatile uint64_t *)conn->cpuTrigger;
*tmp = 0;
}
*run = 1;
*run = MSCCLPP_PROXY_RUN_STATE_IDLE;
PROXYCUDACHECK(cudaStreamDestroy(stream));
// WARN("Proxy exits: rank %d", rank);
@@ -95,7 +91,7 @@ void* mscclppProxyServiceIb(void* _args) {
struct proxyArgs *args = (struct proxyArgs *)_args;
struct mscclppComm *comm = args->comm;
struct mscclppIbContext *ibCtx = args->ibCtx;
volatile int *run = args->run;
volatile mscclppProxyRunState_t *run = args->run;
struct mscclppConn *conn = &comm->conns[args->connIdx];
free(_args);
@@ -120,7 +116,7 @@ void* mscclppProxyServiceIb(void* _args) {
}
#endif
while (*run) {
while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) {
#if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 0)
// Try send
if (sendState == SEND_STATE_INIT) {
@@ -219,7 +215,7 @@ void* mscclppProxyServiceIb(void* _args) {
*tmp = 0;
#endif
}
*run = 1;
*run = MSCCLPP_PROXY_RUN_STATE_IDLE;
// WARN("Proxy exits: rank %d", rank);
return NULL;
}
@@ -263,7 +259,7 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) {
args->ibCtx = comm->ibContext[i];
args->run = &comm->proxyState[i].runs[j];
args->connIdx = j;
*args->run = 1;
*args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING;
pthread_create(&comm->proxyState[i].threads[j], NULL, mscclppProxyService, args);
mscclppSetThreadName(comm->proxyState[i].threads[j], "MSCCLPP Service %2d - %4d", i, j);
}
@@ -286,7 +282,7 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) {
args->run = &proxyState->runs[j];
args->connIdx = j;
CUDACHECK(cudaStreamCreateWithFlags(&args->stream, cudaStreamNonBlocking));
*args->run = 1;
*args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING;
pthread_create(&proxyState->threads[j], NULL, mscclppProxyService, args);
mscclppSetThreadName(proxyState->threads[j], "MSCCLPP Service %2d - %4d", MSCCLPP_IB_MAX_DEVS + 1, j);
}
@@ -295,12 +291,11 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) {
static void _stopProxy(struct mscclppComm* comm, int devIdx, int connIdx) {
volatile int *run = (volatile int *)&comm->proxyState[devIdx].runs[connIdx];
if (*run == 0) return;
*run = 0;
while (*run == 0 && *comm->abortFlag == 0) {
if (*run == MSCCLPP_PROXY_RUN_STATE_IDLE) return;
*run = MSCCLPP_PROXY_RUN_STATE_EXITING;
while (*run == MSCCLPP_PROXY_RUN_STATE_EXITING && *comm->abortFlag == 0) {
usleep(1000);
}
*run = 0;
}
mscclppResult_t mscclppProxyDestroy(struct mscclppComm* comm) {