Core API teasing out WIP

This commit is contained in:
Olli Saarikivi
2023-04-22 00:35:25 +00:00
parent 9fbb0debdd
commit 0bc3c3e574
6 changed files with 516 additions and 417 deletions

22
src/epoch.cc Normal file
View File

@@ -0,0 +1,22 @@
#include "epoch.hpp"
#include "checks.hpp"
namespace mscclpp {
// Private implementation of Epoch: owns the two allocations made with mscclppCudaCalloc
// (deviceEpoch.localSignalEpochId and deviceEpoch.waitEpochId).
struct Epoch::Impl {
  // Value-initialized so every pointer starts out null; remoteSignalEpochId is never assigned
  // in this struct and would otherwise hold an indeterminate value.
  DeviceEpoch deviceEpoch{};

  // Allocates one element for each of localSignalEpochId and waitEpochId.
  // Throws through MSCCLPPTHROW when an allocation fails. If the second allocation fails the
  // first one is released before rethrowing, because the destructor is not run for an object
  // whose constructor did not complete.
  Impl() {
    MSCCLPPTHROW(mscclppCudaCalloc(&deviceEpoch.localSignalEpochId, 1));
    try {
      MSCCLPPTHROW(mscclppCudaCalloc(&deviceEpoch.waitEpochId, 1));
    } catch (...) {
      releaseNoThrow(deviceEpoch.localSignalEpochId);
      throw;
    }
  }

  // Releases both allocations. Destructors are implicitly noexcept, so an exception escaping
  // from here would call std::terminate; failures to free are therefore swallowed.
  ~Impl() {
    releaseNoThrow(deviceEpoch.localSignalEpochId);
    releaseNoThrow(deviceEpoch.waitEpochId);
  }

 private:
  // Frees ptr with mscclppCudaFree and discards any exception raised by MSCCLPPTHROW.
  template <typename T>
  static void releaseNoThrow(T* ptr) noexcept {
    try {
      MSCCLPPTHROW(mscclppCudaFree(ptr));
    } catch (...) {
      // Best effort: there is no way to report a failed free from a destructor or from the
      // constructor's cleanup path without masking the original error.
    }
  }
};
// Creates the implementation object; Impl's constructor performs the allocations.
Epoch::Epoch() : pimpl(std::make_unique<Impl>()) {}

// epoch.hpp declares ~Epoch() and this file did not define it. It has to be defined here,
// where Impl is a complete type, so that std::unique_ptr<Impl> can destroy the object.
Epoch::~Epoch() = default;
} // namespace mscclpp

295
src/include/channel.hpp Normal file
View File

@@ -0,0 +1,295 @@
#ifndef MSCCLPP_CHANNEL_HPP_
#define MSCCLPP_CHANNEL_HPP_
#include "mscclpp.hpp"
#include "proxy.hpp"
namespace mscclpp {
// Proxy FIFO sizing. Device code in this header indexes fifo.triggers modulo
// MSCCLPP_PROXY_FIFO_SIZE.
// For every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER, a flush of the tail to device memory is triggered.
// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem.
#define MSCCLPP_PROXY_FIFO_SIZE 128
#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4
// Bit flags stored in the `type` field of a ChannelTrigger. They can be OR-ed together;
// DeviceConnection pushes Data for put(), Flag for signal(), and Data | Flag | Sync for
// putWithSignalAndFlush().
using ChannelTriggerType = uint64_t;
const ChannelTriggerType channelTriggerData = 0x1;
const ChannelTriggerType channelTriggerFlag = 0x2;
const ChannelTriggerType channelTriggerSync = 0x4;
// This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles
// mapping to the actual buffer (the original sentence was cut off here; presumably it meant the
// registered buffer - TODO confirm).
using BufferHandle = uint32_t;
// Bit widths of the ChannelTrigger fields: size, src/dst offset, src/dst buffer handle,
// trigger type, and connection id.
#define MSCCLPP_BITS_SIZE 32
#define MSCCLPP_BITS_OFFSET 32
#define MSCCLPP_BITS_BUFFER_HANDLE 8
#define MSCCLPP_BITS_TYPE 3
#define MSCCLPP_BITS_CONNID 10
// This is the basic structure of each work element in the fifo: a ProxyTrigger that can be
// read either as its raw value or through named bit-fields.
// The bit-fields of each half must add up to 64 bits or less, so that `fields` occupies
// exactly two 64-bit words (128 bits in total).
union ChannelTrigger {
  static_assert(MSCCLPP_BITS_SIZE + MSCCLPP_BITS_OFFSET <= 64,
                "first word of ChannelTrigger must fit in 64 bits");
  static_assert(MSCCLPP_BITS_OFFSET + 2 * MSCCLPP_BITS_BUFFER_HANDLE + MSCCLPP_BITS_TYPE + MSCCLPP_BITS_CONNID <= 64,
                "second word of ChannelTrigger must fit in 64 bits");

