perf fix for multi-node allgather

This commit is contained in:
Saeed Maleki
2023-03-21 06:26:12 +00:00
parent 9df5077015
commit e2ee8d80b9

View File

@@ -59,18 +59,23 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU)
}
// Each warp receives data from different ranks
#if 0
#if 1
// get a thread-local trigger and a request for waiting on it
mscclppTrigger_t trig;
mscclppRequest_t req = devConn.fifo.getTrigger(&trig);
// Trigger sending data, flag and synchronize after
devConn.fifo.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU*sizeof(int));
devConn.fifo.setTrigger(trig, mscclppData, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU*sizeof(int));
// we cannot reuse buffer and flag until the request is completed
req = devConn.fifo.getTrigger(&trig);
// Trigger sending data, flag and synchronize after
devConn.fifo.setTrigger(trig, mscclppFlag | mscclppSync, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU*sizeof(int));
// we cannot reuse buffer and flag until the request is completed
// Wait on the request to make sure it is safe to reuse buffer and flag
devConn.fifo.waitTrigger(req);
// Wait for receiving data from remote rank
while (*proxyFlag == baseFlag);
#else