diff --git a/src/include/comm.h b/src/include/comm.h index 3cfd772c..8da70875 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -158,8 +158,6 @@ // } channels[MAXCHANNELS]; // }; -#define MSCCLPP_PROXY_FIFO_SIZE 8 - struct mscclppConn { mscclppTransport_t transport; int remoteRank; diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index 2558f138..a47593eb 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -10,6 +10,7 @@ #define MSCCLPP_MAJOR 0 #define MSCCLPP_MINOR 1 +#define MSCCLPP_PROXY_FIFO_SIZE 8 #define MSCCLPP_VERSION (MSCCLPP_MAJOR * 100 + MSCCLPP_MINOR) diff --git a/src/proxy.cc b/src/proxy.cc index 94dc4896..aaf6d6c5 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include // TODO(chhwang): verify if MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 0 is useful, otherwise delete this option. @@ -43,7 +44,14 @@ void* mscclppProxyServiceP2P(void* _args) { struct proxyArgs *args = (struct proxyArgs *)_args; struct mscclppComm *comm = args->comm; volatile mscclppProxyRunState_t *run = args->run; - struct mscclppConn *conn = &comm->conns[args->connIdx]; + std::vector conns; + for (int i = 0; i < comm->nConns; ++i) { + struct mscclppConn *conn = &comm->conns[i]; + // TODO(saemal): we need to create another transport type which doesn't need a proxy. + if (conn->transport == mscclppTransportP2P) { + conns.push_back(conn); + } + } cudaStream_t stream = args->stream; free(_args); @@ -58,31 +66,37 @@ void* mscclppProxyServiceP2P(void* _args) { PROXYCUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) { - // Poll to see if we are ready to send anything - trigger.value = *(volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); - if (trigger.value == 0) continue; + for (struct mscclppConn *conn : conns) { + // Poll to see if we are ready to send anything + trigger.value = *(volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); + if (trigger.value == 0) continue; - // Iterate over what send is needed - if (trigger.fields.type & mscclppData){ - void *srcBuff = (void *)((char *)conn->devConn->localBuff + trigger.fields.dataOffset); - void *dstBuff = (void *)((char *)conn->devConn->remoteBuff + trigger.fields.dataOffset); - PROXYCUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, trigger.fields.dataSize, cudaMemcpyDeviceToDevice, stream)); - } - if (trigger.fields.type & mscclppFlag) { - PROXYCUDACHECK(cudaMemcpyAsync(conn->remoteProxyFlag, conn->devConn->localFlag, sizeof(uint64_t), cudaMemcpyDeviceToDevice, stream)); - } - // Wait for completion - if (trigger.fields.type & mscclppSync){ - PROXYCUDACHECK(cudaStreamSynchronize(stream)); - } + // Iterate over what send is needed + if (trigger.fields.type & mscclppData){ + void *srcBuff = (void *)((char *)conn->devConn->localBuff + trigger.fields.dataOffset); + void *dstBuff = (void *)((char *)conn->devConn->remoteBuff + trigger.fields.dataOffset); + PROXYCUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, trigger.fields.dataSize, cudaMemcpyDeviceToDevice, stream)); + } + if (trigger.fields.type & mscclppFlag) { + PROXYCUDACHECK(cudaMemcpyAsync(conn->remoteProxyFlag, conn->devConn->localFlag, sizeof(uint64_t), cudaMemcpyDeviceToDevice, stream)); + } + // Wait for completion + if (trigger.fields.type & mscclppSync){ + PROXYCUDACHECK(cudaStreamSynchronize(stream)); + } - // send completion - volatile uint64_t *tmp = (volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); - *tmp = 0; - conn->fifoTail++; - if (conn->fifoTail == MSCCLPP_PROXY_FIFO_SIZE) - conn->fifoTail = 0; + // Send completion + volatile uint64_t *tmp = (volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); + *tmp = 0; + conn->fifoTail++; + if (conn->fifoTail == MSCCLPP_PROXY_FIFO_SIZE) + conn->fifoTail = 0; + } } + + // Need a sync in case previous copies are not completed + PROXYCUDACHECK(cudaStreamSynchronize(stream)); + *run = MSCCLPP_PROXY_RUN_STATE_IDLE; PROXYCUDACHECK(cudaStreamDestroy(stream)); @@ -95,7 +109,13 @@ void* mscclppProxyServiceIb(void* _args) { struct mscclppComm *comm = args->comm; struct mscclppIbContext *ibCtx = args->ibCtx; volatile mscclppProxyRunState_t *run = args->run; - struct mscclppConn *conn = &comm->conns[args->connIdx]; + std::vector conns; + for (int i = 0; i < comm->nConns; ++i) { + struct mscclppConn *conn = &comm->conns[i]; + if (conn->transport == mscclppTransportIB) { + conns.push_back(conn); + } + } free(_args); #if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 0) @@ -114,89 +134,36 @@ void* mscclppProxyServiceIb(void* _args) { NumaBind(ibCtx->numaNode); #if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 0) - if (conn->ibQp->postRecv(0) != 0) { - WARN("postRecv failed: errno %d", errno); + for (struct mscclppConn *conn : conns) { + // Post recv + if (conn->ibQp->postRecv(0) != 0) { + WARN("postRecv failed: errno %d", errno); + } } #endif while (*run == MSCCLPP_PROXY_RUN_STATE_RUNNING) { + for (struct mscclppConn *conn : conns) { #if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 0) - // Try send - if (sendState == SEND_STATE_INIT) { - trigger.value = *(volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); - if (trigger.value != 0) { - // Do send - conn->ibQp->stageSendWithImm(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize, - /*wrId=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/true, /*immData=*/0); - if (conn->ibQp->postSend() != 0) { - WARN("postSend failed: errno %d", errno); - } - sendState = SEND_STATE_INPROGRESS; - } - } - - // Poll completions - wcNum = conn->ibQp->pollCq(); - if (wcNum < 0) { - WARN("rank %d pollCq failed: errno %d", rank, errno); - } else { - for (int i = 0; i < wcNum; ++i) { - struct ibv_wc *wc = &conn->ibQp->wcs[i]; - if (wc->status != IBV_WC_SUCCESS) { - WARN("rank %d wc status %d", rank, wc->status); - continue; - } - if (wc->qp_num != conn->ibQp->qp->qp_num) { - WARN("rank %d got wc of unknown qp_num %d", rank, wc->qp_num); - continue; - } - if (wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) { - // TODO(chhwang): cpu flush - *((volatile uint64_t *)conn->cpuProxyFlag) = ++currentProxyFlagVlaue; - // recv completion - if (conn->ibQp->postRecv(wc->wr_id) != 0) { - WARN("postRecv failed: errno %d", errno); + // Try send + if (sendState == SEND_STATE_INIT) { + trigger.value = *(volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); + if (trigger.value != 0) { + // Do send + conn->ibQp->stageSendWithImm(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize, + /*wrId=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/true, /*immData=*/0); + if (conn->ibQp->postSend() != 0) { + WARN("postSend failed: errno %d", errno); } - // WARN("rank %d recv completion", rank); - } else if (wc->opcode == IBV_WC_RDMA_WRITE) { - // send completion - volatile uint64_t *tmp = (volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); - *tmp = 0; - conn->fifoTail++; - if (conn->fifoTail == MSCCLPP_PROXY_FIFO_SIZE) - conn->fifoTail = 0; - sendState = SEND_STATE_INIT; - // WARN("rank %d send completion", rank); + sendState = SEND_STATE_INPROGRESS; } } - } -#else // (MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 1) - // Poll to see if we are ready to send anything - trigger.value = *(volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); - if (trigger.value == 0) continue; - if (trigger.fields.type & mscclppData) { - conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize, - /*wrId=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/false); - } - if (trigger.fields.type & mscclppFlag) { - // My local flag is copied to the peer's proxy flag - conn->ibQp->stageSend(conn->ibLocalFlagMr, &conn->ibProxyFlagMrInfo, sizeof(uint64_t), - /*wrId=*/0, /*offset=*/0, /*signaled=*/true); - } - if (conn->ibQp->postSend() != 0) { - WARN("postSend failed: errno %d", errno); - } - - // Wait for completion - if (trigger.fields.type & mscclppSync) { - bool waiting = true; - while (waiting) { - wcNum = conn->ibQp->pollCq(); - if (wcNum < 0) { - WARN("rank %d pollCq failed: errno %d", rank, errno); - continue; - } + // Poll completions + wcNum = conn->ibQp->pollCq(); + if (wcNum < 0) { + WARN("rank %d pollCq failed: errno %d", rank, errno); + } else { for (int i = 0; i < wcNum; ++i) { struct ibv_wc *wc = &conn->ibQp->wcs[i]; if (wc->status != IBV_WC_SUCCESS) { @@ -207,22 +174,80 @@ void* mscclppProxyServiceIb(void* _args) { WARN("rank %d got wc of unknown qp_num %d", rank, wc->qp_num); continue; } - if (wc->opcode == IBV_WC_RDMA_WRITE) { + if (wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) { + // TODO(chhwang): cpu flush + *((volatile uint64_t *)conn->cpuProxyFlag) = ++currentProxyFlagVlaue; + // recv completion + if (conn->ibQp->postRecv(wc->wr_id) != 0) { + WARN("postRecv failed: errno %d", errno); + } + // WARN("rank %d recv completion", rank); + } else if (wc->opcode == IBV_WC_RDMA_WRITE) { // send completion - waiting = false; - break; + volatile uint64_t *tmp = (volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); + *tmp = 0; + conn->fifoTail++; + if (conn->fifoTail == MSCCLPP_PROXY_FIFO_SIZE) + conn->fifoTail = 0; + sendState = SEND_STATE_INIT; + // WARN("rank %d send completion", rank); } } } - } +#else // (MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 1) + // Poll to see if we are ready to send anything + trigger.value = *(volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); + if (trigger.value == 0) continue; - // Send completion - volatile uint64_t *tmp = (volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); - *tmp = 0; - conn->fifoTail++; - if (conn->fifoTail == MSCCLPP_PROXY_FIFO_SIZE) - conn->fifoTail = 0; + if (trigger.fields.type & mscclppData) { + conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize, + /*wrId=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/false); + } + if (trigger.fields.type & mscclppFlag) { + // My local flag is copied to the peer's proxy flag + conn->ibQp->stageSend(conn->ibLocalFlagMr, &conn->ibProxyFlagMrInfo, sizeof(uint64_t), + /*wrId=*/0, /*offset=*/0, /*signaled=*/true); + } + if (conn->ibQp->postSend() != 0) { + WARN("postSend failed: errno %d", errno); + } + + // Wait for completion + if (trigger.fields.type & mscclppSync) { + bool waiting = true; + while (waiting) { + wcNum = conn->ibQp->pollCq(); + if (wcNum < 0) { + WARN("rank %d pollCq failed: errno %d", rank, errno); + continue; + } + for (int i = 0; i < wcNum; ++i) { + struct ibv_wc *wc = &conn->ibQp->wcs[i]; + if (wc->status != IBV_WC_SUCCESS) { + WARN("rank %d wc status %d", rank, wc->status); + continue; + } + if (wc->qp_num != conn->ibQp->qp->qp_num) { + WARN("rank %d got wc of unknown qp_num %d", rank, wc->qp_num); + continue; + } + if (wc->opcode == IBV_WC_RDMA_WRITE) { + // send completion + waiting = false; + break; + } + } + } + } + + // Send completion + volatile uint64_t *tmp = (volatile uint64_t *)(&conn->cpuTriggerFifo[conn->fifoTail]); + *tmp = 0; + conn->fifoTail++; + if (conn->fifoTail == MSCCLPP_PROXY_FIFO_SIZE) + conn->fifoTail = 0; #endif + } } *run = MSCCLPP_PROXY_RUN_STATE_IDLE; // WARN("Proxy exits: rank %d", rank); @@ -248,53 +273,43 @@ void* mscclppProxyService(void* _args) { // } mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) { - // comm->proxyState.thread is pthread_join()'d by commFree() in init.cc for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) { if (comm->ibContext[i] == NULL) continue; if (comm->proxyState[i].threads == NULL) { - MSCCLPPCHECK(mscclppCalloc(&comm->proxyState[i].threads, comm->nConns)); + MSCCLPPCHECK(mscclppCalloc(&comm->proxyState[i].threads, 1)); } if (comm->proxyState[i].runs == NULL) { - MSCCLPPCHECK(mscclppCalloc(&comm->proxyState[i].runs, comm->nConns)); + MSCCLPPCHECK(mscclppCalloc(&comm->proxyState[i].runs, 1)); } - for (int j = 0; j < comm->nConns; ++j) { - // Create IB proxy threads - struct mscclppConn *conn = &comm->conns[j]; - if (conn->transport != mscclppTransportIB) continue; - if (conn->ibCtx != comm->ibContext[i]) continue; - struct proxyArgs *args; - MSCCLPPCHECK(mscclppCalloc(&args, 1)); - args->comm = comm; - args->ibCtx = comm->ibContext[i]; - args->run = &comm->proxyState[i].runs[j]; - args->connIdx = j; - *args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING; - pthread_create(&comm->proxyState[i].threads[j], NULL, mscclppProxyService, args); - mscclppSetThreadName(comm->proxyState[i].threads[j], "MSCCLPP Service %2d - %4d", i, j); - } - } - // P2P proxies - mscclppProxyState *proxyState = &comm->proxyState[MSCCLPP_IB_MAX_DEVS]; - if (proxyState->threads == NULL) { - MSCCLPPCHECK(mscclppCalloc(&proxyState->threads, comm->nConns)); - } - if (proxyState->runs == NULL) { - MSCCLPPCHECK(mscclppCalloc(&proxyState->runs, comm->nConns)); - } - for (int j = 0; j < comm->nConns; ++j) { - // Create P2P DMA proxy threads - if (comm->conns[j].transport != mscclppTransportP2P) continue; + // Create IB proxy threads struct proxyArgs *args; MSCCLPPCHECK(mscclppCalloc(&args, 1)); args->comm = comm; - args->ibCtx = NULL; - args->run = &proxyState->runs[j]; - args->connIdx = j; - CUDACHECK(cudaStreamCreateWithFlags(&args->stream, cudaStreamNonBlocking)); + args->ibCtx = comm->ibContext[i]; + args->run = comm->proxyState[i].runs; *args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING; - pthread_create(&proxyState->threads[j], NULL, mscclppProxyService, args); - mscclppSetThreadName(proxyState->threads[j], "MSCCLPP Service %2d - %4d", MSCCLPP_IB_MAX_DEVS + 1, j); + pthread_create(comm->proxyState[i].threads, NULL, mscclppProxyService, args); + mscclppSetThreadName(comm->proxyState[i].threads[0], "MSCCLPP Service IB - %02d", i); } + // P2P proxy + mscclppProxyState *proxyState = &comm->proxyState[MSCCLPP_IB_MAX_DEVS]; + if (proxyState->threads == NULL) { + MSCCLPPCHECK(mscclppCalloc(&proxyState->threads, 1)); + } + if (proxyState->runs == NULL) { + MSCCLPPCHECK(mscclppCalloc(&proxyState->runs, 1)); + } + // Create P2P DMA proxy thread + struct proxyArgs *args; + MSCCLPPCHECK(mscclppCalloc(&args, 1)); + args->comm = comm; + args->ibCtx = NULL; + args->run = proxyState->runs; + args->connIdx = -1; // unused + CUDACHECK(cudaStreamCreateWithFlags(&args->stream, cudaStreamNonBlocking)); + *args->run = MSCCLPP_PROXY_RUN_STATE_RUNNING; + pthread_create(proxyState->threads, NULL, mscclppProxyService, args); + mscclppSetThreadName(proxyState->threads[0], "MSCCLPP Service P2P - %02d", comm->cudaDev); return mscclppSuccess; } @@ -310,16 +325,10 @@ static void _stopProxy(struct mscclppComm* comm, int devIdx, int connIdx) { mscclppResult_t mscclppProxyDestroy(struct mscclppComm* comm) { for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) { if (comm->ibContext[i] != NULL) { - for (int j = 0; j < comm->nConns; ++j) { - _stopProxy(comm, i, j); - } + _stopProxy(comm, i, 0); } } // P2P proxies - mscclppProxyState *proxyState = &comm->proxyState[MSCCLPP_IB_MAX_DEVS]; - for (int j = 0; j < comm->nConns; ++j) { - if (comm->conns[j].transport != mscclppTransportP2P) continue; - _stopProxy(comm, MSCCLPP_IB_MAX_DEVS, j); - } + _stopProxy(comm, MSCCLPP_IB_MAX_DEVS, 0); return mscclppSuccess; } diff --git a/tests/p2p_test.cu b/tests/p2p_test.cu index eb962cc0..a621e619 100644 --- a/tests/p2p_test.cu +++ b/tests/p2p_test.cu @@ -29,7 +29,6 @@ } \ } while(false) -#define MSCCLPP_PROXY_FIFO_SIZE 8 __constant__ mscclppDevConn_t constDevConns[16]; __global__ void kernel(int rank, int world_size) @@ -67,13 +66,13 @@ __global__ void kernel(int rank, int world_size) // Each warp receives data from different ranks #if (USE_DMA_FOR_P2P == 1) + // Wait until the proxy have sent my data and flag + while (*trig != 0) {} + // Trigger sending data and flag uint64_t dataOffset = rank * sizeof(int); uint64_t dataSize = sizeof(int); - *trig = TRIGGER_VALUE(mscclppSync | mscclppFlag | mscclppData, dataOffset, dataSize); - - // Wait until the proxy have sent my data and flag - while (*trig != 0) {} + *trig = TRIGGER_VALUE(mscclppFlag | mscclppData, dataOffset, dataSize); // Wait for receiving data from remote rank while (*proxyFlag == baseFlag) {}