  ProxyTrigger value;
  struct
  {
    // first 64 bits: value.fst
    uint64_t size : MSCCLPP_BITS_SIZE;
    uint64_t srcOffset : MSCCLPP_BITS_OFFSET;
    uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
    // second 64 bits: value.snd
    uint64_t dstOffset : MSCCLPP_BITS_OFFSET;
    uint64_t srcBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
    uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
    uint64_t type : MSCCLPP_BITS_TYPE;
    uint64_t connId : MSCCLPP_BITS_CONNID;
    // The padding has to subtract the connId width as well. Without it the second word asked
    // for 74 bits and the padding spilled into a third 64-bit unit.
    uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_TYPE - MSCCLPP_BITS_CONNID); // ensure 64-bit alignment
  } fields;

#ifdef __CUDACC__
  __device__ ChannelTrigger() {}

  // Wraps an already-encoded trigger.
  __device__ ChannelTrigger(ProxyTrigger value) : value(value) {}

  // Packs the arguments into value.fst / value.snd using the same bit positions as `fields`
  // (lowest bits first: size, srcOffset | dstOffset, src, dst, type, connectionId).
  // No masking is applied: each argument must already fit in its field width, otherwise it
  // overflows into the neighbouring field.
  __device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size, int connectionId) {
    value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);

    // Widen before shifting so the whole computation is done in 64 bits.
    uint64_t snd = (uint64_t)connectionId;
    snd = (snd << MSCCLPP_BITS_TYPE) + (uint64_t)type;
    snd = (snd << MSCCLPP_BITS_BUFFER_HANDLE) + dst;
    snd = (snd << MSCCLPP_BITS_BUFFER_HANDLE) + src;
    snd = (snd << MSCCLPP_BITS_OFFSET) + dstOffset;
    value.snd = snd;
  }
#endif // __CUDACC__
};
// Epoch counters used by a connection to implement signal()/wait() on the device.
struct ConnectionEpoch {
#ifdef __CUDACC__
  // Bumps *waitEpochId by one, then spins until localSignalEpochId->proxy (read through a
  // volatile pointer) is no longer smaller than *waitEpochId.
  __forceinline__ __device__ void wait()
  {
    *waitEpochId += 1;
    volatile uint64_t* proxyEpoch = (volatile uint64_t*)&(localSignalEpochId->proxy);
    while (*proxyEpoch < *waitEpochId) {
    }
  }

  // Adds one to localSignalEpochId->device through a volatile pointer.
  __forceinline__ __device__ void epochIncrement()
  {
    volatile uint64_t* deviceEpoch = (volatile uint64_t*)&(localSignalEpochId->device);
    *deviceEpoch += 1;
  }
#endif // __CUDACC__

  SignalEpochId* localSignalEpochId;
  // used by the signal() function directly from gpu
  SignalEpochId* remoteSignalEpochId;
  // every wait(), increments this and then the gpu waits for either:
  // 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
  // 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
  uint64_t* waitEpochId;
};
/* Host-side handle of a connection. Gives access to the connection id, the handles of the
 * locally and remotely registered buffers, the epoch and device FIFO used to build a
 * DeviceConnection, and host-callable put/signal/flush/wait operations.
 * The implementation lives behind the private Impl type.
 */
class HostConnection {
  struct Impl;
public:
  /* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */
  HostConnection(std::unique_ptr<Impl>);

  ~HostConnection();

  // NOTE(review): takes no arguments and is undocumented in this header; it looks like a
  // work-in-progress placeholder. The declaration was missing its terminating semicolon.
  void write();

  int getId();

  /* Get the number of times registerBuffer(...) was called.
   *
   * Returns: the number of buffers registered
   */
  int numLocalBuffers();

  /* Get the BufferHandle returned by a call to registerBuffer(...) as identified by the index
   *
   * Inputs:
   *  index: the index of the handle to get
   *
   * Returns: a handle to the buffer
   */
  BufferHandle getLocalBuffer(int index);

  /* Get the number of times registerBuffer(...) was called on the remote peer.
   *
   * Returns: the number of buffers registered on the remote peer
   */
  int numRemoteBuffers();

  /* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
   *
   * Inputs:
   *  index: the index of the handle to get
   *
   * Returns: a handle to the buffer on the remote peer
   */
  BufferHandle getRemoteBuffer(int index);

  /* Epoch state and device FIFO; DeviceConnection copies both in its constructor. */
  ConnectionEpoch getEpoch();

  DeviceProxyFifo getDeviceFifo();

  /* Host-side counterparts of the DeviceConnection operations with the same names. */
  void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);

  void signal();

  void flush();

  void wait();

private:
  std::unique_ptr<Impl> pimpl;
  friend class Communicator;
};
// Device-usable view of a HostConnection: the connection id, the epoch counters and the device
// end of the proxy FIFO. Every operation is encoded as a ChannelTrigger and pushed to the FIFO.
struct DeviceConnection {
  DeviceConnection() = default;

