From aacee9727bc3d3736ffde485bbb83c0972bd4908 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 14 Mar 2023 09:11:51 +0000 Subject: [PATCH] trigger wrappers --- src/include/mscclpp.h | 20 ++++++++++++++++++++ tests/p2p_test.cu | 23 +++++------------------ 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index 591dc5b0..85a6df24 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -69,6 +69,26 @@ union alignas(16) mscclppTrigger { ***************************************/ struct mscclppDevConn { +#ifdef __CUDACC__ + __forceinline__ __device__ mscclppTrigger *getTrigger() { + unsigned int curFifoHead = atomicInc(this->triggerFifoHead, MSCCLPP_PROXY_FIFO_SIZE - 1); + return &this->trigger[curFifoHead]; + } + + __forceinline__ __device__ void setTrigger(mscclppTrigger *trig, uint64_t type, uint64_t dataOffset, uint64_t dataSize) { + asm volatile( + "st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(&trig->value), + "l"((dataOffset << (MSCCLPP_BITS_SIZE)) + + (dataSize)), + "l"((type << MSCCLPP_BITS_CONNID) + this->connId)); + } + + __forceinline__ __device__ void waitTrigger(mscclppTrigger *trig) { + // Check only the first 64 bits + while (*(volatile uint64_t *)trig->value != 0) {} + } +#endif // __CUDACC__ + int tag; void* localBuff; diff --git a/tests/p2p_test.cu b/tests/p2p_test.cu index 9a5591d8..04a412d6 100644 --- a/tests/p2p_test.cu +++ b/tests/p2p_test.cu @@ -42,16 +42,6 @@ static double getTime(void) __constant__ mscclppDevConn_t constDevConns[16]; -__forceinline__ __device__ void setTrigger(mscclppTrigger *trig, uint64_t connId, uint64_t type, - uint64_t dataOffset, uint64_t dataSize) -{ - asm volatile( - "st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(&trig->value), - "l"((dataOffset << (MSCCLPP_BITS_SIZE)) + - (dataSize)), - "l"((type << MSCCLPP_BITS_CONNID) + connId)); -} - __global__ void kernel(int rank, int world_size) { if (threadIdx.x % 32 != 0) return; @@ -65,8 +55,7 @@ __global__ void kernel(int rank, int world_size) volatile uint64_t *remoteFlag = devConn.remoteFlag; #endif volatile uint64_t *proxyFlag = devConn.proxyFlag; - unsigned int curFifoHead = atomicInc(devConn.triggerFifoHead, MSCCLPP_PROXY_FIFO_SIZE - 1); - mscclppTrigger *trig = &devConn.trigger[curFifoHead]; + mscclppTrigger *trig = devConn.getTrigger(); uint64_t baseFlag = *localFlag; @@ -86,11 +75,10 @@ __global__ void kernel(int rank, int world_size) #if (USE_DMA_FOR_P2P == 1) // Wait until the proxy have sent my data and flag - // Check only the high 64 bits - while (*(volatile uint64_t *)trig->value != 0) {} + devConn.waitTrigger(trig); // Trigger sending data and flag - setTrigger(trig, devConn.connId, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int)); + devConn.setTrigger(trig, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int)); // Wait for receiving data from remote rank while (*proxyFlag == baseFlag) {} @@ -99,11 +87,10 @@ __global__ void kernel(int rank, int world_size) if (devConn.remoteBuff == NULL) { // IB // Wait until the proxy have sent my data and flag - // Check only the high 64 bits - while (*(volatile uint64_t *)trig->value != 0) {} + devConn.waitTrigger(trig); // Trigger sending data and flag - setTrigger(trig, devConn.connId, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int)); + devConn.setTrigger(trig, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int)); // Wait for receiving data from remote rank while (*proxyFlag == baseFlag) {}