merged with api-extension

Saeed Maleki
2023-04-24 23:26:28 +00:00
18 changed files with 784 additions and 394 deletions


@@ -120,6 +120,7 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma
LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc)
LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc)
LIBSRCS += $(addprefix src/,communicator.cc fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc)
ifneq ($(NPKIT), 0)
LIBSRCS += $(addprefix src/misc/,npkit.cc)
endif
@@ -147,7 +148,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS))
TESTSDIR := tests
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc)
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu allgather_test_cpp.cu bootstrap_test_cpp.cc)
TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS))
TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS))


@@ -0,0 +1,29 @@
#include "basic_proxy_handler.hpp"
namespace mscclpp {
ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm) {
return [&comm](ProxyTrigger triggerRaw) {
ChannelTrigger *trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
HostConnection& conn = *comm.connections.at(trigger->fields.connId);
auto result = ProxyHandlerResult::Continue;
if (trigger->fields.type & mscclppData) {
conn.put(trigger->fields.dstBufferHandle, trigger->fields.dstOffset, trigger->fields.srcBufferHandle, trigger->fields.srcOffset, trigger->fields.size);
}
if (trigger->fields.type & mscclppFlag) {
conn.signal();
}
if (trigger->fields.type & mscclppSync) {
conn.flush();
result = ProxyHandlerResult::FlushFifoTailAndContinue;
}
return result;
};
}
} // namespace mscclpp
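
For orientation, a minimal hedged sketch of how a handler like this plugs into the Proxy class declared in mscclpp.hpp (the lambda body is a placeholder):

// A custom handler wired into a Proxy (Proxy/ProxyHandler API from mscclpp.hpp).
mscclpp::Proxy proxy([](mscclpp::ProxyTrigger trigger) {
  // decode and service the trigger here, as makeBasicProxyHandler does above
  return mscclpp::ProxyHandlerResult::Continue;
});
proxy.start(); // spawns the service thread that polls the proxy's fifo
// ... device threads push triggers via DeviceProxyFifo::push ...
proxy.stop();  // joins the service thread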

src/communicator.cc Normal file

@@ -0,0 +1,81 @@
#include "communicator.hpp"
#include "host_connection.hpp"
#include "comm.h"
#include "basic_proxy_handler.hpp"
#include "api.h"
namespace mscclpp {
Communicator::Impl::Impl() : comm(nullptr), proxy(makeBasicProxyHandler(*this)) {}
Communicator::Impl::~Impl() {
if (comm) {
mscclppCommDestroy(comm);
}
}
MSCCLPP_API_CPP Communicator::~Communicator() = default;
mscclppTransport_t transportTypeToCStyle(TransportType type) {
switch (type) {
case TransportType::IB:
return mscclppTransportIB;
case TransportType::P2P:
return mscclppTransportP2P;
default:
throw std::runtime_error("Unknown transport type");
}
}
MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique<Impl>()) {
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
}
MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique<Impl>()) {
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
}
MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) {
mscclppBootstrapAllGather(pimpl->comm, data, size);
}
MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
mscclppBootstrapBarrier(pimpl->comm);
}
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
TransportType transportType, const char* ibDev) {
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
auto connIdx = pimpl->connections.size();
auto conn = std::make_shared<HostConnection>(std::make_unique<HostConnection::Impl>(this, &pimpl->comm->conns[connIdx]));
pimpl->connections.push_back(conn);
return conn;
}
MSCCLPP_API_CPP void Communicator::connectionSetup() {
mscclppConnectionSetup(pimpl->comm);
}
MSCCLPP_API_CPP void Communicator::startProxying() {
pimpl->proxy.start();
}
MSCCLPP_API_CPP void Communicator::stopProxying() {
pimpl->proxy.stop();
}
MSCCLPP_API_CPP int Communicator::rank() {
int result;
mscclppCommRank(pimpl->comm, &result);
return result;
}
MSCCLPP_API_CPP int Communicator::size() {
int result;
mscclppCommSize(pimpl->comm, &result);
return result;
}
} // namespace mscclpp
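
Taken together, a hedged end-to-end sketch of the host-side flow this file implements (the ip:port string, pointers, and sizes are placeholders):

mscclpp::Communicator comm(worldSize, "10.0.0.4:50000", rank);
auto conn = comm.connect(peerRank, /*tag=*/0, mscclpp::TransportType::IB, "mlx5_ib0");
mscclpp::BufferHandle buf = conn->registerBuffer(gpuPtr, nbytes); // must precede connectionSetup()
comm.connectionSetup();  // exchanges buffer registrations with the peers
comm.startProxying();    // launches the proxy thread backing DeviceConnection
// ... launch kernels that use the connection ...
comm.stopProxying();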


@@ -1,90 +0,0 @@
#include "mscclpp.hpp"
#include "mscclpp.h"
namespace mscclpp {
mscclppTransport_t transportTypeToCStyle(TransportType type) {
switch (type) {
case TransportType::IB:
return mscclppTransportIB;
case TransportType::P2P:
return mscclppTransportP2P;
default:
throw std::runtime_error("Unknown transport type");
}
}
struct Communicator::Impl {
mscclppComm_t comm;
std::vector<std::shared_ptr<HostConnection>> connections;
Impl() : comm(nullptr) {}
~Impl() {
if (comm) {
mscclppCommDestroy(comm);
}
}
};
void Communicator::initRank(int nranks, const char* ipPortPair, int rank) {
if (pimpl) {
throw std::runtime_error("Communicator already initialized");
}
pimpl = std::make_unique<Impl>();
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
}
void Communicator::initRankFromId(int nranks, UniqueId id, int rank) {
if (pimpl) {
throw std::runtime_error("Communicator already initialized");
}
pimpl = std::make_unique<Impl>();
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
}
void Communicator::bootstrapAllGather(void* data, int size) {
mscclppBootstrapAllGather(pimpl->comm, data, size);
}
void Communicator::bootstrapBarrier() {
mscclppBootstrapBarrier(pimpl->comm);
}
std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
TransportType transportType, const char* ibDev = 0) {
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
auto conn = std::make_shared<HostConnection>();
auto connIdx = pimpl->connections.size();
pimpl->connections.push_back(conn);
return conn;
}
void Communicator::connectionSetup() {
mscclppConnectionSetup(pimpl->comm);
mscclppHostConn_t *hostConns;
int numHostConns;
mscclppGetAllHostConnections(pimpl->comm, &hostConns, &numHostConns);
if (numHostConns != pimpl->connections.size()) {
throw std::logic_error("Number of HostConnections didn't match number of mscclppHostConns");
}
for (int connIdx = 0; connIdx < pimpl->connections.size(); ++connIdx) {
pimpl->connections[connIdx]->pimpl->setup(hostConns[connIdx]);
}
}
int Communicator::rank() {
int result;
mscclppCommRank(pimpl->comm, &result);
return result;
}
int Communicator::size() {
int result;
mscclppCommSize(pimpl->comm, &result);
return result;
}
} // namespace mscclpp

src/fifo.cc Normal file

@@ -0,0 +1,67 @@
#include "mscclppfifo.hpp"
#include "alloc.h"
#include "checks.hpp"
#include <cuda_runtime.h>
#include <stdexcept>
#include <emmintrin.h>
namespace mscclpp {
struct HostProxyFifo::Impl {
DeviceProxyFifo deviceFifo;
// allocated on the host and only accessed by the host. This is a copy of the
// value pointed to by deviceFifo.tailReplica, and the invariant is that
// *deviceFifo.tailReplica <= hostTail. That is, the host's copy of the tail is
// always at or ahead of the device's copy, and the host updates the device's
// copy only when needed. Therefore, hostTail is the "true" tail
// and tailReplica is a possibly stale copy. See proxy_cpp.cc to understand how
// these updates are pushed to the device.
uint64_t hostTail;
// for transferring fifo tail
cudaStream_t stream;
};
HostProxyFifo::HostProxyFifo() {
pimpl = std::make_unique<Impl>();
MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.head, 1));
MSCCLPPTHROW(mscclppCudaHostCalloc(&pimpl->deviceFifo.triggers, MSCCLPP_PROXY_FIFO_SIZE));
MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.tailReplica, 1));
CUDATHROW(cudaStreamCreateWithFlags(&pimpl->stream, cudaStreamNonBlocking));
pimpl->hostTail = 0;
}
HostProxyFifo::~HostProxyFifo() {
MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.head));
MSCCLPPTHROW(mscclppCudaHostFree(pimpl->deviceFifo.triggers));
MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.tailReplica));
CUDATHROW(cudaStreamDestroy(pimpl->stream));
}
void HostProxyFifo::poll(ProxyTrigger *trigger) {
__m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]);
_mm_store_si128((__m128i*)trigger, xmm0);
}
void HostProxyFifo::pop() {
*(volatile uint64_t*)(&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0;
(pimpl->hostTail)++;
}
void HostProxyFifo::flushTail(bool sync) {
// Flush the tail to device memory. The proxy thread calls this periodically (see
// ProxyFlushPeriod in proxy_cpp.cc) so that the fifo can make progress even when no
// mscclppSync request arrives, and synchronously when a flush is explicitly requested.
CUDATHROW(
cudaMemcpyAsync(pimpl->deviceFifo.tailReplica, &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice, pimpl->stream));
if (sync) {
CUDATHROW(cudaStreamSynchronize(pimpl->stream));
}
}
DeviceProxyFifo HostProxyFifo::toDevice() {
return pimpl->deviceFifo;
}
} // namespace mscclpp
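
A minimal hedged sketch of how a host thread drives this fifo (this mirrors the service loop in src/proxy_cpp.cc below):

mscclpp::HostProxyFifo fifo;
mscclpp::ProxyTrigger trigger;
fifo.poll(&trigger);     // read the slot at the current tail
if (trigger.fst != 0) {  // a zero first word means the slot is still empty
  // ... service the trigger ...
  fifo.pop();            // zero the slot and advance the host-side tail
  fifo.flushTail();      // occasionally publish the tail back to the device
}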

src/host_connection.cc Normal file

@@ -0,0 +1,79 @@
#include "host_connection.hpp"
#include "communicator.hpp"
#include "comm.h"
#include "mscclpp.h"
#include "mscclppfifo.h"
#include "api.h"
namespace mscclpp {
HostConnection::Impl::Impl(Communicator* comm, mscclppConn* conn) : comm(comm), conn(conn) {
this->hostConn = conn->hostConn;
}
HostConnection::Impl::~Impl() {
// TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not.
}
MSCCLPP_API_CPP HostConnection::~HostConnection() = default;
MSCCLPP_API_CPP HostConnection::HostConnection(std::unique_ptr<Impl> p) : pimpl(std::move(p)) {}
MSCCLPP_API_CPP int HostConnection::getId() {
return pimpl->conn->connId;
}
MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) {
BufferHandle result;
static_assert(sizeof(BufferHandle) == sizeof(mscclppBufferHandle_t));
mscclppRegisterBufferForConnection(pimpl->comm->pimpl->comm, pimpl->conn->connId, data, size, reinterpret_cast<mscclppBufferHandle_t*>(&result));
return result;
}
MSCCLPP_API_CPP int HostConnection::numLocalBuffers() {
return pimpl->conn->bufferRegistrations.size() - 1;
}
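// Note: registration slot 0 appears to be reserved internally, so user-registered
// buffers map to handles starting at 1 (hence the -1 above and the +1 below).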
MSCCLPP_API_CPP BufferHandle HostConnection::getLocalBuffer(int index) {
return index + 1;
}
MSCCLPP_API_CPP int HostConnection::numRemoteBuffers() {
return pimpl->conn->remoteBufferRegistrations.size() - 1;
}
MSCCLPP_API_CPP BufferHandle HostConnection::getRemoteBuffer(int index) {
return index + 1;
}
MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() {
ConnectionEpoch epoch;
static_assert(sizeof(SignalEpochId) == sizeof(mscclppDevConnSignalEpochId));
epoch.localSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->localSignalEpochId);
epoch.remoteSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->remoteSignalEpochId);
epoch.waitEpochId = pimpl->conn->devConn->waitEpochId;
return epoch;
}
MSCCLPP_API_CPP DeviceProxyFifo HostConnection::getDeviceFifo() {
return pimpl->comm->pimpl->proxy.fifo().toDevice();
}
MSCCLPP_API_CPP void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) {
pimpl->hostConn->put(dst, dstOffset, src, srcOffset, size);
}
MSCCLPP_API_CPP void HostConnection::signal() {
pimpl->hostConn->signal();
}
MSCCLPP_API_CPP void HostConnection::flush() {
pimpl->hostConn->flush();
}
MSCCLPP_API_CPP void HostConnection::wait() {
pimpl->hostConn->wait();
}
} // namespace mscclpp


@@ -1,55 +0,0 @@
#include "host_connection.hpp"
namespace mscclpp {
HostConnection::Impl::Impl() : hostConn(nullptr) {}
HostConnection::Impl::~Impl() {
// TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not.
}
void HostConnection::Impl::setup(mscclppHostConn_t *hostConn) {
this->hostConn = hostConn;
}
BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) {
}
int HostConnection::numRemoteBuffers() {
}
BufferHandle HostConnection::getRemoteBuffer(int index) {
}
DeviceConnection HostConnection::toDevice(bool startProxyThread = true) {
}
void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) {
}
void HostConnection::put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) {
}
void HostConnection::signal() {
}
void HostConnection::flush() {
}
void HostConnection::wait() {
}
void HostConnection::epochIncrement() {
}
} // namespace mscclpp


@@ -2,5 +2,6 @@
#define MSCCLPP_API_H_
#define MSCCLPP_API extern "C" __attribute__((visibility("default")))
#define MSCCLPP_API_CPP __attribute__((visibility("default")))
#endif // MSCCLPP_API_H_


@@ -0,0 +1,13 @@
#ifndef MSCCLPP_BASIC_PROXY_SERVICE_HPP_
#define MSCCLPP_BASIC_PROXY_SERVICE_HPP_
#include "mscclpp.hpp"
#include "communicator.hpp"
namespace mscclpp {
ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm);
}
#endif


@@ -27,29 +27,3 @@
} while (false)
#endif
#include <errno.h>
// Check system calls
#define SYSCHECKTHROW(call, name) \
do { \
int retval; \
SYSCHECKVAL(call, name, retval); \
} while (false)
#define SYSCHECKVALTHROW(call, name, retval) \
do { \
SYSCHECKSYNC(call, name, retval); \
if (retval == -1) { \
throw std::runtime_error(std::string("Call to " name " failed : ") + strerror(errno)); \
} \
} while (false)
#define SYSCHECKSYNCTHROW(call, name, retval) \
do { \
retval = call; \
if (retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \
INFO(MSCCLPP_ALL, "Call to " name " returned %s, retrying", strerror(errno)); \
} else { \
break; \
} \
} while (true)


@@ -0,0 +1,23 @@
#ifndef MSCCL_COMMUNICATOR_HPP_
#define MSCCL_COMMUNICATOR_HPP_
#include "mscclpp.hpp"
#include "mscclpp.h"
namespace mscclpp {
struct Communicator::Impl {
mscclppComm_t comm;
std::vector<std::shared_ptr<HostConnection>> connections;
Proxy proxy;
Impl();
~Impl();
friend class HostConnection;
};
} // namespace mscclpp
#endif


@@ -3,17 +3,18 @@
#include "mscclpp.hpp"
#include "mscclpp.h"
#include "comm.h"
namespace mscclpp {
struct HostConnection::Impl {
Communicator* comm;
mscclppConn* conn;
mscclppHostConn_t* hostConn;
Impl();
Impl(Communicator* comm, mscclppConn* conn);
~Impl();
void setup(mscclppHostConn_t *hostConn);
};
} // namespace mscclpp


@@ -29,7 +29,7 @@ struct alignas(16) mscclppDevConnSignalEpochId
uint64_t proxy;
};
using mscclppBufferHandle_t = uint8_t;
using mscclppBufferHandle_t = uint32_t;
/***************************************************************************************************************
* A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand.


@@ -13,6 +13,7 @@
#include <vector>
#include <memory>
#include <functional>
#include <mscclppfifo.hpp>
@@ -27,15 +28,14 @@ struct alignas(16) SignalEpochId {
uint64_t proxy;
};
enum ChannelTriggerType : uint64_t {
channelTriggerData = 0x1,
channelTriggerFlag = 0x2,
channelTriggerSync = 0x4
};
using ChannelTriggerType = uint64_t;
const ChannelTriggerType channelTriggerData = 0x1;
const ChannelTriggerType channelTriggerFlag = 0x2;
const ChannelTriggerType channelTriggerSync = 0x4;
// This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles
// mapping to the actual registered buffers
using BufferHandle = uint8_t;
using BufferHandle = uint32_t;
#define MSCCLPP_BITS_SIZE 32
#define MSCCLPP_BITS_OFFSET 32
@@ -58,15 +58,111 @@ union ChannelTrigger {
uint64_t srcBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
uint64_t type : MSCCLPP_BITS_TYPE;
uint64_t connId : MSCCLPP_BITS_CONNID;
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_TYPE - MSCCLPP_BITS_CONNID); // ensure 64-bit alignment
} fields;
ChannelTrigger() {}
ChannelTrigger(ProxyTrigger value) : value(value) {}
ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) {
#ifdef __CUDACC__
__device__ ChannelTrigger() {}
__device__ ChannelTrigger(ProxyTrigger value) : value(value) {}
__device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size, int connectionId) {
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);
value.snd = (((((((uint64_t)type << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset);
value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset);
}
#endif // __CUDACC__
};
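
For illustration, a hedged host-side sketch of decoding a raw trigger back into these packed fields (the same decoding src/basic_proxy_handler.cc performs with the C-style constants):

// rawTrigger would come from HostProxyFifo::poll; names here are illustrative.
ChannelTrigger* t = reinterpret_cast<ChannelTrigger*>(&rawTrigger);
if (t->fields.type & channelTriggerData) {
  // use t->fields.connId, dstBufferHandle, dstOffset, srcBufferHandle, srcOffset, size
}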
struct ConnectionEpoch {
#ifdef __CUDACC__
__forceinline__ __device__ void wait()
{
(*waitEpochId) += 1;
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
;
}
__forceinline__ __device__ void epochIncrement()
{
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
}
#endif // __CUDACC__
SignalEpochId* localSignalEpochId;
// used by the signal() function directly from gpu
SignalEpochId* remoteSignalEpochId;
// every wait() increments this, and then the gpu waits for either:
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
uint64_t* waitEpochId;
};
class HostConnection {
struct Impl;
public:
/* A HostConnection cannot be constructed from user code; it must instead be created through Communicator::connect */
HostConnection(std::unique_ptr<Impl>);
~HostConnection();
int getId();
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
* in the communicator.
*
* Inputs:
* data: base pointer to the memory
* size: size of the memory region in bytes
*
* Returns: a handle to the buffer
*/
BufferHandle registerBuffer(void* data, uint64_t size);
/* Get the number of times registerBuffer(...) was called.
*
* Returns: the number of buffers registered
*/
int numLocalBuffers();
/* Get the BufferHandle returned by a call to registerBuffer(...) as identified by the index
*
* Inputs:
* index: the index of the handle to get
*
* Returns: a handle to the buffer
*/
BufferHandle getLocalBuffer(int index);
/* Get the number of times registerBuffer(...) was called on the remote peer.
*
* Returns: the number of buffers registered on the remote peer
*/
int numRemoteBuffers();
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
*
* Inputs:
* index: the index of the handle to get
*
* Returns: a handle to the buffer on the remote peer
*/
BufferHandle getRemoteBuffer(int index);
ConnectionEpoch getEpoch();
DeviceProxyFifo getDeviceFifo();
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
void signal();
void flush();
void wait();
private:
std::unique_ptr<Impl> pimpl;
friend class Communicator;
};
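
A short hedged sketch of driving a HostConnection directly from the host after connectionSetup() (the handle indices assume each side registered exactly one buffer):

// conn is the std::shared_ptr<HostConnection> returned by Communicator::connect.
BufferHandle src = conn->getLocalBuffer(0);   // first buffer registered locally
BufferHandle dst = conn->getRemoteBuffer(0);  // first buffer the peer registered
conn->put(dst, /*dstOffset=*/0, src, /*srcOffset=*/0, nbytes);
conn->signal(); // raise the epoch flag on the peer
conn->flush();  // wait until the puts have been issued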
/***************************************************************************************************************
@@ -132,12 +228,20 @@ union ChannelTrigger {
* indices in the registered buffer.
**************************************************************************************************************/
struct DeviceConnection {
#ifdef __CUDACC__
// TODO: add buffer handles
DeviceConnection() = default;
DeviceConnection(HostConnection& hostConn)
: connectionId(hostConn.getId()), epoch(hostConn.getEpoch()),
fifo(hostConn.getDeviceFifo()) {}
DeviceConnection(const DeviceConnection& other) = default;
DeviceConnection& operator=(DeviceConnection& other) = default;
#ifdef __CUDACC__
__forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
{
fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size).value);
fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value);
}
__forceinline__ __device__ void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
@@ -148,13 +252,13 @@ struct DeviceConnection {
__forceinline__ __device__ void signal()
{
epochIncrement();
fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1).value);
fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value);
}
__forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
{
epochIncrement();
fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size).value);
fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId).value);
}
__forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
@@ -165,107 +269,116 @@ struct DeviceConnection {
__forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
{
epochIncrement();
uint64_t curFifoHead = fifo.push(channelTriggerData | channelTriggerFlag | channelTriggerSync, dstOffset, srcOffset, size);
while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, dstOffset, src, srcOffset, size, connectionId).value);
while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
*(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
;
}
__forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
{
putWithSignalAndFlush(dst, offset, src, offset, size);
}
__forceinline__ __device__ void flush()
{
uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerSync, 0, 0, 0, 0, 1, connectionId).value);
// we need to wait for two conditions to be met to ensure the CPU is done flushing: (1) wait for the tail
// to go past curFifoHead (this is a safety net) and (2) wait for the work element value to change to 0.
while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
*(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
;
}
__forceinline__ __device__ void wait()
{
epoch.wait();
}
__forceinline__ __device__ void epochIncrement()
{
epoch.epochIncrement();
}
#endif // __CUDACC__
int connectionId;
ConnectionEpoch epoch;
// this is a concurrent fifo to which multiple threads from the device
// can produce and from which the sole proxy thread consumes.
DeviceProxyFifo fifo;
};
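
A hedged device-side sketch of how a kernel might use this struct (the __constant__ symbol and kernel are illustrative; the allgather test below sets things up the same way):

__constant__ mscclpp::DeviceConnection gDevConn;

__global__ void putKernel(uint64_t offset, uint64_t nbytes) {
  if (threadIdx.x == 0) {
    // handles 1/1 assume one registered buffer per side (see getLocalBuffer/getRemoteBuffer)
    gDevConn.putWithSignal(/*dst=*/1, offset, /*src=*/1, offset, nbytes);
    gDevConn.wait(); // block until the peer's matching signal arrives
  }
}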
struct SimpleDeviceConnection {
SimpleDeviceConnection() = default;
SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) {
dst = hostConn.getRemoteBuffer(0);
src = hostConn.getLocalBuffer(0);
}
SimpleDeviceConnection(const SimpleDeviceConnection& other) = default;
SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default;
#ifdef __CUDACC__
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
{
devConn.put(dst, dstOffset, src, srcOffset, size);
}
__forceinline__ __device__ void put(uint64_t offset, uint64_t size)
{
put(offset, offset, size);
}
__forceinline__ __device__ void signal()
{
devConn.signal();
}
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
{
devConn.putWithSignal(dst, dstOffset, src, srcOffset, size);
}
__forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size)
{
putWithSignal(offset, offset, size);
}
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
{
devConn.putWithSignalAndFlush(dst, dstOffset, src, srcOffset, size);
}
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size)
{
putWithSignalAndFlush(offset, offset, size);
}
__forceinline__ __device__ void flush()
{
uint64_t curFifoHead = fifo.push(mscclppSync, 0, 0, 1);
// we need to wait for two conditions to be met to ensure the CPU is done flushing: (1) wait for the tail
// to go past curFifoHead (this is a safety net) and (2) wait for the work element value to change to 0.
while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
;
devConn.flush();
}
__forceinline__ __device__ void wait()
{
(*waitEpochId) += 1;
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
;
devConn.wait();
}
__forceinline__ __device__ void epochIncrement()
{
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
devConn.epochIncrement();
}
#endif // __CUDACC__
int remoteRank;
int tag;
SignalEpochId* localSignalEpochId;
// used by the signal() function directly from gpu
SignalEpochId* remoteSignalEpochId;
// every wait() increments this, and then the gpu waits for either:
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
uint64_t* waitEpochId;
// this is a concurrent fifo to which multiple threads from the device
// can produce and from which the sole proxy thread consumes.
ProxyFifo fifo;
};
class HostConnection {
public:
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
* in the communicator.
*
* Inputs:
* data: base pointer to the memory
* size: size of the memory region in bytes
*
* Returns: a handle to the buffer
*/
BufferHandle registerBuffer(void* data, uint64_t size);
/* Get the number of times registerBuffer(...) was called on the remote peer.
*
* Returns: the number of buffers registered on the remote peer
*/
int numRemoteBuffers();
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
*
* Inputs:
* index: the index of the handle to get
*
* Returns: a handle to the buffer on the remote peer
*/
BufferHandle getRemoteBuffer(int index);
/* Create a DeviceConnection paired with this HostConnection. A background proxy thread will
* trigger operations on this HostConnection corresponding to put/signal/etc. calls made to the
* DeviceConnection.
*
* Inputs:
* startProxyThread: whether to start the proxy thread (default is true)
*
* Returns: the newly created DeviceConnection
*/
DeviceConnection toDevice(bool startProxyThread = true);
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size);
void signal();
void flush();
void wait();
void epochIncrement();
private:
struct Impl;
std::unique_ptr<Impl> pimpl;
DeviceConnection devConn;
BufferHandle dst;
BufferHandle src;
};
#define MSCCLPP_UNIQUE_ID_BYTES 128
@@ -290,6 +403,7 @@ enum class TransportType : uint8_t {
class Communicator {
public:
/* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function.
*
* Inputs:
@@ -297,7 +411,7 @@ public:
* ipPortPair: a string of the form "ip:port" that represents the address of the root process
* rank: rank of the calling process
*/
void initRank(int nranks, const char* ipPortPair, int rank);
Communicator(int nranks, const char* ipPortPair, int rank);
/* Initialize the communicator from a given UniqueId. Same as mscclppCommInitRank() except that
* id is provided by the user by calling getUniqueId()
@@ -307,7 +421,9 @@ public:
* id: the unique ID to be used for communication
* rank: rank of the calling process
*/
void initRankFromId(int nranks, UniqueId id, int rank);
Communicator(int nranks, UniqueId id, int rank);
~Communicator();
/* Ring-based AllGather through the bootstrap socket.
*
@@ -341,6 +457,12 @@ public:
*/
void connectionSetup();
/* Launch proxy thread(s). This function is supposed to be called before starting a kernel that uses DeviceConnection. */
void startProxying();
/* Stop proxy thread(s). */
void stopProxying();
/* Return the rank of the calling process.
*
* Outputs:
@@ -355,6 +477,33 @@ public:
*/
int size();
struct Impl;
private:
std::unique_ptr<Impl> pimpl;
friend class HostConnection;
};
enum class ProxyHandlerResult {
Continue,
FlushFifoTailAndContinue,
Stop,
};
class Proxy;
using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
class Proxy {
public:
Proxy(ProxyHandler handler);
~Proxy();
void start();
void stop();
HostProxyFifo& fifo();
private:
struct Impl;
std::unique_ptr<Impl> pimpl;

View File

@@ -3,6 +3,7 @@
#include <stdint.h>
#include <functional>
#include <memory>
namespace mscclpp {
@@ -13,39 +14,56 @@ struct alignas(16) ProxyTrigger {
/* This is a concurrent fifo where multiple device threads can push mscclppTrigger work elements to
* and a single host proxy thread consumes these work elements. There is a head pointer allocated on device
* which starts with 0 and goes to 2^64-1 which is almost infinity. There are two copies of tail, one
* that is on the device (triggerFifoTail) and another that is on host (proxyState->fifoTailHost).
* that is on the device (tailReplica) and another that is on host (proxyState->fifoTailHost).
* The host always has the "true" tail and occasionally, pushes it to the copy on the device.
* Therefore, most of the time, the device has a stale version. The invariants are:
* triggerFifoTail <= proxyState->fifoTailHost <= triggerFifoHead.
* push() function increments triggerFifoHead, proxyState->fifoTailHost is updated in proxy.cc:mscclppProxyService
* and it occasionally flushes it to triggerFifoTail via a cudaMemcpyAsync.
* tailReplica <= proxyState->fifoTailHost <= head.
* push() function increments head, proxyState->fifoTailHost is updated in proxy.cc:mscclppProxyService
* and it occasionally flushes it to tailReplica via a cudaMemcpyAsync.
*
* Why is duplicating the tail a good idea? The fifo is large enough, and we do not need frequent updates
* for the tail as there is usually enough space for device threads to push their work into.
*/
struct ProxyFifo {
struct DeviceProxyFifo {
#ifdef __CUDACC__
__forceinline__ __device__ uint64_t push(ProxyTrigger element)
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger)
{
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->triggerFifoHead, 1);
while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->triggerFifoTail))
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1);
while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->tailReplica))
;
while (*(volatile uint64_t*)&this->triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0)
while (*(volatile uint64_t*)&this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0)
;
uint64_t* valptr = (uint64_t*)&(this->triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE].value);
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(valptr),
"l"(element.value[0]), "l"(element.value[1]));
ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]);
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr),
"l"(trigger.fst), "l"(trigger.snd));
return curFifoHead;
}
#endif // __CUDACC__
void startProxyThread(std::function<void(ProxyTrigger)> handler);
void stopProxyThread();
ProxyTrigger* triggerFifo; // Allocated on host via cudaHostAlloc. This space is used for pushing the work elements
uint64_t* triggerFifoTail; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pushed
ProxyTrigger* triggers; // Allocated on host via cudaHostAlloc. This space is used for pushing the work elements
uint64_t* tailReplica; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pushed
// occasionally to device
uint64_t* triggerFifoHead; // Allocated on device. Only accessed by device
uint64_t* head; // Allocated on device. Only accessed by device
};
class HostProxyFifo
{
public:
HostProxyFifo();
~HostProxyFifo();
void poll(ProxyTrigger *trigger);
void pop();
void flushTail(bool sync = false);
DeviceProxyFifo toDevice();
private:
struct Impl;
std::unique_ptr<Impl> pimpl;
};
} // namespace mscclpp
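
In other words (a hedged sketch): the host owns a HostProxyFifo, and the device works against a plain DeviceProxyFifo view of the same buffers:

mscclpp::HostProxyFifo hostFifo;                        // host side: poll/pop/flushTail
mscclpp::DeviceProxyFifo devView = hostFifo.toDevice(); // device side: push
// devView can be passed to a kernel by value or through a __constant__ symbol;
// its push() runs on the GPU while the host consumes hostFifo on the proxy thread.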


@@ -629,7 +629,7 @@ struct connInfo
h.numBufferInfos = bufferInfos.size();
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, &h, sizeof(header)));
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
return mscclppSuccess;
return mscclppSuccess;
}
mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag) {
@@ -638,7 +638,7 @@ struct connInfo
infoQp = h.infoQp;
bufferInfos.resize(h.numBufferInfos);
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
return mscclppSuccess;
return mscclppSuccess;
}
};

src/proxy_cpp.cc Normal file

@@ -0,0 +1,97 @@
#include "mscclpp.hpp"
#include "utils.h"
#include "api.h"
#include <thread>
#include <atomic>
namespace mscclpp {
const int ProxyStopCheckPeriod = 1000;
const int ProxyFlushPeriod = 4;
struct Proxy::Impl {
ProxyHandler handler;
HostProxyFifo fifo;
std::thread service;
std::atomic_bool running;
Impl(ProxyHandler handler) : handler(handler), running(false) {}
};
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) {
pimpl = std::make_unique<Impl>(handler);
}
MSCCLPP_API_CPP Proxy::~Proxy() {
if (pimpl) {
stop();
}
}
MSCCLPP_API_CPP void Proxy::start() {
pimpl->running = true;
pimpl->service = std::thread([this] {
// from this point on, proxy thread will stay close to the device
// PROXYMSCCLPPCHECK(numaBind(pimpl->comm->devNumaNode)); // TODO: reenable this
ProxyHandler handler = this->pimpl->handler;
HostProxyFifo& fifo = this->pimpl->fifo;
std::atomic_bool& running = this->pimpl->running;
ProxyTrigger trigger;
int runCnt = ProxyStopCheckPeriod;
uint64_t flushCnt = 0;
for (;;) {
if (runCnt-- == 0) {
runCnt = ProxyStopCheckPeriod;
if (!running) {
break;
}
}
// Poll to see if we are ready to send anything
fifo.poll(&trigger);
if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers
continue; // no trigger has been pushed to this slot yet
}
ProxyHandlerResult result = handler(trigger);
// Send completion: reset only the high 64 bits
fifo.pop();
// Flush the tail to device memory. This is triggered every ProxyFlushPeriod to make sure
// the fifo can make progress even when no mscclppSync request arrives; a handler that needs a
// flush requests one by returning ProxyHandlerResult::FlushFifoTailAndContinue.
if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) {
// TODO: relocate this check: || (trigger.fields.type & mscclppSync)
fifo.flushTail();
}
if (result == ProxyHandlerResult::Stop) {
break;
}
}
// make sure the tail is flushed before we shut down the proxy
fifo.flushTail(/*sync=*/true);
// TODO: do these need to run?
// bool isP2pProxy = (proxyState->ibContext == nullptr);
// if (isP2pProxy) {
// cudaStream_t p2pStream = proxyState->p2pStream;
// PROXYCUDACHECK(cudaStreamSynchronize(p2pStream));
// }
});
}
MSCCLPP_API_CPP void Proxy::stop() {
pimpl->running = false;
if (pimpl->service.joinable()) {
pimpl->service.join();
}
}
MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() {
return pimpl->fifo;
}
} // namespace mscclpp


@@ -10,6 +10,8 @@
#include <string>
#include <unistd.h>
#include <unordered_map>
#include <cassert>
#include <algorithm>
static int nranksPerNode = 8;
@@ -46,9 +48,9 @@ static double getTime(void)
return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec;
}
__constant__ mscclpp::DeviceConnection constDevConns[16];
__constant__ mscclpp::SimpleDeviceConnection constDevConns[16];
__device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU)
__device__ void allgather0(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU)
{
// this allgather is really simple and implemented as an alltoall
@@ -67,7 +69,7 @@ __device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, i
devConn.wait();
}
__device__ void localAllGather(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
__device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
uint64_t offset, uint64_t size)
{
// this allgather algorithm works as follows:
@@ -91,14 +93,14 @@ __device__ void localAllGather(mscclppDevConn_t devConn, int rank, int world_siz
}
}
__device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
__device__ void allgather1(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
size_t nelemsPerGPU)
{
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
nelemsPerGPU * sizeof(int));
}
__device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
__device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
size_t nelemsPerGPU)
{
// this allgather is a pipelined and hierarchical one and only works for two nodes
@@ -167,7 +169,7 @@ __global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelem
int warpId = threadIdx.x / 32;
int remoteRank = (warpId < rank) ? warpId : warpId + 1;
// Each warp is responsible for one of the remote ranks
mscclppDevConn_t devConn = constDevConns[warpId];
mscclpp::SimpleDeviceConnection devConn = constDevConns[warpId];
if (kernel == 0)
allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU);
@@ -219,7 +221,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
int thisNode = rankToNode(rank);
int cudaNum = rankToLocalRank(rank);
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
std::vector<mscclpp::DeviceConnection> devConns(world_size);
std::vector<std::shared_ptr<mscclpp::HostConnection>> hostConns;
for (int r = 0; r < world_size; ++r) {
if (r == rank)
@@ -235,12 +237,19 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
// Connect with all other ranks
auto hostConn = comm.connect(r, 0, transportType, ibDev);
hostConn->registerBuffer(data_d, dataSize);
devConns.push_back(hostConn->toDevice(false));
hostConns.push_back(hostConn);
}
comm.connectionSetup();
assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::DeviceConnection));
CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::DeviceConnection) * devConns.size() ));
std::vector<mscclpp::SimpleDeviceConnection> devConns;
std::transform(hostConns.begin(), hostConns.end(), std::back_inserter(devConns),
[](std::shared_ptr<mscclpp::HostConnection>& hostConn) {
return mscclpp::SimpleDeviceConnection(*hostConn);
});
assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::SimpleDeviceConnection));
CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::SimpleDeviceConnection) * devConns.size() ));
}
void printUsage(const char* prog, bool isMpi)
@@ -391,12 +400,9 @@ int main(int argc, const char* argv[])
size_t nelemsPerGPU = dataSize / sizeof(int) / world_size;
try{
mscclpp::Communicator comm;
if (rank == 0)
printf("Initializing MSCCL++\n");
comm.initRank(world_size, ip_port, rank);
mscclpp::Communicator comm(world_size, ip_port, rank);
if (rank == 0)
printf("Initializing data for allgather test\n");
@@ -406,97 +412,93 @@ int main(int argc, const char* argv[])
printf("Setting up the connection in MSCCL++\n");
setupMscclppConnections(rank, world_size, comm, data_d, dataSize);
if (rank == 0)
printf("Launching MSCCL++ proxy threads\n");
comm.startProxying();
if (rank == 0)
printf("Testing the correctness of AllGather implementation\n");
cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaDeviceSynchronize());
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
CUDACHECK(cudaDeviceSynchronize());
CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost));
for (size_t i = 0; i < nelemsPerGPU * world_size; i++) {
int val = i + 1;
if (data_h[i] != val) {
printf("oh uh! data_h[%ld] (%d) != val (%d)\n", i, data_h[i], val);
break;
}
}
int tmp[16];
// A simple barrier
comm.bootstrapAllGather(tmp, sizeof(int));
if (rank == 0)
printf("Successfully checked the correctness\n");
// Perf test
int iterwithoutcudagraph = 10;
if (rank == 0)
printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph);
CUDACHECK(cudaStreamSynchronize(stream));
comm.bootstrapAllGather(tmp, sizeof(int));
for (int i = 0; i < iterwithoutcudagraph; ++i) {
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
}
CUDACHECK(cudaStreamSynchronize(stream));
comm.bootstrapAllGather(tmp, sizeof(int));
// cudaGraph Capture
int cudagraphiter = 10;
if (rank == 0)
printf("Capturing %d iterations of the kernel in a CUDA graph\n", cudagraphiter);
cudaGraph_t graph;
cudaGraphExec_t instance;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
for (int i = 0; i < cudagraphiter; ++i) {
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
}
cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
int cudagraphwarmup = 10;
if (rank == 0)
printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup,
cudagraphiter);
for (int i = 0; i < cudagraphwarmup; ++i) {
cudaGraphLaunch(instance, stream);
}
CUDACHECK(cudaStreamSynchronize(stream));
// measure runtime
int cudagraphlaunch = 10;
if (rank == 0)
printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch,
cudagraphiter);
comm.bootstrapAllGather(tmp, sizeof(int));
double t0, t1, ms, time_in_us;
t0 = getTime();
for (int i = 0; i < cudagraphlaunch; ++i) {
cudaGraphLaunch(instance, stream);
}
CUDACHECK(cudaStreamSynchronize(stream));
t1 = getTime();
ms = (t1 - t0) * 1000.0;
time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter;
printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us,
(double)(dataSize) / 1e9 / (time_in_us / 1e6));
comm.bootstrapAllGather(tmp, sizeof(int));
if (rank == 0)
printf("Stopping MSCCL++ proxy threads\n");
comm.stopProxying();
} catch (std::exception& e) {
// TODO: throw exceptions in the implementation and process them here
}
if (rank == 0)
printf("Launching MSCCL++ proxy threads\n");
MSCCLPPCHECK(mscclppProxyLaunch(comm));
if (rank == 0)
printf("Testing the correctness of AllGather implementation\n");
cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaDeviceSynchronize());
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
CUDACHECK(cudaDeviceSynchronize());
CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost));
for (size_t i = 0; i < nelemsPerGPU * world_size; i++) {
int val = i + 1;
if (data_h[i] != val) {
printf("oh uh! data_h[%ld] (%d) != val (%d)\n", i, data_h[i], val);
break;
}
}
int tmp[16];
// A simple barrier
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
if (rank == 0)
printf("Successfully checked the correctness\n");
// Perf test
int iterwithoutcudagraph = 10;
if (rank == 0)
printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph);
CUDACHECK(cudaStreamSynchronize(stream));
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
for (int i = 0; i < iterwithoutcudagraph; ++i) {
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
}
CUDACHECK(cudaStreamSynchronize(stream));
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
// cudaGraph Capture
int cudagraphiter = 10;
if (rank == 0)
printf("Capturing %d iterations of the kernel in a CUDA graph\n", cudagraphiter);
cudaGraph_t graph;
cudaGraphExec_t instance;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
for (int i = 0; i < cudagraphiter; ++i) {
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
}
cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
int cudagraphwarmup = 10;
if (rank == 0)
printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup,
cudagraphiter);
for (int i = 0; i < cudagraphwarmup; ++i) {
cudaGraphLaunch(instance, stream);
}
CUDACHECK(cudaStreamSynchronize(stream));
// measure runtime
int cudagraphlaunch = 10;
if (rank == 0)
printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch,
cudagraphiter);
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
double t0, t1, ms, time_in_us;
t0 = getTime();
for (int i = 0; i < cudagraphlaunch; ++i) {
cudaGraphLaunch(instance, stream);
}
CUDACHECK(cudaStreamSynchronize(stream));
t1 = getTime();
ms = (t1 - t0) * 1000.0;
time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter;
printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us,
(double)(dataSize) / 1e9 / (time_in_us / 1e6));
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
if (rank == 0)
printf("Stopping MSCCL++ proxy threads\n");
MSCCLPPCHECK(mscclppProxyStop(comm));
if (rank == 0)
printf("Destroying MSCCL++ communicator\n");
MSCCLPPCHECK(mscclppCommDestroy(comm));
printf("Rank %d succeeded!\n", rank);
#ifdef MSCCLPP_USE_MPI_FOR_TESTS