From 44e612cfbd0836856a9d4b547492d4fccb52097d Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Tue, 12 Sep 2023 21:30:45 +0000 Subject: [PATCH] Move bootstrap example to notebook Also change bootstrap and context into properties in Communicator --- docs/setup_example.ipynb | 164 +++++++++++++++++++++++++++++++++++ python/examples/bootstrap.py | 108 ----------------------- python/examples/send_recv.py | 80 ----------------- python/examples/utils.py | 17 ---- python/mscclpp/core_py.cpp | 4 +- 5 files changed, 166 insertions(+), 207 deletions(-) create mode 100644 docs/setup_example.ipynb delete mode 100644 python/examples/bootstrap.py delete mode 100644 python/examples/send_recv.py delete mode 100644 python/examples/utils.py diff --git a/docs/setup_example.ipynb b/docs/setup_example.ipynb new file mode 100644 index 00000000..743883b4 --- /dev/null +++ b/docs/setup_example.ipynb @@ -0,0 +1,164 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) Microsoft Corporation.\n", + "Licensed under the MIT license.\n", + "\n", + "The following example demonstrates how to initialize the MSCCL++ library and perform necessary setup for communicating from GPU kernels. First we define a function for registering memory, making connections and creating channels." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import mscclpp\n", + "import torch\n", + "\n", + "def setup_channels(comm, memory, proxy_service):\n", + " # Register the memory with the communicator\n", + " ptr = memory.data_ptr()\n", + " size = memory.numel() * memory.element_size()\n", + " reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.CudaIpc)\n", + "\n", + " # Create connections to all other ranks and exchange registered memories\n", + " connections = []\n", + " remote_memories = []\n", + " for r in range(comm.bootstrap.size):\n", + " if r == comm.bootstrap.rank: # Don't connect to self\n", + " continue\n", + " connections.append(comm.connect(r, 0, mscclpp.Transport.CudaIpc))\n", + " comm.send_memory(reg_mem, r, 0)\n", + " remote_mem = comm.recv_memory(r, 0)\n", + " remote_memories.append(remote_mem)\n", + "\n", + " # Both connections and received remote memories are returned as futures,\n", + " # so we wait for them to complete and unwrap them.\n", + " connections = [conn.get() for conn in connections]\n", + " remote_memories = [mem.get() for mem in remote_memories]\n", + "\n", + " # Finally, create proxy channels for each connection\n", + " proxy_channels = [mscclpp.SimpleProxyChannel(\n", + " proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)),\n", + " proxy_service.add_memory(remote_memories[i]),\n", + " proxy_service.add_memory(reg_mem),\n", + " ) for i, conn in enumerate(connections)]\n", + "\n", + " return proxy_channels" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we are ready to write the top-level code for each rank." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def run(rank, world_size, if_ip_port_trio):\n", + " # Use the right GPU for this rank\n", + " torch.cuda.set_device(rank)\n", + " \n", + " # Allocate memory on the GPU\n", + " memory = torch.zeros(1024, dtype=torch.int32)\n", + " memory = memory.to(\"cuda\")\n", + "\n", + " # Initialize a bootstrapper using a known interface/IP/port trio for the root rank\n", + " boot = mscclpp.TcpBootstrap.create(rank, world_size)\n", + " boot.initialize(if_ip_port_trio)\n", + "\n", + " # Create a communicator for the processes in the bootstrapper\n", + " comm = mscclpp.Communicator(boot)\n", + "\n", + " # Create a proxy service, which enables GPU kernels to use connections\n", + " proxy_service = mscclpp.ProxyService()\n", + "\n", + " if rank == 0:\n", + " print(\"Setting up channels\")\n", + " proxy_channels = setup_channels(comm, memory, proxy_service)\n", + "\n", + " if rank == 0:\n", + " print(\"Starting proxy service\")\n", + " proxy_service.start_proxy()\n", + "\n", + " # This is where we could launch a GPU kernel that uses proxy_channels[i].device_handle\n", + " # to initiate communication. See include/mscclpp/proxy_channel_device.hpp for details.\n", + " if rank == 0:\n", + " print(\"GPU kernels that use the proxy go here.\")\n", + "\n", + " if rank == 0:\n", + " print(f\"Stopping proxy service\")\n", + " proxy_service.stop_proxy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, to test the code we can run each process using the `multiprocessing` package." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting up channels\n", + "Starting proxy service\n", + "GPU kernels that use the proxy go here.\n", + "Stopping proxy service\n", + "\n", + "Starting proxy service\n", + "GPU kernels that use the proxy go here.\n", + "Stopping proxy service\n" + ] + } + ], + "source": [ + "import multiprocessing as mp\n", + "\n", + "world_size = 2\n", + "processes = [mp.Process(target=run, args=(rank, world_size, \"eth0:localhost:50051\")) for rank in range(world_size)]\n", + "for p in processes:\n", + " p.start()\n", + "for p in processes:\n", + " p.join()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/examples/bootstrap.py b/python/examples/bootstrap.py deleted file mode 100644 index 28446df2..00000000 --- a/python/examples/bootstrap.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import argparse -import logging -import multiprocessing as mp -import sys - -import mscclpp -import torch - -IB_TRANSPORTS = [ - mscclpp.Transport.IB0, - mscclpp.Transport.IB1, - mscclpp.Transport.IB2, - mscclpp.Transport.IB3, - mscclpp.Transport.IB4, - mscclpp.Transport.IB5, - mscclpp.Transport.IB6, - mscclpp.Transport.IB7, -] - -# Use to hold the sm channels so they don't get garbage collected -sm_channels = [] - - -def setup_connections(comm, rank, world_size, element_size, proxy_service): - simple_proxy_channels = [] - sm_semaphores = [] - connections = [] - remote_memories = [] - memory = torch.zeros(element_size, dtype=torch.int32) - memory = memory.to("cuda") - - transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc - ptr = memory.data_ptr() - size = memory.numel() * memory.element_size() - reg_mem = comm.register_memory(ptr, size, transport_flag) - - for r in range(world_size): - if r == rank: - continue - conn = comm.connect(r, 0, mscclpp.Transport.CudaIpc) - connections.append(conn) - comm.send_memory(reg_mem, r, 0) - remote_mem = comm.recv_memory(r, 0) - remote_memories.append(remote_mem) - - connections = [conn.get() for conn in connections] - - # Create simple proxy channels - for i, conn in enumerate(connections): - proxy_channel = mscclpp.SimpleProxyChannel( - proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)), - proxy_service.add_memory(remote_memories[i].get()), - proxy_service.add_memory(reg_mem), - ) - simple_proxy_channels.append(proxy_channel.device_handle()) - comm.setup() - - # Create sm channels - for i, conn in enumerate(connections): - sm_chan = mscclpp.SmDevice2DeviceSemaphore(comm, conn) - sm_semaphores.append(sm_chan) - comm.setup() - - for i, conn in enumerate(sm_semaphores): - sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr) - sm_channels.append(sm_chan) - return simple_proxy_channels, [sm_chan.device_handle() for sm_chan in sm_channels] - - -def run(rank, args): - world_size = args.gpu_number - torch.cuda.set_device(rank) - - boot = mscclpp.TcpBootstrap.create(rank, world_size) - boot.initialize(args.if_ip_port_trio) - comm = mscclpp.Communicator(boot) - proxy_service = mscclpp.ProxyService() - - logging.info("Rank: %d, setting up connections", rank) - setup_connections(comm, rank, world_size, args.num_elements, proxy_service) - - logging.info("Rank: %d, starting proxy service", rank) - proxy_service.start_proxy() - - -def main(): - logging.basicConfig(stream=sys.stderr, level=logging.DEBUG) - parser = argparse.ArgumentParser() - parser.add_argument("if_ip_port_trio", type=str) - parser.add_argument("-n", "--num-elements", type=int, default=10) - parser.add_argument("-g", "--gpu_number", type=int, default=2) - args = parser.parse_args() - processes = [] - - for rank in range(args.gpu_number): - p = mp.Process(target=run, args=(rank, args)) - p.start() - processes.append(p) - - for p in processes: - p.join() - - -if __name__ == "__main__": - main() diff --git a/python/examples/send_recv.py b/python/examples/send_recv.py deleted file mode 100644 index e3880072..00000000 --- a/python/examples/send_recv.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import argparse -import time - -import mscclpp - - -def main(args): - if args.root: - rank = 0 - else: - rank = 1 - - boot = mscclpp.TcpBootstrap.create(rank, 2) - boot.initialize(args.if_ip_port_trio) - - comm = mscclpp.Communicator(boot) - - if args.gpu: - import torch - - print("Allocating GPU memory") - memory = torch.zeros(args.num_elements, dtype=torch.int32) - memory = memory.to("cuda") - ptr = memory.data_ptr() - size = memory.numel() * memory.element_size() - else: - from array import array - - print("Allocating host memory") - memory = array("i", [0] * args.num_elements) - ptr, elements = memory.buffer_info() - size = elements * memory.itemsize - my_reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.IB0) - - conn = comm.connect((rank + 1) % 2, 0, mscclpp.Transport.IB0) - - other_reg_mem = None - if rank == 0: - other_reg_mem = comm.recv_memory((rank + 1) % 2, 0) - else: - comm.send_memory(my_reg_mem, (rank + 1) % 2, 0) - - if rank == 0: - other_reg_mem = other_reg_mem.get() - - if rank == 0: - for i in range(args.num_elements): - memory[i] = i + 1 - conn.write(other_reg_mem, 0, my_reg_mem, 0, size) - print("Done sending") - else: - print("Checking for correctness") - # polling - for _ in range(args.polling_num): - all_correct = True - for i in range(args.num_elements): - if memory[i] != i + 1: - all_correct = False - print(f"Error: Mismatch at index {i}: expected {i + 1}, got {memory[i]}") - break - if all_correct: - print("All data matched expected values") - break - else: - time.sleep(0.1) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("if_ip_port_trio", type=str) - parser.add_argument("-r", "--root", action="store_true") - parser.add_argument("-n", "--num-elements", type=int, default=10) - parser.add_argument("--gpu", action="store_true") - parser.add_argument("--polling_num", type=int, default=100) - args = parser.parse_args() - - main(args) diff --git a/python/examples/utils.py b/python/examples/utils.py deleted file mode 100644 index 7f2b4c98..00000000 --- a/python/examples/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import time - -import mscclpp - - -def main(): - timer = mscclpp.Timer() - timer.reset() - time.sleep(2) - assert timer.elapsed() >= 2000000 - - -if __name__ == "__main__": - main() diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp index 2bc1e604..086c241e 100644 --- a/python/mscclpp/core_py.cpp +++ b/python/mscclpp/core_py.cpp @@ -153,8 +153,8 @@ void register_core(nb::module_& m) { nb::class_(m, "Communicator") .def(nb::init, std::shared_ptr>(), nb::arg("bootstrap"), nb::arg("context") = nullptr) - .def("bootstrap", &Communicator::bootstrap) - .def("context", &Communicator::context) + .def_prop_ro("bootstrap", &Communicator::bootstrap) + .def_prop_ro("context", &Communicator::context) .def( "register_memory", [](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {