mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-23 06:16:46 +00:00
merged with api-extension
This commit is contained in:
@@ -2,5 +2,6 @@
|
||||
#define MSCCLPP_API_H_
|
||||
|
||||
// Export marker for C-linkage entry points: combines extern "C" with default symbol visibility.
#define MSCCLPP_API extern "C" __attribute__((visibility("default")))
|
||||
// Export marker for C++ entry points: default symbol visibility only, linkage is left unchanged.
#define MSCCLPP_API_CPP __attribute__((visibility("default")))
|
||||
|
||||
#endif // MSCCLPP_API_H_
|
||||
|
||||
13
src/include/basic_proxy_handler.hpp
Normal file
13
src/include/basic_proxy_handler.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
// Internal header declaring the factory for the basic proxy handler.
// NOTE(review): the guard macro says *_PROXY_SERVICE_* while the file is named basic_proxy_handler.hpp;
// confirm no other header uses the same guard name.
#ifndef MSCCLPP_BASIC_PROXY_SERVICE_HPP_
|
||||
#define MSCCLPP_BASIC_PROXY_SERVICE_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "communicator.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// Builds a ProxyHandler from the given communicator implementation object.
// `comm` is taken by non-const reference; whether the returned handler keeps referring to it
// is not visible from this declaration - presumably it does, so `comm` should outlive the handler (TODO confirm).
ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -27,29 +27,3 @@
|
||||
} while (false)
|
||||
|
||||
#endif
|
||||
|
||||
#include <errno.h>
|
||||
// Check system calls
|
||||
// Exception-throwing variants of the system-call checking macros.
//
// All three expect `call` to be an expression yielding an int where -1 signals failure and
// errno describes the error, and `name` to be a string literal naming the call (it is
// concatenated into the message at compile time).

// SYSCHECKTHROW(call, name): run `call` (retrying on transient errors) and throw
// std::runtime_error if it ultimately fails. The return value is discarded.
#define SYSCHECKTHROW(call, name)           \
  do {                                      \
    int retval;                             \
    SYSCHECKVALTHROW(call, name, retval);   \
  } while (false)

// SYSCHECKVALTHROW(call, name, retval): like SYSCHECKTHROW, but stores the result of `call`
// in the caller-provided lvalue `retval`. Throws std::runtime_error("Call to <name> failed : <strerror>")
// when the final result is -1.
#define SYSCHECKVALTHROW(call, name, retval)                                                   \
  do {                                                                                         \
    SYSCHECKSYNCTHROW(call, name, retval);                                                     \
    if ((retval) == -1) {                                                                      \
      throw std::runtime_error(std::string("Call to " name " failed : ") + strerror(errno));   \
    }                                                                                          \
  } while (false)

