Files
mscclpp/docs/setup_example.ipynb
Olli Saarikivi 44e612cfbd Move bootstrap example to notebook
Also change bootstrap and context into properties in Communicator
2023-11-08 18:44:45 +00:00

165 lines
5.2 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation.\n",
"Licensed under the MIT license.\n",
"\n",
"The following example demonstrates how to initialize the MSCCL++ library and perform the setup needed to communicate from GPU kernels. First, we define a function that registers memory, makes connections, and creates channels."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import mscclpp\n",
"import torch\n",
"\n",
"def setup_channels(comm, memory, proxy_service):\n",
"    \"\"\"Register `memory`, connect to every other rank, and build proxy channels.\n",
"\n",
"    Args:\n",
"        comm: the `mscclpp.Communicator` for this rank.\n",
"        memory: a torch tensor whose whole storage\n",
"            (`numel() * element_size()` bytes) is registered for CUDA IPC.\n",
"        proxy_service: the `mscclpp.ProxyService` that will own the channels.\n",
"\n",
"    Returns:\n",
"        A list of `mscclpp.SimpleProxyChannel`, one per remote rank, ordered by\n",
"        ascending rank with this rank skipped.\n",
"    \"\"\"\n",
"    # Register the memory with the communicator\n",
"    ptr = memory.data_ptr()\n",
"    size = memory.numel() * memory.element_size()\n",
"    reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.CudaIpc)\n",
"\n",
"    # Create connections to all other ranks and exchange registered memories\n",
"    my_rank = comm.bootstrap.rank\n",
"    connections = []\n",
"    remote_memories = []\n",
"    for r in range(comm.bootstrap.size):\n",
"        if r == my_rank:  # Don't connect to self\n",
"            continue\n",
"        connections.append(comm.connect(r, 0, mscclpp.Transport.CudaIpc))\n",
"        comm.send_memory(reg_mem, r, 0)\n",
"        remote_memories.append(comm.recv_memory(r, 0))\n",
"\n",
"    # Both connections and received remote memories are returned as futures,\n",
"    # so we wait for them to complete and unwrap them.\n",
"    connections = [conn.get() for conn in connections]\n",
"    remote_memories = [mem.get() for mem in remote_memories]\n",
"\n",
"    # The local buffer is the same for every channel, so add it to the proxy\n",
"    # service once and share its id instead of re-adding it per connection.\n",
"    local_mem_id = proxy_service.add_memory(reg_mem)\n",
"\n",
"    # Finally, create one proxy channel per connection, pairing that peer's\n",
"    # remote memory with the shared local memory id.\n",
"    return [\n",
"        mscclpp.SimpleProxyChannel(\n",
"            proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)),\n",
"            proxy_service.add_memory(remote_mem),\n",
"            local_mem_id,\n",
"        )\n",
"        for conn, remote_mem in zip(connections, remote_memories)\n",
"    ]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
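"Now we are ready to write the top-level function that each rank runs: it selects its GPU, allocates a buffer, bootstraps a communicator, sets up channels, and then starts and stops the proxy service."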
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def run(rank, world_size, if_ip_port_trio):\n",
"    \"\"\"Entry point for one rank: set up MSCCL++ channels and run the proxy service.\n",
"\n",
"    Args:\n",
"        rank: this process's rank; also used as the CUDA device index.\n",
"        world_size: total number of ranks.\n",
"        if_ip_port_trio: bootstrap address of the root rank as an\n",
"            interface/IP/port trio, e.g. \"eth0:localhost:50051\".\n",
"    \"\"\"\n",
"    # Use the right GPU for this rank\n",
"    torch.cuda.set_device(rank)\n",
"\n",
"    # Allocate a zero-initialized buffer directly on this rank's GPU\n",
"    memory = torch.zeros(1024, dtype=torch.int32, device=\"cuda\")\n",
"\n",
"    # Initialize a bootstrapper using a known interface/IP/port trio for the root rank\n",
"    boot = mscclpp.TcpBootstrap.create(rank, world_size)\n",
"    boot.initialize(if_ip_port_trio)\n",
"\n",
"    # Create a communicator for the processes in the bootstrapper\n",
"    comm = mscclpp.Communicator(boot)\n",
"\n",
"    # Create a proxy service, which enables GPU kernels to use connections\n",
"    proxy_service = mscclpp.ProxyService()\n",
"\n",
"    if rank == 0:\n",
"        print(\"Setting up channels\")\n",
"    proxy_channels = setup_channels(comm, memory, proxy_service)\n",
"\n",
"    if rank == 0:\n",
"        print(\"Starting proxy service\")\n",
"    proxy_service.start_proxy()\n",
"    try:\n",
"        # This is where we could launch a GPU kernel that uses proxy_channels[i].device_handle\n",
"        # to initiate communication. See include/mscclpp/proxy_channel_device.hpp for details.\n",
"        if rank == 0:\n",
"            print(\"GPU kernels that use the proxy go here.\")\n",
"    finally:\n",
"        # Stop the proxy even if the kernel section above raises, so a failure\n",
"        # does not leave the proxy service running.\n",
"        if rank == 0:\n",
"            print(\"Stopping proxy service\")\n",
"        proxy_service.stop_proxy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, to test the code we launch one process per rank using Python's `multiprocessing` package. Rank `r` uses GPU `r`, so the machine needs at least `world_size` GPUs."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Setting up channels\n",
"Starting proxy service\n",
"GPU kernels that use the proxy go here.\n",
"Stopping proxy service\n",
"\n",
"Starting proxy service\n",
"GPU kernels that use the proxy go here.\n",
"Stopping proxy service\n"
]
}
],
"source": [
"import multiprocessing as mp\n",
"\n",
"world_size = 2\n",
"# Bootstrap address of the root rank (interface:ip:port); adjust for your machine.\n",
"if_ip_port_trio = \"eth0:localhost:50051\"\n",
"\n",
"# Launch one process per rank; rank r runs `run` on GPU r.\n",
"processes = [mp.Process(target=run, args=(rank, world_size, if_ip_port_trio)) for rank in range(world_size)]\n",
"for p in processes:\n",
"    p.start()\n",
"for p in processes:\n",
"    p.join()\n",
"\n",
"# An exception in a child process does not propagate to the parent, so check\n",
"# exit codes explicitly instead of letting a failed rank pass silently.\n",
"failures = {rank: p.exitcode for rank, p in enumerate(processes) if p.exitcode != 0}\n",
"if failures:\n",
"    raise RuntimeError(f\"Ranks exited with non-zero exit codes (rank: exitcode): {failures}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}