  // Copies the id, epoch and device FIFO out of hostConn.
  DeviceConnection(HostConnection& hostConn)
    : connectionId(hostConn.getId()), epoch(hostConn.getEpoch()),
      fifo(hostConn.getDeviceFifo()) {}

  DeviceConnection(const DeviceConnection& other) = default;

  // Takes a const reference (it used to take a non-const one) so that assigning from const
  // objects and temporaries compiles; assignment from non-const lvalues is unaffected.
  DeviceConnection& operator=(const DeviceConnection& other) = default;

#ifdef __CUDACC__
  // Pushes a data-transfer trigger for `size` units from (src, srcOffset) to (dst, dstOffset).
  __forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
  {
    fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value);
  }

  // Same as put() above with one offset used for both source and destination.
  __forceinline__ __device__ void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
  {
    put(dst, offset, src, offset, size);
  }

  // Increments the local epoch, then pushes a flag-only trigger.
  __forceinline__ __device__ void signal()
  {
    epochIncrement();
    fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value);
  }

  // Increments the local epoch, then pushes a single trigger carrying both data and flag.
  __forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
  {
    epochIncrement();
    fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId).value);
  }

  __forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
  {
    putWithSignal(dst, offset, src, offset, size);
  }

  // Like putWithSignal() but also sets the sync flag and then spins with the same exit
  // condition as flush().
  __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
  {
    epochIncrement();
    uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, dstOffset, src, srcOffset, size, connectionId).value);
    while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
           *(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
      ;
  }

  __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
  {
    putWithSignalAndFlush(dst, offset, src, offset, size);
  }

  // Pushes a sync-only trigger and spins until the proxy has consumed it.
  __forceinline__ __device__ void flush()
  {
    // channelTriggerSync is the sync flag declared alongside channelTriggerData/Flag; the
    // previous `mscclppSync` is not declared in this header.
    uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerSync, 0, 0, 0, 0, 1, connectionId).value);
    // Keep spinning while the pushed slot is still non-zero AND the tail replica has not moved
    // past curFifoHead; i.e. return as soon as either the slot is cleared or the tail passes it.
    // NOTE(review): the original comment said both conditions must be met before returning,
    // which is not what this loop does - confirm which one is intended.
    while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
           *(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
      ;
  }

  __forceinline__ __device__ void wait()
  {
    epoch.wait();
  }

  __forceinline__ __device__ void epochIncrement()
  {
    epoch.epochIncrement();
  }
#endif // __CUDACC__

  int connectionId;

  ConnectionEpoch epoch;

  // this is a concurrent fifo which is multiple threads from the device
  // can produce for and the sole proxy thread consumes it.
  DeviceProxyFifo fifo;
};
// Convenience wrapper around DeviceConnection that always uses remote buffer 0 as the
// destination and local buffer 0 as the source, so callers only pass offsets and sizes.
struct SimpleDeviceConnection {
  SimpleDeviceConnection() = default;

  // Builds the wrapped DeviceConnection and looks up buffer handle 0 on both sides.
  SimpleDeviceConnection(HostConnection& hostConn)
    : devConn(hostConn), dst(hostConn.getRemoteBuffer(0)), src(hostConn.getLocalBuffer(0)) {}

  SimpleDeviceConnection(const SimpleDeviceConnection& other) = default;

  SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default;

#ifdef __CUDACC__
  // --- data transfer -------------------------------------------------------------------
  __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devConn.put(dst, dstOffset, src, srcOffset, size);
  }

  // Single-offset form: the same offset is used on both sides.
  __forceinline__ __device__ void put(uint64_t offset, uint64_t size)
  {
    devConn.put(dst, offset, src, offset, size);
  }

  __forceinline__ __device__ void signal()
  {
    devConn.signal();
  }

  __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devConn.putWithSignal(dst, dstOffset, src, srcOffset, size);
  }

  __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size)
  {
    devConn.putWithSignal(dst, offset, src, offset, size);
  }

  __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devConn.putWithSignalAndFlush(dst, dstOffset, src, srcOffset, size);
  }

  __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size)
  {
    devConn.putWithSignalAndFlush(dst, offset, src, offset, size);
  }

  // --- synchronization: straight forwards to the wrapped connection --------------------
  __forceinline__ __device__ void flush()
  {
    devConn.flush();
  }

  __forceinline__ __device__ void wait()
  {
    devConn.wait();
  }

  __forceinline__ __device__ void epochIncrement()
  {
    devConn.epochIncrement();
  }
#endif // __CUDACC__

  DeviceConnection devConn;
  BufferHandle dst;
  BufferHandle src;
};

52
src/include/epoch.hpp Normal file
View File

