C++ API changes

This commit is contained in:
Olli Saarikivi
2023-04-19 22:02:23 +00:00
parent 83c7ba1afb
commit 9fbb0debdd
6 changed files with 160 additions and 136 deletions

View File

@@ -19,7 +19,7 @@ ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm) {
if (trigger->fields.type & mscclppSync) {
conn.flush();
result = ProxyHandlerResult::FlushAndContinue;
result = ProxyHandlerResult::FlushFifoTailAndContinue;
}
return result;

View File

@@ -14,7 +14,6 @@ Communicator::Impl::~Impl() {
}
}
MSCCLPP_API_CPP Communicator::Communicator() = default;
MSCCLPP_API_CPP Communicator::~Communicator() = default;
mscclppTransport_t transportTypeToCStyle(TransportType type) {
@@ -28,43 +27,26 @@ mscclppTransport_t transportTypeToCStyle(TransportType type) {
}
}
MSCCLPP_API_CPP void Communicator::initRank(int nranks, const char* ipPortPair, int rank) {
if (pimpl) {
throw std::runtime_error("Communicator already initialized");
}
pimpl = std::make_unique<Impl>();
MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique<Impl>()) {
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
}
MSCCLPP_API_CPP void Communicator::initRankFromId(int nranks, UniqueId id, int rank) {
if (pimpl) {
throw std::runtime_error("Communicator already initialized");
}
pimpl = std::make_unique<Impl>();
MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique<Impl>()) {
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
}
MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) {
if (!pimpl) {
throw std::runtime_error("Communicator not initialized");
}
mscclppBootstrapAllGather(pimpl->comm, data, size);
}
MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
if (!pimpl) {
throw std::runtime_error("Communicator not initialized");
}
mscclppBootstrapBarrier(pimpl->comm);
}
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
TransportType transportType, const char* ibDev) {
if (!pimpl) {
throw std::runtime_error("Communicator not initialized");
}
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
auto connIdx = pimpl->connections.size();
auto conn = std::make_shared<HostConnection>(std::make_unique<HostConnection::Impl>(this, &pimpl->comm->conns[connIdx]));
@@ -73,39 +55,24 @@ MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remote
}
MSCCLPP_API_CPP void Communicator::connectionSetup() {
if (!pimpl) {
throw std::runtime_error("Communicator not initialized");
}
mscclppConnectionSetup(pimpl->comm);
}
MSCCLPP_API_CPP void Communicator::startProxying() {
if (!pimpl) {
throw std::runtime_error("Communicator not initialized");
}
pimpl->proxy.start();
}
MSCCLPP_API_CPP void Communicator::stopProxying() {
if (!pimpl) {
throw std::runtime_error("Communicator not initialized");
}
pimpl->proxy.stop();
}
MSCCLPP_API_CPP int Communicator::rank() {
if (!pimpl) {
throw std::runtime_error("Communicator not initialized");
}
int result;
mscclppCommRank(pimpl->comm, &result);
return result;
}
MSCCLPP_API_CPP int Communicator::size() {
if (!pimpl) {
throw std::runtime_error("Communicator not initialized");
}
int result;
mscclppCommSize(pimpl->comm, &result);
return result;

View File

@@ -15,8 +15,14 @@ HostConnection::Impl::~Impl() {
// TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not.
}
MSCCLPP_API_CPP HostConnection::~HostConnection() = default;
MSCCLPP_API_CPP HostConnection::HostConnection(std::unique_ptr<Impl> p) : pimpl(std::move(p)) {}
MSCCLPP_API_CPP int HostConnection::getId() {
return pimpl->conn->connId;
}
MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) {
BufferHandle result;
static_assert(sizeof(BufferHandle) == sizeof(mscclppBufferHandle_t));
@@ -24,10 +30,15 @@ MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t
return result;
}
MSCCLPP_API_CPP int HostConnection::numLocalBuffers() {
return pimpl->conn->bufferRegistrations.size() - 1;
}
MSCCLPP_API_CPP BufferHandle HostConnection::getLocalBuffer(int index) {
return index + 1;
}
MSCCLPP_API_CPP int HostConnection::numRemoteBuffers() {
if (!pimpl->conn) {
throw std::runtime_error("HostConnection not initialized");
}
return pimpl->conn->remoteBufferRegistrations.size() - 1;
}
@@ -35,16 +46,18 @@ MSCCLPP_API_CPP BufferHandle HostConnection::getRemoteBuffer(int index) {
return index + 1;
}
MSCCLPP_API_CPP DeviceConnection HostConnection::toDevice() {
DeviceConnection devConn;
MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() {
ConnectionEpoch epoch;
static_assert(sizeof(SignalEpochId) == sizeof(mscclppDevConnSignalEpochId));
devConn.connectionId = pimpl->conn->connId;
devConn.localSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->localSignalEpochId);
devConn.remoteSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->remoteSignalEpochId);
devConn.waitEpochId = pimpl->conn->devConn->waitEpochId;
devConn.fifo = pimpl->comm->pimpl->proxy.fifo().toDevice();
epoch.localSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->localSignalEpochId);
epoch.remoteSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->remoteSignalEpochId);
epoch.waitEpochId = pimpl->conn->devConn->waitEpochId;
return epoch;
}
return devConn;
MSCCLPP_API_CPP DeviceProxyFifo HostConnection::getDeviceFifo() {
return pimpl->comm->pimpl->proxy.fifo().toDevice();
}
MSCCLPP_API_CPP void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) {

View File

@@ -72,6 +72,99 @@ union ChannelTrigger {
#endif // __CUDACC__
};
struct ConnectionEpoch {
#ifdef __CUDACC__
__forceinline__ __device__ void wait()
{
(*waitEpochId) += 1;
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
;
}
__forceinline__ __device__ void epochIncrement()
{
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
}
#endif // __CUDACC__
SignalEpochId* localSignalEpochId;
// used by the signal() function directly from gpu
SignalEpochId* remoteSignalEpochId;
// every wait(), increments this and then the gpu waits for either:
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
uint64_t* waitEpochId;
};
class HostConnection {
struct Impl;
public:
/* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */
HostConnection(std::unique_ptr<Impl>);
~HostConnection();
int getId();
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
* in the communicator.
*
* Inputs:
* data: base pointer to the memory
* size: size of the memory region in bytes
*
* Returns: a handle to the buffer
*/
BufferHandle registerBuffer(void* data, uint64_t size);
/* Get the number of times registerBuffer(...) was called.
*
* Returns: the number of buffers registered
*/
int numLocalBuffers();
/* Get the BufferHandle returned by a call to registerBuffer(...) as identified by the index
*
* Inputs:
* index: the index of the handle to get
*
* Returns: a handle to the buffer
*/
BufferHandle getLocalBuffer(int index);
/* Get the number of times registerBuffer(...) was called on the remote peer.
*
* Returns: the number of buffers registered on the remote peer
*/
int numRemoteBuffers();
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
*
* Inputs:
* index: the index of the handle to get
*
* Returns: a handle to the buffer on the remote peer
*/
BufferHandle getRemoteBuffer(int index);
ConnectionEpoch getEpoch();
DeviceProxyFifo getDeviceFifo();
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
void signal();
void flush();
void wait();
private:
std::unique_ptr<Impl> pimpl;
friend class Communicator;
};
/***************************************************************************************************************
* A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand.
* The communication API is one-sided meaning that for every single data transfer, only one side
@@ -135,9 +228,17 @@ union ChannelTrigger {
* indices in the registered buffer.
**************************************************************************************************************/
struct DeviceConnection {
#ifdef __CUDACC__
// TODO: add buffer handles
DeviceConnection() = default;
DeviceConnection(HostConnection& hostConn)
: connectionId(hostConn.getId()), epoch(hostConn.getEpoch()),
fifo(hostConn.getDeviceFifo()) {}
DeviceConnection(const DeviceConnection& other) = default;
DeviceConnection& operator=(DeviceConnection& other) = default;
#ifdef __CUDACC__
__forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
{
fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value);
@@ -191,28 +292,18 @@ struct DeviceConnection {
__forceinline__ __device__ void wait()
{
(*waitEpochId) += 1;
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
;
epoch.wait();
}
__forceinline__ __device__ void epochIncrement()
{
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
epoch.epochIncrement();
}
#endif // __CUDACC__
int connectionId;
SignalEpochId* localSignalEpochId;
// used by the signal() function directly from gpu
SignalEpochId* remoteSignalEpochId;
// every wait(), increments this and then the gpu waits for either:
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
uint64_t* waitEpochId;
ConnectionEpoch epoch;
// this is a concurrent fifo which is multiple threads from the device
// can produce for and the sole proxy thread consumes it.
@@ -220,9 +311,15 @@ struct DeviceConnection {
};
struct SimpleDeviceConnection {
SimpleDeviceConnection() {}
SimpleDeviceConnection(DeviceConnection devConn, BufferHandle dst, BufferHandle src) : devConn(devConn), dst(dst), src(src) {}
SimpleDeviceConnection() = default;
SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) {
dst = hostConn.getRemoteBuffer(0);
src = hostConn.getLocalBuffer(0);
}
SimpleDeviceConnection(const SimpleDeviceConnection& other) = default;
SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default;
#ifdef __CUDACC__
@@ -284,59 +381,6 @@ struct SimpleDeviceConnection {
BufferHandle src;
};
class HostConnection {
struct Impl;
public:
/* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */
HostConnection(std::unique_ptr<Impl>);
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
* in the communicator.
*
* Inputs:
* data: base pointer to the memory
* size: size of the memory region in bytes
*
* Returns: a handle to the buffer
*/
BufferHandle registerBuffer(void* data, uint64_t size);
/* Get the number of times registerBuffer(...) was called on the remote peer.
*
* Returns: the number of buffers registered on the remote peer
*/
int numRemoteBuffers();
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
*
* Inputs:
* index: the index of the handle to get
*
* Returns: a handle to the buffer on the remote peer
*/
BufferHandle getRemoteBuffer(int index);
/* Create a DeviceConnection paired with this HostConnection. A background proxy thread will
* trigger operations on this HostConnection corresponding to put/signal/etc. calls made to the
* DeviceConnection.
*
* Returns: the newly created DeviceConnection
*/
DeviceConnection toDevice();
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
void signal();
void flush();
void wait();
private:
std::unique_ptr<Impl> pimpl;
friend class Communicator;
};
#define MSCCLPP_UNIQUE_ID_BYTES 128
struct UniqueId {
char internal[MSCCLPP_UNIQUE_ID_BYTES];
@@ -359,8 +403,6 @@ enum class TransportType : uint8_t {
class Communicator {
public:
Communicator();
~Communicator();
/* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function.
*
@@ -369,7 +411,7 @@ public:
* ipPortPair: a string of the form "ip:port" that represents the address of the root process
* rank: rank of the calling process
*/
void initRank(int nranks, const char* ipPortPair, int rank);
Communicator(int nranks, const char* ipPortPair, int rank);
/* Initialize the communicator from a given UniqueId. Same as mscclppCommInitRank() except that
* id is provided by the user by calling getUniqueId()
@@ -379,7 +421,9 @@ public:
* id: the unique ID to be used for communication
* rank: rank of the calling process
*/
void initRankFromId(int nranks, UniqueId id, int rank);
Communicator(int nranks, UniqueId id, int rank);
~Communicator();
/* Ring-based AllGather through the bootstrap socket.
*
@@ -441,7 +485,7 @@ private:
enum class ProxyHandlerResult {
Continue,
FlushAndContinue,
FlushFifoTailAndContinue,
Stop,
};

View File

@@ -62,10 +62,14 @@ MSCCLPP_API_CPP void Proxy::start() {
// Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
// request.
if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushAndContinue) {
if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) {
// TODO: relocate this check: || (trigger.fields.type & mscclppSync)
fifo.flushTail();
}
if (result == ProxyHandlerResult::Stop) {
break;
}
}
// make sure the tail is flushed before we shut the proxy

View File

@@ -11,6 +11,7 @@
#include <unistd.h>
#include <unordered_map>
#include <cassert>
#include <algorithm>
static int nranksPerNode = 8;
@@ -220,7 +221,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
int thisNode = rankToNode(rank);
int cudaNum = rankToLocalRank(rank);
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
std::vector<std::pair<std::shared_ptr<mscclpp::HostConnection>, mscclpp::BufferHandle>> hostConns;
std::vector<std::shared_ptr<mscclpp::HostConnection>> hostConns;
for (int r = 0; r < world_size; ++r) {
if (r == rank)
@@ -235,19 +236,17 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
}
// Connect with all other ranks
auto hostConn = comm.connect(r, 0, transportType, ibDev);
auto localBuffer = hostConn->registerBuffer(data_d, dataSize);
hostConns.emplace_back(hostConn, localBuffer);
hostConn->registerBuffer(data_d, dataSize);
hostConns.push_back(hostConn);
}
comm.connectionSetup();
std::vector<mscclpp::SimpleDeviceConnection> devConns;
for (auto& entry : hostConns) {
assert(entry.first);
assert(entry.first->numRemoteBuffers() == 1);
auto remoteBuffer = entry.first->getRemoteBuffer(0);
devConns.emplace_back(entry.first->toDevice(), entry.second, remoteBuffer);
}
std::transform(hostConns.begin(), hostConns.end(), std::back_inserter(devConns),
[](std::shared_ptr<mscclpp::HostConnection>& hostConn) {
return mscclpp::SimpleDeviceConnection(*hostConn);
});
assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::SimpleDeviceConnection));
CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::SimpleDeviceConnection) * devConns.size() ));
@@ -401,12 +400,9 @@ int main(int argc, const char* argv[])
size_t nelemsPerGPU = dataSize / sizeof(int) / world_size;
try{
mscclpp::Communicator comm;
if (rank == 0)
printf("Initializing MSCCL++\n");
comm.initRank(world_size, ip_port, rank);
mscclpp::Communicator comm(world_size, ip_port, rank);
if (rank == 0)
printf("Initializing data for allgather test\n");