mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-24 14:54:51 +00:00
a lot of documentation
This commit is contained in:
@@ -473,7 +473,7 @@ static mscclppResult_t socketTryAccept(struct mscclppSocket* sock)
|
||||
return mscclppRemoteError;
|
||||
} else {
|
||||
usleep(SLEEP_INT);
|
||||
if (sock->acceptRetries % 10000 == 0)
|
||||
if (++sock->acceptRetries % 1000 == 0)
|
||||
INFO(MSCCLPP_ALL, "socketTryAccept: Call to try accept returned %s, retrying", strerror(errno));
|
||||
}
|
||||
return mscclppSuccess;
|
||||
|
||||
@@ -266,9 +266,6 @@ mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* b
|
||||
if (pageSize == 0) {
|
||||
pageSize = sysconf(_SC_PAGESIZE);
|
||||
}
|
||||
if (reinterpret_cast<uintptr_t>(buff) % pageSize != 0) {
|
||||
WARN("buff (%p) is not aligned to the page size! Ignoring and proceeding anyway.", buff);
|
||||
}
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
|
||||
size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
|
||||
struct ibv_mr* mr =
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#define MSCCLPP_PATCH 0
|
||||
#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH)
|
||||
|
||||
// After every MSCCLPP_FLUSH_FIFO_COUNTER processed work elements, a flush of the tail to device memory is triggered.
|
||||
// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem.
|
||||
#define MSCCLPP_PROXY_FIFO_SIZE 32
|
||||
#define MSCCLPP_FLUSH_FIFO_COUNTER 4
|
||||
|
||||
@@ -45,13 +47,15 @@ extern "C" {
|
||||
***************************************************************************************************************
|
||||
* At the runtime, a GPU kernel has access to a mscclppDevConn object that provides the following functions:
|
||||
*
|
||||
* put(): the sender initiates a data transfer to the receiver.
|
||||
* put(): [non-blocking] the sender initiates a data transfer to the receiver.
|
||||
*
|
||||
* signal(): the sender signals the receiver that data is ready to be consumed.
|
||||
* signal(): [non-blocking] the sender signals the receiver that data is ready to be consumed.
|
||||
*
|
||||
* wait(): the reciever waits on the signal() to start reading the data.
|
||||
* flush(): [blocking] the sender waits for all the data transfers to complete
|
||||
*
|
||||
* The sender should not reuse the buffer till the signal returns.
|
||||
* wait(): [blocking] the receiver waits on the signal() to start reading the data.
|
||||
*
|
||||
* The sender should not reuse the buffer till the flush returns.
|
||||
* The receiver should only access the data after the wait returns.
|
||||
*
|
||||
* putWithSignal(): the sender initiates a data transfer and signals the receiver that data is ready to be consumed.
|
||||
@@ -68,7 +72,9 @@ extern "C" {
|
||||
* devConn.put(data3) // receiver GPU
|
||||
* // not OK to write to data1, data2, data3 // not OK to read data1, data2, data3
|
||||
* devConn.signal() -------------------------------> devConn.wait()
|
||||
* // OK to write to data1, data2, data3 // OK to read data1, data2, data3
|
||||
* // Not OK to write to data1, data2, data3 // OK to read data1, data2, data3
|
||||
* devConn.flush()
|
||||
* // OK to write to data1, data2, data3
|
||||
*
|
||||
*
|
||||
* The two endpoints can concurrently use the same connection provided they are writing (puts) on different
|
||||
@@ -104,11 +110,13 @@ struct mscclppDevConn
|
||||
putWithSignal(dataOffset, dataOffset, dataSize);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstDataOffset, uint64_t srcDataOffset,
|
||||
uint64_t dataSize)
|
||||
{
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo.push(mscclppData | mscclppFlag | mscclppSync, dstDataOffset, srcDataOffset, dataSize);
|
||||
while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && *(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
|
||||
while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
@@ -120,7 +128,10 @@ struct mscclppDevConn
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
uint64_t curFifoHead = fifo.push(mscclppSync, 0, 0, 1);
|
||||
while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && *(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
|
||||
// there are two ways to know if the CPU is done flushing. It is either by waiting for the tail
|
||||
// to go past curFifoHead (this is a safety net) or by waiting for the work element value to change to 0.
|
||||
while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
@@ -148,7 +159,8 @@ struct mscclppDevConn
|
||||
uint64_t* remoteFlag;
|
||||
uint64_t* proxyEpochId; // this is only written by the proxy thread
|
||||
|
||||
// threads can access the fifo concurrently
|
||||
// this is a concurrent fifo that multiple threads from the device
|
||||
// can push to and that the sole proxy thread consumes.
|
||||
struct mscclppConcurrentFifo fifo;
|
||||
};
|
||||
|
||||
@@ -248,6 +260,9 @@ const char* mscclppGetErrorString(mscclppResult_t result);
|
||||
/* Connect to a remote rank. This function only prepares metadata for connection. The actual connection
|
||||
* is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection
|
||||
* from rank i to remote rank j needs to have a counterpart from rank j to rank i.
|
||||
* Note that with IB, buffers are registered at a page level and if a buffer is spread through multiple pages
|
||||
* and does not fully utilize all of them, IB's QP has to register for all involved pages. This potentially has
|
||||
* security risks if the devConn's accesses are given to a malicious process.
|
||||
*
|
||||
* Inputs:
|
||||
* comm: the communicator
|
||||
|
||||
@@ -19,6 +19,7 @@ typedef enum : uint64_t
|
||||
#define MSCCLPP_BITS_TYPE 3
|
||||
#define MSCCLPP_BITS_CONNID 10
|
||||
|
||||
// this is the basic structure of each work element in the fifo
|
||||
// the summation of number of bits must be 128 or less
|
||||
union alignas(16) mscclppTrigger {
|
||||
uint64_t value[2];
|
||||
@@ -38,6 +39,19 @@ union alignas(16) mscclppTrigger {
|
||||
|
||||
typedef mscclppTrigger* mscclppTrigger_t;
|
||||
|
||||
/* This is a concurrent fifo where multiple device threads can push mscclppTrigger work elements to
|
||||
* and a single host proxy thread consumes these work elements. There is a head pointer allocated on device
|
||||
* which starts with 0 and goes to 2^64-1 which is almost infinity. There are two copies of tail, one
|
||||
* that is on the device (triggerFifoTail) and another that is on host (proxyState->fifoTailHost).
|
||||
* The host always has the "true" tail and occasionally pushes it to the device copy.
|
||||
* Therefore, most of the time, the device has a stale version. The invariants are:
|
||||
* triggerFifoTail <= proxyState->fifoTailHost <= triggerFifoHead.
|
||||
* push function increments triggerFifoHead, proxyState->fifoTailHost is updated in proxy.cc:mscclppProxyService
|
||||
* and it occasionally flushes it to triggerFifoTail via a cudaMemcpyAsync.
|
||||
*
|
||||
* Why is duplicating the tail a good idea? The fifo is large enough and we do not need frequent updates
|
||||
* for the tail as there is usually enough space for device threads to push their work into.
|
||||
*/
|
||||
struct mscclppConcurrentFifo
|
||||
{
|
||||
#ifdef __CUDACC__
|
||||
@@ -57,10 +71,11 @@ struct mscclppConcurrentFifo
|
||||
return curFifoHead;
|
||||
}
|
||||
|
||||
#endif // __CUDACC__
|
||||
mscclppTrigger* triggerFifo; // allocate on host via cudaHostAlloc. produced by device and consumed by host
|
||||
uint64_t* triggerFifoTail; // allocated on device. updated only by host
|
||||
uint64_t* triggerFifoHead; // allocated on device. update only by device
|
||||
#endif // __CUDACC__
|
||||
mscclppTrigger* triggerFifo; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements
|
||||
uint64_t* triggerFifoTail; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pushed
|
||||
// occasionally to device
|
||||
uint64_t* triggerFifoHead; // Allocated on device. Only accessed by device
|
||||
int connId;
|
||||
};
|
||||
|
||||
|
||||
@@ -35,10 +35,12 @@ struct mscclppProxyState
|
||||
// allocated on the device. Read-only by device, write-only by host
|
||||
uint64_t* fifoTailDev;
|
||||
// allocated on the host. Only accessed by the host. This is a copy of the
|
||||
// value pointed to by fifoTailDev and the invariance is that
|
||||
// value pointed to by fifoTailDev and the invariant is that
|
||||
// *fifoTailDev <= fifoTailHost. Meaning that host's copy of tail is
|
||||
// always ahead of the device's copy and host updates the device's copy
|
||||
// only when it is needed.
|
||||
// only when it is needed. Therefore, fifoTailHost is the "true" tail
|
||||
// and fifoTailDev is a "stale" tail. See proxy.cc to understand how
|
||||
// these updates are pushed to the device.
|
||||
uint64_t fifoTailHost;
|
||||
|
||||
struct mscclppIbContext* ibContext; // For IB connection only
|
||||
|
||||
@@ -68,7 +68,7 @@ void* mscclppProxyService(void* _args)
|
||||
uint64_t fifoTailCached = *fifoTail;
|
||||
mscclppTrigger trigger;
|
||||
mscclppIbContext* ibCtx = args->proxyState->ibContext;
|
||||
cudaStream_t p2pStream;
|
||||
cudaStream_t p2pStream = NULL;
|
||||
cudaStream_t stream;
|
||||
|
||||
PROXYCUDACHECK(cudaStreamCreate(&stream));
|
||||
@@ -169,6 +169,9 @@ void* mscclppProxyService(void* _args)
|
||||
// Send completion: reset only the high 64 bits
|
||||
*(volatile uint64_t*)(&fifo[fifoTailCached % MSCCLPP_PROXY_FIFO_SIZE]) = 0;
|
||||
fifoTailCached++;
|
||||
// Flush the tail to device memory. This is triggered either every MSCCLPP_FLUSH_FIFO_COUNTER work elements, to make
|
||||
// sure the fifo can make progress even if there is no mscclppSync request, or when the trigger has the mscclppSync
|
||||
// type, which is an explicit flush request.
|
||||
if (((fifoTailCached % MSCCLPP_FLUSH_FIFO_COUNTER) == 0) || (trigger.fields.type & mscclppSync)) {
|
||||
PROXYCUDACHECK(
|
||||
cudaMemcpyAsync(fifoTailDevPtr, &fifoTailCached, sizeof(uint64_t), cudaMemcpyHostToDevice, stream));
|
||||
@@ -176,6 +179,9 @@ void* mscclppProxyService(void* _args)
|
||||
}
|
||||
*fifoTail = fifoTailCached;
|
||||
|
||||
// make sure the tail is flushed before we shut the proxy
|
||||
PROXYCUDACHECK(cudaMemcpyAsync(fifoTailDevPtr, &fifoTailCached, sizeof(uint64_t), cudaMemcpyHostToDevice, stream));
|
||||
PROXYCUDACHECK(cudaStreamSynchronize(stream));
|
||||
PROXYCUDACHECK(cudaStreamDestroy(stream));
|
||||
if (isP2pProxy) {
|
||||
PROXYCUDACHECK(cudaStreamSynchronize(p2pStream));
|
||||
|
||||
Reference in New Issue
Block a user