mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-04 13:41:27 +00:00
Revised ProxyChannel interfaces (#400)
* Renamed `ProxyChannel` -> `BaseProxyChannel` and `SimpleProxyChannel` -> `ProxyChannel`. It makes the interface more consistent by defining channels to be associated with a certain src/dst memory region: `ProxyChannel` as "sema + src/dst + fifo" and `SmChannel` as "sema + src/dst". BaseProxyChannel is not associated with any memory regions, as "sema + fifo". * `ProxyChannelDeviceHandle` now inherits from `BaseProxyChannelDeviceHandle`, instead of having one as a member.
This commit is contained in:
35
python/test/proxy_channel_test.cu
Normal file
35
python/test/proxy_channel_test.cu
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <mscclpp/packet_device.hpp>
|
||||
#include <mscclpp/proxy_channel_device.hpp>
|
||||
|
||||
// be careful about using channels[my_rank] as it is inavlie and it is there just for simplicity of indexing
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
proxy_channel(mscclpp::ProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data, int* scratch,
|
||||
int num_elements, int use_packet) {
|
||||
int tid = threadIdx.x;
|
||||
int nthreads = blockDim.x;
|
||||
uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;
|
||||
uint64_t my_offset = size_per_rank * my_rank;
|
||||
int nthreads_per_rank = nthreads / nranks;
|
||||
int my_nghr = tid / nthreads_per_rank;
|
||||
uint64_t my_nghr_offset = size_per_rank * my_nghr;
|
||||
__syncthreads();
|
||||
int flag = 123;
|
||||
if (use_packet) {
|
||||
mscclpp::putPackets(scratch, 2 * my_offset, data, my_offset, size_per_rank, tid, nthreads, flag);
|
||||
__syncthreads();
|
||||
if (tid < nranks && tid != my_rank) {
|
||||
channels[tid].put(2 * my_offset, 2 * my_offset, 2 * size_per_rank);
|
||||
}
|
||||
if (my_nghr != my_rank && my_nghr < nranks)
|
||||
mscclpp::getPackets(scratch, 2 * my_nghr_offset, data, my_nghr_offset, size_per_rank, tid % nthreads_per_rank,
|
||||
nthreads_per_rank, flag);
|
||||
} else {
|
||||
if (tid < nranks && tid != my_rank) {
|
||||
channels[tid].putWithSignalAndFlush(my_offset, my_offset, size_per_rank);
|
||||
channels[tid].wait();
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user