mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
trigger wrappers
This commit is contained in:
@@ -69,6 +69,26 @@ union alignas(16) mscclppTrigger {
|
||||
***************************************/
|
||||
|
||||
struct mscclppDevConn {
|
||||
#ifdef __CUDACC__
|
||||
// Reserves the next slot of this connection's trigger FIFO and returns a pointer to it.
// atomicInc hands back the previous head value and resets the head to 0 once it has
// reached MSCCLPP_PROXY_FIFO_SIZE - 1, so the index used below never leaves the FIFO.
__forceinline__ __device__ mscclppTrigger *getTrigger() {
  const unsigned int wrapAt = MSCCLPP_PROXY_FIFO_SIZE - 1;
  const unsigned int slot = atomicInc(this->triggerFifoHead, wrapAt);
  return &this->trigger[slot];
}
|
||||
|
||||
// Fills the trigger slot `trig` with one vectorized volatile global store of two 64-bit words:
//   first word : (dataOffset << MSCCLPP_BITS_SIZE) + dataSize
//   second word: (type << MSCCLPP_BITS_CONNID) + this->connId
// The fields are combined with '+' and are not masked here.
__forceinline__ __device__ void setTrigger(mscclppTrigger *trig, uint64_t type, uint64_t dataOffset, uint64_t dataSize) {
  const uint64_t offsetAndSize = (dataOffset << (MSCCLPP_BITS_SIZE)) + (dataSize);
  const uint64_t typeAndConn = (type << MSCCLPP_BITS_CONNID) + this->connId;
  asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};"
               :
               : "l"(&trig->value), "l"(offsetAndSize), "l"(typeAndConn));
}
|
||||
|
||||
// Busy-waits until the first 64-bit word of the trigger slot `trig` reads as zero.
__forceinline__ __device__ void waitTrigger(mscclppTrigger *trig) {
  // Only the first 64 bits are polled; the volatile access makes every iteration issue a new load.
  volatile uint64_t *firstWord = (volatile uint64_t *)trig->value;
  while (*firstWord != 0) {
  }
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
int tag;
|
||||
|
||||
void* localBuff;
|
||||
|
||||
@@ -42,16 +42,6 @@ static double getTime(void)
|
||||
|
||||
__constant__ mscclppDevConn_t constDevConns[16];
|
||||
|
||||
// Fills the trigger slot `trig` with one vectorized volatile global store of two 64-bit words:
//   first word : (dataOffset << MSCCLPP_BITS_SIZE) + dataSize
//   second word: (type << MSCCLPP_BITS_CONNID) + connId
// The fields are combined with '+' and are not masked here.
__forceinline__ __device__ void setTrigger(mscclppTrigger *trig, uint64_t connId, uint64_t type,
                                           uint64_t dataOffset, uint64_t dataSize)
{
  const uint64_t offsetAndSize = (dataOffset << (MSCCLPP_BITS_SIZE)) + (dataSize);
  const uint64_t typeAndConn = (type << MSCCLPP_BITS_CONNID) + connId;
  asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};"
               :
               : "l"(&trig->value), "l"(offsetAndSize), "l"(typeAndConn));
}
|
||||
|
||||
__global__ void kernel(int rank, int world_size)
|
||||
{
|
||||
if (threadIdx.x % 32 != 0) return;
|
||||
@@ -65,8 +55,7 @@ __global__ void kernel(int rank, int world_size)
|
||||
volatile uint64_t *remoteFlag = devConn.remoteFlag;
|
||||
#endif
|
||||
volatile uint64_t *proxyFlag = devConn.proxyFlag;
|
||||
unsigned int curFifoHead = atomicInc(devConn.triggerFifoHead, MSCCLPP_PROXY_FIFO_SIZE - 1);
|
||||
mscclppTrigger *trig = &devConn.trigger[curFifoHead];
|
||||
mscclppTrigger *trig = devConn.getTrigger();
|
||||
|
||||
uint64_t baseFlag = *localFlag;
|
||||
|
||||
@@ -86,11 +75,10 @@ __global__ void kernel(int rank, int world_size)
|
||||
#if (USE_DMA_FOR_P2P == 1)
|
||||
|
||||
// Wait until the proxy has sent my data and flag
|
||||
// Check only the first 64 bits
|
||||
while (*(volatile uint64_t *)trig->value != 0) {}
|
||||
devConn.waitTrigger(trig);
|
||||
|
||||
// Trigger sending data and flag
|
||||
setTrigger(trig, devConn.connId, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int));
|
||||
devConn.setTrigger(trig, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int));
|
||||
|
||||
// Wait for receiving data from remote rank
|
||||
while (*proxyFlag == baseFlag) {}
|
||||
@@ -99,11 +87,10 @@ __global__ void kernel(int rank, int world_size)
|
||||
|
||||
if (devConn.remoteBuff == NULL) { // IB
|
||||
// Wait until the proxy has sent my data and flag
|
||||
// Check only the first 64 bits
|
||||
while (*(volatile uint64_t *)trig->value != 0) {}
|
||||
devConn.waitTrigger(trig);
|
||||
|
||||
// Trigger sending data and flag
|
||||
setTrigger(trig, devConn.connId, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int));
|
||||
devConn.setTrigger(trig, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int));
|
||||
|
||||
// Wait for receiving data from remote rank
|
||||
while (*proxyFlag == baseFlag) {}
|
||||
|
||||
Reference in New Issue
Block a user