Files
mscclpp/python/test/port_channel_test.cu
Changho Hwang 710f6686dc Revised MemoryChannel interfaces (#508)
* Moved the `MemoryChannel::copy()` method out of the `MemoryChannel` class
into a standalone function.
* Renamed `mscclpp::putPackets()` and `mscclpp::getPackets()` to
`mscclpp::copyToPackets()` and `mscclpp::copyFromPackets()` respectively
for consistency.
* Renamed `MemoryChannel::getPackets()` to
`MemoryChannel::unpackPackets()` for clarity. Renamed `getPacketBuffer`
to `packetBuffer`.
* Added the `MemoryChannel::unpackPacket()` method that unpacks one
packet in the buffer.
* Added the `BaseMemoryChannel` class that only contains a semaphore
without memory addresses.
* Removed the `MemoryDevice2DeviceSemaphoreDeviceHandle::signalPacket()`
method, which had no use cases.
2025-04-25 00:02:56 +00:00

36 lines
1.5 KiB
Plaintext

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <mscclpp/copy_device.hpp>
#include <mscclpp/port_channel_device.hpp>
// Test kernel: each rank ships its own `size_per_rank`-byte slice of `data` to every other rank over port
// (proxy-driven) channels.
//
// Launch layout: only threadIdx.x / blockDim.x are used, so the kernel is written for a single 1-D thread block
// (NOTE(review): presumably launched with gridDim.x == 1 — confirm against the Python launcher). __launch_bounds__
// caps the block at 1024 threads. Any blockDim.x >= 1 is handled; with blockDim.x >= nranks every peer is served
// by its own thread and packet unpacking is split across per-peer thread groups.
//
// Parameters:
//   channels     - `nranks` handles indexed by peer rank. Be careful with channels[my_rank]: it is invalid and only
//                  exists so that peers can be indexed by rank directly; this kernel never touches it. The memory
//                  regions behind each channel are set up by the host (presumably `scratch` in packet mode and
//                  `data` otherwise — confirm).
//   my_rank      - rank of this process.
//   nranks       - total number of ranks.
//   data         - `num_elements` ints split into `nranks` equal byte slices; slice r starts at r * size_per_rank.
//   scratch      - packet staging buffer. Offsets and sizes into it are doubled relative to `data`, because the
//                  packet form of a slice occupies twice its bytes.
//   num_elements - number of ints in `data`. NOTE(review): assumes num_elements * sizeof(int) is divisible by
//                  nranks; any remainder bytes are not exchanged.
//   use_packet   - nonzero: packetize my slice into `scratch`, put() it to every peer, then unpack every peer's
//                  packets from local `scratch` back into `data`. Zero: putWithSignalAndFlush() my slice to each
//                  peer and wait() for that peer's signal.
extern "C" __global__ void __launch_bounds__(1024, 1)
    port_channel(mscclpp::PortChannelDeviceHandle* channels, int my_rank, int nranks, int* data, int* scratch,
                 int num_elements, int use_packet) {
  const int tid = threadIdx.x;
  const int nthreads = blockDim.x;
  const uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;
  const uint64_t my_offset = size_per_rank * my_rank;
  // Flag value stamped into every packet; the receiver treats a packet as arrived once it carries this flag.
  const int flag = 123;

  if (use_packet) {
    mscclpp::copyToPackets((char*)scratch + 2 * my_offset, (char*)data + my_offset, size_per_rank, tid, nthreads, flag);
    // Every thread's packets must be written before any thread asks the proxy to ship them.
    __syncthreads();
    // Strided over peers so every peer gets a put even when blockDim.x < nranks
    // (with blockDim.x >= nranks this is exactly one iteration for tid < nranks).
    for (int peer = tid; peer < nranks; peer += nthreads) {
      if (peer != my_rank) channels[peer].put(2 * my_offset, 2 * my_offset, 2 * size_per_rank);
    }

    const int nthreads_per_rank = nthreads / nranks;
    if (nthreads_per_rank > 0) {
      // One group of nthreads_per_rank threads per source rank. The group mapped to my own rank and any leftover
      // threads (peer == nranks when blockDim.x is not a multiple of nranks) stay idle.
      const int peer = tid / nthreads_per_rank;
      if (peer != my_rank && peer < nranks) {
        const uint64_t peer_offset = size_per_rank * peer;
        mscclpp::copyFromPackets((char*)data + peer_offset, (char*)scratch + 2 * peer_offset, size_per_rank,
                                 tid % nthreads_per_rank, nthreads_per_rank, flag);
      }
    } else {
      // Fewer threads than ranks: per-peer groups would be empty (and `tid / 0` undefined), so the whole block
      // unpacks one peer at a time instead.
      for (int peer = 0; peer < nranks; ++peer) {
        if (peer == my_rank) continue;
        const uint64_t peer_offset = size_per_rank * peer;
        mscclpp::copyFromPackets((char*)data + peer_offset, (char*)scratch + 2 * peer_offset, size_per_rank, tid,
                                 nthreads, flag);
      }
    }
  } else {
    // Each thread signals a peer before waiting on that same peer, and peers are visited in ascending order, so
    // the strided loop cannot form a cyclic wait across ranks even when one thread serves several peers.
    for (int peer = tid; peer < nranks; peer += nthreads) {
      if (peer == my_rank) continue;
      channels[peer].putWithSignalAndFlush(my_offset, my_offset, size_per_rank);
      channels[peer].wait();
    }
  }
}