Revise NVLS interface (#458)

* Rename `NvlsConnection::DeviceMulticastPointer` to `SwitchChannel`
* Minor interface improvements
This commit is contained in:
Changho Hwang
2025-07-11 17:33:03 -07:00
committed by GitHub
parent ae56698d67
commit 199468bc47
17 changed files with 462 additions and 377 deletions

View File

@@ -2,18 +2,18 @@
// Licensed under the MIT license.
#include <mscclpp/concurrency_device.hpp>
-#include <mscclpp/nvls_device.hpp>
#include <mscclpp/poll_device.hpp>
#include <mscclpp/semaphore_device.hpp>
+#include <mscclpp/switch_channel_device.hpp>
__device__ mscclpp::DeviceSyncer deviceSyncer;
extern "C" __global__ void __launch_bounds__(1024, 1)
-nvls_test(mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs,
+nvls_test(mscclpp::SwitchChannelDeviceHandle nvlsPtrs,
mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks, int nbytes) {
int nelem = nbytes / sizeof(float);
float* dev_ptr = (float*)nvlsPtrs.devicePtr;
-float* mc_ptr = (float*)nvlsPtrs.mcPtr;
+mscclpp::f32x4* mc_ptr = (mscclpp::f32x4*)nvlsPtrs.mcPtr;
int tid = threadIdx.x;
int bid = blockIdx.x;
@@ -33,16 +33,15 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
}
deviceSyncer.sync(gridDim.x);
-int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
-int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;
+int my_st = ((int64_t)nelem / 4 * (int64_t)my_rank) / (int64_t)nranks;
+int my_en = ((int64_t)nelem / 4 * (int64_t)(my_rank + 1)) / (int64_t)nranks;
-int my_offset = (tid + bid * blockDim.x) * 4;
-int my_step = blockDim.x * gridDim.x * 4;
+int my_offset = (tid + bid * blockDim.x);
+int my_step = blockDim.x * gridDim.x;
for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
-uint4 val;
-mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
-mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
+mscclpp::f32x4 val = mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce(mc_ptr + idx);
+mscclpp::SwitchChannelDeviceHandle::multimemStore(val, mc_ptr + idx);
}
deviceSyncer.sync(gridDim.x);