@@ -0,0 +1,52 @@
#ifndef MSCCLPP_EPOCH_HPP_
#define MSCCLPP_EPOCH_HPP_
#include "mscclpp.hpp"
namespace mscclpp {
// Pair of 64-bit epoch counters, aligned to 16 bytes.
struct alignas(16) SignalEpochId {
// every signal() increments this and either:
// 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy
// 2) gpu thread directly writes it to remoteSignalEpochId->device
uint64_t device;
// signal() function triggers the cpu proxy thread to write to it
uint64_t proxy;
};
// Device-visible epoch state: pointers to the local and remote signal counters and to the
// counter tracking how many waits have been issued.
struct DeviceEpoch {
#ifdef __CUDACC__
  // Bumps *waitEpochId by one, then spins until localSignalEpochId->proxy (read through a
  // volatile pointer) is no longer smaller than *waitEpochId.
  __forceinline__ __device__ void wait()
  {
    *waitEpochId += 1;
    volatile uint64_t* proxyEpoch = (volatile uint64_t*)&(localSignalEpochId->proxy);
    while (*proxyEpoch < *waitEpochId) {
    }
  }

  // Adds one to localSignalEpochId->device through a volatile pointer.
  __forceinline__ __device__ void epochIncrement()
  {
    volatile uint64_t* deviceEpoch = (volatile uint64_t*)&(localSignalEpochId->device);
    *deviceEpoch += 1;
  }
#endif // __CUDACC__

  SignalEpochId* localSignalEpochId;
  SignalEpochId* remoteSignalEpochId;
  uint64_t* waitEpochId;
};
// Host-side owner of a DeviceEpoch. The state is held by the private Impl type (pimpl), so the
// constructor and destructor are defined out of line.
class Epoch {
struct Impl;
std::unique_ptr<Impl> pimpl;
public:
Epoch();
~Epoch();
// NOTE(review): no definition is visible in epoch.cc; presumably advances the epoch from the
// host - confirm.
void signal();
// Presumably returns the DeviceEpoch stored in Impl, for use from device code - TODO confirm
// once the definition exists.
DeviceEpoch& getDeviceEpoch();
};
} // namespace mscclpp
#endif // MSCCLPP_EPOCH_HPP_

View File

@@ -6,381 +6,11 @@
#define MSCCLPP_PATCH 0
// Single integer version: major * 10000 + minor * 100 + patch.
#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH)
// Proxy FIFO sizing. Device code indexes fifo.triggers modulo MSCCLPP_PROXY_FIFO_SIZE.
// For every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER, a flush of the tail to device memory is triggered.
// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem.
#define MSCCLPP_PROXY_FIFO_SIZE 128
#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4
#include <vector>
#include <memory>
#include <functional>
#include <mscclppfifo.hpp>
namespace mscclpp {
// Pair of 64-bit epoch counters, aligned to 16 bytes.
struct alignas(16) SignalEpochId {
// every signal() increments this and either:
// 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy
// 2) gpu thread directly writes it to remoteSignalEpochId->device
uint64_t device;
// signal() function triggers the cpu proxy thread to write to it
uint64_t proxy;
};
// Bit flags stored in the `type` field of a ChannelTrigger. They can be OR-ed together;
// DeviceConnection pushes Data for put(), Flag for signal(), and Data | Flag | Sync for
// putWithSignalAndFlush().
using ChannelTriggerType = uint64_t;
const ChannelTriggerType channelTriggerData = 0x1;
const ChannelTriggerType channelTriggerFlag = 0x2;
const ChannelTriggerType channelTriggerSync = 0x4;
// This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles
// mapping to the actual buffer (the original sentence was cut off here; presumably it meant the
// registered buffer - TODO confirm).
using BufferHandle = uint32_t;
// Bit widths of the ChannelTrigger fields: size, src/dst offset, src/dst buffer handle,
// trigger type, and connection id.
#define MSCCLPP_BITS_SIZE 32
#define MSCCLPP_BITS_OFFSET 32
#define MSCCLPP_BITS_BUFFER_HANDLE 8
#define MSCCLPP_BITS_TYPE 3
#define MSCCLPP_BITS_CONNID 10
// This is the basic structure of each work element in the fifo: a ProxyTrigger that can be
// read either as its raw value or through named bit-fields.
// The bit-fields of each half must add up to 64 bits or less, so that `fields` occupies
// exactly two 64-bit words (128 bits in total).
union ChannelTrigger {
  static_assert(MSCCLPP_BITS_SIZE + MSCCLPP_BITS_OFFSET <= 64,
                "first word of ChannelTrigger must fit in 64 bits");
  static_assert(MSCCLPP_BITS_OFFSET + 2 * MSCCLPP_BITS_BUFFER_HANDLE + MSCCLPP_BITS_TYPE + MSCCLPP_BITS_CONNID <= 64,
                "second word of ChannelTrigger must fit in 64 bits");

  ProxyTrigger value;
  struct
  {
    // first 64 bits: value.fst
    uint64_t size : MSCCLPP_BITS_SIZE;
    uint64_t srcOffset : MSCCLPP_BITS_OFFSET;
    uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
    // second 64 bits: value.snd
    uint64_t dstOffset : MSCCLPP_BITS_OFFSET;
    uint64_t srcBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
    uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
    uint64_t type : MSCCLPP_BITS_TYPE;
    uint64_t connId : MSCCLPP_BITS_CONNID;
    // The padding has to subtract the connId width as well. Without it the second word asked
    // for 74 bits and the padding spilled into a third 64-bit unit.
    uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_TYPE - MSCCLPP_BITS_CONNID); // ensure 64-bit alignment
  } fields;

#ifdef __CUDACC__
  __device__ ChannelTrigger() {}

