mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-03-29 19:47:48 +00:00
This PR implements and closes #137. The new `Endpoint` and `Context` classes expose the connection-establishing functionality from `Communicator`, which is now only responsible for tying together the bootstrapper with a context. The largest breaking change here is that `Communicator.connectOnSetup(...)` now returns the `Connection` wrapped inside a `NonblockingFuture`. This is because, with the way `Context` is implemented, a `Connection` is now fully initialized on construction. Some smaller breaking API changes from this change are that `RegisteredMemory` no longer has a `rank()` function (as there may be no concept of rank), and similarly `Connection` has no `remoteRank()` and `tag()` functions. The latter are replaced by `remoteRankOf` and `tagOf` functions in `Communicator`. A new `EndpointConfig` class is introduced to avoid duplicating the IB configuration parameters in the APIs of `Context` and `Communicator`. The usual usage pattern of just passing in a `Transport` still works due to an implicit conversion into `EndpointConfig`. Miscellaneous changes: - Cleans up how the PIMPL pattern is applied by making both the `Impl` struct and the `pimpl_` pointers private for all relevant classes in the core API. - Enables ctest to be run from the build root directory.
110 lines
3.3 KiB
Python
110 lines
3.3 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT license.
|
|
|
|
import argparse
|
|
import logging
|
|
import multiprocessing as mp
|
|
import sys
|
|
|
|
import mscclpp
|
|
import torch
|
|
|
|
# InfiniBand transports IB0..IB7, in order; setup_connections picks the entry
# at index `rank`, so rank r registers its memory on IB device r.
IB_TRANSPORTS = [getattr(mscclpp.Transport, f"IB{idx}") for idx in range(8)]

# Module-level list that keeps every SmChannel referenced so the channel
# objects are not garbage collected while their device handles are in use.
sm_channels = []
|
|
|
|
|
|
def setup_connections(comm, rank, world_size, element_size, proxy_service):
    """Connect this rank to every peer and build proxy and SM channels to each.

    One zero-initialized ``int32`` CUDA buffer is allocated, registered for
    both this rank's InfiniBand transport and CUDA IPC, and exchanged with
    every other rank. Connections themselves use CUDA IPC.

    Args:
        comm: ``mscclpp.Communicator`` used to connect and exchange memory.
        rank: Rank of this process; also selects ``IB_TRANSPORTS[rank]``.
        world_size: Total number of ranks; peers are ``range(world_size)``
            minus ``rank``.
        element_size: Number of ``int32`` elements in the local buffer
            (despite the name, an element count rather than a byte size).
        proxy_service: ``mscclpp.ProxyService`` that owns the semaphores and
            memory entries backing the proxy channels.

    Returns:
        ``(proxy_handles, sm_handles)``: device handles of the simple proxy
        channels and of the SM channels created by *this* call, one per peer,
        in ascending peer-rank order.

    Raises:
        ValueError: if ``rank`` has no entry in ``IB_TRANSPORTS``.
    """
    if not 0 <= rank < len(IB_TRANSPORTS):
        raise ValueError(
            f"rank {rank} has no IB transport; IB_TRANSPORTS covers ranks "
            f"0-{len(IB_TRANSPORTS) - 1}"
        )

    # Allocate directly on the current CUDA device instead of zeros-on-host + copy.
    # NOTE(review): `memory` is only referenced locally; after this function
    # returns nothing keeps the tensor alive while reg_mem and the channels
    # still point at its address -- confirm whether callers must hold it.
    memory = torch.zeros(element_size, dtype=torch.int32, device="cuda")

    transport_flags = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
    ptr = memory.data_ptr()
    size = memory.numel() * memory.element_size()
    reg_mem = comm.register_memory(ptr, size, transport_flags)

    # Queue connection and memory exchange with every peer; the returned
    # futures are only read after comm.setup() has run below.
    conn_futures = []
    remote_mem_futures = []
    for peer in range(world_size):
        if peer == rank:
            continue
        conn_futures.append(comm.connect_on_setup(peer, 0, mscclpp.Transport.CudaIpc))
        comm.send_memory_on_setup(reg_mem, peer, 0)
        remote_mem_futures.append(comm.recv_memory_on_setup(peer, 0))
    comm.setup()

    connections = [future.get() for future in conn_futures]
    remote_memories = [future.get() for future in remote_mem_futures]

    # Simple proxy channels: one per peer, from remote memory to local memory.
    proxy_handles = []
    for conn, remote_mem in zip(connections, remote_memories):
        proxy_channel = mscclpp.SimpleProxyChannel(
            proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)),
            proxy_service.add_memory(remote_mem),
            proxy_service.add_memory(reg_mem),
        )
        proxy_handles.append(proxy_channel.device_handle())
    comm.setup()

    # SM channels need their semaphores set up before they can be built.
    sm_semaphores = [mscclpp.SmDevice2DeviceSemaphore(comm, conn) for conn in connections]
    comm.setup()

    new_sm_channels = [
        mscclpp.SmChannel(semaphore, remote_mem, ptr)
        for semaphore, remote_mem in zip(sm_semaphores, remote_memories)
    ]
    # Keep the channel objects referenced at module level so they outlive this
    # call, but return handles only for the channels created here.
    sm_channels.extend(new_sm_channels)
    return proxy_handles, [channel.device_handle() for channel in new_sm_channels]
|
|
|
|
|
|
def run(rank, args):
    """Worker body for one rank: bootstrap, build channels, start the proxy.

    Args:
        rank: Rank of this worker; also passed to ``torch.cuda.set_device``.
        args: Parsed CLI namespace; reads ``gpu_number`` (world size),
            ``if_ip_port_trio`` (bootstrap address) and ``num_elements``.
    """
    world_size = args.gpu_number
    torch.cuda.set_device(rank)

    bootstrap = mscclpp.TcpBootstrap.create(rank, world_size)
    bootstrap.initialize(args.if_ip_port_trio)
    communicator = mscclpp.Communicator(bootstrap)
    service = mscclpp.ProxyService()

    logging.info("Rank: %d, setting up connections", rank)
    setup_connections(
        comm=communicator,
        rank=rank,
        world_size=world_size,
        element_size=args.num_elements,
        proxy_service=service,
    )

    logging.info("Rank: %d, starting proxy service", rank)
    service.start_proxy()
|
|
|
|
|
|
def main():
    """Parse CLI arguments, launch one ``run`` process per GPU, and wait.

    Exits the launcher with status 1 if any worker process finished with a
    nonzero exit code, so failures in a rank are not silently reported as
    success.
    """
    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
    parser = argparse.ArgumentParser()
    parser.add_argument("if_ip_port_trio", type=str)
    parser.add_argument("-n", "--num-elements", type=int, default=10)
    # "--gpu-number" is an alias matching the "--num-elements" spelling; the
    # first long option keeps the destination attribute as `gpu_number`.
    parser.add_argument("-g", "--gpu_number", "--gpu-number", type=int, default=2)
    args = parser.parse_args()

    processes = []
    for rank in range(args.gpu_number):
        p = mp.Process(target=run, args=(rank, args))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Propagate worker failures to the launcher's exit status.
    failed = [(rank, p.exitcode) for rank, p in enumerate(processes) if p.exitcode != 0]
    for rank, exitcode in failed:
        logging.error("Rank: %d, worker exited with code %s", rank, exitcode)
    if failed:
        sys.exit(1)
|
|
|
|
|
|
# Launch only when executed as a script. The guard also stops child processes
# that re-import this module (e.g. under the "spawn" start method) from
# launching workers themselves.
if __name__ == "__main__":
    main()
|