mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
merged with api-extension
This commit is contained in:
3
Makefile
3
Makefile
@@ -120,6 +120,7 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma
|
||||
|
||||
LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc)
|
||||
LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc)
|
||||
LIBSRCS += $(addprefix src/,communicator.cc fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc)
|
||||
ifneq ($(NPKIT), 0)
|
||||
LIBSRCS += $(addprefix src/misc/,npkit.cc)
|
||||
endif
|
||||
@@ -147,7 +148,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS))
|
||||
|
||||
TESTSDIR := tests
|
||||
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc)
|
||||
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu allgather_test_cpp.cu bootstrap_test_cpp.cc)
|
||||
TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS))
|
||||
TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS))
|
||||
|
||||
29
src/basic_proxy_handler.cc
Normal file
29
src/basic_proxy_handler.cc
Normal file
@@ -0,0 +1,29 @@
|
||||
#include "basic_proxy_handler.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm) {
|
||||
return [&comm](ProxyTrigger triggerRaw) {
|
||||
ChannelTrigger *trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
|
||||
HostConnection& conn = *comm.connections.at(trigger->fields.connId);
|
||||
|
||||
auto result = ProxyHandlerResult::Continue;
|
||||
|
||||
if (trigger->fields.type & mscclppData) {
|
||||
conn.put(trigger->fields.dstBufferHandle, trigger->fields.dstOffset, trigger->fields.srcBufferHandle, trigger->fields.srcOffset, trigger->fields.size);
|
||||
}
|
||||
|
||||
if (trigger->fields.type & mscclppFlag) {
|
||||
conn.signal();
|
||||
}
|
||||
|
||||
if (trigger->fields.type & mscclppSync) {
|
||||
conn.flush();
|
||||
result = ProxyHandlerResult::FlushFifoTailAndContinue;
|
||||
}
|
||||
|
||||
return result;
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
81
src/communicator.cc
Normal file
81
src/communicator.cc
Normal file
@@ -0,0 +1,81 @@
|
||||
#include "communicator.hpp"
|
||||
#include "host_connection.hpp"
|
||||
#include "comm.h"
|
||||
#include "basic_proxy_handler.hpp"
|
||||
#include "api.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
Communicator::Impl::Impl() : comm(nullptr), proxy(makeBasicProxyHandler(*this)) {}
|
||||
|
||||
Communicator::Impl::~Impl() {
|
||||
if (comm) {
|
||||
mscclppCommDestroy(comm);
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::~Communicator() = default;
|
||||
|
||||
mscclppTransport_t transportTypeToCStyle(TransportType type) {
|
||||
switch (type) {
|
||||
case TransportType::IB:
|
||||
return mscclppTransportIB;
|
||||
case TransportType::P2P:
|
||||
return mscclppTransportP2P;
|
||||
default:
|
||||
throw std::runtime_error("Unknown transport type");
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique<Impl>()) {
|
||||
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique<Impl>()) {
|
||||
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
|
||||
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
|
||||
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) {
|
||||
mscclppBootstrapAllGather(pimpl->comm, data, size);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
|
||||
mscclppBootstrapBarrier(pimpl->comm);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
|
||||
TransportType transportType, const char* ibDev) {
|
||||
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
|
||||
auto connIdx = pimpl->connections.size();
|
||||
auto conn = std::make_shared<HostConnection>(std::make_unique<HostConnection::Impl>(this, &pimpl->comm->conns[connIdx]));
|
||||
pimpl->connections.push_back(conn);
|
||||
return conn;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::connectionSetup() {
|
||||
mscclppConnectionSetup(pimpl->comm);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::startProxying() {
|
||||
pimpl->proxy.start();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::stopProxying() {
|
||||
pimpl->proxy.stop();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::rank() {
|
||||
int result;
|
||||
mscclppCommRank(pimpl->comm, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::size() {
|
||||
int result;
|
||||
mscclppCommSize(pimpl->comm, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
@@ -1,90 +0,0 @@
|
||||
#include "mscclpp.hpp"
|
||||
#include "mscclpp.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
mscclppTransport_t transportTypeToCStyle(TransportType type) {
|
||||
switch (type) {
|
||||
case TransportType::IB:
|
||||
return mscclppTransportIB;
|
||||
case TransportType::P2P:
|
||||
return mscclppTransportP2P;
|
||||
default:
|
||||
throw std::runtime_error("Unknown transport type");
|
||||
}
|
||||
}
|
||||
|
||||
struct Communicator::Impl {
|
||||
mscclppComm_t comm;
|
||||
std::vector<std::shared_ptr<HostConnection>> connections;
|
||||
|
||||
Impl() : comm(nullptr) {}
|
||||
|
||||
~Impl() {
|
||||
if (comm) {
|
||||
mscclppCommDestroy(comm);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void Communicator::initRank(int nranks, const char* ipPortPair, int rank) {
|
||||
if (pimpl) {
|
||||
throw std::runtime_error("Communicator already initialized");
|
||||
}
|
||||
pimpl = std::make_unique<Impl>();
|
||||
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
|
||||
}
|
||||
|
||||
void Communicator::initRankFromId(int nranks, UniqueId id, int rank) {
|
||||
if (pimpl) {
|
||||
throw std::runtime_error("Communicator already initialized");
|
||||
}
|
||||
pimpl = std::make_unique<Impl>();
|
||||
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
|
||||
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
|
||||
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
|
||||
}
|
||||
|
||||
void Communicator::bootstrapAllGather(void* data, int size) {
|
||||
mscclppBootstrapAllGather(pimpl->comm, data, size);
|
||||
}
|
||||
|
||||
void Communicator::bootstrapBarrier() {
|
||||
mscclppBootstrapBarrier(pimpl->comm);
|
||||
}
|
||||
|
||||
std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
|
||||
TransportType transportType, const char* ibDev = 0) {
|
||||
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
|
||||
auto conn = std::make_shared<HostConnection>();
|
||||
auto connIdx = pimpl->connections.size();
|
||||
pimpl->connections.push_back(conn);
|
||||
return conn;
|
||||
}
|
||||
|
||||
void Communicator::connectionSetup() {
|
||||
mscclppConnectionSetup(pimpl->comm);
|
||||
mscclppHostConn_t *hostConns;
|
||||
int numHostConns;
|
||||
mscclppGetAllHostConnections(pimpl->comm, &hostConns, &numHostConns);
|
||||
if (numHostConns != pimpl->connections.size()) {
|
||||
throw std::logic_error("Number of HostConnections didn't match number of mscclppHostConns");
|
||||
}
|
||||
for (int connIdx = 0; connIdx < pimpl->connections.size(); ++connIdx) {
|
||||
pimpl->connections[connIdx]->pimpl->setup(hostConns[connIdx]);
|
||||
}
|
||||
}
|
||||
|
||||
int Communicator::rank() {
|
||||
int result;
|
||||
mscclppCommRank(pimpl->comm, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
int Communicator::size() {
|
||||
int result;
|
||||
mscclppCommSize(pimpl->comm, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
67
src/fifo.cc
Normal file
67
src/fifo.cc
Normal file
@@ -0,0 +1,67 @@
|
||||
#include "mscclppfifo.hpp"
|
||||
#include "alloc.h"
|
||||
#include "checks.hpp"
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdexcept>
|
||||
#include <emmintrin.h>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct HostProxyFifo::Impl {
|
||||
DeviceProxyFifo deviceFifo;
|
||||
|
||||
// allocated on the host. Only accessed by the host. This is a copy of the
|
||||
// value pointed to by fifoTailDev and the invariant is that
|
||||
// *fifoTailDev <= hostTail. Meaning that host's copy of tail is
|
||||
// always ahead of the device's copy and host updates the device's copy
|
||||
// only when it is needed. Therefore, hostTail is the "true" tail
|
||||
// and fifoTailDev is a "stale" tail. See proxy.cc to undertand how
|
||||
// these updates are pushed to the device.
|
||||
uint64_t hostTail;
|
||||
|
||||
// for transferring fifo tail
|
||||
cudaStream_t stream;
|
||||
};
|
||||
|
||||
HostProxyFifo::HostProxyFifo() {
|
||||
pimpl = std::make_unique<Impl>();
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.head, 1));
|
||||
MSCCLPPTHROW(mscclppCudaHostCalloc(&pimpl->deviceFifo.triggers, MSCCLPP_PROXY_FIFO_SIZE));
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.tailReplica, 1));
|
||||
CUDATHROW(cudaStreamCreateWithFlags(&pimpl->stream, cudaStreamNonBlocking));
|
||||
pimpl->hostTail = 0;
|
||||
}
|
||||
|
||||
HostProxyFifo::~HostProxyFifo() {
|
||||
MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.head));
|
||||
MSCCLPPTHROW(mscclppCudaHostFree(pimpl->deviceFifo.triggers));
|
||||
MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.tailReplica));
|
||||
CUDATHROW(cudaStreamDestroy(pimpl->stream));
|
||||
}
|
||||
|
||||
void HostProxyFifo::poll(ProxyTrigger *trigger) {
|
||||
__m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]);
|
||||
_mm_store_si128((__m128i*)trigger, xmm0);
|
||||
}
|
||||
|
||||
void HostProxyFifo::pop() {
|
||||
*(volatile uint64_t*)(&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0;
|
||||
(pimpl->hostTail)++;
|
||||
}
|
||||
|
||||
void HostProxyFifo::flushTail(bool sync) {
|
||||
// Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure
|
||||
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
|
||||
// request.
|
||||
CUDATHROW(
|
||||
cudaMemcpyAsync(pimpl->deviceFifo.tailReplica, &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice, pimpl->stream));
|
||||
if (sync) {
|
||||
CUDATHROW(cudaStreamSynchronize(pimpl->stream));
|
||||
}
|
||||
}
|
||||
|
||||
DeviceProxyFifo HostProxyFifo::toDevice() {
|
||||
return pimpl->deviceFifo;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
79
src/host_connection.cc
Normal file
79
src/host_connection.cc
Normal file
@@ -0,0 +1,79 @@
|
||||
#include "host_connection.hpp"
|
||||
#include "communicator.hpp"
|
||||
#include "comm.h"
|
||||
#include "mscclpp.h"
|
||||
#include "mscclppfifo.h"
|
||||
#include "api.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
HostConnection::Impl::Impl(Communicator* comm, mscclppConn* conn) : comm(comm), conn(conn) {
|
||||
this->hostConn = conn->hostConn;
|
||||
}
|
||||
|
||||
HostConnection::Impl::~Impl() {
|
||||
// TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not.
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP HostConnection::~HostConnection() = default;
|
||||
|
||||
MSCCLPP_API_CPP HostConnection::HostConnection(std::unique_ptr<Impl> p) : pimpl(std::move(p)) {}
|
||||
|
||||
MSCCLPP_API_CPP int HostConnection::getId() {
|
||||
return pimpl->conn->connId;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) {
|
||||
BufferHandle result;
|
||||
static_assert(sizeof(BufferHandle) == sizeof(mscclppBufferHandle_t));
|
||||
mscclppRegisterBufferForConnection(pimpl->comm->pimpl->comm, pimpl->conn->connId, data, size, reinterpret_cast<mscclppBufferHandle_t*>(&result));
|
||||
return result;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int HostConnection::numLocalBuffers() {
|
||||
return pimpl->conn->bufferRegistrations.size() - 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP BufferHandle HostConnection::getLocalBuffer(int index) {
|
||||
return index + 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int HostConnection::numRemoteBuffers() {
|
||||
return pimpl->conn->remoteBufferRegistrations.size() - 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP BufferHandle HostConnection::getRemoteBuffer(int index) {
|
||||
return index + 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() {
|
||||
ConnectionEpoch epoch;
|
||||
static_assert(sizeof(SignalEpochId) == sizeof(mscclppDevConnSignalEpochId));
|
||||
epoch.localSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->localSignalEpochId);
|
||||
epoch.remoteSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->remoteSignalEpochId);
|
||||
epoch.waitEpochId = pimpl->conn->devConn->waitEpochId;
|
||||
return epoch;
|
||||
}
|
||||
|
||||
|
||||
MSCCLPP_API_CPP DeviceProxyFifo HostConnection::getDeviceFifo() {
|
||||
return pimpl->comm->pimpl->proxy.fifo().toDevice();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) {
|
||||
pimpl->hostConn->put(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostConnection::signal() {
|
||||
pimpl->hostConn->signal();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostConnection::flush() {
|
||||
pimpl->hostConn->flush();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostConnection::wait() {
|
||||
pimpl->hostConn->wait();
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
@@ -1,55 +0,0 @@
|
||||
#include "host_connection.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
HostConnection::Impl::Impl() : hostConn(nullptr) {}
|
||||
|
||||
HostConnection::Impl::~Impl() {
|
||||
// TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not.
|
||||
}
|
||||
|
||||
void HostConnection::Impl::setup(mscclppHostConn_t *hostConn) {
|
||||
this->hostConn = hostConn;
|
||||
}
|
||||
|
||||
BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) {
|
||||
|
||||
}
|
||||
|
||||
int HostConnection::numRemoteBuffers() {
|
||||
|
||||
}
|
||||
|
||||
BufferHandle HostConnection::getRemoteBuffer(int index) {
|
||||
|
||||
}
|
||||
|
||||
DeviceConnection HostConnection::toDevice(bool startProxyThread = true) {
|
||||
|
||||
}
|
||||
|
||||
void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) {
|
||||
|
||||
}
|
||||
|
||||
void HostConnection::put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) {
|
||||
|
||||
}
|
||||
|
||||
void HostConnection::signal() {
|
||||
|
||||
}
|
||||
|
||||
void HostConnection::flush() {
|
||||
|
||||
}
|
||||
|
||||
void HostConnection::wait() {
|
||||
|
||||
}
|
||||
|
||||
void HostConnection::epochIncrement() {
|
||||
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
@@ -2,5 +2,6 @@
|
||||
#define MSCCLPP_API_H_
|
||||
|
||||
#define MSCCLPP_API extern "C" __attribute__((visibility("default")))
|
||||
#define MSCCLPP_API_CPP __attribute__((visibility("default")))
|
||||
|
||||
#endif // MSCCLPP_API_H_
|
||||
|
||||
13
src/include/basic_proxy_handler.hpp
Normal file
13
src/include/basic_proxy_handler.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#ifndef MSCCLPP_BASIC_PROXY_SERVICE_HPP_
|
||||
#define MSCCLPP_BASIC_PROXY_SERVICE_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "communicator.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -27,29 +27,3 @@
|
||||
} while (false)
|
||||
|
||||
#endif
|
||||
|
||||
#include <errno.h>
|
||||
// Check system calls
|
||||
#define SYSCHECKTHROW(call, name) \
|
||||
do { \
|
||||
int retval; \
|
||||
SYSCHECKVAL(call, name, retval); \
|
||||
} while (false)
|
||||
|
||||
#define SYSCHECKVALTHROW(call, name, retval) \
|
||||
do { \
|
||||
SYSCHECKSYNC(call, name, retval); \
|
||||
if (retval == -1) { \
|
||||
std::runtime_error(std::string("Call to " name " failed : ") + strerror(errno)); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define SYSCHECKSYNCTHROW(call, name, retval) \
|
||||
do { \
|
||||
retval = call; \
|
||||
if (retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \
|
||||
INFO(MSCCLPP_ALL, "Call to " name " returned %s, retrying", strerror(errno)); \
|
||||
} else { \
|
||||
break; \
|
||||
} \
|
||||
} while (true)
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
#ifndef MSCCL_COMMUNICATOR_HPP_
|
||||
#define MSCCL_COMMUNICATOR_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "mscclpp.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct Communicator::Impl {
|
||||
mscclppComm_t comm;
|
||||
std::vector<std::shared_ptr<HostConnection>> connections;
|
||||
Proxy proxy;
|
||||
|
||||
Impl();
|
||||
|
||||
~Impl();
|
||||
|
||||
friend class HostConnection;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif
|
||||
@@ -3,17 +3,18 @@
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "mscclpp.h"
|
||||
#include "comm.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct HostConnection::Impl {
|
||||
Communicator* comm;
|
||||
mscclppConn* conn;
|
||||
mscclppHostConn_t* hostConn;
|
||||
|
||||
Impl();
|
||||
Impl(Communicator* comm, mscclppConn* conn);
|
||||
|
||||
~Impl();
|
||||
|
||||
void setup(mscclppHostConn_t *hostConn);
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -29,7 +29,7 @@ struct alignas(16) mscclppDevConnSignalEpochId
|
||||
uint64_t proxy;
|
||||
};
|
||||
|
||||
using mscclppBufferHandle_t = uint8_t;
|
||||
using mscclppBufferHandle_t = uint32_t;
|
||||
|
||||
/***************************************************************************************************************
|
||||
* A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand.
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
|
||||
#include <mscclppfifo.hpp>
|
||||
|
||||
@@ -27,15 +28,14 @@ struct alignas(16) SignalEpochId {
|
||||
uint64_t proxy;
|
||||
};
|
||||
|
||||
enum ChannelTriggerType : uint64_t {
|
||||
channelTriggerData = 0x1,
|
||||
channelTriggerFlag = 0x2,
|
||||
channelTriggerSync = 0x4
|
||||
};
|
||||
using ChannelTriggerType = uint64_t;
|
||||
const ChannelTriggerType channelTriggerData = 0x1;
|
||||
const ChannelTriggerType channelTriggerFlag = 0x2;
|
||||
const ChannelTriggerType channelTriggerSync = 0x4;
|
||||
|
||||
// This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles
|
||||
// mapping to the actual
|
||||
using BufferHandle = uint8_t;
|
||||
using BufferHandle = uint32_t;
|
||||
|
||||
#define MSCCLPP_BITS_SIZE 32
|
||||
#define MSCCLPP_BITS_OFFSET 32
|
||||
@@ -58,15 +58,111 @@ union ChannelTrigger {
|
||||
uint64_t srcBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
|
||||
uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
|
||||
uint64_t type : MSCCLPP_BITS_TYPE;
|
||||
uint64_t connId : MSCCLPP_BITS_CONNID;
|
||||
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
|
||||
} fields;
|
||||
|
||||
ChannelTrigger() {}
|
||||
ChannelTrigger(ProxyTrigger value) : value(value) {}
|
||||
ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) {
|
||||
#ifdef __CUDACC__
|
||||
__device__ ChannelTrigger() {}
|
||||
__device__ ChannelTrigger(ProxyTrigger value) : value(value) {}
|
||||
__device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size, int connectionId) {
|
||||
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);
|
||||
value.snd = (((((((uint64_t)type << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset);
|
||||
value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
struct ConnectionEpoch {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
(*waitEpochId) += 1;
|
||||
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
SignalEpochId* localSignalEpochId;
|
||||
// used by the signal() function directly from gpu
|
||||
SignalEpochId* remoteSignalEpochId;
|
||||
|
||||
// every wait(), increments this and then the gpu waits for either:
|
||||
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
|
||||
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
|
||||
uint64_t* waitEpochId;
|
||||
};
|
||||
|
||||
class HostConnection {
|
||||
struct Impl;
|
||||
public:
|
||||
/* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */
|
||||
HostConnection(std::unique_ptr<Impl>);
|
||||
|
||||
~HostConnection();
|
||||
|
||||
int getId();
|
||||
|
||||
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
|
||||
* in the communicator.
|
||||
*
|
||||
* Inputs:
|
||||
* data: base pointer to the memory
|
||||
* size: size of the memory region in bytes
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
BufferHandle registerBuffer(void* data, uint64_t size);
|
||||
|
||||
/* Get the number of times registerBuffer(...) was called.
|
||||
*
|
||||
* Returns: the number of buffers registered
|
||||
*/
|
||||
int numLocalBuffers();
|
||||
|
||||
/* Get the BufferHandle returned by a call to registerBuffer(...) as identified by the index
|
||||
*
|
||||
* Inputs:
|
||||
* index: the index of the handle to get
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
BufferHandle getLocalBuffer(int index);
|
||||
|
||||
/* Get the number of times registerBuffer(...) was called on the remote peer.
|
||||
*
|
||||
* Returns: the number of buffers registered on the remote peer
|
||||
*/
|
||||
int numRemoteBuffers();
|
||||
|
||||
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
|
||||
*
|
||||
* Inputs:
|
||||
* index: the index of the handle to get
|
||||
*
|
||||
* Returns: a handle to the buffer on the remote peer
|
||||
*/
|
||||
BufferHandle getRemoteBuffer(int index);
|
||||
|
||||
ConnectionEpoch getEpoch();
|
||||
|
||||
DeviceProxyFifo getDeviceFifo();
|
||||
|
||||
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
|
||||
|
||||
void signal();
|
||||
|
||||
void flush();
|
||||
|
||||
void wait();
|
||||
|
||||
private:
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
friend class Communicator;
|
||||
};
|
||||
|
||||
/***************************************************************************************************************
|
||||
@@ -132,12 +228,20 @@ union ChannelTrigger {
|
||||
* indices in the registered buffer.
|
||||
**************************************************************************************************************/
|
||||
struct DeviceConnection {
|
||||
#ifdef __CUDACC__
|
||||
// TODO: add buffer handles
|
||||
DeviceConnection() = default;
|
||||
|
||||
DeviceConnection(HostConnection& hostConn)
|
||||
: connectionId(hostConn.getId()), epoch(hostConn.getEpoch()),
|
||||
fifo(hostConn.getDeviceFifo()) {}
|
||||
|
||||
DeviceConnection(const DeviceConnection& other) = default;
|
||||
|
||||
DeviceConnection& operator=(DeviceConnection& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size).value);
|
||||
fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
|
||||
@@ -148,13 +252,13 @@ struct DeviceConnection {
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
epochIncrement();
|
||||
fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1).value);
|
||||
fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
epochIncrement();
|
||||
fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size).value);
|
||||
fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
|
||||
@@ -165,107 +269,116 @@ struct DeviceConnection {
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo.push(channelTriggerData | channelTriggerFlag | channelTriggerSync, dstOffset, srcOffset, size);
|
||||
while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
|
||||
uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, dstOffset, src, srcOffset, size, connectionId).value);
|
||||
while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignalAndFlush(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
uint64_t curFifoHead = fifo.push(ChannelTrigger(mscclppSync, 0, 0, 0, 0, 1, connectionId).value);
|
||||
// we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail
|
||||
// to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0.
|
||||
while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
epoch.wait();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
epoch.epochIncrement();
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
int connectionId;
|
||||
|
||||
ConnectionEpoch epoch;
|
||||
|
||||
// this is a concurrent fifo which is multiple threads from the device
|
||||
// can produce for and the sole proxy thread consumes it.
|
||||
DeviceProxyFifo fifo;
|
||||
};
|
||||
|
||||
struct SimpleDeviceConnection {
|
||||
SimpleDeviceConnection() = default;
|
||||
|
||||
SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) {
|
||||
dst = hostConn.getRemoteBuffer(0);
|
||||
src = hostConn.getLocalBuffer(0);
|
||||
}
|
||||
|
||||
SimpleDeviceConnection(const SimpleDeviceConnection& other) = default;
|
||||
|
||||
SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devConn.put(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t offset, uint64_t size)
|
||||
{
|
||||
put(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
devConn.signal();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devConn.putWithSignal(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignal(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devConn.putWithSignalAndFlush(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignalAndFlush(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
uint64_t curFifoHead = fifo.push(mscclppSync, 0, 0, 1);
|
||||
// we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail
|
||||
// to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0.
|
||||
while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
|
||||
;
|
||||
devConn.flush();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
(*waitEpochId) += 1;
|
||||
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
|
||||
;
|
||||
devConn.wait();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
|
||||
devConn.epochIncrement();
|
||||
}
|
||||
|
||||
#endif // __CUDACC__
|
||||
|
||||
int remoteRank;
|
||||
int tag;
|
||||
|
||||
SignalEpochId* localSignalEpochId;
|
||||
// used by the signal() function directly from gpu
|
||||
SignalEpochId* remoteSignalEpochId;
|
||||
|
||||
// every wait(), increments this and then the gpu waits for either:
|
||||
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
|
||||
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
|
||||
uint64_t* waitEpochId;
|
||||
|
||||
// this is a concurrent fifo which is multiple threads from the device
|
||||
// can produce for and the sole proxy thread consumes it.
|
||||
ProxyFifo fifo;
|
||||
};
|
||||
|
||||
class HostConnection {
|
||||
public:
|
||||
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
|
||||
* in the communicator.
|
||||
*
|
||||
* Inputs:
|
||||
* data: base pointer to the memory
|
||||
* size: size of the memory region in bytes
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
BufferHandle registerBuffer(void* data, uint64_t size);
|
||||
|
||||
/* Get the number of times registerBuffer(...) was called on the remote peer.
|
||||
*
|
||||
* Returns: the number of buffers registered on the remote peer
|
||||
*/
|
||||
int numRemoteBuffers();
|
||||
|
||||
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
|
||||
*
|
||||
* Inputs:
|
||||
* index: the index of the handle to get
|
||||
*
|
||||
* Returns: a handle to the buffer on the remote peer
|
||||
*/
|
||||
BufferHandle getRemoteBuffer(int index);
|
||||
|
||||
/* Create a DeviceConnection paired with this HostConnection. A background proxy thread will
|
||||
* trigger operations on this HostConnection corresponding to put/signal/etc. calls made to the
|
||||
* DeviceConnection.
|
||||
*
|
||||
* Inputs:
|
||||
* startProxyThread: whether to start the proxy thread (default is true)
|
||||
*
|
||||
* Returns: the newly created DeviceConnection
|
||||
*/
|
||||
DeviceConnection toDevice(bool startProxyThread = true);
|
||||
|
||||
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
|
||||
void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size);
|
||||
void signal();
|
||||
void flush();
|
||||
void wait();
|
||||
void epochIncrement();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
DeviceConnection devConn;
|
||||
BufferHandle dst;
|
||||
BufferHandle src;
|
||||
};
|
||||
|
||||
#define MSCCLPP_UNIQUE_ID_BYTES 128
|
||||
@@ -290,6 +403,7 @@ enum class TransportType : uint8_t {
|
||||
|
||||
class Communicator {
|
||||
public:
|
||||
|
||||
/* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function.
|
||||
*
|
||||
* Inputs:
|
||||
@@ -297,7 +411,7 @@ public:
|
||||
* ipPortPair: a string of the form "ip:port" that represents the address of the root process
|
||||
* rank: rank of the calling process
|
||||
*/
|
||||
void initRank(int nranks, const char* ipPortPair, int rank);
|
||||
Communicator(int nranks, const char* ipPortPair, int rank);
|
||||
|
||||
/* Initialize the communicator from a given UniqueId. Same as mscclppCommInitRank() except that
|
||||
* id is provided by the user by calling getUniqueId()
|
||||
@@ -307,7 +421,9 @@ public:
|
||||
* id: the unique ID to be used for communication
|
||||
* rank: rank of the calling process
|
||||
*/
|
||||
void initRankFromId(int nranks, UniqueId id, int rank);
|
||||
Communicator(int nranks, UniqueId id, int rank);
|
||||
|
||||
~Communicator();
|
||||
|
||||
/* Ring-based AllGather through the bootstrap socket.
|
||||
*
|
||||
@@ -341,6 +457,12 @@ public:
|
||||
*/
|
||||
void connectionSetup();
|
||||
|
||||
/* Launch proxy thread(s). This function is supposed to be called before starting a kernel that uses DeviceConnection. */
|
||||
void startProxying();
|
||||
|
||||
/* Stop proxy thread(s). */
|
||||
void stopProxying();
|
||||
|
||||
/* Return the rank of the calling process.
|
||||
*
|
||||
* Outputs:
|
||||
@@ -355,6 +477,33 @@ public:
|
||||
*/
|
||||
int size();
|
||||
|
||||
struct Impl;
|
||||
private:
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
friend class HostConnection;
|
||||
};
|
||||
|
||||
enum class ProxyHandlerResult {
|
||||
Continue,
|
||||
FlushFifoTailAndContinue,
|
||||
Stop,
|
||||
};
|
||||
|
||||
class Proxy;
|
||||
using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
|
||||
|
||||
class Proxy {
|
||||
public:
|
||||
Proxy(ProxyHandler handler);
|
||||
|
||||
~Proxy();
|
||||
|
||||
void start();
|
||||
|
||||
void stop();
|
||||
|
||||
HostProxyFifo& fifo();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <stdint.h>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -13,39 +14,56 @@ struct alignas(16) ProxyTrigger {
|
||||
/* This is a concurrent fifo where multiple device threads can push mscclppTrigger work elements to
|
||||
* and a single host proxy thread consumes these work elements. There is a head pointer allocated on device
|
||||
* which starts with 0 and goes to 2^64-1 which is almost infinity. There are two copies of tail, one
|
||||
* that is on the deivce (triggerFifoTail) and another that is on host (proxyState->fifoTailHost).
|
||||
* that is on the deivce (tailReplica) and another that is on host (proxyState->fifoTailHost).
|
||||
* The host always has the "true" tail and occasionally, pushes it to the copy on the device.
|
||||
* Therefore, most of the time, the device has a stale version. The invariants are:
|
||||
* triggerFifoTail <= proxyState->fifoTailHost <= triggerFifoHead.
|
||||
* push() function increments triggerFifoHead, proxyState->fifoTailHost is updated in proxy.cc:mscclppProxyService
|
||||
* and it occasionally flushes it to triggerFifoTail via a cudaMemcpyAsync.
|
||||
* tailReplica <= proxyState->fifoTailHost <= head.
|
||||
* push() function increments head, proxyState->fifoTailHost is updated in proxy.cc:mscclppProxyService
|
||||
* and it occasionally flushes it to tailReplica via a cudaMemcpyAsync.
|
||||
*
|
||||
* Why duplicating the tail is a good idea? The fifo is large engouh and we do not need frequent updates
|
||||
* for the tail as there is usually enough space for device threads to push their work into.
|
||||
*/
|
||||
struct ProxyFifo {
|
||||
struct DeviceProxyFifo {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ uint64_t push(ProxyTrigger element)
|
||||
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger)
|
||||
{
|
||||
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->triggerFifoHead, 1);
|
||||
while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->triggerFifoTail))
|
||||
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1);
|
||||
while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->tailReplica))
|
||||
;
|
||||
while (*(volatile uint64_t*)&this->triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0)
|
||||
while (*(volatile uint64_t*)&this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0)
|
||||
;
|
||||
uint64_t* valptr = (uint64_t*)&(this->triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE].value);
|
||||
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(valptr),
|
||||
"l"(element.value[0]), "l"(element.value[1]));
|
||||
ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]);
|
||||
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr),
|
||||
"l"(trigger.fst), "l"(trigger.snd));
|
||||
return curFifoHead;
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
void startProxyThread(std::function<void(ProxyTrigger)> handler);
|
||||
void stopProxyThread();
|
||||
|
||||
ProxyTrigger* triggerFifo; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements
|
||||
uint64_t* triggerFifoTail; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused
|
||||
ProxyTrigger* triggers; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements
|
||||
uint64_t* tailReplica; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused
|
||||
// occasionally to device
|
||||
uint64_t* triggerFifoHead; // Allocated on device. Only accessed by device
|
||||
uint64_t* head; // Allocated on device. Only accessed by device
|
||||
};
|
||||
|
||||
class HostProxyFifo
|
||||
{
|
||||
public:
|
||||
HostProxyFifo();
|
||||
|
||||
~HostProxyFifo();
|
||||
|
||||
void poll(ProxyTrigger *trigger);
|
||||
|
||||
void pop();
|
||||
|
||||
void flushTail(bool sync = false);
|
||||
|
||||
DeviceProxyFifo toDevice();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -629,7 +629,7 @@ struct connInfo
|
||||
h.numBufferInfos = bufferInfos.size();
|
||||
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, &h, sizeof(header)));
|
||||
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
|
||||
return mscclppSuccess;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag) {
|
||||
@@ -638,7 +638,7 @@ struct connInfo
|
||||
infoQp = h.infoQp;
|
||||
bufferInfos.resize(h.numBufferInfos);
|
||||
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
|
||||
return mscclppSuccess;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
97
src/proxy_cpp.cc
Normal file
97
src/proxy_cpp.cc
Normal file
@@ -0,0 +1,97 @@
|
||||
#include "mscclpp.hpp"
|
||||
#include "utils.h"
|
||||
#include "api.h"
|
||||
#include <thread>
|
||||
#include <atomic>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// Iterations of the proxy loop between checks of the `running` flag; amortizes
// the atomic load off the hot polling path.
constexpr int ProxyStopCheckPeriod = 1000;

// Number of handled triggers between automatic fifo-tail flushes to the device.
constexpr int ProxyFlushPeriod = 4;
|
||||
|
||||
/* Private state of Proxy: the user handler, the host fifo it drains, and the
 * background service thread controlled by the `running` flag. */
struct Proxy::Impl {
  ProxyHandler handler;      // callback invoked for every trigger popped from the fifo
  HostProxyFifo fifo;        // host side of the device<->host trigger fifo
  std::thread service;       // proxy thread; valid only between start() and stop()
  std::atomic_bool running;  // set by start(), cleared by stop(); polled by the loop

  // Take the handler by value and move it into place so the std::function's
  // captured state is not copied a second time.
  explicit Impl(ProxyHandler handler) : handler(std::move(handler)), running(false) {}
};
|
||||
|
||||
/* Construct a proxy that will feed every fifo trigger to `handler`.
 * The thread is not started here; call start(). */
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler)
  // Member-init list + move: avoids default-construct-then-assign of pimpl and
  // a copy of the handler's captured state.
  : pimpl(std::make_unique<Impl>(std::move(handler))) {}
|
||||
|
||||
/* Join the service thread (via stop()) before Impl is destroyed. A moved-from
 * Proxy has a null pimpl and nothing to shut down. */
MSCCLPP_API_CPP Proxy::~Proxy() {
  if (!pimpl) {
    return;
  }
  stop();
}
|
||||
|
||||
/* Launch the background proxy thread. The thread drains the fifo, invoking the
 * handler on each trigger, until stop() is called or the handler returns Stop. */
MSCCLPP_API_CPP void Proxy::start() {
  // Guard against a double start: assigning a new thread over a still-joinable
  // std::thread calls std::terminate. exchange() makes a redundant start() a no-op.
  if (pimpl->running.exchange(true)) {
    return;
  }
  pimpl->service = std::thread([this] {
    // from this point on, proxy thread will stay close to the device
    // PROXYMSCCLPPCHECK(numaBind(pimpl->comm->devNumaNode)); // TODO: reenable this

    ProxyHandler handler = this->pimpl->handler;
    HostProxyFifo& fifo = this->pimpl->fifo;
    std::atomic_bool& running = this->pimpl->running;
    ProxyTrigger trigger;

    int runCnt = ProxyStopCheckPeriod;
    uint64_t flushCnt = 0;
    for (;;) {
      // Check the stop flag only every ProxyStopCheckPeriod iterations to keep
      // the atomic load off the hot polling path.
      if (runCnt-- == 0) {
        runCnt = ProxyStopCheckPeriod;
        if (!running) {
          break;
        }
      }
      // Poll to see if we are ready to send anything
      fifo.poll(&trigger);
      if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers
        continue; // there is one in progress
      }

      ProxyHandlerResult result = handler(trigger);

      // Send completion: reset only the high 64 bits
      fifo.pop();
      // Flush the tail to device memory either every ProxyFlushPeriod handled
      // triggers (so the device-side replica makes progress even without an
      // explicit flush request) or when the handler asks for it.
      if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) {
        // TODO: relocate this check: || (trigger.fields.type & mscclppSync)
        fifo.flushTail();
      }

      if (result == ProxyHandlerResult::Stop) {
        break;
      }
    }

    // make sure the tail is flushed before we shut the proxy
    fifo.flushTail(/*sync=*/true);
    // TODO: do these need to run?
    // bool isP2pProxy = (proxyState->ibContext == nullptr);
    // if (isP2pProxy) {
    //   cudaStream_t p2pStream = proxyState->p2pStream;
    //   PROXYCUDACHECK(cudaStreamSynchronize(p2pStream));
    // }
  });
}
|
||||
|
||||
/* Signal the service loop to exit, then join the thread. Safe to call even if
 * start() was never called (the thread is simply not joinable). */
MSCCLPP_API_CPP void Proxy::stop() {
  pimpl->running = false;
  std::thread& service = pimpl->service;
  if (service.joinable()) {
    service.join();
  }
}
|
||||
|
||||
/* Expose the host-side fifo, e.g. so callers can hand its device view to kernels. */
MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() {
  return this->pimpl->fifo;
}
|
||||
|
||||
} // namespace mscclpp
|
||||
@@ -10,6 +10,8 @@
|
||||
#include <string>
|
||||
#include <unistd.h>
|
||||
#include <unordered_map>
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
|
||||
static int nranksPerNode = 8;
|
||||
|
||||
@@ -46,9 +48,9 @@ static double getTime(void)
|
||||
return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec;
|
||||
}
|
||||
|
||||
__constant__ mscclpp::DeviceConnection constDevConns[16];
|
||||
__constant__ mscclpp::SimpleDeviceConnection constDevConns[16];
|
||||
|
||||
__device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU)
|
||||
__device__ void allgather0(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU)
|
||||
{
|
||||
// this allgather is really simple and implemented as an alltoall
|
||||
|
||||
@@ -67,7 +69,7 @@ __device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, i
|
||||
devConn.wait();
|
||||
}
|
||||
|
||||
__device__ void localAllGather(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
__device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
uint64_t offset, uint64_t size)
|
||||
{
|
||||
// this allgather algorithm works as follows:
|
||||
@@ -91,14 +93,14 @@ __device__ void localAllGather(mscclppDevConn_t devConn, int rank, int world_siz
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
__device__ void allgather1(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
size_t nelemsPerGPU)
|
||||
{
|
||||
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
|
||||
nelemsPerGPU * sizeof(int));
|
||||
}
|
||||
|
||||
__device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
__device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
size_t nelemsPerGPU)
|
||||
{
|
||||
// this allgather is a pipelined and hierarchical one and only works for two nodes
|
||||
@@ -167,7 +169,7 @@ __global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelem
|
||||
int warpId = threadIdx.x / 32;
|
||||
int remoteRank = (warpId < rank) ? warpId : warpId + 1;
|
||||
// Each warp is responsible for one of the remote ranks
|
||||
mscclppDevConn_t devConn = constDevConns[warpId];
|
||||
mscclpp::SimpleDeviceConnection devConn = constDevConns[warpId];
|
||||
|
||||
if (kernel == 0)
|
||||
allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU);
|
||||
@@ -219,7 +221,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
|
||||
int thisNode = rankToNode(rank);
|
||||
int cudaNum = rankToLocalRank(rank);
|
||||
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
|
||||
std::vector<mscclpp::DeviceConnection> devConns(world_size);
|
||||
std::vector<std::shared_ptr<mscclpp::HostConnection>> hostConns;
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank)
|
||||
@@ -235,12 +237,19 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
|
||||
// Connect with all other ranks
|
||||
auto hostConn = comm.connect(r, 0, transportType, ibDev);
|
||||
hostConn->registerBuffer(data_d, dataSize);
|
||||
devConns.push_back(hostConn->toDevice(false));
|
||||
hostConns.push_back(hostConn);
|
||||
}
|
||||
|
||||
comm.connectionSetup();
|
||||
assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::DeviceConnection));
|
||||
CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::DeviceConnection) * devConns.size() ));
|
||||
|
||||
std::vector<mscclpp::SimpleDeviceConnection> devConns;
|
||||
std::transform(hostConns.begin(), hostConns.end(), std::back_inserter(devConns),
|
||||
[](std::shared_ptr<mscclpp::HostConnection>& hostConn) {
|
||||
return mscclpp::SimpleDeviceConnection(*hostConn);
|
||||
});
|
||||
|
||||
assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::SimpleDeviceConnection));
|
||||
CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::SimpleDeviceConnection) * devConns.size() ));
|
||||
}
|
||||
|
||||
void printUsage(const char* prog, bool isMpi)
|
||||
@@ -391,12 +400,9 @@ int main(int argc, const char* argv[])
|
||||
size_t nelemsPerGPU = dataSize / sizeof(int) / world_size;
|
||||
|
||||
try{
|
||||
mscclpp::Communicator comm;
|
||||
|
||||
if (rank == 0)
|
||||
printf("Initializing MSCCL++\n");
|
||||
|
||||
comm.initRank(world_size, ip_port, rank);
|
||||
mscclpp::Communicator comm(world_size, ip_port, rank);
|
||||
|
||||
if (rank == 0)
|
||||
printf("Initializing data for allgather test\n");
|
||||
@@ -406,97 +412,93 @@ int main(int argc, const char* argv[])
|
||||
printf("Setting up the connection in MSCCL++\n");
|
||||
setupMscclppConnections(rank, world_size, comm, data_d, dataSize);
|
||||
|
||||
if (rank == 0)
|
||||
printf("Launching MSCCL++ proxy threads\n");
|
||||
comm.startProxying();
|
||||
|
||||
if (rank == 0)
|
||||
printf("Testing the correctness of AllGather implementation\n");
|
||||
cudaStream_t stream;
|
||||
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
CUDACHECK(cudaDeviceSynchronize());
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
|
||||
CUDACHECK(cudaDeviceSynchronize());
|
||||
CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost));
|
||||
|
||||
for (size_t i = 0; i < nelemsPerGPU * world_size; i++) {
|
||||
int val = i + 1;
|
||||
if (data_h[i] != val) {
|
||||
printf("oh uh! data_h[%ld] (%d) != val (%d)\n", i, data_h[i], val);
|
||||
break;
|
||||
}
|
||||
}
|
||||
int tmp[16];
|
||||
// A simple barrier
|
||||
comm.bootstrapAllGather(tmp, sizeof(int));
|
||||
if (rank == 0)
|
||||
printf("Successfully checked the correctness\n");
|
||||
|
||||
// Perf test
|
||||
int iterwithoutcudagraph = 10;
|
||||
if (rank == 0)
|
||||
printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph);
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
comm.bootstrapAllGather(tmp, sizeof(int));
|
||||
for (int i = 0; i < iterwithoutcudagraph; ++i) {
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
|
||||
}
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
comm.bootstrapAllGather(tmp, sizeof(int));
|
||||
|
||||
// cudaGraph Capture
|
||||
int cudagraphiter = 10;
|
||||
if (rank == 0)
|
||||
printf("Capturing %d iterations of the kernel in a CUDA graph\n", cudagraphiter);
|
||||
cudaGraph_t graph;
|
||||
cudaGraphExec_t instance;
|
||||
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
|
||||
for (int i = 0; i < cudagraphiter; ++i) {
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
|
||||
}
|
||||
cudaStreamEndCapture(stream, &graph);
|
||||
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
|
||||
|
||||
int cudagraphwarmup = 10;
|
||||
if (rank == 0)
|
||||
printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup,
|
||||
cudagraphiter);
|
||||
for (int i = 0; i < cudagraphwarmup; ++i) {
|
||||
cudaGraphLaunch(instance, stream);
|
||||
}
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
// measure runtime
|
||||
int cudagraphlaunch = 10;
|
||||
if (rank == 0)
|
||||
printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch,
|
||||
cudagraphiter);
|
||||
comm.bootstrapAllGather(tmp, sizeof(int));
|
||||
double t0, t1, ms, time_in_us;
|
||||
t0 = getTime();
|
||||
for (int i = 0; i < cudagraphlaunch; ++i) {
|
||||
cudaGraphLaunch(instance, stream);
|
||||
}
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
t1 = getTime();
|
||||
ms = (t1 - t0) * 1000.0;
|
||||
time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter;
|
||||
printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us,
|
||||
(double)(dataSize) / 1e9 / (time_in_us / 1e6));
|
||||
comm.bootstrapAllGather(tmp, sizeof(int));
|
||||
|
||||
if (rank == 0)
|
||||
printf("Stopping MSCCL++ proxy threads\n");
|
||||
comm.stopProxying();
|
||||
|
||||
} catch (std::exception& e) {
|
||||
// todo: throw exceptions in the implementation and process them here
|
||||
}
|
||||
|
||||
if (rank == 0)
|
||||
printf("Launching MSCCL++ proxy threads\n");
|
||||
MSCCLPPCHECK(mscclppProxyLaunch(comm));
|
||||
|
||||
if (rank == 0)
|
||||
printf("Testing the correctness of AllGather implementation\n");
|
||||
cudaStream_t stream;
|
||||
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
CUDACHECK(cudaDeviceSynchronize());
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
|
||||
CUDACHECK(cudaDeviceSynchronize());
|
||||
CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost));
|
||||
|
||||
for (size_t i = 0; i < nelemsPerGPU * world_size; i++) {
|
||||
int val = i + 1;
|
||||
if (data_h[i] != val) {
|
||||
printf("oh uh! data_h[%ld] (%d) != val (%d)\n", i, data_h[i], val);
|
||||
break;
|
||||
}
|
||||
}
|
||||
int tmp[16];
|
||||
// A simple barrier
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
|
||||
if (rank == 0)
|
||||
printf("Successfully checked the correctness\n");
|
||||
|
||||
// Perf test
|
||||
int iterwithoutcudagraph = 10;
|
||||
if (rank == 0)
|
||||
printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph);
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
|
||||
for (int i = 0; i < iterwithoutcudagraph; ++i) {
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
|
||||
}
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
|
||||
|
||||
// cudaGraph Capture
|
||||
int cudagraphiter = 10;
|
||||
if (rank == 0)
|
||||
printf("Capturing %d iterations of the kernel in a CUDA graph\n", cudagraphiter);
|
||||
cudaGraph_t graph;
|
||||
cudaGraphExec_t instance;
|
||||
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
|
||||
for (int i = 0; i < cudagraphiter; ++i) {
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
|
||||
}
|
||||
cudaStreamEndCapture(stream, &graph);
|
||||
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
|
||||
|
||||
int cudagraphwarmup = 10;
|
||||
if (rank == 0)
|
||||
printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup,
|
||||
cudagraphiter);
|
||||
for (int i = 0; i < cudagraphwarmup; ++i) {
|
||||
cudaGraphLaunch(instance, stream);
|
||||
}
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
// measure runtime
|
||||
int cudagraphlaunch = 10;
|
||||
if (rank == 0)
|
||||
printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch,
|
||||
cudagraphiter);
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
|
||||
double t0, t1, ms, time_in_us;
|
||||
t0 = getTime();
|
||||
for (int i = 0; i < cudagraphlaunch; ++i) {
|
||||
cudaGraphLaunch(instance, stream);
|
||||
}
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
t1 = getTime();
|
||||
ms = (t1 - t0) * 1000.0;
|
||||
time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter;
|
||||
printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us,
|
||||
(double)(dataSize) / 1e9 / (time_in_us / 1e6));
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
|
||||
|
||||
if (rank == 0)
|
||||
printf("Stopping MSCCL++ proxy threads\n");
|
||||
MSCCLPPCHECK(mscclppProxyStop(comm));
|
||||
|
||||
if (rank == 0)
|
||||
printf("Destroying MSCCL++ communicator\n");
|
||||
MSCCLPPCHECK(mscclppCommDestroy(comm));
|
||||
printf("Rank %d succeeded!\n", rank);
|
||||
|
||||
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
|
||||
|
||||
Reference in New Issue
Block a user