diff --git a/src/include/api.h b/src/include/api.h new file mode 100644 index 00000000..bc5bd1a6 --- /dev/null +++ b/src/include/api.h @@ -0,0 +1,6 @@ +#ifndef MSCCLPP_API_H_ +#define MSCCLPP_API_H_ + +#define MSCCLPP_API extern "C" __attribute__((visibility("default"))) + +#endif // MSCCLPP_API_H_ diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index c67add94..94cabd58 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -186,8 +186,23 @@ struct mscclppDevConn : mscclppBaseConn struct mscclppConcurrentFifo fifo; }; +struct mscclppHostConn : mscclppBaseConn +{ + void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize); + void put(uint64_t dataOffset, uint64_t dataSize); + void signal(); + void putWithSignal(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize); + void putWithSignal(uint64_t dataOffset, uint64_t dataSize); + void putWithSignalAndFlush(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize); + void putWithSignalAndFlush(uint64_t dataOffset, uint64_t dataSize); + void flush(); + void wait(); + void epochIncrement(); +}; + typedef struct mscclppComm* mscclppComm_t; typedef struct mscclppDevConn mscclppDevConn_t; +typedef struct mscclppHostConn mscclppHostConn_t; #define MSCCLPP_UNIQUE_ID_BYTES 128 typedef struct diff --git a/src/include/proxy.h b/src/include/proxy.h index 8b300919..682164a0 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -3,6 +3,7 @@ #include "comm.h" #include "mscclpp.h" +#include #include #include @@ -15,11 +16,20 @@ typedef enum MSCCLPP_PROXY_RUN_STATE_EXITING, } mscclppProxyRunState_t; +// TODO: virtual functions struct mscclppProxyFifo +{ + // virtual mscclppResult_t create() = 0; + // virtual mscclppResult_t destroy() = 0; + // virtual mscclppResult_t poll(mscclppTrigger*) = 0; + // virtual mscclppResult_t pop() = 0; + // virtual mscclppResult_t flushTail(bool) = 0; +}; + +struct mscclppProxyDevFifo : mscclppProxyFifo { mscclppResult_t create(); mscclppResult_t destroy(); - mscclppResult_t poll(mscclppTrigger* trigger); mscclppResult_t pop(); mscclppResult_t flushTail(bool sync = false); @@ -52,6 +62,24 @@ struct mscclppProxyFifo cudaStream_t stream; }; +struct mscclppProxyHostFifo : mscclppProxyFifo +{ + mscclppResult_t create(); + mscclppResult_t destroy(); + mscclppResult_t poll(mscclppTrigger* trigger); + mscclppResult_t pop(); + mscclppResult_t flushTail(bool sync = false); + + // fifo cudaHostCalloc'ed that is produced by device and consumed by host + mscclppTrigger* triggerFifo; + + // allocated on the device and only accessed by the device + std::atomic* fifoHead; + + // + uint64_t fifoTailHost; +}; + struct mscclppProxyState { mscclppTransport_t transportType; @@ -62,7 +90,8 @@ struct mscclppProxyState struct mscclppIbContext* ibContext; // For IB connection only cudaStream_t p2pStream; // for P2P DMA engine only - struct mscclppProxyFifo fifo; + struct mscclppProxyDevFifo devFifo; + struct mscclppProxyHostFifo hostFifo; }; mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm); diff --git a/src/init.cc b/src/init.cc index daaac49d..10215984 100644 --- a/src/init.cc +++ b/src/init.cc @@ -1,4 +1,5 @@ #include "alloc.h" +#include "api.h" #include "bootstrap.h" #include "checks.h" #include "config.h" @@ -12,8 +13,6 @@ #include "npkit/npkit.h" #endif -#define MSCCLPP_API(ret, func, args...) extern "C" __attribute__((visibility("default"))) ret func(args) - static uint64_t hashUniqueId(mscclppUniqueId const& id) { char const* bytes = (char const*)&id; @@ -70,8 +69,7 @@ static std::string mscclppShmFileName(mscclppComm_t comm, int rank) return ss.str(); } -MSCCLPP_API(mscclppResult_t, mscclppGetUniqueId, mscclppUniqueId* out); -mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out) +MSCCLPP_API mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out) { MSCCLPPCHECK(mscclppInit()); // mscclppCHECK(PtrCheck(out, "GetUniqueId", "out")); @@ -80,15 +78,13 @@ mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out) return res; } -MSCCLPP_API(mscclppResult_t, mscclppBootstrapAllGather, mscclppComm_t comm, void* data, int size); -mscclppResult_t mscclppBootstrapAllGather(mscclppComm_t comm, void* data, int size) +MSCCLPP_API mscclppResult_t mscclppBootstrapAllGather(mscclppComm_t comm, void* data, int size) { MSCCLPPCHECK(bootstrapAllGather(comm->bootstrap, data, size)); return mscclppSuccess; } -MSCCLPP_API(mscclppResult_t, mscclppCommInitRank, mscclppComm_t* comm, int nranks, const char* ipPortPair, int rank); -mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, const char* ipPortPair, int rank) +MSCCLPP_API mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, const char* ipPortPair, int rank) { #if defined(MSCCLPP_USE_GDRCOPY) MSCCLPPCHECK(initGdrCopy()); @@ -133,8 +129,7 @@ fail: return res; } -MSCCLPP_API(mscclppResult_t, mscclppCommInitRankFromId, mscclppComm_t* comm, int nranks, mscclppUniqueId id, int rank); -mscclppResult_t mscclppCommInitRankFromId(mscclppComm_t* comm, int nranks, mscclppUniqueId id, int rank) +MSCCLPP_API mscclppResult_t mscclppCommInitRankFromId(mscclppComm_t* comm, int nranks, mscclppUniqueId id, int rank) { #if defined(MSCCLPP_USE_GDRCOPY) MSCCLPPCHECK(initGdrCopy()); @@ -174,8 +169,7 @@ fail: return res; } -MSCCLPP_API(mscclppResult_t, mscclppCommDestroy, mscclppComm_t comm); -mscclppResult_t mscclppCommDestroy(mscclppComm_t comm) +MSCCLPP_API mscclppResult_t mscclppCommDestroy(mscclppComm_t comm) { #if defined(ENABLE_NPKIT) const char* npkitDumpDir = nullptr; @@ -187,7 +181,8 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm) for (int i = 0; i < MSCCLPP_PROXY_MAX_NUM; ++i) { struct mscclppProxyState* proxyState = comm->proxyState[i]; if (proxyState) { - MSCCLPPCHECK(proxyState->fifo.destroy()); + MSCCLPPCHECK(proxyState->devFifo.destroy()); + MSCCLPPCHECK(proxyState->hostFifo.destroy()); if (proxyState->p2pStream) CUDACHECK(cudaStreamDestroy(proxyState->p2pStream)); free(proxyState); @@ -228,8 +223,7 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm) return mscclppSuccess; } -MSCCLPP_API(const char*, mscclppGetErrorString, mscclppResult_t code); -const char* mscclppGetErrorString(mscclppResult_t code) +MSCCLPP_API const char* mscclppGetErrorString(mscclppResult_t code) { switch (code) { case mscclppSuccess: @@ -253,9 +247,7 @@ const char* mscclppGetErrorString(mscclppResult_t code) } } -MSCCLPP_API(mscclppResult_t, mscclppGetDeviceConnection, mscclppComm_t comm, int remoteRank, int tag, - mscclppDevConn_t** devConn); -mscclppResult_t mscclppGetDeviceConnection(mscclppComm_t comm, int remoteRank, int tag, mscclppDevConn_t** devConn) +MSCCLPP_API mscclppResult_t mscclppGetDeviceConnection(mscclppComm_t comm, int remoteRank, int tag, mscclppDevConn_t** devConn) { for (int i = 0; i < comm->nConns; i++) { if (comm->devConns[i].remoteRank == remoteRank && comm->devConns[i].tag == tag) { @@ -267,18 +259,14 @@ mscclppResult_t mscclppGetDeviceConnection(mscclppComm_t comm, int remoteRank, i return mscclppInvalidArgument; } -MSCCLPP_API(mscclppResult_t, mscclppGetAllDeviceConnections, mscclppComm_t comm, mscclppDevConn_t** devConns, - int* nConns); -mscclppResult_t mscclppGetAllDeviceConnections(mscclppComm_t comm, mscclppDevConn_t** devConns, int* nConns) +MSCCLPP_API mscclppResult_t mscclppGetAllDeviceConnections(mscclppComm_t comm, mscclppDevConn_t** devConns, int* nConns) { *nConns = comm->nConns; *devConns = comm->devConns; return mscclppSuccess; } -MSCCLPP_API(mscclppResult_t, mscclppConnect, mscclppComm_t comm, int remoteRank, int tag, void* localBuff, - uint64_t buffSize, mscclppTransport_t transportType, const char* ibDev); -mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void* localBuff, uint64_t buffSize, +MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void* localBuff, uint64_t buffSize, mscclppTransport_t transportType, const char* ibDev) { // save this processes numa binding and set it to the one closest to the device @@ -367,7 +355,8 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void // If we couldn't find a matching context, create one if (proxyState == NULL) { MSCCLPPCHECK(mscclppCalloc(&proxyState, 1)); - MSCCLPPCHECK(proxyState->fifo.create()); + MSCCLPPCHECK(proxyState->devFifo.create()); + MSCCLPPCHECK(proxyState->hostFifo.create()); if (transportType == mscclppTransportIB) { proxyState->ibContext = conn->ibCtx; @@ -398,12 +387,12 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void conn->devConn->tag = tag; conn->devConn->fifo.connId = comm->nConns; #if defined(MSCCLPP_USE_GDRCOPY) - conn->devConn->fifo.triggerFifo = proxyState->fifo.triggerFifoDev; + conn->devConn->fifo.triggerFifo = proxyState->devFifo.triggerFifoDev; #else - conn->devConn->fifo.triggerFifo = proxyState->fifo.triggerFifo; + conn->devConn->fifo.triggerFifo = proxyState->devFifo.triggerFifo; #endif - conn->devConn->fifo.triggerFifoHead = proxyState->fifo.fifoHead; - conn->devConn->fifo.triggerFifoTail = proxyState->fifo.fifoTailDev; + conn->devConn->fifo.triggerFifoHead = proxyState->devFifo.fifoHead; + conn->devConn->fifo.triggerFifoTail = proxyState->devFifo.fifoTailDev; comm->nConns++; @@ -489,8 +478,7 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/, return mscclppSuccess; } -MSCCLPP_API(mscclppResult_t, mscclppConnectionSetup, mscclppComm_t comm); -mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm) +MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm) { // Send info to peers for (int i = 0; i < comm->nConns; ++i) { @@ -529,9 +517,7 @@ struct bufferInfo mscclppIbMrInfo infoBuffMr; }; -MSCCLPP_API(mscclppResult_t, mscclppRegisterBuffer, mscclppComm_t comm, void* local_memory, size_t size, - mscclppRegisteredMemory* regMem); -mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size, +MSCCLPP_API mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size, mscclppRegisteredMemory* regMem) { std::vector ibMrs; @@ -573,9 +559,7 @@ mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, si return mscclppSuccess; } -MSCCLPP_API(mscclppResult_t, mscclppRegisteredBufferWrite, mscclppComm_t comm, mscclppRegisteredMemory* regMem, - void* srcBuff, size_t size, uint32_t srcOffset, uint32_t dstOffset, int64_t stream); -mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, mscclppRegisteredMemory* regMem, void* srcBuff, +MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, mscclppRegisteredMemory* regMem, void* srcBuff, size_t size, uint32_t srcOffset, uint32_t dstOffset, int64_t stream) { int ret = 0; @@ -605,15 +589,13 @@ mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, mscclppRegister // TODO: destroy registered buffer -MSCCLPP_API(mscclppResult_t, mscclppProxyLaunch, mscclppComm_t comm); -mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm) +MSCCLPP_API mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm) { MSCCLPPCHECK(mscclppProxyCreate(comm)); return mscclppSuccess; } -MSCCLPP_API(mscclppResult_t, mscclppBootstrapBarrier, mscclppComm_t comm); -mscclppResult_t mscclppBootstrapBarrier(mscclppComm_t comm) +MSCCLPP_API mscclppResult_t mscclppBootstrapBarrier(mscclppComm_t comm) { int* tmp = new int[comm->nRanks]; MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); @@ -621,8 +603,7 @@ mscclppResult_t mscclppBootstrapBarrier(mscclppComm_t comm) return mscclppSuccess; } -MSCCLPP_API(mscclppResult_t, mscclppProxyStop, mscclppComm_t comm); -mscclppResult_t mscclppProxyStop(mscclppComm_t comm) +MSCCLPP_API mscclppResult_t mscclppProxyStop(mscclppComm_t comm) { // a barrier to make sure all ranks are done with their work before stopping the proxy MSCCLPPCHECK(mscclppBootstrapBarrier(comm)); @@ -631,8 +612,7 @@ mscclppResult_t mscclppProxyStop(mscclppComm_t comm) return mscclppSuccess; } -MSCCLPP_API(mscclppResult_t, mscclppCommRank, mscclppComm_t comm, int* rank); -mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank) +MSCCLPP_API mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank) { if (comm == NULL || rank == NULL) { WARN("comm or rank cannot be null"); @@ -642,8 +622,7 @@ mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank) return mscclppSuccess; } -MSCCLPP_API(mscclppResult_t, mscclppCommSize, mscclppComm_t comm, int* size); -mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size) +MSCCLPP_API mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size) { if (comm == NULL || size == NULL) { WARN("comm or size cannot be null"); @@ -653,22 +632,76 @@ mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size) return mscclppSuccess; } -MSCCLPP_API(void, mscclppDefaultLogHandler, const char* msg); -void mscclppDefaultLogHandler(const char* msg) +MSCCLPP_API void mscclppDefaultLogHandler(const char* msg) { mscclppDebugDefaultLogHandler(msg); } -MSCCLPP_API(mscclppResult_t, mscclppSetLogHandler, mscclppLogHandler_t handler); -mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler) +MSCCLPP_API mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler) { return mscclppDebugSetLogHandler(handler); } -MSCCLPP_API(mscclppResult_t, mscclppSetBootstrapConnTimeout, int timeout); -mscclppResult_t mscclppSetBootstrapConnTimeout(int timeout) +MSCCLPP_API mscclppResult_t mscclppSetBootstrapConnTimeout(int timeout) { mscclppConfig* config = mscclppConfig::getInstance(); config->setBootstrapConnectionTimeoutConfig(timeout); return mscclppSuccess; } + +static inline uint64_t hostFifoPush(uint64_t type, uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) +{ + +} + +MSCCLPP_API void mscclppHostConn::put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) +{ + +} + + +MSCCLPP_API void mscclppHostConn::put(uint64_t dataOffset, uint64_t dataSize) +{ + +} + +MSCCLPP_API void mscclppHostConn::signal() +{ + +} + +MSCCLPP_API void mscclppHostConn::putWithSignal(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) +{ + +} + +MSCCLPP_API void mscclppHostConn::putWithSignal(uint64_t dataOffset, uint64_t dataSize) +{ + +} + +MSCCLPP_API void mscclppHostConn::putWithSignalAndFlush(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) +{ + +} + +MSCCLPP_API void mscclppHostConn::putWithSignalAndFlush(uint64_t dataOffset, uint64_t dataSize) +{ + +} + +MSCCLPP_API void mscclppHostConn::flush() +{ + +} + +MSCCLPP_API void mscclppHostConn::wait() +{ + +} + +MSCCLPP_API void mscclppHostConn::epochIncrement() +{ + +} + diff --git a/src/proxy.cc b/src/proxy.cc index cda01466..97df77b4 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -40,12 +40,6 @@ struct proxyArgs struct mscclppProxyState* proxyState; }; -static void readTrigger(mscclppTrigger* dst, mscclppTrigger* src) -{ - __m128i xmm0 = _mm_load_si128((__m128i*)src); - _mm_store_si128((__m128i*)dst, xmm0); -} - #if defined(ENABLE_NPKIT) static void npkitInitReqIds(struct mscclppComm* comm) @@ -93,7 +87,7 @@ static void npkitCollectExitEvents(struct mscclppConn* conn, uint8_t type, int c #endif -mscclppResult_t mscclppProxyFifo::create() +mscclppResult_t mscclppProxyDevFifo::create() { MSCCLPPCHECK(mscclppCudaCalloc(&this->fifoHead, 1)); #if defined(MSCCLPP_USE_GDRCOPY) @@ -110,7 +104,7 @@ mscclppResult_t mscclppProxyFifo::create() return mscclppSuccess; } -mscclppResult_t mscclppProxyFifo::destroy() +mscclppResult_t mscclppProxyDevFifo::destroy() { MSCCLPPCHECK(mscclppCudaFree(this->fifoHead)); #if defined(MSCCLPP_USE_GDRCOPY) @@ -125,21 +119,21 @@ mscclppResult_t mscclppProxyFifo::destroy() } // return true if the trigger is valid -mscclppResult_t mscclppProxyFifo::poll(mscclppTrigger* trigger) +mscclppResult_t mscclppProxyDevFifo::poll(mscclppTrigger* trigger) { __m128i xmm0 = _mm_load_si128((__m128i*)&this->triggerFifo[this->fifoTailHost % MSCCLPP_PROXY_FIFO_SIZE]); _mm_store_si128((__m128i*)trigger, xmm0); return mscclppSuccess; } -mscclppResult_t mscclppProxyFifo::pop() +mscclppResult_t mscclppProxyDevFifo::pop() { *(volatile uint64_t*)(&this->triggerFifo[this->fifoTailHost % MSCCLPP_PROXY_FIFO_SIZE]) = 0; (this->fifoTailHost)++; return mscclppSuccess; } -mscclppResult_t mscclppProxyFifo::flushTail(bool sync) +mscclppResult_t mscclppProxyDevFifo::flushTail(bool sync) { // Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure // that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush @@ -156,6 +150,40 @@ mscclppResult_t mscclppProxyFifo::flushTail(bool sync) return mscclppSuccess; } +mscclppResult_t mscclppProxyHostFifo::create() +{ + MSCCLPPCHECK(mscclppCalloc(&this->fifoHead, 1)); + MSCCLPPCHECK(mscclppCalloc(&this->triggerFifo, MSCCLPP_PROXY_FIFO_SIZE)); + this->fifoTailHost = 0; + return mscclppSuccess; +} + +mscclppResult_t mscclppProxyHostFifo::destroy() +{ + free(this->fifoHead); + free(this->triggerFifo); + return mscclppSuccess; +} + +mscclppResult_t mscclppProxyHostFifo::poll(mscclppTrigger* trigger) +{ + __m128i xmm0 = _mm_load_si128((__m128i*)&this->triggerFifo[this->fifoTailHost % MSCCLPP_PROXY_FIFO_SIZE]); + _mm_store_si128((__m128i*)trigger, xmm0); + return mscclppSuccess; +} + +mscclppResult_t mscclppProxyHostFifo::pop() +{ + *(volatile uint64_t*)(&this->triggerFifo[this->fifoTailHost % MSCCLPP_PROXY_FIFO_SIZE]) = 0; + (this->fifoTailHost)++; + return mscclppSuccess; +} + +mscclppResult_t mscclppProxyHostFifo::flushTail(bool) +{ + return mscclppSuccess; +} + void* mscclppProxyService(void* _args) { struct proxyArgs* args = (struct proxyArgs*)_args; @@ -164,9 +192,8 @@ void* mscclppProxyService(void* _args) // from this point on, proxy thread will stay close to the device PROXYMSCCLPPCHECK(numaBind(comm->devNumaNode)); - struct mscclppProxyFifo* fifo = &args->proxyState->fifo; + struct mscclppProxyDevFifo* fifo = &args->proxyState->devFifo; volatile mscclppProxyRunState_t* run = &args->proxyState->run; - mscclppTrigger trigger; mscclppIbContext* ibCtx = args->proxyState->ibContext; cudaStream_t p2pStream = args->proxyState->p2pStream;