mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
C++ API changes
This commit is contained in:
@@ -19,7 +19,7 @@ ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm) {
|
||||
|
||||
if (trigger->fields.type & mscclppSync) {
|
||||
conn.flush();
|
||||
result = ProxyHandlerResult::FlushAndContinue;
|
||||
result = ProxyHandlerResult::FlushFifoTailAndContinue;
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
@@ -14,7 +14,6 @@ Communicator::Impl::~Impl() {
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::Communicator() = default;
|
||||
MSCCLPP_API_CPP Communicator::~Communicator() = default;
|
||||
|
||||
mscclppTransport_t transportTypeToCStyle(TransportType type) {
|
||||
@@ -28,43 +27,26 @@ mscclppTransport_t transportTypeToCStyle(TransportType type) {
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::initRank(int nranks, const char* ipPortPair, int rank) {
|
||||
if (pimpl) {
|
||||
throw std::runtime_error("Communicator already initialized");
|
||||
}
|
||||
pimpl = std::make_unique<Impl>();
|
||||
MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique<Impl>()) {
|
||||
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::initRankFromId(int nranks, UniqueId id, int rank) {
|
||||
if (pimpl) {
|
||||
throw std::runtime_error("Communicator already initialized");
|
||||
}
|
||||
pimpl = std::make_unique<Impl>();
|
||||
MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique<Impl>()) {
|
||||
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
|
||||
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
|
||||
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) {
|
||||
if (!pimpl) {
|
||||
throw std::runtime_error("Communicator not initialized");
|
||||
}
|
||||
mscclppBootstrapAllGather(pimpl->comm, data, size);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
|
||||
if (!pimpl) {
|
||||
throw std::runtime_error("Communicator not initialized");
|
||||
}
|
||||
mscclppBootstrapBarrier(pimpl->comm);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
|
||||
TransportType transportType, const char* ibDev) {
|
||||
if (!pimpl) {
|
||||
throw std::runtime_error("Communicator not initialized");
|
||||
}
|
||||
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
|
||||
auto connIdx = pimpl->connections.size();
|
||||
auto conn = std::make_shared<HostConnection>(std::make_unique<HostConnection::Impl>(this, &pimpl->comm->conns[connIdx]));
|
||||
@@ -73,39 +55,24 @@ MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remote
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::connectionSetup() {
|
||||
if (!pimpl) {
|
||||
throw std::runtime_error("Communicator not initialized");
|
||||
}
|
||||
mscclppConnectionSetup(pimpl->comm);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::startProxying() {
|
||||
if (!pimpl) {
|
||||
throw std::runtime_error("Communicator not initialized");
|
||||
}
|
||||
pimpl->proxy.start();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::stopProxying() {
|
||||
if (!pimpl) {
|
||||
throw std::runtime_error("Communicator not initialized");
|
||||
}
|
||||
pimpl->proxy.stop();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::rank() {
|
||||
if (!pimpl) {
|
||||
throw std::runtime_error("Communicator not initialized");
|
||||
}
|
||||
int result;
|
||||
mscclppCommRank(pimpl->comm, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::size() {
|
||||
if (!pimpl) {
|
||||
throw std::runtime_error("Communicator not initialized");
|
||||
}
|
||||
int result;
|
||||
mscclppCommSize(pimpl->comm, &result);
|
||||
return result;
|
||||
|
||||
@@ -15,8 +15,14 @@ HostConnection::Impl::~Impl() {
|
||||
// TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not.
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP HostConnection::~HostConnection() = default;
|
||||
|
||||
MSCCLPP_API_CPP HostConnection::HostConnection(std::unique_ptr<Impl> p) : pimpl(std::move(p)) {}
|
||||
|
||||
MSCCLPP_API_CPP int HostConnection::getId() {
|
||||
return pimpl->conn->connId;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) {
|
||||
BufferHandle result;
|
||||
static_assert(sizeof(BufferHandle) == sizeof(mscclppBufferHandle_t));
|
||||
@@ -24,10 +30,15 @@ MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t
|
||||
return result;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int HostConnection::numLocalBuffers() {
|
||||
return pimpl->conn->bufferRegistrations.size() - 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP BufferHandle HostConnection::getLocalBuffer(int index) {
|
||||
return index + 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int HostConnection::numRemoteBuffers() {
|
||||
if (!pimpl->conn) {
|
||||
throw std::runtime_error("HostConnection not initialized");
|
||||
}
|
||||
return pimpl->conn->remoteBufferRegistrations.size() - 1;
|
||||
}
|
||||
|
||||
@@ -35,16 +46,18 @@ MSCCLPP_API_CPP BufferHandle HostConnection::getRemoteBuffer(int index) {
|
||||
return index + 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP DeviceConnection HostConnection::toDevice() {
|
||||
DeviceConnection devConn;
|
||||
MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() {
|
||||
ConnectionEpoch epoch;
|
||||
static_assert(sizeof(SignalEpochId) == sizeof(mscclppDevConnSignalEpochId));
|
||||
devConn.connectionId = pimpl->conn->connId;
|
||||
devConn.localSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->localSignalEpochId);
|
||||
devConn.remoteSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->remoteSignalEpochId);
|
||||
devConn.waitEpochId = pimpl->conn->devConn->waitEpochId;
|
||||
devConn.fifo = pimpl->comm->pimpl->proxy.fifo().toDevice();
|
||||
epoch.localSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->localSignalEpochId);
|
||||
epoch.remoteSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->remoteSignalEpochId);
|
||||
epoch.waitEpochId = pimpl->conn->devConn->waitEpochId;
|
||||
return epoch;
|
||||
}
|
||||
|
||||
return devConn;
|
||||
|
||||
MSCCLPP_API_CPP DeviceProxyFifo HostConnection::getDeviceFifo() {
|
||||
return pimpl->comm->pimpl->proxy.fifo().toDevice();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) {
|
||||
|
||||
@@ -72,6 +72,99 @@ union ChannelTrigger {
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
struct ConnectionEpoch {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
(*waitEpochId) += 1;
|
||||
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
SignalEpochId* localSignalEpochId;
|
||||
// used by the signal() function directly from gpu
|
||||
SignalEpochId* remoteSignalEpochId;
|
||||
|
||||
// every wait(), increments this and then the gpu waits for either:
|
||||
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
|
||||
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
|
||||
uint64_t* waitEpochId;
|
||||
};
|
||||
|
||||
class HostConnection {
|
||||
struct Impl;
|
||||
public:
|
||||
/* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */
|
||||
HostConnection(std::unique_ptr<Impl>);
|
||||
|
||||
~HostConnection();
|
||||
|
||||
int getId();
|
||||
|
||||
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
|
||||
* in the communicator.
|
||||
*
|
||||
* Inputs:
|
||||
* data: base pointer to the memory
|
||||
* size: size of the memory region in bytes
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
BufferHandle registerBuffer(void* data, uint64_t size);
|
||||
|
||||
/* Get the number of times registerBuffer(...) was called.
|
||||
*
|
||||
* Returns: the number of buffers registered
|
||||
*/
|
||||
int numLocalBuffers();
|
||||
|
||||
/* Get the BufferHandle returned by a call to registerBuffer(...) as identified by the index
|
||||
*
|
||||
* Inputs:
|
||||
* index: the index of the handle to get
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
BufferHandle getLocalBuffer(int index);
|
||||
|
||||
/* Get the number of times registerBuffer(...) was called on the remote peer.
|
||||
*
|
||||
* Returns: the number of buffers registered on the remote peer
|
||||
*/
|
||||
int numRemoteBuffers();
|
||||
|
||||
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
|
||||
*
|
||||
* Inputs:
|
||||
* index: the index of the handle to get
|
||||
*
|
||||
* Returns: a handle to the buffer on the remote peer
|
||||
*/
|
||||
BufferHandle getRemoteBuffer(int index);
|
||||
|
||||
ConnectionEpoch getEpoch();
|
||||
|
||||
DeviceProxyFifo getDeviceFifo();
|
||||
|
||||
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
|
||||
|
||||
void signal();
|
||||
|
||||
void flush();
|
||||
|
||||
void wait();
|
||||
|
||||
private:
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
friend class Communicator;
|
||||
};
|
||||
|
||||
/***************************************************************************************************************
|
||||
* A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand.
|
||||
* The communication API is one-sided meaning that for every single data transfer, only one side
|
||||
@@ -135,9 +228,17 @@ union ChannelTrigger {
|
||||
* indices in the registered buffer.
|
||||
**************************************************************************************************************/
|
||||
struct DeviceConnection {
|
||||
#ifdef __CUDACC__
|
||||
// TODO: add buffer handles
|
||||
DeviceConnection() = default;
|
||||
|
||||
DeviceConnection(HostConnection& hostConn)
|
||||
: connectionId(hostConn.getId()), epoch(hostConn.getEpoch()),
|
||||
fifo(hostConn.getDeviceFifo()) {}
|
||||
|
||||
DeviceConnection(const DeviceConnection& other) = default;
|
||||
|
||||
DeviceConnection& operator=(DeviceConnection& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value);
|
||||
@@ -191,28 +292,18 @@ struct DeviceConnection {
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
(*waitEpochId) += 1;
|
||||
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
|
||||
;
|
||||
epoch.wait();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
|
||||
epoch.epochIncrement();
|
||||
}
|
||||
|
||||
#endif // __CUDACC__
|
||||
|
||||
int connectionId;
|
||||
|
||||
SignalEpochId* localSignalEpochId;
|
||||
// used by the signal() function directly from gpu
|
||||
SignalEpochId* remoteSignalEpochId;
|
||||
|
||||
// every wait(), increments this and then the gpu waits for either:
|
||||
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
|
||||
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
|
||||
uint64_t* waitEpochId;
|
||||
ConnectionEpoch epoch;
|
||||
|
||||
// this is a concurrent fifo which is multiple threads from the device
|
||||
// can produce for and the sole proxy thread consumes it.
|
||||
@@ -220,9 +311,15 @@ struct DeviceConnection {
|
||||
};
|
||||
|
||||
struct SimpleDeviceConnection {
|
||||
SimpleDeviceConnection() {}
|
||||
SimpleDeviceConnection(DeviceConnection devConn, BufferHandle dst, BufferHandle src) : devConn(devConn), dst(dst), src(src) {}
|
||||
SimpleDeviceConnection() = default;
|
||||
|
||||
SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) {
|
||||
dst = hostConn.getRemoteBuffer(0);
|
||||
src = hostConn.getLocalBuffer(0);
|
||||
}
|
||||
|
||||
SimpleDeviceConnection(const SimpleDeviceConnection& other) = default;
|
||||
|
||||
SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
@@ -284,59 +381,6 @@ struct SimpleDeviceConnection {
|
||||
BufferHandle src;
|
||||
};
|
||||
|
||||
class HostConnection {
|
||||
struct Impl;
|
||||
public:
|
||||
/* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */
|
||||
HostConnection(std::unique_ptr<Impl>);
|
||||
|
||||
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
|
||||
* in the communicator.
|
||||
*
|
||||
* Inputs:
|
||||
* data: base pointer to the memory
|
||||
* size: size of the memory region in bytes
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
BufferHandle registerBuffer(void* data, uint64_t size);
|
||||
|
||||
/* Get the number of times registerBuffer(...) was called on the remote peer.
|
||||
*
|
||||
* Returns: the number of buffers registered on the remote peer
|
||||
*/
|
||||
int numRemoteBuffers();
|
||||
|
||||
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
|
||||
*
|
||||
* Inputs:
|
||||
* index: the index of the handle to get
|
||||
*
|
||||
* Returns: a handle to the buffer on the remote peer
|
||||
*/
|
||||
BufferHandle getRemoteBuffer(int index);
|
||||
|
||||
/* Create a DeviceConnection paired with this HostConnection. A background proxy thread will
|
||||
* trigger operations on this HostConnection corresponding to put/signal/etc. calls made to the
|
||||
* DeviceConnection.
|
||||
*
|
||||
* Returns: the newly created DeviceConnection
|
||||
*/
|
||||
DeviceConnection toDevice();
|
||||
|
||||
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
|
||||
|
||||
void signal();
|
||||
|
||||
void flush();
|
||||
|
||||
void wait();
|
||||
|
||||
private:
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
friend class Communicator;
|
||||
};
|
||||
|
||||
#define MSCCLPP_UNIQUE_ID_BYTES 128
|
||||
struct UniqueId {
|
||||
char internal[MSCCLPP_UNIQUE_ID_BYTES];
|
||||
@@ -359,8 +403,6 @@ enum class TransportType : uint8_t {
|
||||
|
||||
class Communicator {
|
||||
public:
|
||||
Communicator();
|
||||
~Communicator();
|
||||
|
||||
/* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function.
|
||||
*
|
||||
@@ -369,7 +411,7 @@ public:
|
||||
* ipPortPair: a string of the form "ip:port" that represents the address of the root process
|
||||
* rank: rank of the calling process
|
||||
*/
|
||||
void initRank(int nranks, const char* ipPortPair, int rank);
|
||||
Communicator(int nranks, const char* ipPortPair, int rank);
|
||||
|
||||
/* Initialize the communicator from a given UniqueId. Same as mscclppCommInitRank() except that
|
||||
* id is provided by the user by calling getUniqueId()
|
||||
@@ -379,7 +421,9 @@ public:
|
||||
* id: the unique ID to be used for communication
|
||||
* rank: rank of the calling process
|
||||
*/
|
||||
void initRankFromId(int nranks, UniqueId id, int rank);
|
||||
Communicator(int nranks, UniqueId id, int rank);
|
||||
|
||||
~Communicator();
|
||||
|
||||
/* Ring-based AllGather through the bootstrap socket.
|
||||
*
|
||||
@@ -441,7 +485,7 @@ private:
|
||||
|
||||
enum class ProxyHandlerResult {
|
||||
Continue,
|
||||
FlushAndContinue,
|
||||
FlushFifoTailAndContinue,
|
||||
Stop,
|
||||
};
|
||||
|
||||
|
||||
@@ -62,10 +62,14 @@ MSCCLPP_API_CPP void Proxy::start() {
|
||||
// Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure
|
||||
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
|
||||
// request.
|
||||
if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushAndContinue) {
|
||||
if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) {
|
||||
// TODO: relocate this check: || (trigger.fields.type & mscclppSync)
|
||||
fifo.flushTail();
|
||||
}
|
||||
|
||||
if (result == ProxyHandlerResult::Stop) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// make sure the tail is flushed before we shut the proxy
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <unistd.h>
|
||||
#include <unordered_map>
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
|
||||
static int nranksPerNode = 8;
|
||||
|
||||
@@ -220,7 +221,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
|
||||
int thisNode = rankToNode(rank);
|
||||
int cudaNum = rankToLocalRank(rank);
|
||||
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
|
||||
std::vector<std::pair<std::shared_ptr<mscclpp::HostConnection>, mscclpp::BufferHandle>> hostConns;
|
||||
std::vector<std::shared_ptr<mscclpp::HostConnection>> hostConns;
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank)
|
||||
@@ -235,19 +236,17 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
|
||||
}
|
||||
// Connect with all other ranks
|
||||
auto hostConn = comm.connect(r, 0, transportType, ibDev);
|
||||
auto localBuffer = hostConn->registerBuffer(data_d, dataSize);
|
||||
hostConns.emplace_back(hostConn, localBuffer);
|
||||
hostConn->registerBuffer(data_d, dataSize);
|
||||
hostConns.push_back(hostConn);
|
||||
}
|
||||
|
||||
comm.connectionSetup();
|
||||
|
||||
std::vector<mscclpp::SimpleDeviceConnection> devConns;
|
||||
for (auto& entry : hostConns) {
|
||||
assert(entry.first);
|
||||
assert(entry.first->numRemoteBuffers() == 1);
|
||||
auto remoteBuffer = entry.first->getRemoteBuffer(0);
|
||||
devConns.emplace_back(entry.first->toDevice(), entry.second, remoteBuffer);
|
||||
}
|
||||
std::transform(hostConns.begin(), hostConns.end(), std::back_inserter(devConns),
|
||||
[](std::shared_ptr<mscclpp::HostConnection>& hostConn) {
|
||||
return mscclpp::SimpleDeviceConnection(*hostConn);
|
||||
});
|
||||
|
||||
assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::SimpleDeviceConnection));
|
||||
CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::SimpleDeviceConnection) * devConns.size() ));
|
||||
@@ -401,12 +400,9 @@ int main(int argc, const char* argv[])
|
||||
size_t nelemsPerGPU = dataSize / sizeof(int) / world_size;
|
||||
|
||||
try{
|
||||
mscclpp::Communicator comm;
|
||||
|
||||
if (rank == 0)
|
||||
printf("Initializing MSCCL++\n");
|
||||
|
||||
comm.initRank(world_size, ip_port, rank);
|
||||
mscclpp::Communicator comm(world_size, ip_port, rank);
|
||||
|
||||
if (rank == 0)
|
||||
printf("Initializing data for allgather test\n");
|
||||
|
||||
Reference in New Issue
Block a user