mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-02 20:51:26 +00:00
Revised ProxyChannel interfaces (#400)
* Renamed `ProxyChannel` -> `BaseProxyChannel` and `SimpleProxyChannel` -> `ProxyChannel`. It makes the interface more consistent by defining channels to be associated with a certain src/dst memory region: `ProxyChannel` as "sema + src/dst + fifo" and `SmChannel` as "sema + src/dst". BaseProxyChannel is not associated with any memory regions, as "sema + fifo". * `ProxyChannelDeviceHandle` now inherits from `BaseProxyChannelDeviceHandle`, instead of having one as a member.
This commit is contained in:
@@ -21,7 +21,7 @@ We will setup a mesh topology with eight GPUs. Each GPU will be connected to its
|
||||
|
||||
template <class T>
|
||||
using DeviceHandle = mscclpp::DeviceHandle<T>;
|
||||
__constant__ DeviceHandle<mscclpp::SimpleProxyChannel> constProxyChans[8];
|
||||
__constant__ DeviceHandle<mscclpp::ProxyChannel> constProxyChans[8];
|
||||
|
||||
void setupMeshTopology(int rank, int worldsize, void* data, size_t dataSize) {
|
||||
std::string ip_port = "10.0.0.4:50000";
|
||||
@@ -55,17 +55,17 @@ void setupMeshTopology(int rank, int worldsize, void* data, size_t dataSize) {
|
||||
|
||||
comm.setup();
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
|
||||
std::vector<DeviceHandle<mscclpp::ProxyChannel>> proxyChannels;
|
||||
for (size_t i = 0; i < semaphoreIds.size(); ++i) {
|
||||
proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel(
|
||||
proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::ProxyChannel(
|
||||
proxyService.proxyChannel(semaphoreIds[i]), proxyService.addMemory(remoteMemories[i].get()),
|
||||
proxyService.addMemory(localMemories[i]))));
|
||||
}
|
||||
|
||||
if (proxyChannels.size() > sizeof(constProxyChans) / sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>)) {
|
||||
if (proxyChannels.size() > sizeof(constProxyChans) / sizeof(DeviceHandle<mscclpp::ProxyChannel>)) {
|
||||
std::runtime_error("unexpected error");
|
||||
}
|
||||
CUDACHECK(cudaMemcpyToSymbol(constProxyChans, proxyChannels.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>) * proxyChannels.size()));
|
||||
sizeof(DeviceHandle<mscclpp::ProxyChannel>) * proxyChannels.size()));
|
||||
}
|
||||
```
|
||||
|
||||
@@ -47,7 +47,7 @@ We provide some Python utils to help you launch kernels via Python. Here is an example
|
||||
```python
|
||||
from mscclpp.utils import KernelBuilder, pack
|
||||
|
||||
def launch_kernel(my_rank: int, nranks: int, simple_channels: List[SimpleProxyChannel], memory: cp.ndarray):
|
||||
def launch_kernel(my_rank: int, nranks: int, simple_channels: List[ProxyChannel], memory: cp.ndarray):
|
||||
file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
kernel = KernelBuilder(file="test.cu", kernel_name="test", file_dir=file_dir).get_compiled_kernel()
|
||||
params = b""
|
||||
@@ -77,7 +77,7 @@ The test kernel is defined in `test.cu` as follows:
|
||||
|
||||
// be careful about using channels[my_rank] as it is invalid and is there just for simplicity of indexing
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
simple_proxy_channel(mscclpp::SimpleProxyChannelDeviceHandle* channels, int my_rank, int nranks,
|
||||
proxy_channel(mscclpp::ProxyChannelDeviceHandle* channels, int my_rank, int nranks,
|
||||
int num_elements) {
|
||||
int tid = threadIdx.x;
|
||||
int nthreads = blockDim.x;
|
||||
|
||||
Reference in New Issue
Block a user