mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 09:46:00 +00:00
Merge branch 'main' into chhwang/ib-proxy-merge
This commit is contained in:
@@ -5,9 +5,15 @@
|
||||
#include "comm.h"
|
||||
#include <pthread.h>
|
||||
|
||||
typedef enum {
|
||||
MSCCLPP_PROXY_RUN_STATE_IDLE = 0,
|
||||
MSCCLPP_PROXY_RUN_STATE_RUNNING,
|
||||
MSCCLPP_PROXY_RUN_STATE_EXITING,
|
||||
} mscclppProxyRunState_t;
|
||||
|
||||
struct mscclppProxyState {
|
||||
pthread_t *threads;
|
||||
int *runs;
|
||||
mscclppProxyRunState_t *runs;
|
||||
};
|
||||
|
||||
mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm);
|
||||
|
||||
29
src/proxy.cc
29
src/proxy.cc
@@ -34,7 +34,7 @@ struct proxyArgs {
|
||||
struct mscclppComm* comm;
|
||||
struct mscclppIbContext* ibCtx;
|
||||
cudaStream_t stream;
|
||||
volatile int* run;
|
||||
volatile mscclppProxyRunState_t* run;
|
||||
int connIdx;
|
||||
};
|
||||
|
||||
@@ -42,8 +42,7 @@ struct proxyArgs {
|
||||
void* mscclppProxyServiceP2P(void* _args) {
|
||||
struct proxyArgs *args = (struct proxyArgs *)_args;
|
||||
struct mscclppComm *comm = args->comm;
|
||||
// TODO(saemal): we perhaps need a finite state for run instead of just 0 and 1
|
||||
volatile int *run = args->run;
|
||||
volatile mscclppProxyRunState_t *run = args->run;
|
||||
struct mscclppConn *conn = &comm->conns[args->connIdx];
|
||||
cudaStream_t stream = args->stream;
|
||||
free(_args);
|
||||
@@ -58,10 +57,7 @@ void* mscclppProxyServiceP2P(void* _args) {
|
||||
PROXYCUDACHECK(cudaSetDevice(comm->cudaDev));
|
||||
PROXYCUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
|
||||
cudaStreamCaptureStatus stat;
|
||||
cudaStreamIsCapturing(stream, &stat);
|
||||
|
||||
while (*run) {
|
||||
while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) {
|
||||
// Poll to see if we are ready to send anything
|
||||
trigger.value = *(volatile uint64_t *)conn->cpuTrigger;
|
||||
if (trigger.value == 0) continue;
|
||||
@@ -84,7 +80,7 @@ void* mscclppProxyServiceP2P(void* _args) {
|
||||
volatile uint64_t *tmp = (volatile uint64_t *)conn->cpuTrigger;
|
||||
*tmp = 0;
|
||||
}
|
||||
*run = 1;
|
||||
*run = MSCCLPP_PROXY_RUN_STATE_IDLE;
|
||||
PROXYCUDACHECK(cudaStreamDestroy(stream));
|
||||
|
||||
// WARN("Proxy exits: rank %d", rank);
|
||||
@@ -95,7 +91,7 @@ void* mscclppProxyServiceIb(void* _args) {
|
||||
struct proxyArgs *args = (struct proxyArgs *)_args;
|
||||
struct mscclppComm *comm = args->comm;
|
||||
struct mscclppIbContext *ibCtx = args->ibCtx;
|
||||
volatile int *run = args->run;
|
||||
volatile mscclppProxyRunState_t *run = args->run;
|
||||
struct mscclppConn *conn = &comm->conns[args->connIdx];
|
||||
free(_args);
|
||||
|
||||
@@ -120,7 +116,7 @@ void* mscclppProxyServiceIb(void* _args) {
|
||||
}
|
||||
#endif
|
||||
|
||||
while (*run) {
|
||||
while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) {
|
||||
#if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 0)
|
||||
// Try send
|
||||
if (sendState == SEND_STATE_INIT) {
|
||||
@@ -219,7 +215,7 @@ void* mscclppProxyServiceIb(void* _args) {
|
||||
*tmp = 0;
|
||||
#endif
|
||||
}
|
||||
*run = 1;
|
||||
*run = MSCCLPP_PROXY_RUN_STATE_IDLE;
|
||||
// WARN("Proxy exits: rank %d", rank);
|
||||
return NULL;
|
||||
}
|
||||
@@ -263,7 +259,7 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) {
|
||||
args->ibCtx = comm->ibContext[i];
|
||||
args->run = &comm->proxyState[i].runs[j];
|
||||
args->connIdx = j;
|
||||
*args->run = 1;
|
||||
*args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING;
|
||||
pthread_create(&comm->proxyState[i].threads[j], NULL, mscclppProxyService, args);
|
||||
mscclppSetThreadName(comm->proxyState[i].threads[j], "MSCCLPP Service %2d - %4d", i, j);
|
||||
}
|
||||
@@ -286,7 +282,7 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) {
|
||||
args->run = &proxyState->runs[j];
|
||||
args->connIdx = j;
|
||||
CUDACHECK(cudaStreamCreateWithFlags(&args->stream, cudaStreamNonBlocking));
|
||||
*args->run = 1;
|
||||
*args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING;
|
||||
pthread_create(&proxyState->threads[j], NULL, mscclppProxyService, args);
|
||||
mscclppSetThreadName(proxyState->threads[j], "MSCCLPP Service %2d - %4d", MSCCLPP_IB_MAX_DEVS + 1, j);
|
||||
}
|
||||
@@ -295,12 +291,11 @@ mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) {
|
||||
|
||||
static void _stopProxy(struct mscclppComm* comm, int devIdx, int connIdx) {
|
||||
volatile int *run = (volatile int *)&comm->proxyState[devIdx].runs[connIdx];
|
||||
if (*run == 0) return;
|
||||
*run = 0;
|
||||
while (*run == 0 && *comm->abortFlag == 0) {
|
||||
if (*run == MSCCLPP_PROXY_RUN_STATE_IDLE) return;
|
||||
*run = MSCCLPP_PROXY_RUN_STATE_EXITING;
|
||||
while (*run == MSCCLPP_PROXY_RUN_STATE_EXITING && *comm->abortFlag == 0) {
|
||||
usleep(1000);
|
||||
}
|
||||
*run = 0;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppProxyDestroy(struct mscclppComm* comm) {
|
||||
|
||||
Reference in New Issue
Block a user