mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
both allgather algorithms
This commit is contained in:
@@ -7,7 +7,8 @@
|
||||
#include <unistd.h>
|
||||
#include <string>
|
||||
|
||||
#define RANKS_PER_NODE 2
|
||||
#define RANKS_PER_NODE 8
|
||||
#define KERNEL 1
|
||||
|
||||
#define MSCCLPPCHECK(call) do { \
|
||||
mscclppResult_t res = call; \
|
||||
@@ -61,7 +62,7 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU)
|
||||
}
|
||||
|
||||
// Each warp receives data from different ranks
|
||||
|
||||
#if 0
|
||||
// get a thread-local trigger and a request for waiting on it
|
||||
mscclppTrigger_t trig;
|
||||
mscclppRequest_t req = devConn.fifo.getTrigger(&trig);
|
||||
@@ -73,7 +74,24 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU)
|
||||
devConn.fifo.waitTrigger(req);
|
||||
|
||||
// Wait for receiving data from remote rank
|
||||
while (*proxyFlag == baseFlag) {}
|
||||
while (*proxyFlag == baseFlag);
|
||||
#else
|
||||
for (int i = 1; i < world_size; i++){
|
||||
__syncthreads();
|
||||
if (remoteRank != ((rank+i) % world_size)) continue;
|
||||
// get a thread-local trigger and a request for waiting on it
|
||||
mscclppTrigger_t trig;
|
||||
mscclppRequest_t req = devConn.fifo.getTrigger(&trig);
|
||||
|
||||
// Trigger sending data, flag and synchronize after
|
||||
devConn.fifo.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU*sizeof(int));
|
||||
|
||||
// Wait on the request to make sure it is safe to reuse buffer and flag
|
||||
devConn.fifo.waitTrigger(req);
|
||||
}
|
||||
// Wait for receiving data from remote rank
|
||||
while (*proxyFlag == baseFlag);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
@@ -162,7 +180,7 @@ int main(int argc, const char *argv[])
|
||||
|
||||
int *data_d;
|
||||
uint64_t *flag_d;
|
||||
size_t data_size = 1024*1;
|
||||
size_t data_size = 1024*1024*1024;
|
||||
int nelemsPerGPU = data_size / sizeof(int) / world_size;
|
||||
CUDACHECK(cudaMalloc(&data_d, data_size));
|
||||
CUDACHECK(cudaMalloc(&flag_d, sizeof(uint64_t)));
|
||||
|
||||
Reference in New Issue
Block a user