From 9fbb0debdd0951524775c48f2c05e2707e54c341 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Wed, 19 Apr 2023 22:02:23 +0000 Subject: [PATCH] C++ API changes --- src/basic_proxy_handler.cc | 2 +- src/communicator.cc | 37 +------ src/host_connection.cc | 35 +++++-- src/include/mscclpp.hpp | 194 ++++++++++++++++++++++-------------- src/proxy_cpp.cc | 6 +- tests/allgather_test_cpp.cu | 22 ++-- 6 files changed, 160 insertions(+), 136 deletions(-) diff --git a/src/basic_proxy_handler.cc b/src/basic_proxy_handler.cc index 736c44bd..482aa842 100644 --- a/src/basic_proxy_handler.cc +++ b/src/basic_proxy_handler.cc @@ -19,7 +19,7 @@ ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm) { if (trigger->fields.type & mscclppSync) { conn.flush(); - result = ProxyHandlerResult::FlushAndContinue; + result = ProxyHandlerResult::FlushFifoTailAndContinue; } return result; diff --git a/src/communicator.cc b/src/communicator.cc index cade59a3..5a843c78 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -14,7 +14,6 @@ Communicator::Impl::~Impl() { } } -MSCCLPP_API_CPP Communicator::Communicator() = default; MSCCLPP_API_CPP Communicator::~Communicator() = default; mscclppTransport_t transportTypeToCStyle(TransportType type) { @@ -28,43 +27,26 @@ mscclppTransport_t transportTypeToCStyle(TransportType type) { } } -MSCCLPP_API_CPP void Communicator::initRank(int nranks, const char* ipPortPair, int rank) { - if (pimpl) { - throw std::runtime_error("Communicator already initialized"); - } - pimpl = std::make_unique(); +MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique()) { mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank); } -MSCCLPP_API_CPP void Communicator::initRankFromId(int nranks, UniqueId id, int rank) { - if (pimpl) { - throw std::runtime_error("Communicator already initialized"); - } - pimpl = std::make_unique(); +MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique()) { static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch"); mscclppUniqueId *cstyle_id = reinterpret_cast(&id); mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank); } MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) { - if (!pimpl) { - throw std::runtime_error("Communicator not initialized"); - } mscclppBootstrapAllGather(pimpl->comm, data, size); } MSCCLPP_API_CPP void Communicator::bootstrapBarrier() { - if (!pimpl) { - throw std::runtime_error("Communicator not initialized"); - } mscclppBootstrapBarrier(pimpl->comm); } MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportType transportType, const char* ibDev) { - if (!pimpl) { - throw std::runtime_error("Communicator not initialized"); - } mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev); auto connIdx = pimpl->connections.size(); auto conn = std::make_shared(std::make_unique(this, &pimpl->comm->conns[connIdx])); @@ -73,39 +55,24 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remote } MSCCLPP_API_CPP void Communicator::connectionSetup() { - if (!pimpl) { - throw std::runtime_error("Communicator not initialized"); - } mscclppConnectionSetup(pimpl->comm); } MSCCLPP_API_CPP void Communicator::startProxying() { - if (!pimpl) { - throw std::runtime_error("Communicator not initialized"); - } pimpl->proxy.start(); } MSCCLPP_API_CPP void Communicator::stopProxying() { - if (!pimpl) { - throw std::runtime_error("Communicator not initialized"); - } pimpl->proxy.stop(); } MSCCLPP_API_CPP int Communicator::rank() { - if (!pimpl) { - throw std::runtime_error("Communicator not initialized"); - } int result; mscclppCommRank(pimpl->comm, &result); return result; } MSCCLPP_API_CPP int Communicator::size() { - if (!pimpl) { - throw std::runtime_error("Communicator not initialized"); - } int result; mscclppCommSize(pimpl->comm, &result); return result; diff --git a/src/host_connection.cc b/src/host_connection.cc index cba9f81d..72e11ffc 100644 --- a/src/host_connection.cc +++ b/src/host_connection.cc @@ -15,8 +15,14 @@ HostConnection::Impl::~Impl() { // TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not. } +MSCCLPP_API_CPP HostConnection::~HostConnection() = default; + MSCCLPP_API_CPP HostConnection::HostConnection(std::unique_ptr p) : pimpl(std::move(p)) {} +MSCCLPP_API_CPP int HostConnection::getId() { + return pimpl->conn->connId; +} + MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) { BufferHandle result; static_assert(sizeof(BufferHandle) == sizeof(mscclppBufferHandle_t)); @@ -24,10 +30,15 @@ MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t return result; } +MSCCLPP_API_CPP int HostConnection::numLocalBuffers() { + return pimpl->conn->bufferRegistrations.size() - 1; +} + +MSCCLPP_API_CPP BufferHandle HostConnection::getLocalBuffer(int index) { + return index + 1; +} + MSCCLPP_API_CPP int HostConnection::numRemoteBuffers() { - if (!pimpl->conn) { - throw std::runtime_error("HostConnection not initialized"); - } return pimpl->conn->remoteBufferRegistrations.size() - 1; } @@ -35,16 +46,18 @@ MSCCLPP_API_CPP BufferHandle HostConnection::getRemoteBuffer(int index) { return index + 1; } -MSCCLPP_API_CPP DeviceConnection HostConnection::toDevice() { - DeviceConnection devConn; +MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() { + ConnectionEpoch epoch; static_assert(sizeof(SignalEpochId) == sizeof(mscclppDevConnSignalEpochId)); - devConn.connectionId = pimpl->conn->connId; - devConn.localSignalEpochId = reinterpret_cast(pimpl->conn->devConn->localSignalEpochId); - devConn.remoteSignalEpochId = reinterpret_cast(pimpl->conn->devConn->remoteSignalEpochId); - devConn.waitEpochId = pimpl->conn->devConn->waitEpochId; - devConn.fifo = pimpl->comm->pimpl->proxy.fifo().toDevice(); + epoch.localSignalEpochId = reinterpret_cast(pimpl->conn->devConn->localSignalEpochId); + epoch.remoteSignalEpochId = reinterpret_cast(pimpl->conn->devConn->remoteSignalEpochId); + epoch.waitEpochId = pimpl->conn->devConn->waitEpochId; + return epoch; +} - return devConn; + +MSCCLPP_API_CPP DeviceProxyFifo HostConnection::getDeviceFifo() { + return pimpl->comm->pimpl->proxy.fifo().toDevice(); } MSCCLPP_API_CPP void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) { diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 4a21bae7..e41e94b8 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -72,6 +72,99 @@ union ChannelTrigger { #endif // __CUDACC__ }; +struct ConnectionEpoch { +#ifdef __CUDACC__ + __forceinline__ __device__ void wait() + { + (*waitEpochId) += 1; + while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId)) + ; + } + + __forceinline__ __device__ void epochIncrement() + { + *(volatile uint64_t*)&(localSignalEpochId->device) += 1; + } +#endif // __CUDACC__ + + SignalEpochId* localSignalEpochId; + // used by the signal() function directly from gpu + SignalEpochId* remoteSignalEpochId; + + // every wait(), increments this and then the gpu waits for either: + // 1) localSignalEpochId->proxy to be >= this in case of a proxy thread + // 2) remoteSignalEpochId->device to be >= this in case of a gpu thread + uint64_t* waitEpochId; +}; + +class HostConnection { + struct Impl; +public: + /* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */ + HostConnection(std::unique_ptr); + + ~HostConnection(); + + int getId(); + + /* Register a region of GPU memory for use with this connection. Must be called before connectionSetup() + * in the communicator. + * + * Inputs: + * data: base pointer to the memory + * size: size of the memory region in bytes + * + * Returns: a handle to the buffer + */ + BufferHandle registerBuffer(void* data, uint64_t size); + + /* Get the number of times registerBuffer(...) was called. + * + * Returns: the number of buffers registered + */ + int numLocalBuffers(); + + /* Get the BufferHandle returned by a call to registerBuffer(...) as identified by the index + * + * Inputs: + * index: the index of the handle to get + * + * Returns: a handle to the buffer + */ + BufferHandle getLocalBuffer(int index); + + /* Get the number of times registerBuffer(...) was called on the remote peer. + * + * Returns: the number of buffers registered on the remote peer + */ + int numRemoteBuffers(); + + /* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index + * + * Inputs: + * index: the index of the handle to get + * + * Returns: a handle to the buffer on the remote peer + */ + BufferHandle getRemoteBuffer(int index); + + ConnectionEpoch getEpoch(); + + DeviceProxyFifo getDeviceFifo(); + + void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size); + + void signal(); + + void flush(); + + void wait(); + +private: + std::unique_ptr pimpl; + friend class Communicator; +}; + /*************************************************************************************************************** * A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand. * The communication API is one-sided meaning that for every single data transfer, only one side @@ -135,9 +228,17 @@ union ChannelTrigger { * indices in the registered buffer. **************************************************************************************************************/ struct DeviceConnection { -#ifdef __CUDACC__ - // TODO: add buffer handles + DeviceConnection() = default; + DeviceConnection(HostConnection& hostConn) + : connectionId(hostConn.getId()), epoch(hostConn.getEpoch()), + fifo(hostConn.getDeviceFifo()) {} + + DeviceConnection(const DeviceConnection& other) = default; + + DeviceConnection& operator=(DeviceConnection& other) = default; + +#ifdef __CUDACC__ __forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) { fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value); @@ -191,28 +292,18 @@ struct DeviceConnection { __forceinline__ __device__ void wait() { - (*waitEpochId) += 1; - while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId)) - ; + epoch.wait(); } __forceinline__ __device__ void epochIncrement() { - *(volatile uint64_t*)&(localSignalEpochId->device) += 1; + epoch.epochIncrement(); } - #endif // __CUDACC__ int connectionId; - SignalEpochId* localSignalEpochId; - // used by the signal() function directly from gpu - SignalEpochId* remoteSignalEpochId; - - // every wait(), increments this and then the gpu waits for either: - // 1) localSignalEpochId->proxy to be >= this in case of a proxy thread - // 2) remoteSignalEpochId->device to be >= this in case of a gpu thread - uint64_t* waitEpochId; + ConnectionEpoch epoch; // this is a concurrent fifo which is multiple threads from the device // can produce for and the sole proxy thread consumes it. @@ -220,9 +311,15 @@ struct DeviceConnection { }; struct SimpleDeviceConnection { - SimpleDeviceConnection() {} - SimpleDeviceConnection(DeviceConnection devConn, BufferHandle dst, BufferHandle src) : devConn(devConn), dst(dst), src(src) {} + SimpleDeviceConnection() = default; + + SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) { + dst = hostConn.getRemoteBuffer(0); + src = hostConn.getLocalBuffer(0); + } + SimpleDeviceConnection(const SimpleDeviceConnection& other) = default; + SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default; #ifdef __CUDACC__ @@ -284,59 +381,6 @@ struct SimpleDeviceConnection { BufferHandle src; }; -class HostConnection { - struct Impl; -public: - /* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */ - HostConnection(std::unique_ptr); - - /* Register a region of GPU memory for use with this connection. Must be called before connectionSetup() - * in the communicator. - * - * Inputs: - * data: base pointer to the memory - * size: size of the memory region in bytes - * - * Returns: a handle to the buffer - */ - BufferHandle registerBuffer(void* data, uint64_t size); - - /* Get the number of times registerBuffer(...) was called on the remote peer. - * - * Returns: the number of buffers registered on the remote peer - */ - int numRemoteBuffers(); - - /* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index - * - * Inputs: - * index: the index of the handle to get - * - * Returns: a handle to the buffer on the remote peer - */ - BufferHandle getRemoteBuffer(int index); - - /* Create a DeviceConnection paired with this HostConnection. A background proxy thread will - * trigger operations on this HostConnection corresponding to put/signal/etc. calls made to the - * DeviceConnection. - * - * Returns: the newly created DeviceConnection - */ - DeviceConnection toDevice(); - - void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size); - - void signal(); - - void flush(); - - void wait(); - -private: - std::unique_ptr pimpl; - friend class Communicator; -}; - #define MSCCLPP_UNIQUE_ID_BYTES 128 struct UniqueId { char internal[MSCCLPP_UNIQUE_ID_BYTES]; @@ -359,8 +403,6 @@ enum class TransportType : uint8_t { class Communicator { public: - Communicator(); - ~Communicator(); /* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function. * @@ -369,7 +411,7 @@ public: * ipPortPair: a string of the form "ip:port" that represents the address of the root process * rank: rank of the calling process */ - void initRank(int nranks, const char* ipPortPair, int rank); + Communicator(int nranks, const char* ipPortPair, int rank); /* Initialize the communicator from a given UniqueId. Same as mscclppCommInitRank() except that * id is provided by the user by calling getUniqueId() @@ -379,7 +421,9 @@ public: * id: the unique ID to be used for communication * rank: rank of the calling process */ - void initRankFromId(int nranks, UniqueId id, int rank); + Communicator(int nranks, UniqueId id, int rank); + + ~Communicator(); /* Ring-based AllGather through the bootstrap socket. * @@ -441,7 +485,7 @@ private: enum class ProxyHandlerResult { Continue, - FlushAndContinue, + FlushFifoTailAndContinue, Stop, }; diff --git a/src/proxy_cpp.cc b/src/proxy_cpp.cc index 9360d560..2d1cf098 100644 --- a/src/proxy_cpp.cc +++ b/src/proxy_cpp.cc @@ -62,10 +62,14 @@ MSCCLPP_API_CPP void Proxy::start() { // Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure // that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush // request. - if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushAndContinue) { + if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) { // TODO: relocate this check: || (trigger.fields.type & mscclppSync) fifo.flushTail(); } + + if (result == ProxyHandlerResult::Stop) { + break; + } } // make sure the tail is flushed before we shut the proxy diff --git a/tests/allgather_test_cpp.cu b/tests/allgather_test_cpp.cu index ddecadbf..9b056e84 100644 --- a/tests/allgather_test_cpp.cu +++ b/tests/allgather_test_cpp.cu @@ -11,6 +11,7 @@ #include #include #include +#include static int nranksPerNode = 8; @@ -220,7 +221,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co int thisNode = rankToNode(rank); int cudaNum = rankToLocalRank(rank); std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum); - std::vector, mscclpp::BufferHandle>> hostConns; + std::vector> hostConns; for (int r = 0; r < world_size; ++r) { if (r == rank) @@ -235,19 +236,17 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co } // Connect with all other ranks auto hostConn = comm.connect(r, 0, transportType, ibDev); - auto localBuffer = hostConn->registerBuffer(data_d, dataSize); - hostConns.emplace_back(hostConn, localBuffer); + hostConn->registerBuffer(data_d, dataSize); + hostConns.push_back(hostConn); } comm.connectionSetup(); std::vector devConns; - for (auto& entry : hostConns) { - assert(entry.first); - assert(entry.first->numRemoteBuffers() == 1); - auto remoteBuffer = entry.first->getRemoteBuffer(0); - devConns.emplace_back(entry.first->toDevice(), entry.second, remoteBuffer); - } + std::transform(hostConns.begin(), hostConns.end(), std::back_inserter(devConns), + [](std::shared_ptr& hostConn) { + return mscclpp::SimpleDeviceConnection(*hostConn); + }); assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::SimpleDeviceConnection)); CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::SimpleDeviceConnection) * devConns.size() )); @@ -401,12 +400,9 @@ int main(int argc, const char* argv[]) size_t nelemsPerGPU = dataSize / sizeof(int) / world_size; try{ - mscclpp::Communicator comm; - if (rank == 0) printf("Initializing MSCCL++\n"); - - comm.initRank(world_size, ip_port, rank); + mscclpp::Communicator comm(world_size, ip_port, rank); if (rank == 0) printf("Initializing data for allgather test\n");