Add fifo for host connections

This commit is contained in:
Changho Hwang
2023-04-11 12:28:45 +00:00
parent 35acdf796c
commit 7a0e64813a
5 changed files with 178 additions and 68 deletions

6
src/include/api.h Normal file
View File

@@ -0,0 +1,6 @@
#ifndef MSCCLPP_API_H_
#define MSCCLPP_API_H_
#define MSCCLPP_API extern "C" __attribute__((visibility("default")))
#endif // MSCCLPP_API_H_

View File

@@ -186,8 +186,23 @@ struct mscclppDevConn : mscclppBaseConn
struct mscclppConcurrentFifo fifo;
};
struct mscclppHostConn : mscclppBaseConn
{
void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize);
void put(uint64_t dataOffset, uint64_t dataSize);
void signal();
void putWithSignal(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize);
void putWithSignal(uint64_t dataOffset, uint64_t dataSize);
void putWithSignalAndFlush(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize);
void putWithSignalAndFlush(uint64_t dataOffset, uint64_t dataSize);
void flush();
void wait();
void epochIncrement();
};
typedef struct mscclppComm* mscclppComm_t;
typedef struct mscclppDevConn mscclppDevConn_t;
typedef struct mscclppHostConn mscclppHostConn_t;
#define MSCCLPP_UNIQUE_ID_BYTES 128
typedef struct

View File

@@ -3,6 +3,7 @@
#include "comm.h"
#include "mscclpp.h"
#include <atomic>
#include <cuda_runtime.h>
#include <pthread.h>
@@ -15,11 +16,20 @@ typedef enum
MSCCLPP_PROXY_RUN_STATE_EXITING,
} mscclppProxyRunState_t;
// TODO: virtual functions
struct mscclppProxyFifo
{
// virtual mscclppResult_t create() = 0;
// virtual mscclppResult_t destroy() = 0;
// virtual mscclppResult_t poll(mscclppTrigger*) = 0;
// virtual mscclppResult_t pop() = 0;
// virtual mscclppResult_t flushTail(bool) = 0;
};
struct mscclppProxyDevFifo : mscclppProxyFifo
{
mscclppResult_t create();
mscclppResult_t destroy();
mscclppResult_t poll(mscclppTrigger* trigger);
mscclppResult_t pop();
mscclppResult_t flushTail(bool sync = false);
@@ -52,6 +62,24 @@ struct mscclppProxyFifo
cudaStream_t stream;
};
struct mscclppProxyHostFifo : mscclppProxyFifo
{
mscclppResult_t create();
mscclppResult_t destroy();
mscclppResult_t poll(mscclppTrigger* trigger);
mscclppResult_t pop();
mscclppResult_t flushTail(bool sync = false);
// fifo cudaHostCalloc'ed that is produced by device and consumed by host
mscclppTrigger* triggerFifo;
// allocated on the device and only accessed by the device
std::atomic<uint64_t>* fifoHead;
//
uint64_t fifoTailHost;
};
struct mscclppProxyState
{
mscclppTransport_t transportType;
@@ -62,7 +90,8 @@ struct mscclppProxyState
struct mscclppIbContext* ibContext; // For IB connection only
cudaStream_t p2pStream; // for P2P DMA engine only
struct mscclppProxyFifo fifo;
struct mscclppProxyDevFifo devFifo;
struct mscclppProxyHostFifo hostFifo;
};
mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm);

View File

