mirror of https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00

Move bootstrap example to notebook

Also change bootstrap and context into properties in Communicator.

committed by Saeed Maleki
parent 6b39fd9a54
commit 44e612cfbd

164  docs/setup_example.ipynb  Normal file
@@ -0,0 +1,164 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) Microsoft Corporation.\n",
    "Licensed under the MIT license.\n",
    "\n",
    "The following example demonstrates how to initialize the MSCCL++ library and perform the setup necessary for communicating from GPU kernels. First we define a function that registers memory, makes connections, and creates channels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mscclpp\n",
    "import torch\n",
    "\n",
    "def setup_channels(comm, memory, proxy_service):\n",
    "    # Register the memory with the communicator\n",
    "    ptr = memory.data_ptr()\n",
    "    size = memory.numel() * memory.element_size()\n",
    "    reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.CudaIpc)\n",
    "\n",
    "    # Create connections to all other ranks and exchange registered memories\n",
    "    connections = []\n",
    "    remote_memories = []\n",
    "    for r in range(comm.bootstrap.size):\n",
    "        if r == comm.bootstrap.rank:  # Don't connect to self\n",
    "            continue\n",
    "        connections.append(comm.connect(r, 0, mscclpp.Transport.CudaIpc))\n",
    "        comm.send_memory(reg_mem, r, 0)\n",
    "        remote_mem = comm.recv_memory(r, 0)\n",
    "        remote_memories.append(remote_mem)\n",
    "\n",
    "    # Both connections and received remote memories are returned as futures,\n",
    "    # so we wait for them to complete and unwrap them.\n",
    "    connections = [conn.get() for conn in connections]\n",
    "    remote_memories = [mem.get() for mem in remote_memories]\n",
    "\n",
    "    # Finally, create proxy channels for each connection\n",
    "    proxy_channels = [mscclpp.SimpleProxyChannel(\n",
    "        proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)),\n",
    "        proxy_service.add_memory(remote_memories[i]),\n",
    "        proxy_service.add_memory(reg_mem),\n",
    "    ) for i, conn in enumerate(connections)]\n",
    "\n",
    "    return proxy_channels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to write the top-level code for each rank."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run(rank, world_size, if_ip_port_trio):\n",
    "    # Use the right GPU for this rank\n",
    "    torch.cuda.set_device(rank)\n",
    "\n",
    "    # Allocate memory on the GPU\n",
    "    memory = torch.zeros(1024, dtype=torch.int32)\n",
    "    memory = memory.to(\"cuda\")\n",
    "\n",
    "    # Initialize a bootstrapper using a known interface/IP/port trio for the root rank\n",
    "    boot = mscclpp.TcpBootstrap.create(rank, world_size)\n",
    "    boot.initialize(if_ip_port_trio)\n",
    "\n",
    "    # Create a communicator for the processes in the bootstrapper\n",
    "    comm = mscclpp.Communicator(boot)\n",
    "\n",
    "    # Create a proxy service, which enables GPU kernels to use connections\n",
    "    proxy_service = mscclpp.ProxyService()\n",
    "\n",
    "    if rank == 0:\n",
    "        print(\"Setting up channels\")\n",
    "    proxy_channels = setup_channels(comm, memory, proxy_service)\n",
    "\n",
    "    if rank == 0:\n",
    "        print(\"Starting proxy service\")\n",
    "    proxy_service.start_proxy()\n",
    "\n",
    "    # This is where we could launch a GPU kernel that uses proxy_channels[i].device_handle\n",
    "    # to initiate communication. See include/mscclpp/proxy_channel_device.hpp for details.\n",
    "    if rank == 0:\n",
    "        print(\"GPU kernels that use the proxy go here.\")\n",
    "\n",
    "    if rank == 0:\n",
    "        print(\"Stopping proxy service\")\n",
    "    proxy_service.stop_proxy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, to test the code we can run each process using the `multiprocessing` package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting up channels\n",
      "Starting proxy service\n",
      "GPU kernels that use the proxy go here.\n",
      "Stopping proxy service\n",
      "\n",
      "Starting proxy service\n",
      "GPU kernels that use the proxy go here.\n",
      "Stopping proxy service\n"
     ]
    }
   ],
   "source": [
    "import multiprocessing as mp\n",
    "\n",
    "world_size = 2\n",
    "processes = [mp.Process(target=run, args=(rank, world_size, \"eth0:localhost:50051\")) for rank in range(world_size)]\n",
    "for p in processes:\n",
    "    p.start()\n",
    "for p in processes:\n",
    "    p.join()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
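The notebook stops just short of launching a device kernel. As a rough illustration of the host-side step (not part of the committed notebook), each channel returned by setup_channels exposes a device handle that would be copied to the GPU before launch; launch_kernel is a hypothetical placeholder, and the kernel itself would be CUDA code written against include/mscclpp/proxy_channel_device.hpp:

# Hedged sketch, continuing run() above after setup_channels returns.
# device_handle() is the accessor the removed example below also uses;
# launch_kernel is hypothetical and stands in for copying the handles
# into device memory and launching a CUDA kernel that drives the channels.
device_handles = [ch.device_handle() for ch in proxy_channels]
# launch_kernel(device_handles)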
@@ -1,108 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
import logging
import multiprocessing as mp
import sys

import mscclpp
import torch

IB_TRANSPORTS = [
    mscclpp.Transport.IB0,
    mscclpp.Transport.IB1,
    mscclpp.Transport.IB2,
    mscclpp.Transport.IB3,
    mscclpp.Transport.IB4,
    mscclpp.Transport.IB5,
    mscclpp.Transport.IB6,
    mscclpp.Transport.IB7,
]

# Used to hold the sm channels so they don't get garbage collected
sm_channels = []


def setup_connections(comm, rank, world_size, element_size, proxy_service):
    simple_proxy_channels = []
    sm_semaphores = []
    connections = []
    remote_memories = []
    memory = torch.zeros(element_size, dtype=torch.int32)
    memory = memory.to("cuda")

    transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
    ptr = memory.data_ptr()
    size = memory.numel() * memory.element_size()
    reg_mem = comm.register_memory(ptr, size, transport_flag)

    for r in range(world_size):
        if r == rank:
            continue
        conn = comm.connect(r, 0, mscclpp.Transport.CudaIpc)
        connections.append(conn)
        comm.send_memory(reg_mem, r, 0)
        remote_mem = comm.recv_memory(r, 0)
        remote_memories.append(remote_mem)

    connections = [conn.get() for conn in connections]

    # Create simple proxy channels
    for i, conn in enumerate(connections):
        proxy_channel = mscclpp.SimpleProxyChannel(
            proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)),
            proxy_service.add_memory(remote_memories[i].get()),
            proxy_service.add_memory(reg_mem),
        )
        simple_proxy_channels.append(proxy_channel.device_handle())
    comm.setup()

    # Create sm channels
    for i, conn in enumerate(connections):
        sm_chan = mscclpp.SmDevice2DeviceSemaphore(comm, conn)
        sm_semaphores.append(sm_chan)
    comm.setup()

    for i, conn in enumerate(sm_semaphores):
        sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr)
        sm_channels.append(sm_chan)
    return simple_proxy_channels, [sm_chan.device_handle() for sm_chan in sm_channels]


def run(rank, args):
    world_size = args.gpu_number
    torch.cuda.set_device(rank)

    boot = mscclpp.TcpBootstrap.create(rank, world_size)
    boot.initialize(args.if_ip_port_trio)
    comm = mscclpp.Communicator(boot)
    proxy_service = mscclpp.ProxyService()

    logging.info("Rank: %d, setting up connections", rank)
    setup_connections(comm, rank, world_size, args.num_elements, proxy_service)

    logging.info("Rank: %d, starting proxy service", rank)
    proxy_service.start_proxy()


def main():
    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
    parser = argparse.ArgumentParser()
    parser.add_argument("if_ip_port_trio", type=str)
    parser.add_argument("-n", "--num-elements", type=int, default=10)
    parser.add_argument("-g", "--gpu_number", type=int, default=2)
    args = parser.parse_args()
    processes = []

    for rank in range(args.gpu_number):
        p = mp.Process(target=run, args=(rank, args))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


if __name__ == "__main__":
    main()
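The removed script above builds two kinds of channels over the same connections: SimpleProxyChannel, whose operations are serviced by the host-side proxy thread, and SmChannel, which synchronizes peer GPUs directly via device-to-device semaphores. A hedged sketch of consuming its return value (launch_kernel is a hypothetical placeholder; the other names are in scope in the script):

# Continuation sketch for the script above.
proxy_handles, sm_handles = setup_connections(comm, rank, world_size, 1024, proxy_service)
# Either handle list can be shipped to device memory and used from a kernel;
# SM channels need a transport with direct peer access (e.g. CudaIpc), while
# proxy channels also work over IB by going through the proxy thread.
# launch_kernel(proxy_handles)  # hypothetical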
@@ -1,80 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
import time

import mscclpp


def main(args):
    if args.root:
        rank = 0
    else:
        rank = 1

    boot = mscclpp.TcpBootstrap.create(rank, 2)
    boot.initialize(args.if_ip_port_trio)

    comm = mscclpp.Communicator(boot)

    if args.gpu:
        import torch

        print("Allocating GPU memory")
        memory = torch.zeros(args.num_elements, dtype=torch.int32)
        memory = memory.to("cuda")
        ptr = memory.data_ptr()
        size = memory.numel() * memory.element_size()
    else:
        from array import array

        print("Allocating host memory")
        memory = array("i", [0] * args.num_elements)
        ptr, elements = memory.buffer_info()
        size = elements * memory.itemsize
    my_reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.IB0)

    # connect() returns a future, so unwrap it before use
    conn = comm.connect((rank + 1) % 2, 0, mscclpp.Transport.IB0)
    conn = conn.get()

    other_reg_mem = None
    if rank == 0:
        other_reg_mem = comm.recv_memory((rank + 1) % 2, 0)
    else:
        comm.send_memory(my_reg_mem, (rank + 1) % 2, 0)

    if rank == 0:
        other_reg_mem = other_reg_mem.get()

    if rank == 0:
        for i in range(args.num_elements):
            memory[i] = i + 1
        conn.write(other_reg_mem, 0, my_reg_mem, 0, size)
        print("Done sending")
    else:
        print("Checking for correctness")
        # Poll until the sender's data has arrived, or give up after polling_num attempts
        for _ in range(args.polling_num):
            all_correct = True
            for i in range(args.num_elements):
                if memory[i] != i + 1:
                    all_correct = False
                    print(f"Error: Mismatch at index {i}: expected {i + 1}, got {memory[i]}")
                    break
            if all_correct:
                print("All data matched expected values")
                break
            else:
                time.sleep(0.1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("if_ip_port_trio", type=str)
    parser.add_argument("-r", "--root", action="store_true")
    parser.add_argument("-n", "--num-elements", type=int, default=10)
    parser.add_argument("--gpu", action="store_true")
    parser.add_argument("--polling_num", type=int, default=100)
    args = parser.parse_args()

    main(args)
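Since one process must pass --root and the other must not, the script above was run as two separate processes. A minimal launch sketch under assumed names (the script's on-disk path is not shown in this diff, and the interface/IP/port trio is illustrative):

import subprocess

trio = "eth0:10.0.0.1:50051"  # assumption: root rank's interface:IP:port
procs = [
    subprocess.Popen(["python", "bootstrap_test.py", trio, "--root"]),  # rank 0
    subprocess.Popen(["python", "bootstrap_test.py", trio]),  # rank 1
]
for p in procs:
    p.wait()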
@@ -1,17 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import time

import mscclpp


def main():
    timer = mscclpp.Timer()
    timer.reset()
    time.sleep(2)
    # elapsed() reports microseconds: 2 s == 2,000,000 us
    assert timer.elapsed() >= 2000000


if __name__ == "__main__":
    main()
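A hedged sketch of using the same Timer API to report a measured interval in seconds, given the microsecond unit the assertion above implies (the timed workload is a hypothetical placeholder):

import mscclpp

timer = mscclpp.Timer()
timer.reset()
# ... hypothetical workload to be timed goes here ...
print(f"elapsed: {timer.elapsed() / 1e6:.3f} s")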
@@ -153,8 +153,8 @@ void register_core(nb::module_& m) {
   nb::class_<Communicator>(m, "Communicator")
       .def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
            nb::arg("context") = nullptr)
-      .def("bootstrap", &Communicator::bootstrap)
-      .def("context", &Communicator::context)
+      .def_prop_ro("bootstrap", &Communicator::bootstrap)
+      .def_prop_ro("context", &Communicator::context)
       .def(
           "register_memory",
           [](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
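On the Python side, the nanobind change above turns the accessors into read-only properties, so call sites drop the parentheses, as the new notebook already does:

# bootstrap and context are now read-only properties
# (previously methods: comm.bootstrap(), comm.context()).
comm = mscclpp.Communicator(boot)
world_size = comm.bootstrap.size  # as used in the notebook above
ctx = comm.context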