Files
mscclpp/python/examples/bootstrap.py
Olli Saarikivi 828be48b21 Add Context and Endpoint classes to enable non-Communicator use-cases (#166)
This PR implements and closes #137. The new `Endpoint` and `Context`
classes expose the connection-establishing functionality from
`Communicator`, which is now only responsible for tying the
bootstrapper together with a context.

The largest breaking change here is that
`Communicator.connectOnSetup(...)` now returns the `Connection` wrapped
inside a `NonblockingFuture`. This is because, with the way `Context` is
implemented, a `Connection` is now fully initialized on construction.

Smaller breaking API changes are that
`RegisteredMemory` no longer has a `rank()` function (as there may be no
concept of rank), and similarly `Connection` no longer has `remoteRank()` and
`tag()` functions. These are replaced by the `remoteRankOf` and `tagOf`
functions in `Communicator`.

A new `EndpointConfig` class is introduced to avoid duplication of the
IB configuration parameters in the APIs of `Context` and `Communicator`.
The usual usage pattern of just passing in a `Transport` still works due
to an implicit conversion into `EndpointConfig`.

Miscellaneous changes:
- Cleans up how the PIMPL pattern is applied by making both the `Impl`
struct and the `pimpl_` pointers private for all relevant classes in the
core API.
- Enables ctest to be run from the build root directory.
2023-09-06 13:10:04 +08:00

110 lines
3.3 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import logging
import multiprocessing as mp
import sys
import mscclpp
import torch
# InfiniBand transports, one per local rank: rank r registers its buffer on
# ``mscclpp.Transport.IB<r>`` (see setup_connections). Only eight entries
# (IB0..IB7) exist, so ranks 0-7 are the valid indices.
IB_TRANSPORTS = [getattr(mscclpp.Transport, f"IB{index}") for index in range(8)]
# Use to hold the sm channels so they don't get garbage collected
sm_channels = []
# Torch buffers whose raw device pointers were registered with mscclpp.
# mscclpp only receives ``data_ptr()``, not the tensor, so the tensor must be
# referenced here; otherwise it could be freed (or its block reused by the
# allocator) while registered memory and channels still point at it.
_registered_buffers = []


def setup_connections(comm, rank, world_size, element_size, proxy_service):
    """Connect this rank to every peer and build proxy and SM channels.

    Allocates a zeroed CUDA int32 buffer of ``element_size`` elements and
    registers it for this rank's IB transport plus CUDA IPC. It then opens a
    CUDA IPC connection to every other rank and exchanges registered memory
    with it (tag 0). Finally it creates one ``SimpleProxyChannel`` and one
    ``SmChannel`` per peer.

    Args:
        comm: ``mscclpp.Communicator`` used for all setup-phase exchanges.
        rank: This process's rank; also indexes ``IB_TRANSPORTS``.
        world_size: Total number of ranks taking part.
        element_size: Number of int32 *elements* (not bytes) in the local
            buffer.
        proxy_service: ``mscclpp.ProxyService`` that receives the proxy
            channels' semaphores and memories.

    Returns:
        ``(proxy_handles, sm_handles)``: device handles created by this call,
        one per peer, ordered by ascending peer rank (``rank`` itself skipped).
    """
    memory = torch.zeros(element_size, dtype=torch.int32).to("cuda")
    # Keep the tensor alive beyond this call; see _registered_buffers above.
    _registered_buffers.append(memory)
    transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
    ptr = memory.data_ptr()
    size = memory.numel() * memory.element_size()
    reg_mem = comm.register_memory(ptr, size, transport_flag)

    # Queue connection and memory exchange with each peer. Both calls return
    # futures that are only resolved by the comm.setup() below.
    connection_futures = []
    remote_memory_futures = []
    for r in range(world_size):
        if r == rank:
            continue
        connection_futures.append(comm.connect_on_setup(r, 0, mscclpp.Transport.CudaIpc))
        comm.send_memory_on_setup(reg_mem, r, 0)
        remote_memory_futures.append(comm.recv_memory_on_setup(r, 0))
    comm.setup()
    connections = [future.get() for future in connection_futures]
    remote_memories = [future.get() for future in remote_memory_futures]

    # Create simple proxy channels (argument evaluation order matches the
    # original: semaphore first, then remote memory, then local memory).
    simple_proxy_channels = []
    for conn, remote_mem in zip(connections, remote_memories):
        proxy_channel = mscclpp.SimpleProxyChannel(
            proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)),
            proxy_service.add_memory(remote_mem),
            proxy_service.add_memory(reg_mem),
        )
        simple_proxy_channels.append(proxy_channel.device_handle())
    comm.setup()

    # Create SM semaphores; their setup-phase exchange completes on setup().
    sm_semaphores = [mscclpp.SmDevice2DeviceSemaphore(comm, conn) for conn in connections]
    comm.setup()

    new_sm_channels = [
        mscclpp.SmChannel(semaphore, remote_mem, ptr)
        for semaphore, remote_mem in zip(sm_semaphores, remote_memories)
    ]
    # Retain the channel objects globally, but return only this call's handles
    # so they stay aligned with simple_proxy_channels on repeated calls.
    sm_channels.extend(new_sm_channels)
    return simple_proxy_channels, [sm_chan.device_handle() for sm_chan in new_sm_channels]
def run(rank, args):
    """Worker entry point for a single GPU rank.

    Binds the process to CUDA device ``rank`` and bootstraps a TCP
    communicator of ``args.gpu_number`` ranks via ``args.if_ip_port_trio``.
    It then builds all connections and channels and starts the proxy service.

    Args:
        rank: Rank of this process; also used as the CUDA device index.
        args: Parsed CLI namespace providing ``gpu_number``,
            ``if_ip_port_trio`` and ``num_elements``.
    """
    torch.cuda.set_device(rank)
    world_size = args.gpu_number

    bootstrap = mscclpp.TcpBootstrap.create(rank, world_size)
    bootstrap.initialize(args.if_ip_port_trio)
    communicator = mscclpp.Communicator(bootstrap)
    proxy = mscclpp.ProxyService()

    logging.info("Rank: %d, setting up connections", rank)
    setup_connections(communicator, rank, world_size, args.num_elements, proxy)

    logging.info("Rank: %d, starting proxy service", rank)
    # NOTE(review): stop_proxy() is never called here; confirm that teardown
    # relies on ProxyService destruction when this worker exits.
    proxy.start_proxy()
def main():
    """Parse command-line options and run one ``run`` worker per GPU.

    Logs at DEBUG level to stderr, spawns ``--gpu_number`` processes (ranks
    0..N-1) and waits for all of them to finish.
    """
    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)

    cli = argparse.ArgumentParser()
    cli.add_argument("if_ip_port_trio", type=str)
    cli.add_argument("-n", "--num-elements", type=int, default=10)
    cli.add_argument("-g", "--gpu_number", type=int, default=2)
    options = cli.parse_args()

    workers = [mp.Process(target=run, args=(rank, options)) for rank in range(options.gpu_number)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()


if __name__ == "__main__":
    main()