trigger wrappers

This commit is contained in:
Changho Hwang
2023-03-14 09:11:51 +00:00
parent 75ec82d257
commit aacee9727b
2 changed files with 25 additions and 18 deletions

View File

@@ -69,6 +69,26 @@ union alignas(16) mscclppTrigger {
***************************************/
struct mscclppDevConn {
#ifdef __CUDACC__
// Reserves the next slot of the trigger FIFO and returns a pointer to it.
// atomicInc returns the previous head value and resets the counter to 0 once
// it has reached MSCCLPP_PROXY_FIFO_SIZE - 1, so the returned index always
// stays below MSCCLPP_PROXY_FIFO_SIZE.
__forceinline__ __device__ mscclppTrigger *getTrigger() {
  const unsigned int slot = atomicInc(this->triggerFifoHead, MSCCLPP_PROXY_FIFO_SIZE - 1);
  mscclppTrigger *reserved = &this->trigger[slot];
  return reserved;
}
// Fills the trigger slot `trig` with one 16-byte volatile global store.
//   trig       - slot to write
//   type       - request type bits, placed above the connection id
//   dataOffset - offset field, placed above the size field
//   dataSize   - size field
// The stored 64-bit words are:
//   word 0: (dataOffset << MSCCLPP_BITS_SIZE) + dataSize
//   word 1: (type << MSCCLPP_BITS_CONNID) + this->connId
// NOTE(review): the fields are combined with `+` and are not masked; this
// presumably relies on dataSize fitting in MSCCLPP_BITS_SIZE bits and connId
// fitting in MSCCLPP_BITS_CONNID bits - confirm against the callers.
__forceinline__ __device__ void setTrigger(mscclppTrigger *trig, uint64_t type, uint64_t dataOffset, uint64_t dataSize) {
  asm volatile(
      "st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(&trig->value),
      "l"((dataOffset << (MSCCLPP_BITS_SIZE)) + (dataSize)),
      "l"((type << MSCCLPP_BITS_CONNID) + this->connId)
      // The store writes memory through an input pointer operand, so declare
      // a "memory" clobber: it keeps the compiler from moving or caching
      // surrounding memory accesses across this statement.
      : "memory");
}
// Busy-waits until the first 64-bit word of the trigger slot `trig` reads as
// zero. Only that first word is polled; the rest of the slot is not examined.
// NOTE(review): presumably the proxy zeroes the slot once it has consumed the
// request - confirm against the proxy implementation.
__forceinline__ __device__ void waitTrigger(mscclppTrigger *trig) {
  // Volatile access forces a fresh load from memory on every iteration.
  volatile uint64_t *firstWord = (volatile uint64_t *)trig->value;
  while (*firstWord != 0) {
  }
}
#endif // __CUDACC__
int tag;
void* localBuff;

View File

@@ -42,16 +42,6 @@ static double getTime(void)
__constant__ mscclppDevConn_t constDevConns[16];
// Fills the trigger slot `trig` with one 16-byte volatile global store.
//   trig       - slot to write
//   connId     - connection id, placed in the low bits of the second word
//   type       - request type bits, placed above the connection id
//   dataOffset - offset field, placed above the size field
//   dataSize   - size field
// The stored 64-bit words are:
//   word 0: (dataOffset << MSCCLPP_BITS_SIZE) + dataSize
//   word 1: (type << MSCCLPP_BITS_CONNID) + connId
// NOTE(review): the fields are combined with `+` and are not masked; this
// presumably relies on dataSize fitting in MSCCLPP_BITS_SIZE bits and connId
// fitting in MSCCLPP_BITS_CONNID bits - confirm against the callers.
__forceinline__ __device__ void setTrigger(mscclppTrigger *trig, uint64_t connId, uint64_t type,
                                           uint64_t dataOffset, uint64_t dataSize)
{
  asm volatile(
      "st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(&trig->value),
      "l"((dataOffset << (MSCCLPP_BITS_SIZE)) + (dataSize)),
      "l"((type << MSCCLPP_BITS_CONNID) + connId)
      // The store writes memory through an input pointer operand, so declare
      // a "memory" clobber: it keeps the compiler from moving or caching
      // surrounding memory accesses across this statement.
      : "memory");
}
__global__ void kernel(int rank, int world_size)
{
if (threadIdx.x % 32 != 0) return;
@@ -65,8 +55,7 @@ __global__ void kernel(int rank, int world_size)
volatile uint64_t *remoteFlag = devConn.remoteFlag;
#endif
volatile uint64_t *proxyFlag = devConn.proxyFlag;
unsigned int curFifoHead = atomicInc(devConn.triggerFifoHead, MSCCLPP_PROXY_FIFO_SIZE - 1);
mscclppTrigger *trig = &devConn.trigger[curFifoHead];
mscclppTrigger *trig = devConn.getTrigger();
uint64_t baseFlag = *localFlag;
@@ -86,11 +75,10 @@ __global__ void kernel(int rank, int world_size)
#if (USE_DMA_FOR_P2P == 1)
// Wait until the proxy has sent my data and flag
// Check only the first 64 bits
while (*(volatile uint64_t *)trig->value != 0) {}
devConn.waitTrigger(trig);
// Trigger sending data and flag
setTrigger(trig, devConn.connId, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int));
devConn.setTrigger(trig, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int));
// Wait for receiving data from remote rank
while (*proxyFlag == baseFlag) {}
@@ -99,11 +87,10 @@ __global__ void kernel(int rank, int world_size)
if (devConn.remoteBuff == NULL) { // IB
// Wait until the proxy has sent my data and flag
// Check only the first 64 bits
while (*(volatile uint64_t *)trig->value != 0) {}
devConn.waitTrigger(trig);
// Trigger sending data and flag
setTrigger(trig, devConn.connId, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int));
devConn.setTrigger(trig, mscclppFlag | mscclppData, rank * sizeof(int), sizeof(int));
// Wait for receiving data from remote rank
while (*proxyFlag == baseFlag) {}