  // Wraps an already-encoded trigger.
  __device__ ChannelTrigger(ProxyTrigger value) : value(value) {}

  // Packs the arguments into value.fst / value.snd using the same bit positions as `fields`
  // (lowest bits first: size, srcOffset | dstOffset, src, dst, type, connectionId).
  // No masking is applied: each argument must already fit in its field width, otherwise it
  // overflows into the neighbouring field.
  __device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size, int connectionId) {
    value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);

    // Widen before shifting so the whole computation is done in 64 bits.
    uint64_t snd = (uint64_t)connectionId;
    snd = (snd << MSCCLPP_BITS_TYPE) + (uint64_t)type;
    snd = (snd << MSCCLPP_BITS_BUFFER_HANDLE) + dst;
    snd = (snd << MSCCLPP_BITS_BUFFER_HANDLE) + src;
    snd = (snd << MSCCLPP_BITS_OFFSET) + dstOffset;
    value.snd = snd;
  }
#endif // __CUDACC__
};
// Epoch counters used by a connection to implement signal()/wait() on the device.
struct ConnectionEpoch {
#ifdef __CUDACC__
  // Bumps *waitEpochId by one, then spins until localSignalEpochId->proxy (read through a
  // volatile pointer) is no longer smaller than *waitEpochId.
  __forceinline__ __device__ void wait()
  {
    *waitEpochId += 1;
    volatile uint64_t* proxyEpoch = (volatile uint64_t*)&(localSignalEpochId->proxy);
    while (*proxyEpoch < *waitEpochId) {
    }
  }

  // Adds one to localSignalEpochId->device through a volatile pointer.
  __forceinline__ __device__ void epochIncrement()
  {
    volatile uint64_t* deviceEpoch = (volatile uint64_t*)&(localSignalEpochId->device);
    *deviceEpoch += 1;
  }
#endif // __CUDACC__

  SignalEpochId* localSignalEpochId;
  // used by the signal() function directly from gpu
  SignalEpochId* remoteSignalEpochId;
  // every wait(), increments this and then the gpu waits for either:
  // 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
  // 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
  uint64_t* waitEpochId;
};
/* Host-side handle of a connection. Gives access to the connection id, buffer registration and
 * lookup, the epoch and device FIFO used to build a DeviceConnection, and host-callable
 * put/signal/flush/wait operations. The implementation lives behind the private Impl type.
 */
class HostConnection {
struct Impl;
public:
/* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */
HostConnection(std::unique_ptr<Impl>);
~HostConnection();
/* Returns the connection id; DeviceConnection stores it as connectionId. */
int getId();
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
* in the communicator.
*
* Inputs:
* data: base pointer to the memory
* size: size of the memory region in bytes
*
* Returns: a handle to the buffer
*/
BufferHandle registerBuffer(void* data, uint64_t size);
/* Get the number of times registerBuffer(...) was called.
*
* Returns: the number of buffers registered
*/
int numLocalBuffers();
/* Get the BufferHandle returned by a call to registerBuffer(...) as identified by the index
*
* Inputs:
* index: the index of the handle to get
*
* Returns: a handle to the buffer
*/
BufferHandle getLocalBuffer(int index);
/* Get the number of times registerBuffer(...) was called on the remote peer.
*
* Returns: the number of buffers registered on the remote peer
*/
int numRemoteBuffers();
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
*
* Inputs:
* index: the index of the handle to get
*
* Returns: a handle to the buffer on the remote peer
*/
BufferHandle getRemoteBuffer(int index);
/* Epoch state and device FIFO; DeviceConnection copies both in its constructor. */
ConnectionEpoch getEpoch();
DeviceProxyFifo getDeviceFifo();
/* Host-side counterparts of the DeviceConnection operations with the same names. */
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
void signal();
void flush();
void wait();
private:
std::unique_ptr<Impl> pimpl;
friend class Communicator;
};
/***************************************************************************************************************
* A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand.
* The communication API is one-sided meaning that for every single data transfer, only one side
* needs to execute unlike a two-sided communication stack such as NCCL where both sides
* need to execute a send and a receive instruction, respectively, for every transfer.
*
* A connection is uniquely identified by the (remoteRank, tag) pair at an endpoint.
* The two endpoints register buffers of the same size with the connection.
*
* The endpoints provide the remoteRank, tag, and the buffer when registering a connection with mscclppConnect().
*
* mscclppConnectionSetup() sets up all the registered connections.
*
***************************************************************************************************************
* A proxy thread running on the CPU is necessary to perform transfers using InfiniBand or the DMA engine.
* The current implementation uses a single proxy thread per context - one IB connection or DMA engine per node.
* Thus multiple threadblocks using different connections might use the same CPU proxy thread.
*
* Before using any of the functionality of connections, mscclppProxyLaunch needs to be called to spawn the
* proxy threads. There are currently two types of connections:
*
* P2P via NVLink: the DMA engine can perform the copy between the buffers. DMA engine has higher latency
* but has a higher bandwidth and costs no compute cycles on the GPU.
*
* InfiniBand: the RDMA engine copies the data over MLX devices.
*
***************************************************************************************************************
* At the runtime, a GPU kernel has access to a mscclppDevConn object that provides the following functions:
*
* put(): [non-blocking] the sender initiates a data transfer to the receiver.
*
* signal(): [non-blocking] the sender signals the receiver that data is ready to be consumed.
*
* flush(): [blocking] the sender waits for all the data transfers to complete
*
* wait(): [blocking] the receiver waits on the signal() to start reading the data.
*
* The sender should not reuse the buffer till the flush() returns.
* The receiver should only access the data after the wait() returns.
*
* putWithSignal(): the sender initiates a data transfer and signals the receiver that data is ready to be consumed.
* This is an optimized version of a put() followed by a signal().
*
* These functions hide the complexity of synchronization between the two GPUs and the CPU proxy thread.
* Example:
*
* // sender GPU
* devConn.put(data1)
* // not OK to write to data1
* devConn.put(data2)
* // not OK to write to data1, data2
* devConn.put(data3) // receiver GPU
* // not OK to write to data1, data2, data3 // not OK to read data1, data2, data3
* devConn.signal() -------------------------------> devConn.wait()
* // not OK to write to data1, data2, data3 // OK to read data1, data2, data3
* devConn.flush()
* // OK to write to data1, data2, data3
*
*
* The two endpoints can concurrently use the same connection provided they are writing (puts) on different
* indices in the registered buffer.
**************************************************************************************************************/
// Device-usable view of a HostConnection: the connection id, the epoch counters and the device
// end of the proxy FIFO. Every operation is encoded as a ChannelTrigger and pushed to the FIFO.
struct DeviceConnection {
  DeviceConnection() = default;