// SYSCHECKSYNCTHROW(call, name, retval): evaluate `call` repeatedly while it fails with a
// transient errno (EINTR, EWOULDBLOCK, EAGAIN), logging each retry. On exit `retval` holds the
// last result; this macro itself never throws, the caller inspects `retval`.
#define SYSCHECKSYNCTHROW(call, name, retval)                                                \
  do {                                                                                       \
    (retval) = (call);                                                                       \
    if ((retval) == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) {     \
      INFO(MSCCLPP_ALL, "Call to " name " returned %s, retrying", strerror(errno));          \
    } else {                                                                                 \
      break;                                                                                 \
    }                                                                                        \
  } while (true)
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
// Internal header defining the private implementation struct behind mscclpp::Communicator.
// NOTE(review): the guard uses the MSCCL_ prefix while other headers in this change use MSCCLPP_.
#ifndef MSCCL_COMMUNICATOR_HPP_
|
||||
#define MSCCL_COMMUNICATOR_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "mscclpp.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// State held by a Communicator through its pimpl pointer.
struct Communicator::Impl {
|
||||
// Handle from the C API (mscclpp.h).
mscclppComm_t comm;
|
||||
// HostConnection objects held via shared_ptr.
std::vector<std::shared_ptr<HostConnection>> connections;
|
||||
// Proxy instance owned by value; Proxy exposes start()/stop() and a HostProxyFifo.
Proxy proxy;
|
||||
|
||||
Impl();
|
||||
|
||||
~Impl();
|
||||
|
||||
// NOTE(review): all members of this struct are already public, so this friend declaration looks redundant.
friend class HostConnection;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif
|
||||
@@ -3,17 +3,18 @@
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "mscclpp.h"
|
||||
#include "comm.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// Private implementation struct behind mscclpp::HostConnection.
// NOTE(review): this text comes from a rendered diff hunk, so removed and added declarations
// (e.g. the two constructors below) may both be shown; confirm against the actual header.
struct HostConnection::Impl {
|
||||
// Non-owning raw pointer to a Communicator (presumably the one that created this connection - TODO confirm).
Communicator* comm;
|
||||
// Raw pointer to the C-API connection object.
mscclppConn* conn;
|
||||
// Raw pointer to the C-API host connection handle.
mscclppHostConn_t* hostConn;
|
||||
|
||||
Impl();
|
||||
// Construct from a communicator and a C-API connection.
Impl(Communicator* comm, mscclppConn* conn);
|
||||
|
||||
~Impl();
|
||||
|
||||
// Takes a host connection handle; presumably stores it in `hostConn` (definition not visible here).
void setup(mscclppHostConn_t *hostConn);
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -29,7 +29,7 @@ struct alignas(16) mscclppDevConnSignalEpochId
|
||||
uint64_t proxy;
|
||||
};
|
||||
|
||||
using mscclppBufferHandle_t = uint8_t;
|
||||
using mscclppBufferHandle_t = uint32_t;
|
||||
|
||||
/***************************************************************************************************************
|
||||
* A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand.
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
|
||||
#include <mscclppfifo.hpp>
|
||||
|
||||
@@ -27,15 +28,14 @@ struct alignas(16) SignalEpochId {
|
||||
uint64_t proxy;
|
||||
};
|
||||
|
||||
enum ChannelTriggerType : uint64_t {
|
||||
channelTriggerData = 0x1,
|
||||
channelTriggerFlag = 0x2,
|
||||
channelTriggerSync = 0x4
|
||||
};
|
||||
using ChannelTriggerType = uint64_t;
|
||||
const ChannelTriggerType channelTriggerData = 0x1;
|
||||
const ChannelTriggerType channelTriggerFlag = 0x2;
|
||||
const ChannelTriggerType channelTriggerSync = 0x4;
|
||||
|
||||
// This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles
|
||||
// mapping to the actual
|
||||
using BufferHandle = uint8_t;
|
||||
using BufferHandle = uint32_t;
|
||||
|
||||
#define MSCCLPP_BITS_SIZE 32
|
||||
#define MSCCLPP_BITS_OFFSET 32
|
||||
@@ -58,15 +58,111 @@ union ChannelTrigger {
|
||||
uint64_t srcBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
|
||||
uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
|
||||
uint64_t type : MSCCLPP_BITS_TYPE;
|
||||
uint64_t connId : MSCCLPP_BITS_CONNID;
|
||||
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
|
||||
} fields;
|
||||
|
||||
ChannelTrigger() {}
|
||||
ChannelTrigger(ProxyTrigger value) : value(value) {}
|
||||
ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) {
|
||||
#ifdef __CUDACC__
|
||||
__device__ ChannelTrigger() {}
|
||||
__device__ ChannelTrigger(ProxyTrigger value) : value(value) {}
|
||||
__device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size, int connectionId) {
|
||||
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);
|
||||
value.snd = (((((((uint64_t)type << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset);
|
||||
value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
struct ConnectionEpoch {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
(*waitEpochId) += 1;
|
||||
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
SignalEpochId* localSignalEpochId;
|
||||
// used by the signal() function directly from gpu
|
||||
SignalEpochId* remoteSignalEpochId;
|
||||
|
||||
// every wait(), increments this and then the gpu waits for either:
|
||||
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
|
||||
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
|
||||
uint64_t* waitEpochId;
|
||||
};
|
||||
|
||||
// Host-side handle to a connection; state lives behind a pimpl pointer.
// DeviceConnection's constructor reads getId(), getEpoch() and getDeviceFifo() from this class.
class HostConnection {
|
||||
struct Impl;
|
||||
public:
|
||||
/* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */
|
||||
HostConnection(std::unique_ptr<Impl>);
|
||||
|
||||
~HostConnection();
|
||||
|
||||
// Returns the integer id of this connection; DeviceConnection stores it as its connectionId.
int getId();
|
||||
|
||||
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
|
||||
* in the communicator.
|
||||
*
|
||||
* Inputs:
|
||||
* data: base pointer to the memory
|
||||
* size: size of the memory region in bytes
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
BufferHandle registerBuffer(void* data, uint64_t size);
|
||||
|
||||
/* Get the number of times registerBuffer(...) was called.
|
||||
*
|
||||
* Returns: the number of buffers registered
|
||||
*/
|
||||
int numLocalBuffers();
|
||||
|
||||
/* Get the BufferHandle returned by a call to registerBuffer(...) as identified by the index
|
||||
*
|
||||
* Inputs:
|
||||
* index: the index of the handle to get
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
BufferHandle getLocalBuffer(int index);
|
||||
|
||||
/* Get the number of times registerBuffer(...) was called on the remote peer.
|
||||
*
|
||||
* Returns: the number of buffers registered on the remote peer
|
||||
*/
|
||||
int numRemoteBuffers();
|
||||
|
||||
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
|
||||
*
|
||||
* Inputs:
|
||||
* index: the index of the handle to get
|
||||
*
|
||||
* Returns: a handle to the buffer on the remote peer
|
||||
*/
|
||||
BufferHandle getRemoteBuffer(int index);
|
||||
|
||||
// Returns the ConnectionEpoch for this connection (by value); DeviceConnection copies it into its `epoch` member.
ConnectionEpoch getEpoch();
|
||||
|
||||
// Returns a DeviceProxyFifo (by value); DeviceConnection copies it into its `fifo` member.
DeviceProxyFifo getDeviceFifo();
|
||||
|
||||
// Host-side operations sharing names with the DeviceConnection device methods (put/signal/flush/wait).
// Their definitions are not visible here; presumably put() transfers `size` bytes from
// (src, srcOffset) to (dst, dstOffset) - TODO confirm against the implementation.
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
|
||||
|
||||
void signal();
|
||||
|
||||
void flush();
|
||||
|
||||
void wait();
|
||||
|
||||
private:
|
||||
// Sole owner of the implementation object passed to the constructor.
std::unique_ptr<Impl> pimpl;
|
||||
// Communicator is granted access to the private members above.
friend class Communicator;
|
||||
};
|
||||
|
||||
/***************************************************************************************************************
|
||||
@@ -132,12 +228,20 @@ union ChannelTrigger {
|
||||
* indices in the registered buffer.
|
||||
**************************************************************************************************************/
|
||||
struct DeviceConnection {
|
||||
#ifdef __CUDACC__
|
||||
// TODO: add buffer handles
|
||||
DeviceConnection() = default;
|
||||
|
||||
DeviceConnection(HostConnection& hostConn)
|
||||
: connectionId(hostConn.getId()), epoch(hostConn.getEpoch()),
|
||||
fifo(hostConn.getDeviceFifo()) {}
|
||||
|
||||
DeviceConnection(const DeviceConnection& other) = default;
|
||||
|
||||
DeviceConnection& operator=(DeviceConnection& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size).value);
|
||||
fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
|
||||
@@ -148,13 +252,13 @@ struct DeviceConnection {
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
epochIncrement();
|
||||
fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1).value);
|
||||
fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
epochIncrement();
|
||||
fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size).value);
|
||||
fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
|
||||
@@ -165,107 +269,116 @@ struct DeviceConnection {
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo.push(channelTriggerData | channelTriggerFlag | channelTriggerSync, dstOffset, srcOffset, size);
|
||||
while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
|
||||
uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, dstOffset, src, srcOffset, size, connectionId).value);
|
||||
while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignalAndFlush(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
uint64_t curFifoHead = fifo.push(ChannelTrigger(mscclppSync, 0, 0, 0, 0, 1, connectionId).value);
|
||||
// we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail
|
||||
// to go past curFifoHead (this is a safety net) and (2) wait for the work element value to change to 0.
|
||||
while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
epoch.wait();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
epoch.epochIncrement();
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
int connectionId;
|
||||
|
||||
ConnectionEpoch epoch;
|
||||
|
||||
// this is a concurrent fifo which multiple threads from the device
|
||||
// can produce for and the sole proxy thread consumes it.
|
||||
DeviceProxyFifo fifo;
|
||||
};
|
||||
|
||||
struct SimpleDeviceConnection {
|
||||
SimpleDeviceConnection() = default;
|
||||
|
||||
SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) {
|
||||
dst = hostConn.getRemoteBuffer(0);
|
||||
src = hostConn.getLocalBuffer(0);
|
||||
}
|
||||
|
||||
SimpleDeviceConnection(const SimpleDeviceConnection& other) = default;
|
||||
|
||||
SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devConn.put(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t offset, uint64_t size)
|
||||
{
|
||||
put(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
devConn.signal();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devConn.putWithSignal(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignal(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devConn.putWithSignalAndFlush(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignalAndFlush(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
uint64_t curFifoHead = fifo.push(mscclppSync, 0, 0, 1);
|
||||
// we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail
|
||||
// to go past curFifoHead (this is a safety net) and (2) wait for the work element value to change to 0.
|
||||
while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
|
||||
;
|
||||
devConn.flush();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
(*waitEpochId) += 1;
|
||||
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
|
||||
;
|
||||
devConn.wait();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
|
||||
devConn.epochIncrement();
|
||||
}
|
||||
|
||||
#endif // __CUDACC__
|
||||
|
||||
int remoteRank;
|
||||
int tag;
|
||||
|
||||
SignalEpochId* localSignalEpochId;
|
||||
// used by the signal() function directly from gpu
|
||||
SignalEpochId* remoteSignalEpochId;
|
||||
|
||||
// every wait(), increments this and then the gpu waits for either:
|
||||
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
|
||||
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
|
||||
uint64_t* waitEpochId;
|
||||
|
||||
// this is a concurrent fifo which multiple threads from the device
|
||||
// can produce for and the sole proxy thread consumes it.
|
||||
ProxyFifo fifo;
|
||||
};
|
||||
|
||||
class HostConnection {
|
||||
public:
|
||||
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
|
||||
* in the communicator.
|
||||
*
|
||||
* Inputs:
|
||||
* data: base pointer to the memory
|
||||
* size: size of the memory region in bytes
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
BufferHandle registerBuffer(void* data, uint64_t size);
|
||||
|
||||
/* Get the number of times registerBuffer(...) was called on the remote peer.
|
||||
*
|
||||
* Returns: the number of buffers registered on the remote peer
|
||||
*/
|
||||
int numRemoteBuffers();
|
||||
|
||||
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
|
||||
*
|
||||
* Inputs:
|
||||
* index: the index of the handle to get
|
||||
*
|
||||
* Returns: a handle to the buffer on the remote peer
|
||||
*/
|
||||
BufferHandle getRemoteBuffer(int index);
|
||||
|
||||
/* Create a DeviceConnection paired with this HostConnection. A background proxy thread will
|
||||
* trigger operations on this HostConnection corresponding to put/signal/etc. calls made to the
|
||||
* DeviceConnection.
|
||||
*
|
||||
* Inputs:
|
||||
* startProxyThread: whether to start the proxy thread (default is true)
|
||||
*
|
||||
* Returns: the newly created DeviceConnection
|
||||
*/
|
||||
DeviceConnection toDevice(bool startProxyThread = true);
|
||||
|
||||
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
|
||||
void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size);
|
||||
void signal();
|
||||
void flush();
|
||||
void wait();
|
||||
void epochIncrement();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
DeviceConnection devConn;
|
||||
BufferHandle dst;
|
||||
BufferHandle src;
|
||||
};
|
||||
|
||||
#define MSCCLPP_UNIQUE_ID_BYTES 128
|
||||
@@ -290,6 +403,7 @@ enum class TransportType : uint8_t {
|
||||
|
||||
class Communicator {
|
||||
public:
|
||||
|
||||
/* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function.
|
||||
*
|
||||
* Inputs:
|
||||
@@ -297,7 +411,7 @@ public:
|
||||
* ipPortPair: a string of the form "ip:port" that represents the address of the root process
|
||||
* rank: rank of the calling process
|
||||
*/
|
||||
void initRank(int nranks, const char* ipPortPair, int rank);
|
||||
Communicator(int nranks, const char* ipPortPair, int rank);
|
||||
|
||||
/* Initialize the communicator from a given UniqueId. Same as mscclppCommInitRank() except that
|
||||
* id is provided by the user by calling getUniqueId()
|
||||
@@ -307,7 +421,9 @@ public:
|
||||
* id: the unique ID to be used for communication
|
||||
* rank: rank of the calling process
|
||||
*/
|
||||
void initRankFromId(int nranks, UniqueId id, int rank);
|
||||
Communicator(int nranks, UniqueId id, int rank);
|
||||
|
||||
~Communicator();
|
||||
|
||||
/* Ring-based AllGather through the bootstrap socket.
|
||||
*
|
||||
@@ -341,6 +457,12 @@ public:
|
||||
*/
|
||||
void connectionSetup();
|
||||
|
||||
/* Launch proxy thread(s). This function is supposed to be called before starting a kernel that uses DeviceConnection. */
|
||||
void startProxying();
|
||||
|
||||
/* Stop proxy thread(s). */
|
||||
void stopProxying();
|
||||
|
||||
/* Return the rank of the calling process.
|
||||
*
|
||||
* Outputs:
|
||||
@@ -355,6 +477,33 @@ public:
|
||||
*/
|
||||
int size();
|
||||
|
||||
struct Impl;
|
||||
private:
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
friend class HostConnection;
|
||||
};
|
||||
|
||||
// Value returned by a ProxyHandler (std::function<ProxyHandlerResult(ProxyTrigger)>) for each trigger it is given.
// NOTE(review): how each value is acted upon is decided by the Proxy implementation, which is not visible here;
// the per-value notes below are inferred from the enumerator names and should be confirmed.
enum class ProxyHandlerResult {
|
||||
// presumably: keep processing triggers
Continue,
|
||||
// presumably: flush the FIFO tail (cf. HostProxyFifo::flushTail), then keep processing
FlushFifoTailAndContinue,
|
||||
// presumably: stop the proxy loop
Stop,
|
||||
};
|
||||
|
||||
class Proxy;
|
||||
using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
|
||||
|
||||
class Proxy {
|
||||
public:
|
||||
Proxy(ProxyHandler handler);
|
||||
|
||||
~Proxy();
|
||||
|
||||
void start();
|
||||
|
||||
void stop();
|
||||
|
||||
HostProxyFifo& fifo();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <stdint.h>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -13,39 +14,56 @@ struct alignas(16) ProxyTrigger {
|
||||
/* This is a concurrent fifo where multiple device threads can push mscclppTrigger work elements to
|
||||
* and a single host proxy thread consumes these work elements. There is a head pointer allocated on device
|
||||
* which starts with 0 and goes to 2^64-1 which is almost infinity. There are two copies of tail, one
|
||||
* that is on the device (triggerFifoTail) and another that is on host (proxyState->fifoTailHost).
|
||||
* that is on the device (tailReplica) and another that is on host (proxyState->fifoTailHost).
|
||||
* The host always has the "true" tail and occasionally, pushes it to the copy on the device.
|
||||
* Therefore, most of the time, the device has a stale version. The invariants are:
|
||||
* triggerFifoTail <= proxyState->fifoTailHost <= triggerFifoHead.
|
||||
* push() function increments triggerFifoHead, proxyState->fifoTailHost is updated in proxy.cc:mscclppProxyService
|
||||
* and it occasionally flushes it to triggerFifoTail via a cudaMemcpyAsync.
|
||||
* tailReplica <= proxyState->fifoTailHost <= head.
|
||||
* push() function increments head, proxyState->fifoTailHost is updated in proxy.cc:mscclppProxyService
|
||||
* and it occasionally flushes it to tailReplica via a cudaMemcpyAsync.
|
||||
*
|
||||
* Why is duplicating the tail a good idea? The fifo is large enough and we do not need frequent updates
|
||||
* for the tail as there is usually enough space for device threads to push their work into.
|
||||
*/
|
||||
struct ProxyFifo {
|
||||
struct DeviceProxyFifo {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ uint64_t push(ProxyTrigger element)
|
||||
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger)
|
||||
{
|
||||
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->triggerFifoHead, 1);
|
||||
while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->triggerFifoTail))
|
||||
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1);
|
||||
while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->tailReplica))
|
||||
;
|
||||
while (*(volatile uint64_t*)&this->triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0)
|
||||
while (*(volatile uint64_t*)&this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0)
|
||||
;
|
||||
uint64_t* valptr = (uint64_t*)&(this->triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE].value);
|
||||
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(valptr),
|
||||
"l"(element.value[0]), "l"(element.value[1]));
|
||||
ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]);
|
||||
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr),
|
||||
"l"(trigger.fst), "l"(trigger.snd));
|
||||
return curFifoHead;
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
void startProxyThread(std::function<void(ProxyTrigger)> handler);
|
||||
void stopProxyThread();
|
||||
|
||||
ProxyTrigger* triggerFifo; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements
|
||||
uint64_t* triggerFifoTail; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused
|
||||
ProxyTrigger* triggers; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements
|
||||
uint64_t* tailReplica; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused
|
||||
// occasionally to device
|
||||
uint64_t* triggerFifoHead; // Allocated on device. Only accessed by device
|
||||
uint64_t* head; // Allocated on device. Only accessed by device
|
||||
};
|
||||
|
||||
// Host-side counterpart of DeviceProxyFifo; implementation details are hidden behind a pimpl pointer.
// NOTE(review): method semantics below are inferred from the names and the FIFO description in this
// header (host holds the true tail and occasionally pushes it to the device replica); confirm against the .cc file.
class HostProxyFifo
|
||||
{
|
||||
public:
|
||||
HostProxyFifo();
|
||||
|
||||
~HostProxyFifo();
|
||||
|
||||
// Takes an output pointer; presumably copies the trigger at the current tail into *trigger.
void poll(ProxyTrigger *trigger);
|
||||
|
||||
// Presumably consumes the trigger at the current tail and advances the host tail.
void pop();
|
||||
|
||||
// Presumably pushes the host tail to the device-side replica; `sync` defaults to false.
void flushTail(bool sync = false);
|
||||
|
||||
// Returns a DeviceProxyFifo by value (the struct carrying the triggers/tailReplica/head pointers).
DeviceProxyFifo toDevice();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
Reference in New Issue
Block a user