Implement C API buffer registration support

This commit is contained in:
Olli Saarikivi
2023-04-14 23:04:48 +00:00
parent a0f1d36026
commit 46790d79e8
3 changed files with 177 additions and 53 deletions

View File

@@ -9,13 +9,16 @@
#include "ib.h"
#include "proxy.h"
#if defined(ENABLE_NPKIT)
#include <vector>
#endif
#define MAXCONNECTIONS 64
struct mscclppBufferRegistration
{
void *data;
uint64_t size;
};
struct mscclppConn
{
int connId;
@@ -25,6 +28,9 @@ struct mscclppConn
struct mscclppDevConn* devConn;
struct mscclppHostConn* hostConn;
std::vector<mscclppBufferRegistration> bufferRegistrations;
std::vector<mscclppBufferRegistration> remoteBufferRegistrations;
struct mscclppIbContext* ibCtx;
#if defined(ENABLE_NPKIT)
std::vector<uint64_t> npkitUsedReqIds;

View File

@@ -370,7 +370,7 @@ mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank,
* Outputs:
* handle: a handle to the buffer registration
*/
mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle);
mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle);
/* Establish all connections declared by mscclppConnect(). This function must be called after all mscclppConnect()
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.

View File

@@ -9,6 +9,7 @@
#include "mscclpp.h"
#include <map>
#include <sstream>
#include <vector>
#if defined(ENABLE_NPKIT)
#include "npkit/npkit.h"
#endif
@@ -323,8 +324,12 @@ struct mscclppHostP2PConn : mscclppHostConn
void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
{
void* srcBuff = (void*)((char*)conn->devConn->localBuff + srcDataOffset);
void* dstBuff = (void*)((char*)conn->devConn->remoteBuff + dstDataOffset);
put(1, dstDataOffset, 1, srcDataOffset, dataSize);
}
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize)
{
void* srcBuff = (void*)((char*)conn->bufferRegistrations[src].data + srcDataOffset);
void* dstBuff = (void*)((char*)conn->remoteBufferRegistrations[dst].data + dstDataOffset);
CUDACHECKNORET(cudaMemcpyAsync(dstBuff, srcBuff, dataSize, cudaMemcpyDeviceToDevice, p2pStream));
npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)dataSize);
}
@@ -357,7 +362,11 @@ struct mscclppHostIBConn : mscclppHostConn
void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
{
this->ibQp->stageSend(this->ibBuffMr, &this->ibBuffMrRemoteInfo, (uint32_t)dataSize,
put(1, dstDataOffset, 1, srcDataOffset, dataSize);
}
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize)
{
this->ibQp->stageSend(this->ibMrs[src], &this->remoteIbMrInfos[dst], (uint32_t)dataSize,
/*wrId=*/0, /*srcOffset=*/srcDataOffset, /*dstOffset=*/dstDataOffset, /*signaled=*/false);
int ret = this->ibQp->postSend();
if (ret != 0) {
@@ -369,7 +378,7 @@ struct mscclppHostIBConn : mscclppHostConn
void signal()
{
// My local device flag is copied to the remote's proxy flag
this->ibQp->stageSend(this->ibSignalEpochIdMr, &this->ibSignalEpochIdMrRemoteInfo, sizeof(uint64_t),
this->ibQp->stageSend(this->ibMrs[0], &this->remoteIbMrInfos[0], sizeof(uint64_t),
/*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true);
int ret = this->ibQp->postSend();
if (ret != 0) {
@@ -410,14 +419,11 @@ struct mscclppHostIBConn : mscclppHostConn
mscclppConn* conn;
struct mscclppIbQp* ibQp;
struct mscclppIbMr* ibBuffMr;
struct mscclppIbMr* ibSignalEpochIdMr;
struct mscclppIbMrInfo ibBuffMrRemoteInfo;
struct mscclppIbMrInfo ibSignalEpochIdMrRemoteInfo;
std::vector<mscclppIbMr*> ibMrs;
std::vector<mscclppIbMrInfo> remoteIbMrInfos;
};
MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void* localBuff,
uint64_t buffSize, mscclppTransport_t transportType, const char* ibDev)
MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, mscclppTransport_t transportType, const char* ibDev)
{
// save this processes numa binding and set it to the one closest to the device
// so that all the allocation are close to the device
@@ -440,7 +446,7 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
struct mscclppConn* conn = &comm->conns[connId];
conn->connId = connId;
conn->transport = transportType;
conn->buffSize = buffSize;
conn->buffSize = 0;
conn->ibCtx = NULL;
int ibDevIdx = -1;
@@ -537,7 +543,7 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
struct mscclppDevConn* devConn = &comm->devConns[connId];
conn->devConn = devConn;
conn->devConn->localBuff = localBuff;
conn->devConn->localBuff = nullptr;
MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->localSignalEpochId, 1));
MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->waitEpochId, 1));
conn->devConn->remoteRank = remoteRank;
@@ -556,27 +562,99 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
// change the numa binding back to user's
MSCCLPPCHECK(setNumaState(curProcessState));
mscclppBufferHandle_t signalHandle = -1;
MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, conn->devConn->localSignalEpochId, sizeof(mscclppDevConnSignalEpochId), &signalHandle));
if (signalHandle != 0) {
WARN("signal handle should be 0");
return mscclppInternalError;
}
return mscclppSuccess;
}
struct connInfo
MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void* localBuff,
uint64_t buffSize, mscclppTransport_t transportType, const char* ibDev)
{
cudaIpcMemHandle_t handleBuff;
cudaIpcMemHandle_t handleSignalEpochId;
mscclppIbQpInfo infoQp;
mscclppIbMrInfo infoBuffMr;
mscclppIbMrInfo infoSignalEpochIdMr;
};
int connId = comm->nConns;
MSCCLPPCHECK(mscclppConnectWithoutBuffer(comm, remoteRank, tag, transportType, ibDev));
struct mscclppConn* conn = &comm->conns[connId];
mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*output*/, struct mscclppConn* conn /*input*/)
{
if (connInfo == NULL || conn == NULL) {
WARN("connInfo or connection cannot be null");
conn->buffSize = buffSize;
conn->devConn->localBuff = localBuff;
mscclppBufferHandle_t localBuffHandle = -1;
MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, conn->devConn->localSignalEpochId, buffSize, &localBuffHandle));
if (localBuffHandle != 1) {
WARN("data buffer handle should be 1");
return mscclppInternalError;
}
struct mscclppDevConn* devConn = conn->devConn;
CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleBuff, devConn->localBuff));
CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleSignalEpochId, devConn->localSignalEpochId));
return mscclppSuccess;
}
MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle) {
if (connIdx >= comm->nConns) {
WARN("connIdx out of range");
return mscclppInvalidArgument;
}
mscclppConn& conn = comm->conns[connIdx];
*handle = conn.bufferRegistrations.size();
conn.bufferRegistrations.emplace_back();
conn.bufferRegistrations.back().data = localBuff;
conn.bufferRegistrations.back().size = buffSize;
return mscclppSuccess;
}
struct mscclppBufferRegistrationInfo
{
cudaIpcMemHandle_t cudaHandle;
mscclppIbMrInfo ibMrInfo;
uint64_t size;
};
struct connInfo
{
mscclppIbQpInfo infoQp;
std::vector<mscclppBufferRegistrationInfo> bufferInfos;
struct header {
mscclppIbQpInfo infoQp;
int numBufferInfos;
};
mscclppResult_t sendOverBootstrap(void* bootstrap, int remoteRank, int tag) {
header h;
h.infoQp = infoQp;
h.numBufferInfos = bufferInfos.size();
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, &h, sizeof(header)));
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
return mscclppSuccess;
}
mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag) {
header h;
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, &h, sizeof(header)));
infoQp = h.infoQp;
bufferInfos.resize(h.numBufferInfos);
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
return mscclppSuccess;
}
};
mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*input*/, struct mscclppConn* conn /*input*/)
{
if (conn == NULL) {
WARN("connection cannot be null");
return mscclppInternalError;
}
// Add all registered buffers
for (const auto &bufReg : conn->bufferRegistrations) {
connInfo->bufferInfos.emplace_back();
CUDACHECK(cudaIpcGetMemHandle(&connInfo->bufferInfos.back().cudaHandle, bufReg.data));
connInfo->bufferInfos.back().size = bufReg.size;
}
return mscclppSuccess;
}
@@ -586,10 +664,30 @@ mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/
WARN("ipcHandles or connection cannot be null");
return mscclppInternalError;
}
CUDACHECK(
cudaIpcOpenMemHandle((void**)&conn->devConn->remoteBuff, connInfo->handleBuff, cudaIpcMemLazyEnablePeerAccess));
CUDACHECK(cudaIpcOpenMemHandle((void**)&conn->devConn->remoteSignalEpochId, connInfo->handleSignalEpochId,
cudaIpcMemLazyEnablePeerAccess));
if (connInfo->bufferInfos.size() < 1) {
WARN("at least 1 buffer info expected");
return mscclppInternalError;
}
// Open all remote registered buffers
for (size_t i = 0; i < connInfo->bufferInfos.size(); i++) {
mscclppBufferRegistration newBufReg;
CUDACHECK(cudaIpcOpenMemHandle(&newBufReg.data, connInfo->bufferInfos[i].cudaHandle, cudaIpcMemLazyEnablePeerAccess));
newBufReg.size = connInfo->bufferInfos[i].size;
conn->remoteBufferRegistrations.push_back(newBufReg);
}
if (conn->remoteBufferRegistrations[0].size != sizeof(mscclppDevConnSignalEpochId)) {
WARN("buffer registration zero size doesn't match sizeof(mscclppDevConnSignalEpochId)");
return mscclppInternalError;
}
conn->devConn->remoteSignalEpochId = (mscclppDevConnSignalEpochId*)conn->remoteBufferRegistrations[0].data;
// For backwards compatibility with the previous API that assumed one data buffer per connection, set the remote buffer
// to the first remote data buffer
if (conn->remoteBufferRegistrations.size() > 1) {
conn->devConn->remoteBuff = conn->remoteBufferRegistrations[1].data;
}
return mscclppSuccess;
}
@@ -608,12 +706,18 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output
if (hostConn->ibQp == NULL) {
MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &hostConn->ibQp));
}
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &hostConn->ibBuffMr));
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localSignalEpochId,
sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibSignalEpochIdMr));
// Add all registered buffers
for (const auto &bufReg : conn->bufferRegistrations) {
hostConn->ibMrs.emplace_back();
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, bufReg.data,
sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibMrs.back()));
connInfo->bufferInfos.emplace_back();
connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->info;
connInfo->bufferInfos.back().size = bufReg.size;
}
connInfo->infoQp = hostConn->ibQp->info;
connInfo->infoBuffMr = hostConn->ibBuffMr->info;
connInfo->infoSignalEpochIdMr = hostConn->ibSignalEpochIdMr->info;
return mscclppSuccess;
}
@@ -632,8 +736,18 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/,
WARN("Failed to transition QP to RTS");
return mscclppInvalidUsage;
}
hostConn->ibBuffMrRemoteInfo = connInfo->infoBuffMr;
hostConn->ibSignalEpochIdMrRemoteInfo = connInfo->infoSignalEpochIdMr;
// No remote pointers to set with IB, so we just set the Mrs
// Push the Mrs for all the remote registered buffers
for (size_t i = 1; i < connInfo->bufferInfos.size(); i++) {
hostConn->remoteIbMrInfos.push_back(connInfo->bufferInfos[i].ibMrInfo);
mscclppBufferRegistration newBufReg;
newBufReg.data = nullptr;
newBufReg.size = connInfo->bufferInfos[i].size;
conn->remoteBufferRegistrations.push_back(newBufReg);
}
return mscclppSuccess;
}
@@ -650,14 +764,15 @@ MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
MSCCLPPCHECK(mscclppIbConnectionSetupStart(&cInfo, conn));
}
// TODO: from saemal: do we possibly deadlock if there are too many outstanding sends?
MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, sizeof(cInfo)));
// MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, sizeof(cInfo)));
MSCCLPPCHECK(cInfo.sendOverBootstrap(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag));
}
// Recv info from peers
for (int i = 0; i < comm->nConns; ++i) {
struct mscclppConn* conn = &comm->conns[i];
struct connInfo cInfo;
MSCCLPPCHECK(bootstrapRecv(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, sizeof(cInfo)));
MSCCLPPCHECK(cInfo.recvOverBootstrap(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag));
if (conn->transport == mscclppTransportP2P) {
MSCCLPPCHECK(mscclppP2pConnectionSetupEnd(&cInfo, conn));
} else if (conn->transport == mscclppTransportIB) {
@@ -731,16 +846,19 @@ MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, msc
void* dstBuff = regMem->p2p[i].remoteBuff;
CUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, size, cudaMemcpyDeviceToDevice, (cudaStream_t)stream));
} else {
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
hostConn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrRemoteInfo, (uint32_t)size,
/*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false);
if ((ret = hostConn->ibQp->postSend()) != 0) {
// Return value is errno.
WARN("data postSend failed: errno %d", ret);
}
// ??
// npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_ENTRY, (uint32_t)trigger.fields.dataSize,
// trigger.fields.connId);
WARN("mscclppRegisteredBufferWrite not implemented for IB");
return mscclppInternalError;
// TODO: fix the following (Olli: probably by including the relevant ibBuffMr in the mscclppRegisteredMemory)
// struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
// hostConn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrRemoteInfo, (uint32_t)size,
// /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false);
// if ((ret = hostConn->ibQp->postSend()) != 0) {
// // Return value is errno.
// WARN("data postSend failed: errno %d", ret);
// }
// // ??
// // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_ENTRY, (uint32_t)trigger.fields.dataSize,
// // trigger.fields.connId);
}
}
return mscclppSuccess;