  // Copies the id, epoch and device FIFO out of hostConn.
  DeviceConnection(HostConnection& hostConn)
    : connectionId(hostConn.getId()), epoch(hostConn.getEpoch()),
      fifo(hostConn.getDeviceFifo()) {}

  DeviceConnection(const DeviceConnection& other) = default;

  // Takes a const reference (it used to take a non-const one) so that assigning from const
  // objects and temporaries compiles; assignment from non-const lvalues is unaffected.
  DeviceConnection& operator=(const DeviceConnection& other) = default;

#ifdef __CUDACC__
  // Pushes a data-transfer trigger for `size` units from (src, srcOffset) to (dst, dstOffset).
  __forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
  {
    fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value);
  }

  // Same as put() above with one offset used for both source and destination.
  __forceinline__ __device__ void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
  {
    put(dst, offset, src, offset, size);
  }

  // Increments the local epoch, then pushes a flag-only trigger.
  __forceinline__ __device__ void signal()
  {
    epochIncrement();
    fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value);
  }

  // Increments the local epoch, then pushes a single trigger carrying both data and flag.
  __forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
  {
    epochIncrement();
    fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId).value);
  }

  __forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
  {
    putWithSignal(dst, offset, src, offset, size);
  }

  // Like putWithSignal() but also sets the sync flag and then spins with the same exit
  // condition as flush().
  __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
  {
    epochIncrement();
    uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, dstOffset, src, srcOffset, size, connectionId).value);
    while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
           *(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
      ;
  }

  __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
  {
    putWithSignalAndFlush(dst, offset, src, offset, size);
  }

  // Pushes a sync-only trigger and spins until the proxy has consumed it.
  __forceinline__ __device__ void flush()
  {
    // channelTriggerSync is the sync flag declared alongside channelTriggerData/Flag; the
    // previous `mscclppSync` is not declared in this header.
    uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerSync, 0, 0, 0, 0, 1, connectionId).value);
    // Keep spinning while the pushed slot is still non-zero AND the tail replica has not moved
    // past curFifoHead; i.e. return as soon as either the slot is cleared or the tail passes it.
    // NOTE(review): the original comment said both conditions must be met before returning,
    // which is not what this loop does - confirm which one is intended.
    while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
           *(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
      ;
  }

  __forceinline__ __device__ void wait()
  {
    epoch.wait();
  }

  __forceinline__ __device__ void epochIncrement()
  {
    epoch.epochIncrement();
  }
#endif // __CUDACC__

  int connectionId;

  ConnectionEpoch epoch;

  // this is a concurrent fifo which is multiple threads from the device
  // can produce for and the sole proxy thread consumes it.
  DeviceProxyFifo fifo;
};
// Convenience wrapper around DeviceConnection that always uses remote buffer 0 as the
// destination and local buffer 0 as the source, so callers only pass offsets and sizes.
struct SimpleDeviceConnection {
  SimpleDeviceConnection() = default;

