Change device handle interfaces & others (#142)

* Changed device handle interfaces
* Changed proxy service interfaces
* Move device code into separate files
* Fixed FIFO polling issues
* Add configuration arguments in several interface functions

---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
Co-authored-by: Binyang Li <binyli@microsoft.com>
Co-authored-by: root <root@a100-saemal0.qxveptpukjsuthqvv514inp03c.gx.internal.cloudapp.net>
This commit is contained in:
Saeed Maleki
2023-08-16 05:00:56 -07:00
committed by GitHub
parent 4865b2017b
commit 8d1b984bed
59 changed files with 1271 additions and 1036 deletions

View File

@@ -1,13 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import mscclpp
import argparse
import multiprocessing as mp
import logging
import torch
import multiprocessing as mp
import sys
import mscclpp
import torch
IB_TRANSPORTS = [
mscclpp.Transport.IB0,
mscclpp.Transport.IB1,
@@ -19,15 +20,19 @@ IB_TRANSPORTS = [
mscclpp.Transport.IB7,
]
# Use to hold the sm channels so they don't get garbage collected
sm_channels = []
def setup_connections(comm, rank, world_size, element_size, proxy_service):
simple_proxy_channels = []
sm_semaphores = []
connections = []
remote_memories = []
memory = torch.zeros(element_size, dtype=torch.int32)
memory = memory.to("cuda")
transport_flag = IB_TRANSPORTS[rank] or mscclpp.Transport.CudaIpc
transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
ptr = memory.data_ptr()
size = memory.numel() * memory.element_size()
reg_mem = comm.register_memory(ptr, size, transport_flag)
@@ -42,15 +47,26 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
remote_memories.append(remote_mem)
comm.setup()
# Create simple proxy channels
for i, conn in enumerate(connections):
proxy_channel = mscclpp.SimpleProxyChannel(
proxy_service.device_channel(proxy_service.add_semaphore(conn)),
proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(conn)),
proxy_service.add_memory(remote_memories[i].get()),
proxy_service.add_memory(reg_mem),
)
simple_proxy_channels.append(mscclpp.device_handle(proxy_channel))
comm.setup()
return simple_proxy_channels
# Create sm channels
for i, conn in enumerate(connections):
sm_chan = mscclpp.SmDevice2DeviceSemaphore(comm, conn)
sm_semaphores.append(sm_chan)
comm.setup()
for i, conn in enumerate(sm_semaphores):
sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr)
sm_channels.append(sm_chan)
return simple_proxy_channels, [mscclpp.device_handle(sm_chan) for sm_chan in sm_channels]
def run(rank, args):
@@ -60,7 +76,7 @@ def run(rank, args):
boot = mscclpp.TcpBootstrap.create(rank, world_size)
boot.initialize(args.if_ip_port_trio)
comm = mscclpp.Communicator(boot)
proxy_service = mscclpp.ProxyService(comm)
proxy_service = mscclpp.ProxyService()
logging.info("Rank: %d, setting up connections", rank)
setup_connections(comm, rank, world_size, args.num_elements, proxy_service)