mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Revise NVLS interface (#458)
* Rename `NvlsConnection::DeviceMulticastPointer` to `SwitchChannel` * Minor interface improvements
This commit is contained in:
@@ -2,18 +2,18 @@
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <mscclpp/concurrency_device.hpp>
|
||||
#include <mscclpp/nvls_device.hpp>
|
||||
#include <mscclpp/poll_device.hpp>
|
||||
#include <mscclpp/semaphore_device.hpp>
|
||||
#include <mscclpp/switch_channel_device.hpp>
|
||||
|
||||
__device__ mscclpp::DeviceSyncer deviceSyncer;
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
nvls_test(mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs,
|
||||
nvls_test(mscclpp::SwitchChannelDeviceHandle nvlsPtrs,
|
||||
mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks, int nbytes) {
|
||||
int nelem = nbytes / sizeof(float);
|
||||
float* dev_ptr = (float*)nvlsPtrs.devicePtr;
|
||||
float* mc_ptr = (float*)nvlsPtrs.mcPtr;
|
||||
mscclpp::f32x4* mc_ptr = (mscclpp::f32x4*)nvlsPtrs.mcPtr;
|
||||
int tid = threadIdx.x;
|
||||
int bid = blockIdx.x;
|
||||
|
||||
@@ -33,16 +33,15 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
}
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
|
||||
int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;
|
||||
int my_st = ((int64_t)nelem / 4 * (int64_t)my_rank) / (int64_t)nranks;
|
||||
int my_en = ((int64_t)nelem / 4 * (int64_t)(my_rank + 1)) / (int64_t)nranks;
|
||||
|
||||
int my_offset = (tid + bid * blockDim.x) * 4;
|
||||
int my_step = blockDim.x * gridDim.x * 4;
|
||||
int my_offset = (tid + bid * blockDim.x);
|
||||
int my_step = blockDim.x * gridDim.x;
|
||||
|
||||
for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
|
||||
uint4 val;
|
||||
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
|
||||
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
|
||||
mscclpp::f32x4 val = mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce(mc_ptr + idx);
|
||||
mscclpp::SwitchChannelDeviceHandle::multimemStore(val, mc_ptr + idx);
|
||||
}
|
||||
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
Reference in New Issue
Block a user