mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-07 00:05:19 +00:00
Revised MemoryChannel interfaces (#508)
* Moved the `MemoryChannel::copy()` method out of the `MemoryChannel` class into a standalone function. * Renamed `mscclpp::putPackets()` and `mscclpp::getPackets()` to `mscclpp::copyToPackets()` and `mscclpp::copyFromPackets()`, respectively, for consistency. * Renamed `MemoryChannel::getPackets()` to `MemoryChannel::unpackPackets()` for clarity. Also renamed `getPacketBuffer` to `packetBuffer`. * Added the `MemoryChannel::unpackPacket()` method, which unpacks one packet in the buffer. * Added the `BaseMemoryChannel` class, which contains only a semaphore and no memory addresses. * Removed the `MemoryDevice2DeviceSemaphoreDeviceHandle::signalPacket()` method, which lacks use cases.
This commit is contained in:
@@ -26,9 +26,9 @@ void register_memory_channel(nb::module_& m) {
|
||||
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &MemoryChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("src_", &MemoryChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &MemoryChannel::DeviceHandle::dst_)
|
||||
.def_rw("getPacketBuffer_", &MemoryChannel::DeviceHandle::getPacketBuffer_)
|
||||
.def_rw("src_", &MemoryChannel::DeviceHandle::src_)
|
||||
.def_rw("packetBuffer_", &MemoryChannel::DeviceHandle::packetBuffer_)
|
||||
.def_prop_ro("raw", [](const MemoryChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
@@ -16,7 +16,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
if (bid < nranks && bid != my_rank) {
|
||||
if (use_packet) {
|
||||
channels[bid].putPackets(2 * my_offset, my_offset, size_per_rank, tid, blockDim.x, flag);
|
||||
channels[bid].getPackets(2 * my_nghr_offset, my_nghr_offset, size_per_rank, tid, blockDim.x, flag);
|
||||
channels[bid].unpackPackets(2 * my_nghr_offset, my_nghr_offset, size_per_rank, tid, blockDim.x, flag);
|
||||
} else {
|
||||
channels[bid].put(my_offset, my_offset, size_per_rank, tid, blockDim.x);
|
||||
__syncthreads();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <mscclpp/packet_device.hpp>
|
||||
#include <mscclpp/copy_device.hpp>
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
|
||||
// be careful about using channels[my_rank]: it is invalid and is there just for simplicity of indexing
|
||||
@@ -18,14 +18,14 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
__syncthreads();
|
||||
int flag = 123;
|
||||
if (use_packet) {
|
||||
mscclpp::putPackets(scratch, 2 * my_offset, data, my_offset, size_per_rank, tid, nthreads, flag);
|
||||
mscclpp::copyToPackets((char*)scratch + 2 * my_offset, (char*)data + my_offset, size_per_rank, tid, nthreads, flag);
|
||||
__syncthreads();
|
||||
if (tid < nranks && tid != my_rank) {
|
||||
channels[tid].put(2 * my_offset, 2 * my_offset, 2 * size_per_rank);
|
||||
}
|
||||
if (my_nghr != my_rank && my_nghr < nranks)
|
||||
mscclpp::getPackets(scratch, 2 * my_nghr_offset, data, my_nghr_offset, size_per_rank, tid % nthreads_per_rank,
|
||||
nthreads_per_rank, flag);
|
||||
mscclpp::copyFromPackets((char*)data + my_nghr_offset, (char*)scratch + 2 * my_nghr_offset, size_per_rank,
|
||||
tid % nthreads_per_rank, nthreads_per_rank, flag);
|
||||
} else {
|
||||
if (tid < nranks && tid != my_rank) {
|
||||
channels[tid].putWithSignalAndFlush(my_offset, my_offset, size_per_rank);
|
||||
|
||||
Reference in New Issue
Block a user