Move bootstrap example to notebook

Also change bootstrap and context into properties in Communicator
This commit is contained in:
Olli Saarikivi
2023-09-12 21:30:45 +00:00
committed by Saeed Maleki
parent 6b39fd9a54
commit 44e612cfbd
5 changed files with 166 additions and 207 deletions

docs/setup_example.ipynb Normal file

@@ -0,0 +1,164 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation.\n",
"Licensed under the MIT license.\n",
"\n",
"The following example demonstrates how to initialize the MSCCL++ library and perform necessary setup for communicating from GPU kernels. First we define a function for registering memory, making connections and creating channels."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import mscclpp\n",
"import torch\n",
"\n",
"def setup_channels(comm, memory, proxy_service):\n",
" # Register the memory with the communicator\n",
" ptr = memory.data_ptr()\n",
" size = memory.numel() * memory.element_size()\n",
" reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.CudaIpc)\n",
"\n",
" # Create connections to all other ranks and exchange registered memories\n",
" connections = []\n",
" remote_memories = []\n",
" for r in range(comm.bootstrap.size):\n",
" if r == comm.bootstrap.rank: # Don't connect to self\n",
" continue\n",
" connections.append(comm.connect(r, 0, mscclpp.Transport.CudaIpc))\n",
" comm.send_memory(reg_mem, r, 0)\n",
" remote_mem = comm.recv_memory(r, 0)\n",
" remote_memories.append(remote_mem)\n",
"\n",
" # Both connections and received remote memories are returned as futures,\n",
" # so we wait for them to complete and unwrap them.\n",
" connections = [conn.get() for conn in connections]\n",
" remote_memories = [mem.get() for mem in remote_memories]\n",
"\n",
" # Finally, create proxy channels for each connection\n",
" proxy_channels = [mscclpp.SimpleProxyChannel(\n",
" proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)),\n",
" proxy_service.add_memory(remote_memories[i]),\n",
" proxy_service.add_memory(reg_mem),\n",
" ) for i, conn in enumerate(connections)]\n",
"\n",
" return proxy_channels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we are ready to write the top-level code for each rank."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def run(rank, world_size, if_ip_port_trio):\n",
" # Use the right GPU for this rank\n",
" torch.cuda.set_device(rank)\n",
" \n",
" # Allocate memory on the GPU\n",
" memory = torch.zeros(1024, dtype=torch.int32)\n",
" memory = memory.to(\"cuda\")\n",
"\n",
" # Initialize a bootstrapper using a known interface/IP/port trio for the root rank\n",
" boot = mscclpp.TcpBootstrap.create(rank, world_size)\n",
" boot.initialize(if_ip_port_trio)\n",
"\n",
" # Create a communicator for the processes in the bootstrapper\n",
" comm = mscclpp.Communicator(boot)\n",
"\n",
" # Create a proxy service, which enables GPU kernels to use connections\n",
" proxy_service = mscclpp.ProxyService()\n",
"\n",
" if rank == 0:\n",
" print(\"Setting up channels\")\n",
" proxy_channels = setup_channels(comm, memory, proxy_service)\n",
"\n",
" if rank == 0:\n",
" print(\"Starting proxy service\")\n",
" proxy_service.start_proxy()\n",
"\n",
" # This is where we could launch a GPU kernel that uses proxy_channels[i].device_handle\n",
" # to initiate communication. See include/mscclpp/proxy_channel_device.hpp for details.\n",
" if rank == 0:\n",
" print(\"GPU kernels that use the proxy go here.\")\n",
"\n",
" if rank == 0:\n",
" print(f\"Stopping proxy service\")\n",
" proxy_service.stop_proxy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, to test the code we can run each process using the `multiprocessing` package."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Setting up channels\n",
"Starting proxy service\n",
"GPU kernels that use the proxy go here.\n",
"Stopping proxy service\n",
"\n",
"Starting proxy service\n",
"GPU kernels that use the proxy go here.\n",
"Stopping proxy service\n"
]
}
],
"source": [
"import multiprocessing as mp\n",
"\n",
"world_size = 2\n",
"processes = [mp.Process(target=run, args=(rank, world_size, \"eth0:localhost:50051\")) for rank in range(world_size)]\n",
"for p in processes:\n",
" p.start()\n",
"for p in processes:\n",
" p.join()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
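
The notebook stops short of an actual kernel launch. As a rough illustration of what that placeholder step could look like on the host side, here is a sketch that packs the channels' device handles into a GPU buffer for a kernel to consume. It assumes each handle exposes its bytes as `.raw` and uses a hypothetical compiled kernel `my_kernel`; neither is confirmed by this commit, so treat both as assumptions rather than the library's fixed API.

# Sketch: pass proxy channel device handles to a GPU kernel (hypothetical API details).
import torch

def launch_kernel(proxy_channels):
    # Serialize each channel's device handle; `.raw` as the bytes view is an assumption.
    handle_bytes = b"".join(ch.device_handle().raw for ch in proxy_channels)
    # Copy the packed handles to device memory so a kernel can read them.
    handles_gpu = torch.frombuffer(bytearray(handle_bytes), dtype=torch.uint8).cuda()
    # A kernel would reinterpret this pointer as an array of
    # mscclpp::SimpleProxyChannelDeviceHandle and call its put()/signal()/flush()/wait()
    # operations (see include/mscclpp/proxy_channel_device.hpp).
    my_kernel(handles_gpu.data_ptr(), len(proxy_channels))  # my_kernel is hypothetical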


@@ -1,108 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
import logging
import multiprocessing as mp
import sys

import mscclpp
import torch

IB_TRANSPORTS = [
    mscclpp.Transport.IB0,
    mscclpp.Transport.IB1,
    mscclpp.Transport.IB2,
    mscclpp.Transport.IB3,
    mscclpp.Transport.IB4,
    mscclpp.Transport.IB5,
    mscclpp.Transport.IB6,
    mscclpp.Transport.IB7,
]

# Used to hold the sm channels so they don't get garbage collected
sm_channels = []


def setup_connections(comm, rank, world_size, element_size, proxy_service):
    simple_proxy_channels = []
    sm_semaphores = []
    connections = []
    remote_memories = []

    memory = torch.zeros(element_size, dtype=torch.int32)
    memory = memory.to("cuda")
    transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
    ptr = memory.data_ptr()
    size = memory.numel() * memory.element_size()
    reg_mem = comm.register_memory(ptr, size, transport_flag)

    for r in range(world_size):
        if r == rank:
            continue
        conn = comm.connect(r, 0, mscclpp.Transport.CudaIpc)
        connections.append(conn)
        comm.send_memory(reg_mem, r, 0)
        remote_mem = comm.recv_memory(r, 0)
        remote_memories.append(remote_mem)
    connections = [conn.get() for conn in connections]

    # Create simple proxy channels
    for i, conn in enumerate(connections):
        proxy_channel = mscclpp.SimpleProxyChannel(
            proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)),
            proxy_service.add_memory(remote_memories[i].get()),
            proxy_service.add_memory(reg_mem),
        )
        simple_proxy_channels.append(proxy_channel.device_handle())
    comm.setup()

    # Create sm channels
    for i, conn in enumerate(connections):
        sm_chan = mscclpp.SmDevice2DeviceSemaphore(comm, conn)
        sm_semaphores.append(sm_chan)
    comm.setup()
    for i, conn in enumerate(sm_semaphores):
        sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr)
        sm_channels.append(sm_chan)

    return simple_proxy_channels, [sm_chan.device_handle() for sm_chan in sm_channels]


def run(rank, args):
    world_size = args.gpu_number
    torch.cuda.set_device(rank)
    boot = mscclpp.TcpBootstrap.create(rank, world_size)
    boot.initialize(args.if_ip_port_trio)
    comm = mscclpp.Communicator(boot)
    proxy_service = mscclpp.ProxyService()
    logging.info("Rank: %d, setting up connections", rank)
    setup_connections(comm, rank, world_size, args.num_elements, proxy_service)
    logging.info("Rank: %d, starting proxy service", rank)
    proxy_service.start_proxy()


def main():
    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
    parser = argparse.ArgumentParser()
    parser.add_argument("if_ip_port_trio", type=str)
    parser.add_argument("-n", "--num-elements", type=int, default=10)
    parser.add_argument("-g", "--gpu_number", type=int, default=2)
    args = parser.parse_args()

    processes = []
    for rank in range(args.gpu_number):
        p = mp.Process(target=run, args=(rank, args))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


if __name__ == "__main__":
    main()


@@ -1,80 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
import time

import mscclpp


def main(args):
    if args.root:
        rank = 0
    else:
        rank = 1
    boot = mscclpp.TcpBootstrap.create(rank, 2)
    boot.initialize(args.if_ip_port_trio)
    comm = mscclpp.Communicator(boot)

    if args.gpu:
        import torch

        print("Allocating GPU memory")
        memory = torch.zeros(args.num_elements, dtype=torch.int32)
        memory = memory.to("cuda")
        ptr = memory.data_ptr()
        size = memory.numel() * memory.element_size()
    else:
        from array import array

        print("Allocating host memory")
        memory = array("i", [0] * args.num_elements)
        ptr, elements = memory.buffer_info()
        size = elements * memory.itemsize

    my_reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.IB0)
    conn = comm.connect((rank + 1) % 2, 0, mscclpp.Transport.IB0)
    other_reg_mem = None
    if rank == 0:
        other_reg_mem = comm.recv_memory((rank + 1) % 2, 0)
    else:
        comm.send_memory(my_reg_mem, (rank + 1) % 2, 0)
    if rank == 0:
        other_reg_mem = other_reg_mem.get()

    if rank == 0:
        for i in range(args.num_elements):
            memory[i] = i + 1
        conn.write(other_reg_mem, 0, my_reg_mem, 0, size)
        print("Done sending")
    else:
        print("Checking for correctness")
        # polling
        for _ in range(args.polling_num):
            all_correct = True
            for i in range(args.num_elements):
                if memory[i] != i + 1:
                    all_correct = False
                    print(f"Error: Mismatch at index {i}: expected {i + 1}, got {memory[i]}")
                    break
            if all_correct:
                print("All data matched expected values")
                break
            else:
                time.sleep(0.1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("if_ip_port_trio", type=str)
    parser.add_argument("-r", "--root", action="store_true")
    parser.add_argument("-n", "--num-elements", type=int, default=10)
    parser.add_argument("--gpu", action="store_true")
    parser.add_argument("--polling_num", type=int, default=100)
    args = parser.parse_args()
    main(args)
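
For comparison, the one-sided write in the deleted example above can also be expressed with the future-based pattern the new notebook uses, where connect and the memory exchange return futures that are unwrapped with .get(). A minimal sketch, assuming the same bindings as above and posting the memory exchange symmetrically on both ranks (the deleted example only posted one direction per rank):

# Sketch: future-based connect/exchange followed by a one-sided write.
def exchange_and_write(comm, rank, ptr, size):
    peer = (rank + 1) % 2
    reg = comm.register_memory(ptr, size, mscclpp.Transport.IB0)
    conn_future = comm.connect(peer, 0, mscclpp.Transport.IB0)
    comm.send_memory(reg, peer, 0)
    remote_future = comm.recv_memory(peer, 0)
    # Unwrap the futures once both sides have posted their ends.
    conn = conn_future.get()
    remote = remote_future.get()
    if rank == 0:
        # One-sided write of our registered memory into the peer's.
        conn.write(remote, 0, reg, 0, size)
    return conn, reg, remote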


@@ -1,17 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import time

import mscclpp


def main():
    timer = mscclpp.Timer()
    timer.reset()
    time.sleep(2)
    # elapsed() reports microseconds, so sleeping 2 s should yield at least 2000000
    assert timer.elapsed() >= 2000000


if __name__ == "__main__":
    main()


@@ -153,8 +153,8 @@ void register_core(nb::module_& m) {
nb::class_<Communicator>(m, "Communicator")
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
nb::arg("context") = nullptr)
.def("bootstrap", &Communicator::bootstrap)
.def("context", &Communicator::context)
.def_prop_ro("bootstrap", &Communicator::bootstrap)
.def_prop_ro("context", &Communicator::context)
.def(
"register_memory",
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
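
With def_prop_ro, bootstrap and context become read-only properties on the Python side, so callers drop the call parentheses; this is exactly what the notebook's comm.bootstrap.size and comm.bootstrap.rank usage expects:

comm = mscclpp.Communicator(boot)

# Before this commit, bootstrap was exposed as a method: comm.bootstrap()
# Now it is a read-only property:
bootstrap = comm.bootstrap
print(bootstrap.rank, bootstrap.size)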