mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
Move MRs / MR infos to mscclppHostIBConn & cleanup
This commit is contained in:
@@ -20,6 +20,7 @@
|
||||
|
||||
struct mscclppConn
|
||||
{
|
||||
int connId;
|
||||
mscclppTransport_t transport;
|
||||
int remoteRank;
|
||||
uint64_t buffSize;
|
||||
@@ -28,10 +29,6 @@ struct mscclppConn
|
||||
|
||||
struct mscclppIbContext* ibCtx;
|
||||
struct mscclppIbQp* ibQp;
|
||||
struct mscclppIbMr* ibBuffMr;
|
||||
struct mscclppIbMr* ibSignalEpochIdMr;
|
||||
struct mscclppIbMrInfo ibBuffMrInfo;
|
||||
struct mscclppIbMrInfo ibSignalEpochIdMrInfo;
|
||||
#if defined(ENABLE_NPKIT)
|
||||
std::vector<uint64_t> npkitUsedReqIds;
|
||||
std::vector<uint64_t> npkitFreeReqIds;
|
||||
|
||||
@@ -188,6 +188,7 @@ struct mscclppDevConn
|
||||
|
||||
// Host interface for mscclppDevConn functionality
|
||||
struct mscclppHostConn{
|
||||
virtual ~mscclppHostConn() = default;
|
||||
virtual void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) = 0;
|
||||
virtual void signal() = 0;
|
||||
virtual void wait() = 0;
|
||||
|
||||
88
src/init.cc
88
src/init.cc
@@ -267,6 +267,53 @@ MSCCLPP_API mscclppResult_t mscclppGetAllDeviceConnections(mscclppComm_t comm, m
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
|
||||
static void npkitInitReqIds(struct mscclppComm* comm)
|
||||
{
|
||||
for (int i = 0; i < comm->nConns; i++) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
conn->npkitUsedReqIds.resize(0);
|
||||
conn->npkitFreeReqIds.resize(MSCCLPP_IB_MAX_SENDS);
|
||||
for (uint64_t j = 0; j < MSCCLPP_IB_MAX_SENDS; j++) {
|
||||
conn->npkitFreeReqIds[j] = MSCCLPP_IB_MAX_SENDS - j - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void npkitCollectEntryEvent(struct mscclppConn* conn, uint8_t type, uint32_t size)
|
||||
{
|
||||
uint64_t reqId = 0;
|
||||
if (conn->npkitFreeReqIds.size() == 0) {
|
||||
reqId = conn->npkitUsedReqIds.size();
|
||||
} else {
|
||||
reqId = conn->npkitFreeReqIds.back();
|
||||
conn->npkitFreeReqIds.pop_back();
|
||||
}
|
||||
conn->npkitUsedReqIds.push_back(reqId);
|
||||
NpKit::CollectCpuEvent(type, size, (uint32_t)reqId, NpKit::GetCpuTimestamp(), conn->connId);
|
||||
}
|
||||
|
||||
static void npkitCollectExitEvents(struct mscclppConn* conn, uint8_t type)
|
||||
{
|
||||
while (conn->npkitUsedReqIds.size()) {
|
||||
uint64_t reqId = conn->npkitUsedReqIds.back();
|
||||
NpKit::CollectCpuEvent(type, 0, (uint32_t)reqId, NpKit::GetCpuTimestamp(), conn->connId);
|
||||
conn->npkitFreeReqIds.push_back(reqId);
|
||||
conn->npkitUsedReqIds.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#define npkitInitReqIds(comm)
|
||||
|
||||
#define npkitCollectEntryEvent(conn, type, size)
|
||||
|
||||
#define npkitCollectExitEvents(conn, type)
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
struct mscclppHostP2PConn : mscclppHostConn{
|
||||
mscclppHostP2PConn(mscclppConn* _conn, cudaStream_t _stream) : conn(_conn), p2pStream(_stream){}
|
||||
@@ -275,15 +322,18 @@ struct mscclppHostP2PConn : mscclppHostConn{
|
||||
void* srcBuff = (void*)((char*)conn->devConn->localBuff + srcDataOffset);
|
||||
void* dstBuff = (void*)((char*)conn->devConn->remoteBuff + dstDataOffset);
|
||||
CUDACHECKNORET(cudaMemcpyAsync(dstBuff, srcBuff, dataSize, cudaMemcpyDeviceToDevice, p2pStream));
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)dataSize);
|
||||
}
|
||||
void signal(){
|
||||
CUDACHECKNORET(cudaMemcpyAsync(&conn->devConn->remoteSignalEpochId->proxy,
|
||||
&(conn->devConn->localSignalEpochId->device), sizeof(uint64_t),
|
||||
cudaMemcpyDeviceToDevice, p2pStream));
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t));
|
||||
}
|
||||
void wait(){}
|
||||
void flush(){
|
||||
CUDACHECKNORET(cudaStreamSynchronize(p2pStream));
|
||||
npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
|
||||
}
|
||||
|
||||
mscclppConn* conn;
|
||||
@@ -294,7 +344,7 @@ struct mscclppHostIBConn : mscclppHostConn{
|
||||
mscclppHostIBConn(mscclppConn* conn) : conn(conn) {}
|
||||
|
||||
void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize){
|
||||
conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)dataSize,
|
||||
conn->ibQp->stageSend(this->ibBuffMr, &this->ibBuffMrInfo, (uint32_t)dataSize,
|
||||
/*wrId=*/0, /*srcOffset=*/srcDataOffset,
|
||||
/*dstOffset=*/dstDataOffset,
|
||||
/*signaled=*/false);
|
||||
@@ -303,15 +353,17 @@ struct mscclppHostIBConn : mscclppHostConn{
|
||||
// Return value is errno.
|
||||
WARN("data postSend failed: errno %d", ret);
|
||||
}
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)dataSize);
|
||||
}
|
||||
void signal(){
|
||||
// My local device flag is copied to the remote's proxy flag
|
||||
conn->ibQp->stageSend(conn->ibSignalEpochIdMr, &conn->ibSignalEpochIdMrInfo, sizeof(uint64_t),
|
||||
conn->ibQp->stageSend(this->ibSignalEpochIdMr, &this->ibSignalEpochIdMrInfo, sizeof(uint64_t),
|
||||
/*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true);
|
||||
int ret = conn->ibQp->postSend();
|
||||
if (ret != 0) {
|
||||
WARN("flag postSend failed: errno %d", ret);
|
||||
}
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t));
|
||||
}
|
||||
void wait(){}
|
||||
void flush(){
|
||||
@@ -338,9 +390,14 @@ struct mscclppHostIBConn : mscclppHostConn{
|
||||
}
|
||||
}
|
||||
}
|
||||
npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
|
||||
}
|
||||
|
||||
mscclppConn* conn;
|
||||
struct mscclppIbMr* ibBuffMr;
|
||||
struct mscclppIbMr* ibSignalEpochIdMr;
|
||||
struct mscclppIbMrInfo ibBuffMrInfo;
|
||||
struct mscclppIbMrInfo ibSignalEpochIdMrInfo;
|
||||
};
|
||||
|
||||
|
||||
@@ -365,7 +422,9 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
|
||||
WARN("Too many connections made");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
struct mscclppConn* conn = &comm->conns[comm->nConns];
|
||||
int connId = comm->nConns;
|
||||
struct mscclppConn* conn = &comm->conns[connId];
|
||||
conn->connId = connId;
|
||||
conn->transport = transportType;
|
||||
conn->buffSize = buffSize;
|
||||
|
||||
@@ -463,8 +522,7 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
|
||||
conn->hostConn = new mscclppHostP2PConn(conn, proxyState->p2pStream);
|
||||
}
|
||||
|
||||
|
||||
struct mscclppDevConn* devConn = &comm->devConns[comm->nConns];
|
||||
struct mscclppDevConn* devConn = &comm->devConns[connId];
|
||||
|
||||
conn->devConn = devConn;
|
||||
conn->devConn->localBuff = localBuff;
|
||||
@@ -472,7 +530,7 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->waitEpochId, 1));
|
||||
conn->devConn->remoteRank = remoteRank;
|
||||
conn->devConn->tag = tag;
|
||||
conn->devConn->fifo.connId = comm->nConns;
|
||||
conn->devConn->fifo.connId = connId;
|
||||
#if defined(MSCCLPP_USE_GDRCOPY)
|
||||
conn->devConn->fifo.triggerFifo = proxyState->fifo.triggerFifoDev;
|
||||
#else
|
||||
@@ -530,6 +588,7 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output
|
||||
return mscclppInternalError;
|
||||
}
|
||||
struct mscclppDevConn* devConn = conn->devConn;
|
||||
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
devConn->remoteBuff = NULL;
|
||||
devConn->remoteSignalEpochId = NULL;
|
||||
|
||||
@@ -537,12 +596,12 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output
|
||||
if (conn->ibQp == NULL) {
|
||||
MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &conn->ibQp));
|
||||
}
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &conn->ibBuffMr));
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &hostConn->ibBuffMr));
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localSignalEpochId,
|
||||
sizeof(struct mscclppDevConnSignalEpochId), &conn->ibSignalEpochIdMr));
|
||||
sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibSignalEpochIdMr));
|
||||
connInfo->infoQp = conn->ibQp->info;
|
||||
connInfo->infoBuffMr = conn->ibBuffMr->info;
|
||||
connInfo->infoSignalEpochIdMr = conn->ibSignalEpochIdMr->info;
|
||||
connInfo->infoBuffMr = hostConn->ibBuffMr->info;
|
||||
connInfo->infoSignalEpochIdMr = hostConn->ibSignalEpochIdMr->info;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
@@ -560,8 +619,9 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/,
|
||||
WARN("Failed to transition QP to RTS");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
conn->ibBuffMrInfo = connInfo->infoBuffMr;
|
||||
conn->ibSignalEpochIdMrInfo = connInfo->infoSignalEpochIdMr;
|
||||
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
hostConn->ibBuffMrInfo = connInfo->infoBuffMr;
|
||||
hostConn->ibSignalEpochIdMrInfo = connInfo->infoSignalEpochIdMr;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
@@ -658,7 +718,8 @@ MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, msc
|
||||
void* dstBuff = regMem->p2p[i].remoteBuff;
|
||||
CUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, size, cudaMemcpyDeviceToDevice, (cudaStream_t)stream));
|
||||
} else {
|
||||
conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)size,
|
||||
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
conn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrInfo, (uint32_t)size,
|
||||
/*wrId=*/0, /*srcOffset=*/srcOffset,
|
||||
/*dstOffset=*/dstOffset,
|
||||
/*signaled=*/false);
|
||||
@@ -678,6 +739,7 @@ MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, msc
|
||||
|
||||
// Launches the proxy for `comm`.
//
// Invokes npkitInitReqIds(comm) first, then mscclppProxyCreate(comm) through
// MSCCLPPCHECK, and returns mscclppSuccess after both steps.
// NOTE(review): MSCCLPPCHECK is defined elsewhere; presumably it returns early
// with the failing result code -- confirm against the macro definition.
MSCCLPP_API mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm)
{
  npkitInitReqIds(comm);
  MSCCLPPCHECK(mscclppProxyCreate(comm));
  return mscclppSuccess;
}
|
||||
|
||||
66
src/proxy.cc
66
src/proxy.cc
@@ -10,8 +10,6 @@
|
||||
#include <sys/syscall.h>
|
||||
#include <thread>
|
||||
|
||||
#include "npkit/npkit.h"
|
||||
|
||||
#define MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD 100
|
||||
|
||||
#define PROXYCUDACHECK(cmd) \
|
||||
@@ -40,53 +38,6 @@ struct proxyArgs
|
||||
struct mscclppProxyState* proxyState;
|
||||
};
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
|
||||
static void npkitInitReqIds(struct mscclppComm* comm)
|
||||
{
|
||||
for (int i = 0; i < comm->nConns; i++) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
conn->npkitUsedReqIds.resize(0);
|
||||
conn->npkitFreeReqIds.resize(MSCCLPP_IB_MAX_SENDS);
|
||||
for (uint64_t j = 0; j < MSCCLPP_IB_MAX_SENDS; j++) {
|
||||
conn->npkitFreeReqIds[j] = MSCCLPP_IB_MAX_SENDS - j - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void npkitCollectEntryEvent(struct mscclppConn* conn, uint8_t type, uint32_t size, int channelId)
|
||||
{
|
||||
uint64_t reqId = 0;
|
||||
if (conn->npkitFreeReqIds.size() == 0) {
|
||||
reqId = conn->npkitUsedReqIds.size();
|
||||
} else {
|
||||
reqId = conn->npkitFreeReqIds.back();
|
||||
conn->npkitFreeReqIds.pop_back();
|
||||
}
|
||||
conn->npkitUsedReqIds.push_back(reqId);
|
||||
NpKit::CollectCpuEvent(type, size, (uint32_t)reqId, NpKit::GetCpuTimestamp(), channelId);
|
||||
}
|
||||
|
||||
static void npkitCollectExitEvents(struct mscclppConn* conn, uint8_t type, int channelId)
|
||||
{
|
||||
while (conn->npkitUsedReqIds.size()) {
|
||||
uint64_t reqId = conn->npkitUsedReqIds.back();
|
||||
NpKit::CollectCpuEvent(type, 0, (uint32_t)reqId, NpKit::GetCpuTimestamp(), channelId);
|
||||
conn->npkitFreeReqIds.push_back(reqId);
|
||||
conn->npkitUsedReqIds.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#define npkitInitReqIds(comm)
|
||||
|
||||
#define npkitCollectEntryEvent(conn, type, size, channelId)
|
||||
|
||||
#define npkitCollectExitEvents(conn, type, channelId)
|
||||
|
||||
#endif
|
||||
|
||||
mscclppResult_t mscclppProxyFifo::create()
|
||||
{
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&this->fifoHead, 1));
|
||||
@@ -150,29 +101,20 @@ mscclppResult_t mscclppProxyFifo::flushTail(bool sync)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
void processTrigger(const mscclppTrigger trigger, mscclppConn* conn, mscclppProxyState* proxyState){
|
||||
mscclppIbContext* ibCtx = proxyState->ibContext;
|
||||
bool isP2pProxy = (ibCtx == nullptr);
|
||||
|
||||
static void processTrigger(const mscclppTrigger trigger, mscclppConn* conn)
|
||||
{
|
||||
// Iterate over what send is needed
|
||||
if (trigger.fields.type & mscclppData) {
|
||||
conn->hostConn->put(trigger.fields.dstDataOffset, trigger.fields.srcDataOffset, trigger.fields.dataSize);
|
||||
|
||||
npkitCollectEntryEvent(conn, isP2pProxy ? NPKIT_EVENT_DMA_SEND_DATA_ENTRY : NPKIT_EVENT_IB_SEND_DATA_ENTRY,
|
||||
(uint32_t)trigger.fields.dataSize, trigger.fields.connId);
|
||||
}
|
||||
|
||||
if (trigger.fields.type & mscclppFlag) {
|
||||
conn->hostConn->signal();
|
||||
|
||||
npkitCollectEntryEvent(conn, isP2pProxy ? NPKIT_EVENT_P2P_SEND_FLAG_ENTRY : NPKIT_EVENT_IB_SEND_FLAG_ENTRY,
|
||||
(uint32_t)sizeof(uint64_t), trigger.fields.connId);
|
||||
}
|
||||
|
||||
// Wait for completion
|
||||
if (trigger.fields.type & mscclppSync) {
|
||||
conn->hostConn->flush();
|
||||
npkitCollectExitEvents(conn, isP2pProxy? NPKIT_EVENT_DMA_SEND_EXIT : NPKIT_EVENT_IB_SEND_EXIT, trigger.fields.connId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,8 +133,6 @@ void* mscclppProxyService(void* _args)
|
||||
volatile mscclppProxyRunState_t* run = &proxyState->run;
|
||||
mscclppTrigger trigger;
|
||||
|
||||
npkitInitReqIds(comm);
|
||||
|
||||
int runCnt = MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD;
|
||||
uint64_t flushCnt = 0;
|
||||
for (;;) {
|
||||
@@ -209,7 +149,7 @@ void* mscclppProxyService(void* _args)
|
||||
}
|
||||
|
||||
mscclppConn* conn = &comm->conns[trigger.fields.connId];
|
||||
processTrigger(trigger, conn, proxyState);
|
||||
processTrigger(trigger, conn);
|
||||
|
||||
// Send completion: reset only the high 64 bits
|
||||
PROXYMSCCLPPCHECK(fifo->pop());
|
||||
|
||||
Reference in New Issue
Block a user