mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-21 21:39:21 +00:00
Implement C API buffer registration support
This commit is contained in:
@@ -9,13 +9,16 @@
|
||||
|
||||
#include "ib.h"
|
||||
#include "proxy.h"
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
#include <vector>
|
||||
#endif
|
||||
|
||||
#define MAXCONNECTIONS 64
|
||||
|
||||
struct mscclppBufferRegistration
|
||||
{
|
||||
void *data;
|
||||
uint64_t size;
|
||||
};
|
||||
|
||||
struct mscclppConn
|
||||
{
|
||||
int connId;
|
||||
@@ -25,6 +28,9 @@ struct mscclppConn
|
||||
struct mscclppDevConn* devConn;
|
||||
struct mscclppHostConn* hostConn;
|
||||
|
||||
std::vector<mscclppBufferRegistration> bufferRegistrations;
|
||||
std::vector<mscclppBufferRegistration> remoteBufferRegistrations;
|
||||
|
||||
struct mscclppIbContext* ibCtx;
|
||||
#if defined(ENABLE_NPKIT)
|
||||
std::vector<uint64_t> npkitUsedReqIds;
|
||||
|
||||
@@ -370,7 +370,7 @@ mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank,
|
||||
* Outputs:
|
||||
* handle: a handle to the buffer registration
|
||||
*/
|
||||
mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle);
|
||||
mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle);
|
||||
|
||||
/* Establish all connections declared by mscclppConnect(). This function must be called after all mscclppConnect()
|
||||
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.
|
||||
|
||||
216
src/init.cc
216
src/init.cc
@@ -9,6 +9,7 @@
|
||||
#include "mscclpp.h"
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#if defined(ENABLE_NPKIT)
|
||||
#include "npkit/npkit.h"
|
||||
#endif
|
||||
@@ -323,8 +324,12 @@ struct mscclppHostP2PConn : mscclppHostConn
|
||||
|
||||
void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
{
|
||||
void* srcBuff = (void*)((char*)conn->devConn->localBuff + srcDataOffset);
|
||||
void* dstBuff = (void*)((char*)conn->devConn->remoteBuff + dstDataOffset);
|
||||
put(1, dstDataOffset, 1, srcDataOffset, dataSize);
|
||||
}
|
||||
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
{
|
||||
void* srcBuff = (void*)((char*)conn->bufferRegistrations[src].data + srcDataOffset);
|
||||
void* dstBuff = (void*)((char*)conn->remoteBufferRegistrations[dst].data + dstDataOffset);
|
||||
CUDACHECKNORET(cudaMemcpyAsync(dstBuff, srcBuff, dataSize, cudaMemcpyDeviceToDevice, p2pStream));
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)dataSize);
|
||||
}
|
||||
@@ -357,7 +362,11 @@ struct mscclppHostIBConn : mscclppHostConn
|
||||
|
||||
void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
{
|
||||
this->ibQp->stageSend(this->ibBuffMr, &this->ibBuffMrRemoteInfo, (uint32_t)dataSize,
|
||||
put(1, dstDataOffset, 1, srcDataOffset, dataSize);
|
||||
}
|
||||
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
{
|
||||
this->ibQp->stageSend(this->ibMrs[src], &this->remoteIbMrInfos[dst], (uint32_t)dataSize,
|
||||
/*wrId=*/0, /*srcOffset=*/srcDataOffset, /*dstOffset=*/dstDataOffset, /*signaled=*/false);
|
||||
int ret = this->ibQp->postSend();
|
||||
if (ret != 0) {
|
||||
@@ -369,7 +378,7 @@ struct mscclppHostIBConn : mscclppHostConn
|
||||
void signal()
|
||||
{
|
||||
// My local device flag is copied to the remote's proxy flag
|
||||
this->ibQp->stageSend(this->ibSignalEpochIdMr, &this->ibSignalEpochIdMrRemoteInfo, sizeof(uint64_t),
|
||||
this->ibQp->stageSend(this->ibMrs[0], &this->remoteIbMrInfos[0], sizeof(uint64_t),
|
||||
/*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true);
|
||||
int ret = this->ibQp->postSend();
|
||||
if (ret != 0) {
|
||||
@@ -410,14 +419,11 @@ struct mscclppHostIBConn : mscclppHostConn
|
||||
|
||||
mscclppConn* conn;
|
||||
struct mscclppIbQp* ibQp;
|
||||
struct mscclppIbMr* ibBuffMr;
|
||||
struct mscclppIbMr* ibSignalEpochIdMr;
|
||||
struct mscclppIbMrInfo ibBuffMrRemoteInfo;
|
||||
struct mscclppIbMrInfo ibSignalEpochIdMrRemoteInfo;
|
||||
std::vector<mscclppIbMr*> ibMrs;
|
||||
std::vector<mscclppIbMrInfo> remoteIbMrInfos;
|
||||
};
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void* localBuff,
|
||||
uint64_t buffSize, mscclppTransport_t transportType, const char* ibDev)
|
||||
MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, mscclppTransport_t transportType, const char* ibDev)
|
||||
{
|
||||
// save this processes numa binding and set it to the one closest to the device
|
||||
// so that all the allocation are close to the device
|
||||
@@ -440,7 +446,7 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
|
||||
struct mscclppConn* conn = &comm->conns[connId];
|
||||
conn->connId = connId;
|
||||
conn->transport = transportType;
|
||||
conn->buffSize = buffSize;
|
||||
conn->buffSize = 0;
|
||||
|
||||
conn->ibCtx = NULL;
|
||||
int ibDevIdx = -1;
|
||||
@@ -537,7 +543,7 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
|
||||
struct mscclppDevConn* devConn = &comm->devConns[connId];
|
||||
|
||||
conn->devConn = devConn;
|
||||
conn->devConn->localBuff = localBuff;
|
||||
conn->devConn->localBuff = nullptr;
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->localSignalEpochId, 1));
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->waitEpochId, 1));
|
||||
conn->devConn->remoteRank = remoteRank;
|
||||
@@ -556,27 +562,99 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
|
||||
// change the numa binding back to user's
|
||||
MSCCLPPCHECK(setNumaState(curProcessState));
|
||||
|
||||
mscclppBufferHandle_t signalHandle = -1;
|
||||
MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, conn->devConn->localSignalEpochId, sizeof(mscclppDevConnSignalEpochId), &signalHandle));
|
||||
if (signalHandle != 0) {
|
||||
WARN("signal handle should be 0");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
struct connInfo
|
||||
MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void* localBuff,
|
||||
uint64_t buffSize, mscclppTransport_t transportType, const char* ibDev)
|
||||
{
|
||||
cudaIpcMemHandle_t handleBuff;
|
||||
cudaIpcMemHandle_t handleSignalEpochId;
|
||||
mscclppIbQpInfo infoQp;
|
||||
mscclppIbMrInfo infoBuffMr;
|
||||
mscclppIbMrInfo infoSignalEpochIdMr;
|
||||
};
|
||||
int connId = comm->nConns;
|
||||
MSCCLPPCHECK(mscclppConnectWithoutBuffer(comm, remoteRank, tag, transportType, ibDev));
|
||||
struct mscclppConn* conn = &comm->conns[connId];
|
||||
|
||||
mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*output*/, struct mscclppConn* conn /*input*/)
|
||||
{
|
||||
if (connInfo == NULL || conn == NULL) {
|
||||
WARN("connInfo or connection cannot be null");
|
||||
conn->buffSize = buffSize;
|
||||
conn->devConn->localBuff = localBuff;
|
||||
|
||||
mscclppBufferHandle_t localBuffHandle = -1;
|
||||
MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, conn->devConn->localSignalEpochId, buffSize, &localBuffHandle));
|
||||
if (localBuffHandle != 1) {
|
||||
WARN("data buffer handle should be 1");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
struct mscclppDevConn* devConn = conn->devConn;
|
||||
CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleBuff, devConn->localBuff));
|
||||
CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleSignalEpochId, devConn->localSignalEpochId));
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle) {
|
||||
if (connIdx >= comm->nConns) {
|
||||
WARN("connIdx out of range");
|
||||
return mscclppInvalidArgument;
|
||||
}
|
||||
mscclppConn& conn = comm->conns[connIdx];
|
||||
*handle = conn.bufferRegistrations.size();
|
||||
conn.bufferRegistrations.emplace_back();
|
||||
conn.bufferRegistrations.back().data = localBuff;
|
||||
conn.bufferRegistrations.back().size = buffSize;
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
struct mscclppBufferRegistrationInfo
|
||||
{
|
||||
cudaIpcMemHandle_t cudaHandle;
|
||||
mscclppIbMrInfo ibMrInfo;
|
||||
uint64_t size;
|
||||
};
|
||||
|
||||
struct connInfo
|
||||
{
|
||||
mscclppIbQpInfo infoQp;
|
||||
std::vector<mscclppBufferRegistrationInfo> bufferInfos;
|
||||
|
||||
struct header {
|
||||
mscclppIbQpInfo infoQp;
|
||||
int numBufferInfos;
|
||||
};
|
||||
|
||||
mscclppResult_t sendOverBootstrap(void* bootstrap, int remoteRank, int tag) {
|
||||
header h;
|
||||
h.infoQp = infoQp;
|
||||
h.numBufferInfos = bufferInfos.size();
|
||||
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, &h, sizeof(header)));
|
||||
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag) {
|
||||
header h;
|
||||
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, &h, sizeof(header)));
|
||||
infoQp = h.infoQp;
|
||||
bufferInfos.resize(h.numBufferInfos);
|
||||
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*input*/, struct mscclppConn* conn /*input*/)
|
||||
{
|
||||
if (conn == NULL) {
|
||||
WARN("connection cannot be null");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
// Add all registered buffers
|
||||
for (const auto &bufReg : conn->bufferRegistrations) {
|
||||
connInfo->bufferInfos.emplace_back();
|
||||
CUDACHECK(cudaIpcGetMemHandle(&connInfo->bufferInfos.back().cudaHandle, bufReg.data));
|
||||
connInfo->bufferInfos.back().size = bufReg.size;
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
@@ -586,10 +664,30 @@ mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/
|
||||
WARN("ipcHandles or connection cannot be null");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
CUDACHECK(
|
||||
cudaIpcOpenMemHandle((void**)&conn->devConn->remoteBuff, connInfo->handleBuff, cudaIpcMemLazyEnablePeerAccess));
|
||||
CUDACHECK(cudaIpcOpenMemHandle((void**)&conn->devConn->remoteSignalEpochId, connInfo->handleSignalEpochId,
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
if (connInfo->bufferInfos.size() < 1) {
|
||||
WARN("at least 1 buffer info expected");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
// Open all remote registered buffers
|
||||
for (size_t i = 0; i < connInfo->bufferInfos.size(); i++) {
|
||||
mscclppBufferRegistration newBufReg;
|
||||
CUDACHECK(cudaIpcOpenMemHandle(&newBufReg.data, connInfo->bufferInfos[i].cudaHandle, cudaIpcMemLazyEnablePeerAccess));
|
||||
newBufReg.size = connInfo->bufferInfos[i].size;
|
||||
conn->remoteBufferRegistrations.push_back(newBufReg);
|
||||
}
|
||||
|
||||
if (conn->remoteBufferRegistrations[0].size != sizeof(mscclppDevConnSignalEpochId)) {
|
||||
WARN("buffer registration zero size doesn't match sizeof(mscclppDevConnSignalEpochId)");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
conn->devConn->remoteSignalEpochId = (mscclppDevConnSignalEpochId*)conn->remoteBufferRegistrations[0].data;
|
||||
|
||||
// For backwards compatibility with the previous API that assumed one data buffer per connection, set the remote buffer
|
||||
// to the first remote data buffer
|
||||
if (conn->remoteBufferRegistrations.size() > 1) {
|
||||
conn->devConn->remoteBuff = conn->remoteBufferRegistrations[1].data;
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
@@ -608,12 +706,18 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output
|
||||
if (hostConn->ibQp == NULL) {
|
||||
MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &hostConn->ibQp));
|
||||
}
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &hostConn->ibBuffMr));
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localSignalEpochId,
|
||||
sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibSignalEpochIdMr));
|
||||
|
||||
// Add all registered buffers
|
||||
for (const auto &bufReg : conn->bufferRegistrations) {
|
||||
hostConn->ibMrs.emplace_back();
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, bufReg.data,
|
||||
sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibMrs.back()));
|
||||
connInfo->bufferInfos.emplace_back();
|
||||
connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->info;
|
||||
connInfo->bufferInfos.back().size = bufReg.size;
|
||||
}
|
||||
|
||||
connInfo->infoQp = hostConn->ibQp->info;
|
||||
connInfo->infoBuffMr = hostConn->ibBuffMr->info;
|
||||
connInfo->infoSignalEpochIdMr = hostConn->ibSignalEpochIdMr->info;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
@@ -632,8 +736,18 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/,
|
||||
WARN("Failed to transition QP to RTS");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
hostConn->ibBuffMrRemoteInfo = connInfo->infoBuffMr;
|
||||
hostConn->ibSignalEpochIdMrRemoteInfo = connInfo->infoSignalEpochIdMr;
|
||||
|
||||
// No remote pointers to set with IB, so we just set the Mrs
|
||||
|
||||
// Push the Mrs for all the remote registered buffers
|
||||
for (size_t i = 1; i < connInfo->bufferInfos.size(); i++) {
|
||||
hostConn->remoteIbMrInfos.push_back(connInfo->bufferInfos[i].ibMrInfo);
|
||||
|
||||
mscclppBufferRegistration newBufReg;
|
||||
newBufReg.data = nullptr;
|
||||
newBufReg.size = connInfo->bufferInfos[i].size;
|
||||
conn->remoteBufferRegistrations.push_back(newBufReg);
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
@@ -650,14 +764,15 @@ MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
|
||||
MSCCLPPCHECK(mscclppIbConnectionSetupStart(&cInfo, conn));
|
||||
}
|
||||
// TODO: from saemal: do we possibly deadlock if there are too many outstanding sends?
|
||||
MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, sizeof(cInfo)));
|
||||
// MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, sizeof(cInfo)));
|
||||
MSCCLPPCHECK(cInfo.sendOverBootstrap(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag));
|
||||
}
|
||||
|
||||
// Recv info from peers
|
||||
for (int i = 0; i < comm->nConns; ++i) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
struct connInfo cInfo;
|
||||
MSCCLPPCHECK(bootstrapRecv(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, sizeof(cInfo)));
|
||||
MSCCLPPCHECK(cInfo.recvOverBootstrap(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag));
|
||||
if (conn->transport == mscclppTransportP2P) {
|
||||
MSCCLPPCHECK(mscclppP2pConnectionSetupEnd(&cInfo, conn));
|
||||
} else if (conn->transport == mscclppTransportIB) {
|
||||
@@ -731,16 +846,19 @@ MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, msc
|
||||
void* dstBuff = regMem->p2p[i].remoteBuff;
|
||||
CUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, size, cudaMemcpyDeviceToDevice, (cudaStream_t)stream));
|
||||
} else {
|
||||
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
hostConn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrRemoteInfo, (uint32_t)size,
|
||||
/*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false);
|
||||
if ((ret = hostConn->ibQp->postSend()) != 0) {
|
||||
// Return value is errno.
|
||||
WARN("data postSend failed: errno %d", ret);
|
||||
}
|
||||
// ??
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_ENTRY, (uint32_t)trigger.fields.dataSize,
|
||||
// trigger.fields.connId);
|
||||
WARN("mscclppRegisteredBufferWrite not implemented for IB");
|
||||
return mscclppInternalError;
|
||||
// TODO: fix the following (Olli: probably by including the relevant ibBuffMr in the mscclppRegisteredMemory)
|
||||
// struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
// hostConn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrRemoteInfo, (uint32_t)size,
|
||||
// /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false);
|
||||
// if ((ret = hostConn->ibQp->postSend()) != 0) {
|
||||
// // Return value is errno.
|
||||
// WARN("data postSend failed: errno %d", ret);
|
||||
// }
|
||||
// // ??
|
||||
// // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_ENTRY, (uint32_t)trigger.fields.dataSize,
|
||||
// // trigger.fields.connId);
|
||||
}
|
||||
}
|
||||
return mscclppSuccess;
|
||||
|
||||
Reference in New Issue
Block a user