Move MRs / MR infos to mscclppHostIBConn & cleanup

This commit is contained in:
Changho Hwang
2023-04-12 09:05:42 +00:00
parent fd3f928108
commit bc729cd481
4 changed files with 80 additions and 80 deletions

View File

@@ -20,6 +20,7 @@
struct mscclppConn
{
int connId;
mscclppTransport_t transport;
int remoteRank;
uint64_t buffSize;
@@ -28,10 +29,6 @@ struct mscclppConn
struct mscclppIbContext* ibCtx;
struct mscclppIbQp* ibQp;
struct mscclppIbMr* ibBuffMr;
struct mscclppIbMr* ibSignalEpochIdMr;
struct mscclppIbMrInfo ibBuffMrInfo;
struct mscclppIbMrInfo ibSignalEpochIdMrInfo;
#if defined(ENABLE_NPKIT)
std::vector<uint64_t> npkitUsedReqIds;
std::vector<uint64_t> npkitFreeReqIds;

View File

@@ -188,6 +188,7 @@ struct mscclppDevConn
// Host interface for mscclppDevCon functionality
struct mscclppHostConn{
virtual ~mscclppHostConn() = default;
virtual void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) = 0;
virtual void signal() = 0;
virtual void wait() = 0;

View File

@@ -267,6 +267,53 @@ MSCCLPP_API mscclppResult_t mscclppGetAllDeviceConnections(mscclppComm_t comm, m
return mscclppSuccess;
}
#if defined(ENABLE_NPKIT)
static void npkitInitReqIds(struct mscclppComm* comm)
{
for (int i = 0; i < comm->nConns; i++) {
struct mscclppConn* conn = &comm->conns[i];
conn->npkitUsedReqIds.resize(0);
conn->npkitFreeReqIds.resize(MSCCLPP_IB_MAX_SENDS);
for (uint64_t j = 0; j < MSCCLPP_IB_MAX_SENDS; j++) {
conn->npkitFreeReqIds[j] = MSCCLPP_IB_MAX_SENDS - j - 1;
}
}
}
static void npkitCollectEntryEvent(struct mscclppConn* conn, uint8_t type, uint32_t size)
{
uint64_t reqId = 0;
if (conn->npkitFreeReqIds.size() == 0) {
reqId = conn->npkitUsedReqIds.size();
} else {
reqId = conn->npkitFreeReqIds.back();
conn->npkitFreeReqIds.pop_back();
}
conn->npkitUsedReqIds.push_back(reqId);
NpKit::CollectCpuEvent(type, size, (uint32_t)reqId, NpKit::GetCpuTimestamp(), conn->connId);
}
static void npkitCollectExitEvents(struct mscclppConn* conn, uint8_t type)
{
while (conn->npkitUsedReqIds.size()) {
uint64_t reqId = conn->npkitUsedReqIds.back();
NpKit::CollectCpuEvent(type, 0, (uint32_t)reqId, NpKit::GetCpuTimestamp(), conn->connId);
conn->npkitFreeReqIds.push_back(reqId);
conn->npkitUsedReqIds.pop_back();
}
}
#else
#define npkitInitReqIds(comm)
#define npkitCollectEntryEvent(conn, type, size)
#define npkitCollectExitEvents(conn, type)
#endif
struct mscclppHostP2PConn : mscclppHostConn{
mscclppHostP2PConn(mscclppConn* _conn, cudaStream_t _stream) : conn(_conn), p2pStream(_stream){}
@@ -275,15 +322,18 @@ struct mscclppHostP2PConn : mscclppHostConn{
void* srcBuff = (void*)((char*)conn->devConn->localBuff + srcDataOffset);
void* dstBuff = (void*)((char*)conn->devConn->remoteBuff + dstDataOffset);
CUDACHECKNORET(cudaMemcpyAsync(dstBuff, srcBuff, dataSize, cudaMemcpyDeviceToDevice, p2pStream));
npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)dataSize);
}
void signal(){
CUDACHECKNORET(cudaMemcpyAsync(&conn->devConn->remoteSignalEpochId->proxy,
&(conn->devConn->localSignalEpochId->device), sizeof(uint64_t),
cudaMemcpyDeviceToDevice, p2pStream));
npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t));
}
void wait(){}
void flush(){
CUDACHECKNORET(cudaStreamSynchronize(p2pStream));
npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
}
mscclppConn* conn;
@@ -294,7 +344,7 @@ struct mscclppHostIBConn : mscclppHostConn{
mscclppHostIBConn(mscclppConn* conn) : conn(conn) {}
void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize){
conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)dataSize,
conn->ibQp->stageSend(this->ibBuffMr, &this->ibBuffMrInfo, (uint32_t)dataSize,
/*wrId=*/0, /*srcOffset=*/srcDataOffset,
/*dstOffset=*/dstDataOffset,
/*signaled=*/false);
@@ -303,15 +353,17 @@ struct mscclppHostIBConn : mscclppHostConn{
// Return value is errno.
WARN("data postSend failed: errno %d", ret);
}
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)dataSize);
}
void signal(){
// My local device flag is copied to the remote's proxy flag
conn->ibQp->stageSend(conn->ibSignalEpochIdMr, &conn->ibSignalEpochIdMrInfo, sizeof(uint64_t),
conn->ibQp->stageSend(this->ibSignalEpochIdMr, &this->ibSignalEpochIdMrInfo, sizeof(uint64_t),
/*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true);
int ret = conn->ibQp->postSend();
if (ret != 0) {
WARN("flag postSend failed: errno %d", ret);
}
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t));
}
void wait(){}
void flush(){
@@ -338,9 +390,14 @@ struct mscclppHostIBConn : mscclppHostConn{
}
}
}
npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
}
mscclppConn* conn;
struct mscclppIbMr* ibBuffMr;
struct mscclppIbMr* ibSignalEpochIdMr;
struct mscclppIbMrInfo ibBuffMrInfo;
struct mscclppIbMrInfo ibSignalEpochIdMrInfo;
};
@@ -365,7 +422,9 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
WARN("Too many connections made");
return mscclppInternalError;
}
struct mscclppConn* conn = &comm->conns[comm->nConns];
int connId = comm->nConns;
struct mscclppConn* conn = &comm->conns[connId];
conn->connId = connId;
conn->transport = transportType;
conn->buffSize = buffSize;
@@ -463,8 +522,7 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
conn->hostConn = new mscclppHostP2PConn(conn, proxyState->p2pStream);
}
struct mscclppDevConn* devConn = &comm->devConns[comm->nConns];
struct mscclppDevConn* devConn = &comm->devConns[connId];
conn->devConn = devConn;
conn->devConn->localBuff = localBuff;
@@ -472,7 +530,7 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->waitEpochId, 1));
conn->devConn->remoteRank = remoteRank;
conn->devConn->tag = tag;
conn->devConn->fifo.connId = comm->nConns;
conn->devConn->fifo.connId = connId;
#if defined(MSCCLPP_USE_GDRCOPY)
conn->devConn->fifo.triggerFifo = proxyState->fifo.triggerFifoDev;
#else
@@ -530,6 +588,7 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output
return mscclppInternalError;
}
struct mscclppDevConn* devConn = conn->devConn;
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
devConn->remoteBuff = NULL;
devConn->remoteSignalEpochId = NULL;
@@ -537,12 +596,12 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output
if (conn->ibQp == NULL) {
MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &conn->ibQp));
}
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &conn->ibBuffMr));
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &hostConn->ibBuffMr));
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localSignalEpochId,
sizeof(struct mscclppDevConnSignalEpochId), &conn->ibSignalEpochIdMr));
sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibSignalEpochIdMr));
connInfo->infoQp = conn->ibQp->info;
connInfo->infoBuffMr = conn->ibBuffMr->info;
connInfo->infoSignalEpochIdMr = conn->ibSignalEpochIdMr->info;
connInfo->infoBuffMr = hostConn->ibBuffMr->info;
connInfo->infoSignalEpochIdMr = hostConn->ibSignalEpochIdMr->info;
return mscclppSuccess;
}
@@ -560,8 +619,9 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/,
WARN("Failed to transition QP to RTS");
return mscclppInvalidUsage;
}
conn->ibBuffMrInfo = connInfo->infoBuffMr;
conn->ibSignalEpochIdMrInfo = connInfo->infoSignalEpochIdMr;
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
hostConn->ibBuffMrInfo = connInfo->infoBuffMr;
hostConn->ibSignalEpochIdMrInfo = connInfo->infoSignalEpochIdMr;
return mscclppSuccess;
}
@@ -658,7 +718,8 @@ MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, msc
void* dstBuff = regMem->p2p[i].remoteBuff;
CUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, size, cudaMemcpyDeviceToDevice, (cudaStream_t)stream));
} else {
conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)size,
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
conn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrInfo, (uint32_t)size,
/*wrId=*/0, /*srcOffset=*/srcOffset,
/*dstOffset=*/dstOffset,
/*signaled=*/false);
@@ -678,6 +739,7 @@ MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, msc
MSCCLPP_API mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm)
{
npkitInitReqIds(comm);
MSCCLPPCHECK(mscclppProxyCreate(comm));
return mscclppSuccess;
}

View File

@@ -10,8 +10,6 @@
#include <sys/syscall.h>
#include <thread>
#include "npkit/npkit.h"
#define MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD 100
#define PROXYCUDACHECK(cmd) \
@@ -40,53 +38,6 @@ struct proxyArgs
struct mscclppProxyState* proxyState;
};
#if defined(ENABLE_NPKIT)
static void npkitInitReqIds(struct mscclppComm* comm)
{
for (int i = 0; i < comm->nConns; i++) {
struct mscclppConn* conn = &comm->conns[i];
conn->npkitUsedReqIds.resize(0);
conn->npkitFreeReqIds.resize(MSCCLPP_IB_MAX_SENDS);
for (uint64_t j = 0; j < MSCCLPP_IB_MAX_SENDS; j++) {
conn->npkitFreeReqIds[j] = MSCCLPP_IB_MAX_SENDS - j - 1;
}
}
}
static void npkitCollectEntryEvent(struct mscclppConn* conn, uint8_t type, uint32_t size, int channelId)
{
uint64_t reqId = 0;
if (conn->npkitFreeReqIds.size() == 0) {
reqId = conn->npkitUsedReqIds.size();
} else {
reqId = conn->npkitFreeReqIds.back();
conn->npkitFreeReqIds.pop_back();
}
conn->npkitUsedReqIds.push_back(reqId);
NpKit::CollectCpuEvent(type, size, (uint32_t)reqId, NpKit::GetCpuTimestamp(), channelId);
}
static void npkitCollectExitEvents(struct mscclppConn* conn, uint8_t type, int channelId)
{
while (conn->npkitUsedReqIds.size()) {
uint64_t reqId = conn->npkitUsedReqIds.back();
NpKit::CollectCpuEvent(type, 0, (uint32_t)reqId, NpKit::GetCpuTimestamp(), channelId);
conn->npkitFreeReqIds.push_back(reqId);
conn->npkitUsedReqIds.pop_back();
}
}
#else
#define npkitInitReqIds(comm)
#define npkitCollectEntryEvent(conn, type, size, channelId)
#define npkitCollectExitEvents(conn, type, channelId)
#endif
mscclppResult_t mscclppProxyFifo::create()
{
MSCCLPPCHECK(mscclppCudaCalloc(&this->fifoHead, 1));
@@ -150,29 +101,20 @@ mscclppResult_t mscclppProxyFifo::flushTail(bool sync)
return mscclppSuccess;
}
void processTrigger(const mscclppTrigger trigger, mscclppConn* conn, mscclppProxyState* proxyState){
mscclppIbContext* ibCtx = proxyState->ibContext;
bool isP2pProxy = (ibCtx == nullptr);
static void processTrigger(const mscclppTrigger trigger, mscclppConn* conn)
{
// Iterate over what send is needed
if (trigger.fields.type & mscclppData) {
conn->hostConn->put(trigger.fields.dstDataOffset, trigger.fields.srcDataOffset, trigger.fields.dataSize);
npkitCollectEntryEvent(conn, isP2pProxy ? NPKIT_EVENT_DMA_SEND_DATA_ENTRY : NPKIT_EVENT_IB_SEND_DATA_ENTRY,
(uint32_t)trigger.fields.dataSize, trigger.fields.connId);
}
if (trigger.fields.type & mscclppFlag) {
conn->hostConn->signal();
npkitCollectEntryEvent(conn, isP2pProxy ? NPKIT_EVENT_P2P_SEND_FLAG_ENTRY : NPKIT_EVENT_IB_SEND_FLAG_ENTRY,
(uint32_t)sizeof(uint64_t), trigger.fields.connId);
}
// Wait for completion
if (trigger.fields.type & mscclppSync) {
conn->hostConn->flush();
npkitCollectExitEvents(conn, isP2pProxy? NPKIT_EVENT_DMA_SEND_EXIT : NPKIT_EVENT_IB_SEND_EXIT, trigger.fields.connId);
}
}
@@ -191,8 +133,6 @@ void* mscclppProxyService(void* _args)
volatile mscclppProxyRunState_t* run = &proxyState->run;
mscclppTrigger trigger;
npkitInitReqIds(comm);
int runCnt = MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD;
uint64_t flushCnt = 0;
for (;;) {
@@ -209,7 +149,7 @@ void* mscclppProxyService(void* _args)
}
mscclppConn* conn = &comm->conns[trigger.fields.connId];
processTrigger(trigger, conn, proxyState);
processTrigger(trigger, conn);
// Send completion: reset only the high 64 bits
PROXYMSCCLPPCHECK(fifo->pop());