  // Builds the wrapped DeviceConnection and looks up buffer handle 0 on both sides.
  SimpleDeviceConnection(HostConnection& hostConn)
    : devConn(hostConn), dst(hostConn.getRemoteBuffer(0)), src(hostConn.getLocalBuffer(0)) {}

  SimpleDeviceConnection(const SimpleDeviceConnection& other) = default;

  SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default;

#ifdef __CUDACC__
  // --- data transfer -------------------------------------------------------------------
  __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devConn.put(dst, dstOffset, src, srcOffset, size);
  }

  // Single-offset form: the same offset is used on both sides.
  __forceinline__ __device__ void put(uint64_t offset, uint64_t size)
  {
    devConn.put(dst, offset, src, offset, size);
  }

  __forceinline__ __device__ void signal()
  {
    devConn.signal();
  }

  __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devConn.putWithSignal(dst, dstOffset, src, srcOffset, size);
  }

  __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size)
  {
    devConn.putWithSignal(dst, offset, src, offset, size);
  }

  __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devConn.putWithSignalAndFlush(dst, dstOffset, src, srcOffset, size);
  }

  __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size)
  {
    devConn.putWithSignalAndFlush(dst, offset, src, offset, size);
  }

  // --- synchronization: straight forwards to the wrapped connection --------------------
  __forceinline__ __device__ void flush()
  {
    devConn.flush();
  }

  __forceinline__ __device__ void wait()
  {
    devConn.wait();
  }

  __forceinline__ __device__ void epochIncrement()
  {
    devConn.epochIncrement();
  }
