mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-08 15:30:41 +00:00
Change device handle interfaces & others (#142)
* Changed device handle interfaces * Changed proxy service interfaces * Move device code into separate files * Fixed FIFO polling issues * Add configuration arguments in several interface functions --------- Co-authored-by: Changho Hwang <changhohwang@microsoft.com> Co-authored-by: Binyang Li <binyli@microsoft.com> Co-authored-by: root <root@a100-saemal0.qxveptpukjsuthqvv514inp03c.gx.internal.cloudapp.net>
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import mscclpp
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import logging
|
||||
import torch
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
|
||||
import mscclpp
|
||||
import torch
|
||||
|
||||
IB_TRANSPORTS = [
|
||||
mscclpp.Transport.IB0,
|
||||
mscclpp.Transport.IB1,
|
||||
@@ -19,15 +20,19 @@ IB_TRANSPORTS = [
|
||||
mscclpp.Transport.IB7,
|
||||
]
|
||||
|
||||
# Use to hold the sm channels so they don't get garbage collected
|
||||
sm_channels = []
|
||||
|
||||
|
||||
def setup_connections(comm, rank, world_size, element_size, proxy_service):
|
||||
simple_proxy_channels = []
|
||||
sm_semaphores = []
|
||||
connections = []
|
||||
remote_memories = []
|
||||
memory = torch.zeros(element_size, dtype=torch.int32)
|
||||
memory = memory.to("cuda")
|
||||
|
||||
transport_flag = IB_TRANSPORTS[rank] or mscclpp.Transport.CudaIpc
|
||||
transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
|
||||
ptr = memory.data_ptr()
|
||||
size = memory.numel() * memory.element_size()
|
||||
reg_mem = comm.register_memory(ptr, size, transport_flag)
|
||||
@@ -42,15 +47,26 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
|
||||
remote_memories.append(remote_mem)
|
||||
comm.setup()
|
||||
|
||||
# Create simple proxy channels
|
||||
for i, conn in enumerate(connections):
|
||||
proxy_channel = mscclpp.SimpleProxyChannel(
|
||||
proxy_service.device_channel(proxy_service.add_semaphore(conn)),
|
||||
proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(conn)),
|
||||
proxy_service.add_memory(remote_memories[i].get()),
|
||||
proxy_service.add_memory(reg_mem),
|
||||
)
|
||||
simple_proxy_channels.append(mscclpp.device_handle(proxy_channel))
|
||||
comm.setup()
|
||||
return simple_proxy_channels
|
||||
|
||||
# Create sm channels
|
||||
for i, conn in enumerate(connections):
|
||||
sm_chan = mscclpp.SmDevice2DeviceSemaphore(comm, conn)
|
||||
sm_semaphores.append(sm_chan)
|
||||
comm.setup()
|
||||
|
||||
for i, conn in enumerate(sm_semaphores):
|
||||
sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr)
|
||||
sm_channels.append(sm_chan)
|
||||
return simple_proxy_channels, [mscclpp.device_handle(sm_chan) for sm_chan in sm_channels]
|
||||
|
||||
|
||||
def run(rank, args):
|
||||
@@ -60,7 +76,7 @@ def run(rank, args):
|
||||
boot = mscclpp.TcpBootstrap.create(rank, world_size)
|
||||
boot.initialize(args.if_ip_port_trio)
|
||||
comm = mscclpp.Communicator(boot)
|
||||
proxy_service = mscclpp.ProxyService(comm)
|
||||
proxy_service = mscclpp.ProxyService()
|
||||
|
||||
logging.info("Rank: %d, setting up connections", rank)
|
||||
setup_connections(comm, rank, world_size, args.num_elements, proxy_service)
|
||||
|
||||
Reference in New Issue
Block a user