diff --git a/Makefile b/Makefile index 801a1ffd..881296f4 100644 --- a/Makefile +++ b/Makefile @@ -120,6 +120,7 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc) LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc) +LIBSRCS += $(addprefix src/,communicator.cc fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc) ifneq ($(NPKIT), 0) LIBSRCS += $(addprefix src/misc/,npkit.cc) endif @@ -147,7 +148,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS)) TESTSDIR := tests -TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc) +TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu allgather_test_cpp.cu bootstrap_test_cpp.cc) TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS)) TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS)) diff --git a/src/basic_proxy_handler.cc b/src/basic_proxy_handler.cc new file mode 100644 index 00000000..482aa842 --- /dev/null +++ b/src/basic_proxy_handler.cc @@ -0,0 +1,29 @@ +#include "basic_proxy_handler.hpp" + +namespace mscclpp { + +ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm) { + return [&comm](ProxyTrigger triggerRaw) { + ChannelTrigger *trigger = reinterpret_cast(&triggerRaw); + HostConnection& conn = *comm.connections.at(trigger->fields.connId); + + auto result = ProxyHandlerResult::Continue; + + if (trigger->fields.type & mscclppData) { + conn.put(trigger->fields.dstBufferHandle, trigger->fields.dstOffset, trigger->fields.srcBufferHandle, trigger->fields.srcOffset, trigger->fields.size); + } + + if (trigger->fields.type & mscclppFlag) { + conn.signal(); + } + + if (trigger->fields.type & mscclppSync) { + conn.flush(); + result = 
ProxyHandlerResult::FlushFifoTailAndContinue; + } + + return result; + }; +} + +} // namespace mscclpp diff --git a/src/communicator.cc b/src/communicator.cc new file mode 100644 index 00000000..5a843c78 --- /dev/null +++ b/src/communicator.cc @@ -0,0 +1,81 @@ +#include "communicator.hpp" +#include "host_connection.hpp" +#include "comm.h" +#include "basic_proxy_handler.hpp" +#include "api.h" + +namespace mscclpp { + +Communicator::Impl::Impl() : comm(nullptr), proxy(makeBasicProxyHandler(*this)) {} + +Communicator::Impl::~Impl() { + if (comm) { + mscclppCommDestroy(comm); + } +} + +MSCCLPP_API_CPP Communicator::~Communicator() = default; + +mscclppTransport_t transportTypeToCStyle(TransportType type) { + switch (type) { + case TransportType::IB: + return mscclppTransportIB; + case TransportType::P2P: + return mscclppTransportP2P; + default: + throw std::runtime_error("Unknown transport type"); + } +} + +MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique()) { + mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank); +} + +MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique()) { + static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch"); + mscclppUniqueId *cstyle_id = reinterpret_cast(&id); + mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank); +} + +MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) { + mscclppBootstrapAllGather(pimpl->comm, data, size); +} + +MSCCLPP_API_CPP void Communicator::bootstrapBarrier() { + mscclppBootstrapBarrier(pimpl->comm); +} + +MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, + TransportType transportType, const char* ibDev) { + mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev); + auto connIdx = pimpl->connections.size(); + auto conn = 
std::make_shared(std::make_unique(this, &pimpl->comm->conns[connIdx])); + pimpl->connections.push_back(conn); + return conn; +} + +MSCCLPP_API_CPP void Communicator::connectionSetup() { + mscclppConnectionSetup(pimpl->comm); +} + +MSCCLPP_API_CPP void Communicator::startProxying() { + pimpl->proxy.start(); +} + +MSCCLPP_API_CPP void Communicator::stopProxying() { + pimpl->proxy.stop(); +} + +MSCCLPP_API_CPP int Communicator::rank() { + int result; + mscclppCommRank(pimpl->comm, &result); + return result; +} + +MSCCLPP_API_CPP int Communicator::size() { + int result; + mscclppCommSize(pimpl->comm, &result); + return result; +} + +} // namespace mscclpp \ No newline at end of file diff --git a/src/communicator.cpp b/src/communicator.cpp deleted file mode 100644 index 73d82997..00000000 --- a/src/communicator.cpp +++ /dev/null @@ -1,90 +0,0 @@ -#include "mscclpp.hpp" -#include "mscclpp.h" - -namespace mscclpp { - -mscclppTransport_t transportTypeToCStyle(TransportType type) { - switch (type) { - case TransportType::IB: - return mscclppTransportIB; - case TransportType::P2P: - return mscclppTransportP2P; - default: - throw std::runtime_error("Unknown transport type"); - } -} - -struct Communicator::Impl { - mscclppComm_t comm; - std::vector> connections; - - Impl() : comm(nullptr) {} - - ~Impl() { - if (comm) { - mscclppCommDestroy(comm); - } - } -}; - -void Communicator::initRank(int nranks, const char* ipPortPair, int rank) { - if (pimpl) { - throw std::runtime_error("Communicator already initialized"); - } - pimpl = std::make_unique(); - mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank); -} - -void Communicator::initRankFromId(int nranks, UniqueId id, int rank) { - if (pimpl) { - throw std::runtime_error("Communicator already initialized"); - } - pimpl = std::make_unique(); - static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch"); - mscclppUniqueId *cstyle_id = reinterpret_cast(&id); - mscclppCommInitRankFromId(&pimpl->comm, 
nranks, *cstyle_id, rank); -} - -void Communicator::bootstrapAllGather(void* data, int size) { - mscclppBootstrapAllGather(pimpl->comm, data, size); -} - -void Communicator::bootstrapBarrier() { - mscclppBootstrapBarrier(pimpl->comm); -} - -std::shared_ptr Communicator::connect(int remoteRank, int tag, - TransportType transportType, const char* ibDev = 0) { - mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev); - auto conn = std::make_shared(); - auto connIdx = pimpl->connections.size(); - pimpl->connections.push_back(conn); - return conn; -} - -void Communicator::connectionSetup() { - mscclppConnectionSetup(pimpl->comm); - mscclppHostConn_t *hostConns; - int numHostConns; - mscclppGetAllHostConnections(pimpl->comm, &hostConns, &numHostConns); - if (numHostConns != pimpl->connections.size()) { - throw std::logic_error("Number of HostConnections didn't match number of mscclppHostConns"); - } - for (int connIdx = 0; connIdx < pimpl->connections.size(); ++connIdx) { - pimpl->connections[connIdx]->pimpl->setup(hostConns[connIdx]); - } -} - -int Communicator::rank() { - int result; - mscclppCommRank(pimpl->comm, &result); - return result; -} - -int Communicator::size() { - int result; - mscclppCommSize(pimpl->comm, &result); - return result; -} - -} // namespace mscclpp \ No newline at end of file diff --git a/src/fifo.cc b/src/fifo.cc new file mode 100644 index 00000000..fe7f12d3 --- /dev/null +++ b/src/fifo.cc @@ -0,0 +1,67 @@ +#include "mscclppfifo.hpp" +#include "alloc.h" +#include "checks.hpp" +#include +#include +#include + +namespace mscclpp { + +struct HostProxyFifo::Impl { + DeviceProxyFifo deviceFifo; + + // allocated on the host. Only accessed by the host. This is a copy of the + // value pointed to by fifoTailDev and the invariant is that + // *fifoTailDev <= hostTail. Meaning that host's copy of tail is + // always ahead of the device's copy and host updates the device's copy + // only when it is needed. 
Therefore, hostTail is the "true" tail + // and fifoTailDev is a "stale" tail. See proxy.cc to understand how + // these updates are pushed to the device. + uint64_t hostTail; + + // for transferring fifo tail + cudaStream_t stream; +}; + +HostProxyFifo::HostProxyFifo() { + pimpl = std::make_unique(); + MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.head, 1)); + MSCCLPPTHROW(mscclppCudaHostCalloc(&pimpl->deviceFifo.triggers, MSCCLPP_PROXY_FIFO_SIZE)); + MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.tailReplica, 1)); + CUDATHROW(cudaStreamCreateWithFlags(&pimpl->stream, cudaStreamNonBlocking)); + pimpl->hostTail = 0; +} + +HostProxyFifo::~HostProxyFifo() { + MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.head)); + MSCCLPPTHROW(mscclppCudaHostFree(pimpl->deviceFifo.triggers)); + MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.tailReplica)); + CUDATHROW(cudaStreamDestroy(pimpl->stream)); +} + +void HostProxyFifo::poll(ProxyTrigger *trigger) { + __m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]); + _mm_store_si128((__m128i*)trigger, xmm0); +} + +void HostProxyFifo::pop() { + *(volatile uint64_t*)(&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0; + (pimpl->hostTail)++; +} + +void HostProxyFifo::flushTail(bool sync) { + // Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure + // that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush + // request. 
+ CUDATHROW( + cudaMemcpyAsync(pimpl->deviceFifo.tailReplica, &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice, pimpl->stream)); + if (sync) { + CUDATHROW(cudaStreamSynchronize(pimpl->stream)); + } +} + +DeviceProxyFifo HostProxyFifo::toDevice() { + return pimpl->deviceFifo; +} + +} // namespace mscclpp diff --git a/src/host_connection.cc b/src/host_connection.cc new file mode 100644 index 00000000..72e11ffc --- /dev/null +++ b/src/host_connection.cc @@ -0,0 +1,79 @@ +#include "host_connection.hpp" +#include "communicator.hpp" +#include "comm.h" +#include "mscclpp.h" +#include "mscclppfifo.h" +#include "api.h" + +namespace mscclpp { + +HostConnection::Impl::Impl(Communicator* comm, mscclppConn* conn) : comm(comm), conn(conn) { + this->hostConn = conn->hostConn; +} + +HostConnection::Impl::~Impl() { + // TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not. +} + +MSCCLPP_API_CPP HostConnection::~HostConnection() = default; + +MSCCLPP_API_CPP HostConnection::HostConnection(std::unique_ptr p) : pimpl(std::move(p)) {} + +MSCCLPP_API_CPP int HostConnection::getId() { + return pimpl->conn->connId; +} + +MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) { + BufferHandle result; + static_assert(sizeof(BufferHandle) == sizeof(mscclppBufferHandle_t)); + mscclppRegisterBufferForConnection(pimpl->comm->pimpl->comm, pimpl->conn->connId, data, size, reinterpret_cast(&result)); + return result; +} + +MSCCLPP_API_CPP int HostConnection::numLocalBuffers() { + return pimpl->conn->bufferRegistrations.size() - 1; +} + +MSCCLPP_API_CPP BufferHandle HostConnection::getLocalBuffer(int index) { + return index + 1; +} + +MSCCLPP_API_CPP int HostConnection::numRemoteBuffers() { + return pimpl->conn->remoteBufferRegistrations.size() - 1; +} + +MSCCLPP_API_CPP BufferHandle HostConnection::getRemoteBuffer(int index) { + return index + 1; +} + +MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() { + 
ConnectionEpoch epoch; + static_assert(sizeof(SignalEpochId) == sizeof(mscclppDevConnSignalEpochId)); + epoch.localSignalEpochId = reinterpret_cast(pimpl->conn->devConn->localSignalEpochId); + epoch.remoteSignalEpochId = reinterpret_cast(pimpl->conn->devConn->remoteSignalEpochId); + epoch.waitEpochId = pimpl->conn->devConn->waitEpochId; + return epoch; +} + + +MSCCLPP_API_CPP DeviceProxyFifo HostConnection::getDeviceFifo() { + return pimpl->comm->pimpl->proxy.fifo().toDevice(); +} + +MSCCLPP_API_CPP void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) { + pimpl->hostConn->put(dst, dstOffset, src, srcOffset, size); +} + +MSCCLPP_API_CPP void HostConnection::signal() { + pimpl->hostConn->signal(); +} + +MSCCLPP_API_CPP void HostConnection::flush() { + pimpl->hostConn->flush(); +} + +MSCCLPP_API_CPP void HostConnection::wait() { + pimpl->hostConn->wait(); +} + +} // namespace mscclpp \ No newline at end of file diff --git a/src/host_connection.cpp b/src/host_connection.cpp deleted file mode 100644 index 6a06de63..00000000 --- a/src/host_connection.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include "host_connection.hpp" - -namespace mscclpp { - -HostConnection::Impl::Impl() : hostConn(nullptr) {} - -HostConnection::Impl::~Impl() { - // TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not. 
-} - -void HostConnection::Impl::setup(mscclppHostConn_t *hostConn) { - this->hostConn = hostConn; -} - -BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) { - -} - -int HostConnection::numRemoteBuffers() { - -} - -BufferHandle HostConnection::getRemoteBuffer(int index) { - -} - -DeviceConnection HostConnection::toDevice(bool startProxyThread = true) { - -} - -void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) { - -} - -void HostConnection::put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) { - -} - -void HostConnection::signal() { - -} - -void HostConnection::flush() { - -} - -void HostConnection::wait() { - -} - -void HostConnection::epochIncrement() { - -} - -} // namespace mscclpp \ No newline at end of file diff --git a/src/include/api.h b/src/include/api.h index bc5bd1a6..cf546e39 100644 --- a/src/include/api.h +++ b/src/include/api.h @@ -2,5 +2,6 @@ #define MSCCLPP_API_H_ #define MSCCLPP_API extern "C" __attribute__((visibility("default"))) +#define MSCCLPP_API_CPP __attribute__((visibility("default"))) #endif // MSCCLPP_API_H_ diff --git a/src/include/basic_proxy_handler.hpp b/src/include/basic_proxy_handler.hpp new file mode 100644 index 00000000..1c4b3f86 --- /dev/null +++ b/src/include/basic_proxy_handler.hpp @@ -0,0 +1,13 @@ +#ifndef MSCCLPP_BASIC_PROXY_SERVICE_HPP_ +#define MSCCLPP_BASIC_PROXY_SERVICE_HPP_ + +#include "mscclpp.hpp" +#include "communicator.hpp" + +namespace mscclpp { + +ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm); + +} + +#endif \ No newline at end of file diff --git a/src/include/checks.hpp b/src/include/checks.hpp index ee5f7058..ad985e76 100644 --- a/src/include/checks.hpp +++ b/src/include/checks.hpp @@ -27,29 +27,3 @@ } while (false) #endif - -#include -// Check system calls -#define SYSCHECKTHROW(call, name) \ - do { \ - int retval; \ - SYSCHECKVAL(call, name, retval); \ - } while (false) - -#define 
SYSCHECKVALTHROW(call, name, retval) \ - do { \ - SYSCHECKSYNC(call, name, retval); \ - if (retval == -1) { \ - std::runtime_error(std::string("Call to " name " failed : ") + strerror(errno)); \ - } \ - } while (false) - -#define SYSCHECKSYNCTHROW(call, name, retval) \ - do { \ - retval = call; \ - if (retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \ - INFO(MSCCLPP_ALL, "Call to " name " returned %s, retrying", strerror(errno)); \ - } else { \ - break; \ - } \ - } while (true) diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index e69de29b..8294eeb6 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -0,0 +1,23 @@ +#ifndef MSCCL_COMMUNICATOR_HPP_ +#define MSCCL_COMMUNICATOR_HPP_ + +#include "mscclpp.hpp" +#include "mscclpp.h" + +namespace mscclpp { + +struct Communicator::Impl { + mscclppComm_t comm; + std::vector> connections; + Proxy proxy; + + Impl(); + + ~Impl(); + + friend class HostConnection; +}; + +} // namespace mscclpp + +#endif \ No newline at end of file diff --git a/src/include/host_connection.hpp b/src/include/host_connection.hpp index 4a66c846..495130d9 100644 --- a/src/include/host_connection.hpp +++ b/src/include/host_connection.hpp @@ -3,17 +3,18 @@ #include "mscclpp.hpp" #include "mscclpp.h" +#include "comm.h" namespace mscclpp { struct HostConnection::Impl { + Communicator* comm; + mscclppConn* conn; mscclppHostConn_t* hostConn; - Impl(); + Impl(Communicator* comm, mscclppConn* conn); ~Impl(); - - void setup(mscclppHostConn_t *hostConn); }; } // namespace mscclpp diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index 0e7f76e5..a9675d1e 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -29,7 +29,7 @@ struct alignas(16) mscclppDevConnSignalEpochId uint64_t proxy; }; -using mscclppBufferHandle_t = uint8_t; +using mscclppBufferHandle_t = uint32_t; 
/*************************************************************************************************************** * A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand. diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index fbc96f43..e41e94b8 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -13,6 +13,7 @@ #include #include +#include #include @@ -27,15 +28,14 @@ struct alignas(16) SignalEpochId { uint64_t proxy; }; -enum ChannelTriggerType : uint64_t { - channelTriggerData = 0x1, - channelTriggerFlag = 0x2, - channelTriggerSync = 0x4 -}; +using ChannelTriggerType = uint64_t; +const ChannelTriggerType channelTriggerData = 0x1; +const ChannelTriggerType channelTriggerFlag = 0x2; +const ChannelTriggerType channelTriggerSync = 0x4; // This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles // mapping to the actual -using BufferHandle = uint8_t; +using BufferHandle = uint32_t; #define MSCCLPP_BITS_SIZE 32 #define MSCCLPP_BITS_OFFSET 32 @@ -58,15 +58,111 @@ union ChannelTrigger { uint64_t srcBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE; uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE; uint64_t type : MSCCLPP_BITS_TYPE; + uint64_t connId : MSCCLPP_BITS_CONNID; uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment } fields; - ChannelTrigger() {} - ChannelTrigger(ProxyTrigger value) : value(value) {} - ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) { +#ifdef __CUDACC__ + __device__ ChannelTrigger() {} + __device__ ChannelTrigger(ProxyTrigger value) : value(value) {} + __device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size, int connectionId) { value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) 
+ size); - value.snd = (((((((uint64_t)type << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset); + value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset); } +#endif // __CUDACC__ +}; + +struct ConnectionEpoch { +#ifdef __CUDACC__ + __forceinline__ __device__ void wait() + { + (*waitEpochId) += 1; + while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId)) + ; + } + + __forceinline__ __device__ void epochIncrement() + { + *(volatile uint64_t*)&(localSignalEpochId->device) += 1; + } +#endif // __CUDACC__ + + SignalEpochId* localSignalEpochId; + // used by the signal() function directly from gpu + SignalEpochId* remoteSignalEpochId; + + // every wait(), increments this and then the gpu waits for either: + // 1) localSignalEpochId->proxy to be >= this in case of a proxy thread + // 2) remoteSignalEpochId->device to be >= this in case of a gpu thread + uint64_t* waitEpochId; +}; + +class HostConnection { + struct Impl; +public: + /* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */ + HostConnection(std::unique_ptr); + + ~HostConnection(); + + int getId(); + + /* Register a region of GPU memory for use with this connection. Must be called before connectionSetup() + * in the communicator. + * + * Inputs: + * data: base pointer to the memory + * size: size of the memory region in bytes + * + * Returns: a handle to the buffer + */ + BufferHandle registerBuffer(void* data, uint64_t size); + + /* Get the number of times registerBuffer(...) was called. + * + * Returns: the number of buffers registered + */ + int numLocalBuffers(); + + /* Get the BufferHandle returned by a call to registerBuffer(...) 
as identified by the index + * + * Inputs: + * index: the index of the handle to get + * + * Returns: a handle to the buffer + */ + BufferHandle getLocalBuffer(int index); + + /* Get the number of times registerBuffer(...) was called on the remote peer. + * + * Returns: the number of buffers registered on the remote peer + */ + int numRemoteBuffers(); + + /* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index + * + * Inputs: + * index: the index of the handle to get + * + * Returns: a handle to the buffer on the remote peer + */ + BufferHandle getRemoteBuffer(int index); + + ConnectionEpoch getEpoch(); + + DeviceProxyFifo getDeviceFifo(); + + void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size); + + void signal(); + + void flush(); + + void wait(); + +private: + std::unique_ptr pimpl; + friend class Communicator; }; /*************************************************************************************************************** @@ -132,12 +228,20 @@ union ChannelTrigger { * indices in the registered buffer. 
**************************************************************************************************************/ struct DeviceConnection { -#ifdef __CUDACC__ - // TODO: add buffer handles + DeviceConnection() = default; + DeviceConnection(HostConnection& hostConn) + : connectionId(hostConn.getId()), epoch(hostConn.getEpoch()), + fifo(hostConn.getDeviceFifo()) {} + + DeviceConnection(const DeviceConnection& other) = default; + + DeviceConnection& operator=(DeviceConnection& other) = default; + +#ifdef __CUDACC__ __forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) { - fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size).value); + fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value); } __forceinline__ __device__ void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) @@ -148,13 +252,13 @@ struct DeviceConnection { __forceinline__ __device__ void signal() { epochIncrement(); - fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1).value); + fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value); } __forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) { epochIncrement(); - fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size).value); + fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId).value); } __forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) @@ -165,107 +269,116 @@ struct DeviceConnection { __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) { epochIncrement(); - uint64_t curFifoHead = fifo.push(channelTriggerData 
| channelTriggerFlag | channelTriggerSync, dstOffset, srcOffset, size); - while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && - *(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead) + uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, dstOffset, src, srcOffset, size, connectionId).value); + while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && + *(volatile uint64_t*)fifo.tailReplica <= curFifoHead) ; } __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) + { + putWithSignalAndFlush(dst, offset, src, offset, size); + } + + __forceinline__ __device__ void flush() + { + uint64_t curFifoHead = fifo.push(ChannelTrigger(mscclppSync, 0, 0, 0, 0, 1, connectionId).value); + // we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail + // to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0. + while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && + *(volatile uint64_t*)fifo.tailReplica <= curFifoHead) + ; + } + + __forceinline__ __device__ void wait() + { + epoch.wait(); + } + + __forceinline__ __device__ void epochIncrement() + { + epoch.epochIncrement(); + } +#endif // __CUDACC__ + + int connectionId; + + ConnectionEpoch epoch; + + // this is a concurrent fifo which is multiple threads from the device + // can produce for and the sole proxy thread consumes it. 
+ DeviceProxyFifo fifo; +}; + +struct SimpleDeviceConnection { + SimpleDeviceConnection() = default; + + SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) { + dst = hostConn.getRemoteBuffer(0); + src = hostConn.getLocalBuffer(0); + } + + SimpleDeviceConnection(const SimpleDeviceConnection& other) = default; + + SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default; + +#ifdef __CUDACC__ + + __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) + { + devConn.put(dst, dstOffset, src, srcOffset, size); + } + + __forceinline__ __device__ void put(uint64_t offset, uint64_t size) + { + put(offset, offset, size); + } + + __forceinline__ __device__ void signal() + { + devConn.signal(); + } + + __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) + { + devConn.putWithSignal(dst, dstOffset, src, srcOffset, size); + } + + __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) + { + putWithSignal(offset, offset, size); + } + + __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) + { + devConn.putWithSignalAndFlush(dst, dstOffset, src, srcOffset, size); + } + + __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) { putWithSignalAndFlush(offset, offset, size); } __forceinline__ __device__ void flush() { - uint64_t curFifoHead = fifo.push(mscclppSync, 0, 0, 1); - // we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail - // to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0. 
- while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && - *(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead) - ; + devConn.flush(); } __forceinline__ __device__ void wait() { - (*waitEpochId) += 1; - while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId)) - ; + devConn.wait(); } __forceinline__ __device__ void epochIncrement() { - *(volatile uint64_t*)&(localSignalEpochId->device) += 1; + devConn.epochIncrement(); } #endif // __CUDACC__ - int remoteRank; - int tag; - - SignalEpochId* localSignalEpochId; - // used by the signal() function directly from gpu - SignalEpochId* remoteSignalEpochId; - - // every wait(), increments this and then the gpu waits for either: - // 1) localSignalEpochId->proxy to be >= this in case of a proxy thread - // 2) remoteSignalEpochId->device to be >= this in case of a gpu thread - uint64_t* waitEpochId; - - // this is a concurrent fifo which is multiple threads from the device - // can produce for and the sole proxy thread consumes it. - ProxyFifo fifo; -}; - -class HostConnection { -public: - /* Register a region of GPU memory for use with this connection. Must be called before connectionSetup() - * in the communicator. - * - * Inputs: - * data: base pointer to the memory - * size: size of the memory region in bytes - * - * Returns: a handle to the buffer - */ - BufferHandle registerBuffer(void* data, uint64_t size); - - /* Get the number of times registerBuffer(...) was called on the remote peer. - * - * Returns: the number of buffers registered on the remote peer - */ - int numRemoteBuffers(); - - /* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index - * - * Inputs: - * index: the index of the handle to get - * - * Returns: a handle to the buffer on the remote peer - */ - BufferHandle getRemoteBuffer(int index); - - /* Create a DeviceConnection paired with this HostConnection. 
A background proxy thread will - * trigger operations on this HostConnection corresponding to put/signal/etc. calls made to the - * DeviceConnection. - * - * Inputs: - * startProxyThread: whether to start the proxy thread (default is true) - * - * Returns: the newly created DeviceConnection - */ - DeviceConnection toDevice(bool startProxyThread = true); - - void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size); - void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size); - void signal(); - void flush(); - void wait(); - void epochIncrement(); - -private: - struct Impl; - std::unique_ptr pimpl; + DeviceConnection devConn; + BufferHandle dst; + BufferHandle src; }; #define MSCCLPP_UNIQUE_ID_BYTES 128 @@ -290,6 +403,7 @@ enum class TransportType : uint8_t { class Communicator { public: + /* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function. * * Inputs: @@ -297,7 +411,7 @@ public: * ipPortPair: a string of the form "ip:port" that represents the address of the root process * rank: rank of the calling process */ - void initRank(int nranks, const char* ipPortPair, int rank); + Communicator(int nranks, const char* ipPortPair, int rank); /* Initialize the communicator from a given UniqueId. Same as mscclppCommInitRank() except that * id is provided by the user by calling getUniqueId() @@ -307,7 +421,9 @@ public: * id: the unique ID to be used for communication * rank: rank of the calling process */ - void initRankFromId(int nranks, UniqueId id, int rank); + Communicator(int nranks, UniqueId id, int rank); + + ~Communicator(); /* Ring-based AllGather through the bootstrap socket. * @@ -341,6 +457,12 @@ public: */ void connectionSetup(); + /* Launch proxy thread(s). This function is supposed to be called before starting a kernel that uses DeviceConnection. */ + void startProxying(); + + /* Stop proxy thread(s). 
*/ + void stopProxying(); + /* Return the rank of the calling process. * * Outputs: @@ -355,6 +477,33 @@ public: */ int size(); + struct Impl; +private: + std::unique_ptr pimpl; + friend class HostConnection; +}; + +enum class ProxyHandlerResult { + Continue, + FlushFifoTailAndContinue, + Stop, +}; + +class Proxy; +using ProxyHandler = std::function; + +class Proxy { +public: + Proxy(ProxyHandler handler); + + ~Proxy(); + + void start(); + + void stop(); + + HostProxyFifo& fifo(); + private: struct Impl; std::unique_ptr pimpl; diff --git a/src/include/mscclppfifo.hpp b/src/include/mscclppfifo.hpp index 27abd4c5..b5f8ba4c 100644 --- a/src/include/mscclppfifo.hpp +++ b/src/include/mscclppfifo.hpp @@ -3,6 +3,7 @@ #include #include +#include namespace mscclpp { @@ -13,39 +14,56 @@ struct alignas(16) ProxyTrigger { /* This is a concurrent fifo where multiple device threads can push mscclppTrigger work elements to * and a single host proxy thread consumes these work elements. There is a head pointer allocated on device * which starts with 0 and goes to 2^64-1 which is almost infinity. There are two copies of tail, one - * that is on the deivce (triggerFifoTail) and another that is on host (proxyState->fifoTailHost). + * that is on the device (tailReplica) and another that is on host (proxyState->fifoTailHost). * The host always has the "true" tail and occasionally, pushes it to the copy on the device. * Therefore, most of the time, the device has a stale version. The invariants are: - * triggerFifoTail <= proxyState->fifoTailHost <= triggerFifoHead. - * push() function increments triggerFifoHead, proxyState->fifoTailHost is updated in proxy.cc:mscclppProxyService - * and it occasionally flushes it to triggerFifoTail via a cudaMemcpyAsync. + * tailReplica <= proxyState->fifoTailHost <= head. + * push() function increments head, proxyState->fifoTailHost is updated in proxy.cc:mscclppProxyService + * and it occasionally flushes it to tailReplica via a cudaMemcpyAsync. 
* * Why duplicating the tail is a good idea? The fifo is large engouh and we do not need frequent updates * for the tail as there is usually enough space for device threads to push their work into. */ -struct ProxyFifo { +struct DeviceProxyFifo { #ifdef __CUDACC__ - __forceinline__ __device__ uint64_t push(ProxyTrigger element) + __forceinline__ __device__ uint64_t push(ProxyTrigger trigger) { - uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->triggerFifoHead, 1); - while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->triggerFifoTail)) + uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1); + while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->tailReplica)) ; - while (*(volatile uint64_t*)&this->triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0) + while (*(volatile uint64_t*)&this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0) ; - uint64_t* valptr = (uint64_t*)&(this->triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE].value); - asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(valptr), - "l"(element.value[0]), "l"(element.value[1])); + ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]); + asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), + "l"(trigger.fst), "l"(trigger.snd)); return curFifoHead; } #endif // __CUDACC__ - void startProxyThread(std::function handler); - void stopProxyThread(); - - ProxyTrigger* triggerFifo; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements - uint64_t* triggerFifoTail; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused + ProxyTrigger* triggers; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements + uint64_t* tailReplica; // Allocated on device. 
proxyState->fifoTailHost is the true tail on host and pushed // occasionally to device - uint64_t* triggerFifoHead; // Allocated on device. Only accessed by device + uint64_t* head; // Allocated on device. Only accessed by device +}; + +class HostProxyFifo +{ +public: + HostProxyFifo(); + + ~HostProxyFifo(); + + void poll(ProxyTrigger *trigger); + + void pop(); + + void flushTail(bool sync = false); + + DeviceProxyFifo toDevice(); + +private: + struct Impl; + std::unique_ptr<Impl> pimpl; }; } // namespace mscclpp diff --git a/src/init.cc b/src/init.cc index 7c3b76b9..7cf159c8 100644 --- a/src/init.cc +++ b/src/init.cc @@ -629,7 +629,7 @@ struct connInfo h.numBufferInfos = bufferInfos.size(); MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, &h, sizeof(header))); MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo))); - return mscclppSuccess; + return mscclppSuccess; } mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag) { @@ -638,7 +638,7 @@ struct connInfo infoQp = h.infoQp; bufferInfos.resize(h.numBufferInfos); MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo))); - return mscclppSuccess; + return mscclppSuccess; } }; diff --git a/src/proxy_cpp.cc b/src/proxy_cpp.cc new file mode 100644 index 00000000..2d1cf098 --- /dev/null +++ b/src/proxy_cpp.cc @@ -0,0 +1,97 @@ +#include "mscclpp.hpp" +#include "utils.h" +#include "api.h" +#include <thread> +#include <atomic> + +namespace mscclpp { + +const int ProxyStopCheckPeriod = 1000; + +const int ProxyFlushPeriod = 4; + +struct Proxy::Impl { + ProxyHandler handler; + HostProxyFifo fifo; + std::thread service; + std::atomic_bool running; + + Impl(ProxyHandler handler) : handler(handler), running(false) {} +}; + +MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) { + pimpl = std::make_unique<Impl>(handler); +} + +MSCCLPP_API_CPP Proxy::~Proxy() { + if (pimpl) { +
stop(); + } +} + +MSCCLPP_API_CPP void Proxy::start() { + pimpl->running = true; + pimpl->service = std::thread([this] { + // from this point on, proxy thread will stay close to the device + // PROXYMSCCLPPCHECK(numaBind(pimpl->comm->devNumaNode)); // TODO: reenable this + + ProxyHandler handler = this->pimpl->handler; + HostProxyFifo& fifo = this->pimpl->fifo; + std::atomic_bool& running = this->pimpl->running; + ProxyTrigger trigger; + + int runCnt = ProxyStopCheckPeriod; + uint64_t flushCnt = 0; + for (;;) { + if (runCnt-- == 0) { + runCnt = ProxyStopCheckPeriod; + if (!running) { + break; + } + } + // Poll to see if we are ready to send anything + fifo.poll(&trigger); + if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers + continue; // there is one in progress + } + + ProxyHandlerResult result = handler(trigger); + + // Send completion: reset only the high 64 bits + fifo.pop(); + // Flush the tail to device memory. This happens either every ProxyFlushPeriod handled triggers, so + // that the fifo can make progress even when no flush is requested, or whenever the handler asks for it by + // returning FlushFifoTailAndContinue (e.g. for an mscclppSync trigger). + if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) { + // TODO: relocate this check: || (trigger.fields.type & mscclppSync) + fifo.flushTail(); + } + + if (result == ProxyHandlerResult::Stop) { + break; + } + } + + // make sure the tail is flushed before we shut down the proxy + fifo.flushTail(/*sync=*/true); + // TODO: do these need to run?
+ // bool isP2pProxy = (proxyState->ibContext == nullptr); + // if (isP2pProxy) { + // cudaStream_t p2pStream = proxyState->p2pStream; + // PROXYCUDACHECK(cudaStreamSynchronize(p2pStream)); + // } + }); +} + +MSCCLPP_API_CPP void Proxy::stop() { + pimpl->running = false; + if (pimpl->service.joinable()) { + pimpl->service.join(); + } +} + +MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() { + return pimpl->fifo; +} + +} // namespace mscclpp \ No newline at end of file diff --git a/tests/allgather_test_cpp.cu b/tests/allgather_test_cpp.cu index ca30945f..9b056e84 100644 --- a/tests/allgather_test_cpp.cu +++ b/tests/allgather_test_cpp.cu @@ -10,6 +10,8 @@ #include #include #include +#include +#include static int nranksPerNode = 8; @@ -46,9 +48,9 @@ static double getTime(void) return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec; } -__constant__ mscclpp::DeviceConnection constDevConns[16]; +__constant__ mscclpp::SimpleDeviceConnection constDevConns[16]; -__device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU) +__device__ void allgather0(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU) { // this allgather is really simple and implemented as an alltoall @@ -67,7 +69,7 @@ __device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, i devConn.wait(); } -__device__ void localAllGather(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank, +__device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank, uint64_t offset, uint64_t size) { // this allgather algorithm works as follows: @@ -91,14 +93,14 @@ __device__ void localAllGather(mscclppDevConn_t devConn, int rank, int world_siz } } -__device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank, +__device__ void allgather1(mscclpp::SimpleDeviceConnection 
devConn, int rank, int world_size, int nranksPerNode, int remoteRank, size_t nelemsPerGPU) { localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); } -__device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank, +__device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank, size_t nelemsPerGPU) { // this allgather is a pipelined and hierarchical one and only works for two nodes @@ -167,7 +169,7 @@ __global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelem int warpId = threadIdx.x / 32; int remoteRank = (warpId < rank) ? warpId : warpId + 1; // Each warp is responsible for one of the remote ranks - mscclppDevConn_t devConn = constDevConns[warpId]; + mscclpp::SimpleDeviceConnection devConn = constDevConns[warpId]; if (kernel == 0) allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU); @@ -219,7 +221,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co int thisNode = rankToNode(rank); int cudaNum = rankToLocalRank(rank); std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum); - std::vector<mscclpp::DeviceConnection> devConns(world_size); + std::vector<std::shared_ptr<mscclpp::HostConnection>> hostConns; for (int r = 0; r < world_size; ++r) { if (r == rank) @@ -235,12 +237,19 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co // Connect with all other ranks auto hostConn = comm.connect(r, 0, transportType, ibDev); hostConn->registerBuffer(data_d, dataSize); - devConns.push_back(hostConn->toDevice(false)); + hostConns.push_back(hostConn); } comm.connectionSetup(); - assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::DeviceConnection)); - CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::DeviceConnection) * devConns.size() )); + + std::vector<mscclpp::SimpleDeviceConnection> devConns; + std::transform(hostConns.begin(),
hostConns.end(), std::back_inserter(devConns), + [](std::shared_ptr<mscclpp::HostConnection>& hostConn) { + return mscclpp::SimpleDeviceConnection(*hostConn); + }); + + assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::SimpleDeviceConnection)); + CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::SimpleDeviceConnection) * devConns.size() )); } void printUsage(const char* prog, bool isMpi) @@ -391,12 +400,9 @@ int main(int argc, const char* argv[]) size_t nelemsPerGPU = dataSize / sizeof(int) / world_size; try{ - mscclpp::Communicator comm; - if (rank == 0) printf("Initializing MSCCL++\n"); - - comm.initRank(world_size, ip_port, rank); + mscclpp::Communicator comm(world_size, ip_port, rank); if (rank == 0) printf("Initializing data for allgather test\n"); @@ -406,97 +412,93 @@ int main(int argc, const char* argv[]) printf("Setting up the connection in MSCCL++\n"); setupMscclppConnections(rank, world_size, comm, data_d, dataSize); + if (rank == 0) + printf("Launching MSCCL++ proxy threads\n"); + comm.startProxying(); + + if (rank == 0) + printf("Testing the correctness of AllGather implementation\n"); + cudaStream_t stream; + CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + CUDACHECK(cudaDeviceSynchronize()); + kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); + CUDACHECK(cudaDeviceSynchronize()); + CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost)); + + for (size_t i = 0; i < nelemsPerGPU * world_size; i++) { + int val = i + 1; + if (data_h[i] != val) { + printf("oh uh!
data_h[%ld] (%d) != val (%d)\n", i, data_h[i], val); + break; + } + } + int tmp[16]; + // A simple barrier + comm.bootstrapAllGather(tmp, sizeof(int)); + if (rank == 0) + printf("Successfully checked the correctness\n"); + + // Perf test + int iterwithoutcudagraph = 10; + if (rank == 0) + printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph); + CUDACHECK(cudaStreamSynchronize(stream)); + comm.bootstrapAllGather(tmp, sizeof(int)); + for (int i = 0; i < iterwithoutcudagraph; ++i) { + kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); + } + CUDACHECK(cudaStreamSynchronize(stream)); + comm.bootstrapAllGather(tmp, sizeof(int)); + + // cudaGraph Capture + int cudagraphiter = 10; + if (rank == 0) + printf("Capturing %d iterations of the kernel in a CUDA graph\n", cudagraphiter); + cudaGraph_t graph; + cudaGraphExec_t instance; + cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal); + for (int i = 0; i < cudagraphiter; ++i) { + kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); + } + cudaStreamEndCapture(stream, &graph); + cudaGraphInstantiate(&instance, graph, NULL, NULL, 0); + + int cudagraphwarmup = 10; + if (rank == 0) + printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup, + cudagraphiter); + for (int i = 0; i < cudagraphwarmup; ++i) { + cudaGraphLaunch(instance, stream); + } + CUDACHECK(cudaStreamSynchronize(stream)); + + // measure runtime + int cudagraphlaunch = 10; + if (rank == 0) + printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch, + cudagraphiter); + comm.bootstrapAllGather(tmp, sizeof(int)); + double t0, t1, ms, time_in_us; + t0 = getTime(); + for (int i = 0; i < cudagraphlaunch; ++i) { + cudaGraphLaunch(instance, stream); + } + CUDACHECK(cudaStreamSynchronize(stream)); + + t1 = getTime(); + ms = (t1 - t0) * 
1000.0; + time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter; + printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us, + (double)(dataSize) / 1e9 / (time_in_us / 1e6)); + comm.bootstrapAllGather(tmp, sizeof(int)); + + if (rank == 0) + printf("Stopping MSCCL++ proxy threads\n"); + comm.stopProxying(); + } catch (std::exception& e) { // todo: throw exceptions in the implementation and process them here } - - if (rank == 0) - printf("Launching MSCCL++ proxy threads\n"); - MSCCLPPCHECK(mscclppProxyLaunch(comm)); - - if (rank == 0) - printf("Testing the correctness of AllGather implementation\n"); - cudaStream_t stream; - CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); - CUDACHECK(cudaDeviceSynchronize()); - kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); - CUDACHECK(cudaDeviceSynchronize()); - CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost)); - - for (size_t i = 0; i < nelemsPerGPU * world_size; i++) { - int val = i + 1; - if (data_h[i] != val) { - printf("oh uh! 
data_h[%ld] (%d) != val (%d)\n", i, data_h[i], val); - break; - } - } - int tmp[16]; - // A simple barrier - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); - if (rank == 0) - printf("Successfully checked the correctness\n"); - - // Perf test - int iterwithoutcudagraph = 10; - if (rank == 0) - printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph); - CUDACHECK(cudaStreamSynchronize(stream)); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); - for (int i = 0; i < iterwithoutcudagraph; ++i) { - kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); - } - CUDACHECK(cudaStreamSynchronize(stream)); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); - - // cudaGraph Capture - int cudagraphiter = 10; - if (rank == 0) - printf("Capturing %d iterations of the kernel in a CUDA graph\n", cudagraphiter); - cudaGraph_t graph; - cudaGraphExec_t instance; - cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal); - for (int i = 0; i < cudagraphiter; ++i) { - kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); - } - cudaStreamEndCapture(stream, &graph); - cudaGraphInstantiate(&instance, graph, NULL, NULL, 0); - - int cudagraphwarmup = 10; - if (rank == 0) - printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup, - cudagraphiter); - for (int i = 0; i < cudagraphwarmup; ++i) { - cudaGraphLaunch(instance, stream); - } - CUDACHECK(cudaStreamSynchronize(stream)); - - // measure runtime - int cudagraphlaunch = 10; - if (rank == 0) - printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch, - cudagraphiter); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); - double t0, t1, ms, time_in_us; - t0 = getTime(); - for (int i = 0; i < cudagraphlaunch; ++i) { - cudaGraphLaunch(instance, stream); 
- } - CUDACHECK(cudaStreamSynchronize(stream)); - - t1 = getTime(); - ms = (t1 - t0) * 1000.0; - time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter; - printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us, - (double)(dataSize) / 1e9 / (time_in_us / 1e6)); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); - - if (rank == 0) - printf("Stopping MSCCL++ proxy threads\n"); - MSCCLPPCHECK(mscclppProxyStop(comm)); - - if (rank == 0) - printf("Destroying MSCCL++ communicator\n"); - MSCCLPPCHECK(mscclppCommDestroy(comm)); printf("Rank %d succeeded!\n", rank); #ifdef MSCCLPP_USE_MPI_FOR_TESTS