mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
kernel 2 is also performant
This commit is contained in:
@@ -84,7 +84,7 @@ __device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, in
|
||||
if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
|
||||
// put your data to GPU (rank+i) % nranksPerNode and signal in one call
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.putWithSignalAndFlush(offset, size);
|
||||
devChan.putWithSignal(offset, size);
|
||||
}
|
||||
// wait for the data from GPU (rank-i) % nranksPerNode to arrive
|
||||
if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
|
||||
@@ -100,6 +100,9 @@ __device__ void allgather1(mscclpp::channel::SimpleDeviceChannel devChan, int ra
|
||||
{
|
||||
localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
|
||||
nelemsPerGPU * sizeof(int));
|
||||
if (remoteRank / nranksPerNode == rank / nranksPerNode)
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.flush();
|
||||
}
|
||||
|
||||
__device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode,
|
||||
@@ -127,7 +130,7 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra
|
||||
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
|
||||
// opposite side
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int),
|
||||
devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int),
|
||||
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.wait();
|
||||
@@ -147,7 +150,7 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra
|
||||
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
|
||||
// opposite side
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.putWithSignalAndFlush((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) *
|
||||
devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) *
|
||||
sizeof(int),
|
||||
nelemsPerGPU / pipelineSize * sizeof(int));
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
@@ -163,6 +166,11 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra
|
||||
(otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
|
||||
nelemsPerGPU / pipelineSize * sizeof(int));
|
||||
}
|
||||
|
||||
if (remoteRank / nranksPerNode == rank / nranksPerNode || remoteRank % nranksPerNode == rank % nranksPerNode) {
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.flush();
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelemsPerGPU, int kernel)
|
||||
|
||||
@@ -39,23 +39,6 @@ void register_all_memories(mscclpp::Communicator& communicator, int rank, int wo
|
||||
remoteMemory[i] = futureRemoteMemory[i].get();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// auto serialized = localMemory.serialize();
|
||||
// int serializedSize = serialized.size();
|
||||
// for (int i = 0; i < worldSize; i++) {
|
||||
// if (i != rank){
|
||||
// communicator.bootstrapper()->send(serialized.data(), serializedSize, i, 0);
|
||||
// }
|
||||
// }
|
||||
// for (int i = 0; i < worldSize; i++) {
|
||||
// if (i != rank){
|
||||
// std::vector<char> deserialized(serializedSize);
|
||||
// communicator.bootstrapper()->recv(deserialized.data(), serializedSize, i, 0);
|
||||
// auto remote = mscclpp::RegisteredMemory::deserialize(deserialized);
|
||||
// remoteMemory[i] = remote;
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode, mscclpp::Transport myIbDevice, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections){
|
||||
|
||||
Reference in New Issue
Block a user