mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-05 22:22:49 +00:00
92 lines
2.6 KiB
Python
92 lines
2.6 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT license.
|
|
|
|
import mscclpp
|
|
import argparse
|
|
import multiprocessing as mp
|
|
import logging
|
|
import torch
|
|
import sys
|
|
|
|
# InfiniBand transport enum members IB0..IB7, stored so that the list index
# equals the numeric suffix; setup_connections() indexes this list by rank.
IB_TRANSPORTS = [getattr(mscclpp.Transport, f"IB{index}") for index in range(8)]
|
|
|
|
|
|
def setup_connections(comm, rank, world_size, element_size, proxy_service):
    """Register a local CUDA buffer and build one proxy channel per peer rank.

    An int32 tensor with ``element_size`` elements is allocated, moved to the
    "cuda" device and registered with ``comm``. For every rank other than
    ``rank`` a connection is requested, the registered memory is sent to that
    peer and the peer's memory is received. After ``comm.setup()`` each
    connection is wrapped in an ``mscclpp.SimpleProxyChannel`` registered
    with ``proxy_service``.

    Args:
        comm: Communicator used to register memory and create connections.
        rank: Rank of the calling process; also the index into IB_TRANSPORTS.
        world_size: Total number of ranks; peers are ``range(world_size)``
            minus ``rank``.
        element_size: Number of int32 elements in the local buffer.
        proxy_service: Proxy service that the semaphores and memories are
            added to.

    Returns:
        A list with one ``mscclpp.device_handle(...)`` result per peer, in
        ascending peer-rank order.

    Raises:
        IndexError: If ``rank`` is outside the bounds of IB_TRANSPORTS.
    """
    local_buffer = torch.zeros(element_size, dtype=torch.int32).to("cuda")

    # NOTE(review): `or` yields its left operand whenever that operand is
    # truthy, so CudaIpc is only selected if IB_TRANSPORTS[rank] is falsy.
    # Confirm whether a bitwise combination of both transports was intended.
    transport_flag = IB_TRANSPORTS[rank] or mscclpp.Transport.CudaIpc
    buffer_ptr = local_buffer.data_ptr()
    buffer_bytes = local_buffer.numel() * local_buffer.element_size()
    # NOTE(review): local_buffer is not returned or stored here; verify that
    # the registration keeps the device memory valid after this function exits.
    reg_mem = comm.register_memory(buffer_ptr, buffer_bytes, transport_flag)

    # Queue the connection and memory exchange with every peer; the literal 0
    # is passed through unchanged as the second/third positional argument.
    peer_links = []
    peer_memories = []
    for peer in (candidate for candidate in range(world_size) if candidate != rank):
        peer_links.append(comm.connect_on_setup(peer, 0, mscclpp.Transport.CudaIpc))
        comm.send_memory_on_setup(reg_mem, peer, 0)
        peer_memories.append(comm.recv_memory_on_setup(peer, 0))
    comm.setup()

    # Both lists were filled in lockstep, so zip pairs each connection with
    # the memory received from the same peer.
    channel_handles = []
    for link, received in zip(peer_links, peer_memories):
        semaphore_id = proxy_service.add_semaphore(link)
        channel = mscclpp.SimpleProxyChannel(
            proxy_service.device_channel(semaphore_id),
            proxy_service.add_memory(received.get()),
            proxy_service.add_memory(reg_mem),
        )
        channel_handles.append(mscclpp.device_handle(channel))
    comm.setup()
    return channel_handles
|
|
|
|
|
|
def run(rank, args):
    """Worker entry point executed once per rank.

    Selects CUDA device ``rank``, bootstraps over TCP using
    ``args.if_ip_port_trio``, creates the communicator and proxy service,
    sets up the connections and finally starts the proxy.

    Args:
        rank: Rank of this worker; also used as the CUDA device index.
        args: Parsed command-line namespace providing ``gpu_number``,
            ``if_ip_port_trio`` and ``num_elements``.
    """
    num_ranks = args.gpu_number
    torch.cuda.set_device(rank)

    bootstrap = mscclpp.TcpBootstrap.create(rank, num_ranks)
    bootstrap.initialize(args.if_ip_port_trio)
    communicator = mscclpp.Communicator(bootstrap)
    service = mscclpp.ProxyService(communicator)

    logging.info("Rank: %d, setting up connections", rank)
    # The returned channel handles are not used by this example.
    setup_connections(communicator, rank, num_ranks, args.num_elements, service)

    logging.info("Rank: %d, starting proxy service", rank)
    service.start_proxy()
|
|
|
|
|
|
def main():
    """Parse the command line and run one worker process per rank.

    Logging is sent to stderr at DEBUG level. One ``multiprocessing.Process``
    targeting ``run`` is started for each rank in ``range(gpu_number)``, and
    the function returns once all of them have been joined.
    """
    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)

    cli = argparse.ArgumentParser()
    cli.add_argument("if_ip_port_trio", type=str)
    cli.add_argument("-n", "--num-elements", type=int, default=10)
    cli.add_argument("-g", "--gpu_number", type=int, default=2)
    options = cli.parse_args()

    # Start every worker before joining any of them so all ranks run
    # concurrently.
    workers = []
    for worker_rank in range(options.gpu_number):
        worker = mp.Process(target=run, args=(worker_rank, options))
        worker.start()
        workers.append(worker)

    for worker in workers:
        worker.join()
|
|
|
|
|
|
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|