#endif // __CUDACC__

  DeviceConnection devConn;
  BufferHandle dst;
  BufferHandle src;
};
#define MSCCLPP_UNIQUE_ID_BYTES 128
struct UniqueId {
char internal[MSCCLPP_UNIQUE_ID_BYTES];
@@ -395,13 +25,66 @@ struct UniqueId {
*/
std::unique_ptr<UniqueId> getUniqueId();
/* Transport Types */
enum class TransportType : uint8_t {
P2P = 0,
IB = 1,
// Bit mask of transports. Each constant occupies one bit, so values can be OR-ed together;
// TransportAll has the nine lowest bits set, i.e. every flag declared here.
using TransportFlags = uint32_t;
const TransportFlags TransportCudaIpc = 1u << 0;
const TransportFlags TransportIB = 1u << 1;
const TransportFlags TransportIB1 = 1u << 2;
const TransportFlags TransportIB2 = 1u << 3;
const TransportFlags TransportIB3 = 1u << 4;
const TransportFlags TransportIB4 = 1u << 5;
const TransportFlags TransportIB5 = 1u << 6;
const TransportFlags TransportIB6 = 1u << 7;
const TransportFlags TransportIB7 = 1u << 8;
const TransportFlags TransportAll = (1u << 9) - 1u;
class Communicator;
// Handle to a registered memory region. State is shared through a std::shared_ptr<Impl>, so
// copies of a RegisteredMemory refer to the same Impl.
class RegisteredMemory {
struct Impl;
std::shared_ptr<Impl> pimpl;
public:
RegisteredMemory(std::shared_ptr<Impl> pimpl);
~RegisteredMemory();
// Base pointer and size of the region (units presumably bytes, matching
// Communicator::registerMemory - TODO confirm).
void* data();
size_t size();
// Transport flags associated with this memory.
TransportFlags transports();
// serialize() produces a byte vector that deserialize() turns back into a RegisteredMemory;
// presumably used to exchange handles between ranks - confirm.
std::vector<char> serialize();
static RegisteredMemory deserialize(const std::vector<char>& data);
// NOTE(review): presumably the rank that owns the memory, with isLocal()/isRemote() telling
// whether that is the calling rank - confirm against the implementation.
int rank();
bool isLocal();
bool isRemote();
};
// Host-side connection that copies between RegisteredMemory regions. The implementation is
// hidden behind the private Impl type.
class Connection {
struct Impl;
std::unique_ptr<Impl> pimpl;
public:
/* Connection can not be constructed from user code and must instead be created through Communicator::connect */
Connection(std::unique_ptr<Impl>);
~Connection();
// Requests a transfer of `size` units from src at srcOffset to dst at dstOffset (offsets and
// size are presumably in bytes, as the commented-out template below multiplies by sizeof(T) -
// TODO confirm).
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
// NOTE(review): presumably waits for previously issued writes to complete - confirm.
void flush();
TransportFlags transport();
TransportFlags remoteTransport(); // Good to have because different IB transports can still connect to each other
// template<typename T> void write(RegisteredPtr<T> dst, RegisteredPtr<T> src, uint64_t size) {
// write(dst.memory(), dst.offset() * sizeof(T), src.memory(), src.offset() * sizeof(T), size);
// }
friend class Communicator;
};
class Communicator {
struct Impl;
std::unique_ptr<Impl> pimpl;
public:
/* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function.
@@ -436,6 +119,16 @@ public:
/* A no-op function that is used to synchronize all processes via a bootstrap allgather*/
void bootstrapBarrier();
/* Register a region of GPU memory for use in this communicator.
*
* Inputs:
* ptr: base pointer to the memory
* size: size of the memory region in bytes
* transports: TransportFlags value for the region (presumably the set of transports the
* memory should be usable with - TODO confirm)
*
* Returns: a RegisteredMemory object for the region
*/
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);
/* Connect to a remote rank. This function only prepares metadata for connection. The actual connection
* is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection
* from rank i to remote rank j needs to have a counterpart from rank j to rank i.
@@ -450,19 +143,8 @@ public:
* transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB)
* ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P.
*/
std::shared_ptr<HostConnection> connect(int remoteRank, int tag, TransportType transportType, const char* ibDev = 0);
std::shared_ptr<Connection> connect(int remoteRank, int tag, TransportFlags transport);
/* Establish all connections created by mscclppConnect(). This function must be called after all mscclppConnect()
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.
*/
void connectionSetup();
/* Launch proxy thread(s). This function is supposed to be called before starting a kernel that uses DeviceConnection. */
void startProxying();
/* Stop proxy thread(s). */
void stopProxying();
/* Return the rank of the calling process.
*
* Outputs:
@@ -476,37 +158,6 @@ public:
* size: the number of ranks of the communicator
*/
int size();
struct Impl;
private:
std::unique_ptr<Impl> pimpl;
friend class HostConnection;
};
// Value returned by a ProxyHandler for each trigger it is given.
// NOTE(review): going by the names, this tells the proxy whether to keep going, to flush the
// FIFO tail first, or to stop - confirm against the Proxy implementation.
enum class ProxyHandlerResult {
Continue,
FlushFifoTailAndContinue,
Stop,
};
class Proxy;
// Callback type accepted by Proxy's constructor: takes a ProxyTrigger by value and returns a
// ProxyHandlerResult.
using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
// Owns a HostProxyFifo and a ProxyHandler; the implementation is hidden behind Impl.
// NOTE(review): presumably start() launches the proxy thread that feeds FIFO triggers to the
// handler and stop() ends it - confirm against the implementation.
class Proxy {
public:
Proxy(ProxyHandler handler);
~Proxy();
void start();
void stop();
// Returns a reference to the host side of the proxy FIFO.
HostProxyFifo& fifo();
private:
struct Impl;
std::unique_ptr<Impl> pimpl;
};
} // namespace mscclpp

39
src/include/proxy.hpp Normal file
View File

@@ -0,0 +1,39 @@
#ifndef MSCCLPP_PROXY_HPP_
#define MSCCLPP_PROXY_HPP_
#include <mscclppfifo.hpp>
#include <memory>
#include <functional>
namespace mscclpp {
// Value returned by a ProxyHandler for each trigger it is given.
// NOTE(review): going by the names, this tells the proxy whether to keep going, to flush the
// FIFO tail first, or to stop - confirm against the Proxy implementation.
enum class ProxyHandlerResult {
Continue,
FlushFifoTailAndContinue,
Stop,
};
class Proxy;
// Callback type accepted by Proxy's constructor: takes a ProxyTrigger by value and returns a
// ProxyHandlerResult.
using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
// Owns a HostProxyFifo and a ProxyHandler; the implementation is hidden behind Impl.
// NOTE(review): presumably start() launches the proxy thread that feeds FIFO triggers to the
// handler and stop() ends it - confirm against the implementation.
class Proxy {
public:
Proxy(ProxyHandler handler);
~Proxy();
void start();
void stop();
// Returns a reference to the host side of the proxy FIFO.
HostProxyFifo& fifo();
private:
struct Impl;
std::unique_ptr<Impl> pimpl;
};
} // namespace mscclpp
#endif // MSCCLPP_PROXY_HPP_

View File

@@ -0,0 +1,40 @@
#ifndef MSCCLPP_REGISTERED_PTR_HPP_
#define MSCCLPP_REGISTERED_PTR_HPP_
namespace mscclpp {
template<typename T>
class RegisteredPtr {
RegisteredMemory memory;
size_t offset;
public:
RegisteredPtr(RegisteredMemory memory, size_t offset) : memory(memory), offset(offset) {}
RegisteredPtr(RegisteredMemory memory) : RegisteredPtr(memory, 0) {}
~RegisteredPtr() {}
RegisteredMemory memory() {
return memory;
}
T* data() {
return reinterpret_cast<T*>(memory.data());
}
size_t size() {
return memory.size() / sizeof(T);
}
size_t offset() {
return offset;
}
RegisteredPtr<T> operator+(size_t offset) {
return RegisteredPtr<T>(memory, this->offset + offset);
}
// TODO: all other relevant overloads
};
} // namespace mscclpp
#endif // MSCCLPP_REGISTERED_PTR_HPP_