@@ -1,4 +1,5 @@
#include "alloc.h"
#include "api.h"
#include "bootstrap.h"
#include "checks.h"
#include "config.h"
@@ -12,8 +13,6 @@
#include "npkit/npkit.h"
#endif
#define MSCCLPP_API(ret, func, args...) extern "C" __attribute__((visibility("default"))) ret func(args)
static uint64_t hashUniqueId(mscclppUniqueId const& id)
{
char const* bytes = (char const*)&id;
@@ -70,8 +69,7 @@ static std::string mscclppShmFileName(mscclppComm_t comm, int rank)
return ss.str();
}
MSCCLPP_API(mscclppResult_t, mscclppGetUniqueId, mscclppUniqueId* out);
mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out)
MSCCLPP_API mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out)
{
MSCCLPPCHECK(mscclppInit());
// mscclppCHECK(PtrCheck(out, "GetUniqueId", "out"));
@@ -80,15 +78,13 @@ mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out)
return res;
}
MSCCLPP_API(mscclppResult_t, mscclppBootstrapAllGather, mscclppComm_t comm, void* data, int size);
mscclppResult_t mscclppBootstrapAllGather(mscclppComm_t comm, void* data, int size)
MSCCLPP_API mscclppResult_t mscclppBootstrapAllGather(mscclppComm_t comm, void* data, int size)
{
MSCCLPPCHECK(bootstrapAllGather(comm->bootstrap, data, size));
return mscclppSuccess;
}
MSCCLPP_API(mscclppResult_t, mscclppCommInitRank, mscclppComm_t* comm, int nranks, const char* ipPortPair, int rank);
mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, const char* ipPortPair, int rank)
MSCCLPP_API mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, const char* ipPortPair, int rank)
{
#if defined(MSCCLPP_USE_GDRCOPY)
MSCCLPPCHECK(initGdrCopy());
@@ -133,8 +129,7 @@ fail:
return res;
}
MSCCLPP_API(mscclppResult_t, mscclppCommInitRankFromId, mscclppComm_t* comm, int nranks, mscclppUniqueId id, int rank);
mscclppResult_t mscclppCommInitRankFromId(mscclppComm_t* comm, int nranks, mscclppUniqueId id, int rank)
MSCCLPP_API mscclppResult_t mscclppCommInitRankFromId(mscclppComm_t* comm, int nranks, mscclppUniqueId id, int rank)
{
#if defined(MSCCLPP_USE_GDRCOPY)
MSCCLPPCHECK(initGdrCopy());
@@ -174,8 +169,7 @@ fail:
return res;
}
MSCCLPP_API(mscclppResult_t, mscclppCommDestroy, mscclppComm_t comm);
mscclppResult_t mscclppCommDestroy(mscclppComm_t comm)
MSCCLPP_API mscclppResult_t mscclppCommDestroy(mscclppComm_t comm)
{
#if defined(ENABLE_NPKIT)
const char* npkitDumpDir = nullptr;
@@ -187,7 +181,8 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm)
for (int i = 0; i < MSCCLPP_PROXY_MAX_NUM; ++i) {
struct mscclppProxyState* proxyState = comm->proxyState[i];
if (proxyState) {
MSCCLPPCHECK(proxyState->fifo.destroy());
MSCCLPPCHECK(proxyState->devFifo.destroy());
MSCCLPPCHECK(proxyState->hostFifo.destroy());
if (proxyState->p2pStream)
CUDACHECK(cudaStreamDestroy(proxyState->p2pStream));
free(proxyState);
@@ -228,8 +223,7 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm)
return mscclppSuccess;
}
MSCCLPP_API(const char*, mscclppGetErrorString, mscclppResult_t code);
const char* mscclppGetErrorString(mscclppResult_t code)
MSCCLPP_API const char* mscclppGetErrorString(mscclppResult_t code)
{
switch (code) {
case mscclppSuccess:
@@ -253,9 +247,7 @@ const char* mscclppGetErrorString(mscclppResult_t code)
}
}
MSCCLPP_API(mscclppResult_t, mscclppGetDeviceConnection, mscclppComm_t comm, int remoteRank, int tag,
mscclppDevConn_t** devConn);
mscclppResult_t mscclppGetDeviceConnection(mscclppComm_t comm, int remoteRank, int tag, mscclppDevConn_t** devConn)
MSCCLPP_API mscclppResult_t mscclppGetDeviceConnection(mscclppComm_t comm, int remoteRank, int tag, mscclppDevConn_t** devConn)
{
for (int i = 0; i < comm->nConns; i++) {
if (comm->devConns[i].remoteRank == remoteRank && comm->devConns[i].tag == tag) {
@@ -267,18 +259,14 @@ mscclppResult_t mscclppGetDeviceConnection(mscclppComm_t comm, int remoteRank, i
return mscclppInvalidArgument;
}
MSCCLPP_API(mscclppResult_t, mscclppGetAllDeviceConnections, mscclppComm_t comm, mscclppDevConn_t** devConns,
int* nConns);
mscclppResult_t mscclppGetAllDeviceConnections(mscclppComm_t comm, mscclppDevConn_t** devConns, int* nConns)
MSCCLPP_API mscclppResult_t mscclppGetAllDeviceConnections(mscclppComm_t comm, mscclppDevConn_t** devConns, int* nConns)
{
*nConns = comm->nConns;
*devConns = comm->devConns;
return mscclppSuccess;
}
MSCCLPP_API(mscclppResult_t, mscclppConnect, mscclppComm_t comm, int remoteRank, int tag, void* localBuff,
uint64_t buffSize, mscclppTransport_t transportType, const char* ibDev);
mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void* localBuff, uint64_t buffSize,
MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void* localBuff, uint64_t buffSize,
mscclppTransport_t transportType, const char* ibDev)
{
// save this processes numa binding and set it to the one closest to the device
@@ -367,7 +355,8 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void
// If we couldn't find a matching context, create one
if (proxyState == NULL) {
MSCCLPPCHECK(mscclppCalloc(&proxyState, 1));
MSCCLPPCHECK(proxyState->fifo.create());
MSCCLPPCHECK(proxyState->devFifo.create());
MSCCLPPCHECK(proxyState->hostFifo.create());
if (transportType == mscclppTransportIB) {
proxyState->ibContext = conn->ibCtx;
@@ -398,12 +387,12 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void
conn->devConn->tag = tag;
conn->devConn->fifo.connId = comm->nConns;
#if defined(MSCCLPP_USE_GDRCOPY)
conn->devConn->fifo.triggerFifo = proxyState->fifo.triggerFifoDev;
conn->devConn->fifo.triggerFifo = proxyState->devFifo.triggerFifoDev;
#else
conn->devConn->fifo.triggerFifo = proxyState->fifo.triggerFifo;
conn->devConn->fifo.triggerFifo = proxyState->devFifo.triggerFifo;
#endif
conn->devConn->fifo.triggerFifoHead = proxyState->fifo.fifoHead;
conn->devConn->fifo.triggerFifoTail = proxyState->fifo.fifoTailDev;
conn->devConn->fifo.triggerFifoHead = proxyState->devFifo.fifoHead;
conn->devConn->fifo.triggerFifoTail = proxyState->devFifo.fifoTailDev;
comm->nConns++;
@@ -489,8 +478,7 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/,
return mscclppSuccess;
}
MSCCLPP_API(mscclppResult_t, mscclppConnectionSetup, mscclppComm_t comm);
mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
{
// Send info to peers
for (int i = 0; i < comm->nConns; ++i) {
@@ -529,9 +517,7 @@ struct bufferInfo
mscclppIbMrInfo infoBuffMr;
};
MSCCLPP_API(mscclppResult_t, mscclppRegisterBuffer, mscclppComm_t comm, void* local_memory, size_t size,
mscclppRegisteredMemory* regMem);
mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size,
MSCCLPP_API mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size,
mscclppRegisteredMemory* regMem)
{
std::vector<struct mscclppIbMr*> ibMrs;
@@ -573,9 +559,7 @@ mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, si
return mscclppSuccess;
}
MSCCLPP_API(mscclppResult_t, mscclppRegisteredBufferWrite, mscclppComm_t comm, mscclppRegisteredMemory* regMem,
void* srcBuff, size_t size, uint32_t srcOffset, uint32_t dstOffset, int64_t stream);
mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, mscclppRegisteredMemory* regMem, void* srcBuff,
MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, mscclppRegisteredMemory* regMem, void* srcBuff,
size_t size, uint32_t srcOffset, uint32_t dstOffset, int64_t stream)
{
int ret = 0;
@@ -605,15 +589,13 @@ mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, mscclppRegister
// TODO: destroy registered buffer
MSCCLPP_API(mscclppResult_t, mscclppProxyLaunch, mscclppComm_t comm);
mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm)
MSCCLPP_API mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm)
{
MSCCLPPCHECK(mscclppProxyCreate(comm));
return mscclppSuccess;
}
MSCCLPP_API(mscclppResult_t, mscclppBootstrapBarrier, mscclppComm_t comm);
mscclppResult_t mscclppBootstrapBarrier(mscclppComm_t comm)
MSCCLPP_API mscclppResult_t mscclppBootstrapBarrier(mscclppComm_t comm)
{
int* tmp = new int[comm->nRanks];
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
@@ -621,8 +603,7 @@ mscclppResult_t mscclppBootstrapBarrier(mscclppComm_t comm)
return mscclppSuccess;
}
MSCCLPP_API(mscclppResult_t, mscclppProxyStop, mscclppComm_t comm);
mscclppResult_t mscclppProxyStop(mscclppComm_t comm)
MSCCLPP_API mscclppResult_t mscclppProxyStop(mscclppComm_t comm)
{
// a barrier to make sure all ranks are done with their work before stopping the proxy
MSCCLPPCHECK(mscclppBootstrapBarrier(comm));
@@ -631,8 +612,7 @@ mscclppResult_t mscclppProxyStop(mscclppComm_t comm)
return mscclppSuccess;
}
MSCCLPP_API(mscclppResult_t, mscclppCommRank, mscclppComm_t comm, int* rank);
mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank)
MSCCLPP_API mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank)
{
if (comm == NULL || rank == NULL) {
WARN("comm or rank cannot be null");
@@ -642,8 +622,7 @@ mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank)
return mscclppSuccess;
}
MSCCLPP_API(mscclppResult_t, mscclppCommSize, mscclppComm_t comm, int* size);
mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size)
MSCCLPP_API mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size)
{
if (comm == NULL || size == NULL) {
WARN("comm or size cannot be null");
@@ -653,22 +632,76 @@ mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size)
return mscclppSuccess;
}
MSCCLPP_API(void, mscclppDefaultLogHandler, const char* msg);
void mscclppDefaultLogHandler(const char* msg)
MSCCLPP_API void mscclppDefaultLogHandler(const char* msg)
{
mscclppDebugDefaultLogHandler(msg);
}
MSCCLPP_API(mscclppResult_t, mscclppSetLogHandler, mscclppLogHandler_t handler);
mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler)
MSCCLPP_API mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler)
{
return mscclppDebugSetLogHandler(handler);
}
MSCCLPP_API(mscclppResult_t, mscclppSetBootstrapConnTimeout, int timeout);
mscclppResult_t mscclppSetBootstrapConnTimeout(int timeout)
MSCCLPP_API mscclppResult_t mscclppSetBootstrapConnTimeout(int timeout)
{
mscclppConfig* config = mscclppConfig::getInstance();
config->setBootstrapConnectionTimeoutConfig(timeout);
return mscclppSuccess;
}
static inline uint64_t hostFifoPush(uint64_t type, uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
{
}
MSCCLPP_API void mscclppHostConn::put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
{
}
MSCCLPP_API void mscclppHostConn::put(uint64_t dataOffset, uint64_t dataSize)
{
}
MSCCLPP_API void mscclppHostConn::signal()
{
}
MSCCLPP_API void mscclppHostConn::putWithSignal(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
{
}
MSCCLPP_API void mscclppHostConn::putWithSignal(uint64_t dataOffset, uint64_t dataSize)
{
}
MSCCLPP_API void mscclppHostConn::putWithSignalAndFlush(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
{
}
MSCCLPP_API void mscclppHostConn::putWithSignalAndFlush(uint64_t dataOffset, uint64_t dataSize)
{
}
MSCCLPP_API void mscclppHostConn::flush()
{
}
MSCCLPP_API void mscclppHostConn::wait()
{
}
MSCCLPP_API void mscclppHostConn::epochIncrement()
{
}

View File

@@ -40,12 +40,6 @@ struct proxyArgs
struct mscclppProxyState* proxyState;
};
static void readTrigger(mscclppTrigger* dst, mscclppTrigger* src)
{
__m128i xmm0 = _mm_load_si128((__m128i*)src);
_mm_store_si128((__m128i*)dst, xmm0);
}
#if defined(ENABLE_NPKIT)
static void npkitInitReqIds(struct mscclppComm* comm)
@@ -93,7 +87,7 @@ static void npkitCollectExitEvents(struct mscclppConn* conn, uint8_t type, int c
#endif
mscclppResult_t mscclppProxyFifo::create()
mscclppResult_t mscclppProxyDevFifo::create()
{
MSCCLPPCHECK(mscclppCudaCalloc(&this->fifoHead, 1));
#if defined(MSCCLPP_USE_GDRCOPY)
@@ -110,7 +104,7 @@ mscclppResult_t mscclppProxyFifo::create()
return mscclppSuccess;
}
mscclppResult_t mscclppProxyFifo::destroy()
mscclppResult_t mscclppProxyDevFifo::destroy()
{
MSCCLPPCHECK(mscclppCudaFree(this->fifoHead));
#if defined(MSCCLPP_USE_GDRCOPY)
@@ -125,21 +119,21 @@ mscclppResult_t mscclppProxyFifo::destroy()
}
// return true if the trigger is valid
mscclppResult_t mscclppProxyFifo::poll(mscclppTrigger* trigger)
mscclppResult_t mscclppProxyDevFifo::poll(mscclppTrigger* trigger)
{
__m128i xmm0 = _mm_load_si128((__m128i*)&this->triggerFifo[this->fifoTailHost % MSCCLPP_PROXY_FIFO_SIZE]);
_mm_store_si128((__m128i*)trigger, xmm0);
return mscclppSuccess;
}
mscclppResult_t mscclppProxyFifo::pop()
mscclppResult_t mscclppProxyDevFifo::pop()
{
*(volatile uint64_t*)(&this->triggerFifo[this->fifoTailHost % MSCCLPP_PROXY_FIFO_SIZE]) = 0;
(this->fifoTailHost)++;
return mscclppSuccess;
}
mscclppResult_t mscclppProxyFifo::flushTail(bool sync)
mscclppResult_t mscclppProxyDevFifo::flushTail(bool sync)
{
// Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
@@ -156,6 +150,40 @@ mscclppResult_t mscclppProxyFifo::flushTail(bool sync)
return mscclppSuccess;
}
mscclppResult_t mscclppProxyHostFifo::create()
{
MSCCLPPCHECK(mscclppCalloc(&this->fifoHead, 1));
MSCCLPPCHECK(mscclppCalloc(&this->triggerFifo, MSCCLPP_PROXY_FIFO_SIZE));
this->fifoTailHost = 0;
return mscclppSuccess;
}
mscclppResult_t mscclppProxyHostFifo::destroy()
{
free(this->fifoHead);
free(this->triggerFifo);
return mscclppSuccess;
}
mscclppResult_t mscclppProxyHostFifo::poll(mscclppTrigger* trigger)
{
__m128i xmm0 = _mm_load_si128((__m128i*)&this->triggerFifo[this->fifoTailHost % MSCCLPP_PROXY_FIFO_SIZE]);
_mm_store_si128((__m128i*)trigger, xmm0);
return mscclppSuccess;
}
mscclppResult_t mscclppProxyHostFifo::pop()
{
*(volatile uint64_t*)(&this->triggerFifo[this->fifoTailHost % MSCCLPP_PROXY_FIFO_SIZE]) = 0;
(this->fifoTailHost)++;
return mscclppSuccess;
}
mscclppResult_t mscclppProxyHostFifo::flushTail(bool)
{
return mscclppSuccess;
}
void* mscclppProxyService(void* _args)
{
struct proxyArgs* args = (struct proxyArgs*)_args;
@@ -164,9 +192,8 @@ void* mscclppProxyService(void* _args)
// from this point on, proxy thread will stay close to the device
PROXYMSCCLPPCHECK(numaBind(comm->devNumaNode));
struct mscclppProxyFifo* fifo = &args->proxyState->fifo;
struct mscclppProxyDevFifo* fifo = &args->proxyState->devFifo;
volatile mscclppProxyRunState_t* run = &args->proxyState->run;
mscclppTrigger trigger;
mscclppIbContext* ibCtx = args->proxyState->ibContext;
cudaStream_t p2pStream = args->proxyState->p2pStream;