Move bootstrap example to notebook

Also change bootstrap and context into properties in Communicator
This commit is contained in:
Olli Saarikivi
2023-09-12 21:30:45 +00:00
committed by Saeed Maleki
parent 6b39fd9a54
commit 44e612cfbd
5 changed files with 166 additions and 207 deletions

View File

@@ -1,108 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import logging
import multiprocessing as mp
import sys
import mscclpp
import torch
# The eight InfiniBand transports IB0..IB7, in index order. setup_connections
# selects one with IB_TRANSPORTS[rank].
IB_TRANSPORTS = [getattr(mscclpp.Transport, f"IB{index}") for index in range(8)]

# Module-level list that keeps the SM channels referenced so they don't get
# garbage collected after setup_connections returns.
sm_channels = []
def setup_connections(comm, rank, world_size, element_size, proxy_service):
    """Connect this rank to every other rank and build one channel pair per peer.

    A zero-filled int32 CUDA tensor with ``element_size`` elements is
    registered with ``comm`` and exchanged with each peer. For every peer
    connection a simple proxy channel and an SM channel are then created.

    Args:
        comm: Communicator used to register memory, connect, and exchange
            registered memories.
        rank: Rank of the calling process; also indexes ``IB_TRANSPORTS``.
        world_size: Total number of ranks to connect.
        element_size: Number of int32 elements in the registered tensor.
        proxy_service: Service that semaphores and memories of the proxy
            channels are added to.

    Returns:
        A tuple ``(proxy_handles, sm_handles)``: the device handles of the
        proxy channels created by this call, and the device handles of every
        channel currently held in the module-level ``sm_channels`` list.
    """
    local_buffer = torch.zeros(element_size, dtype=torch.int32).to("cuda")
    # NOTE(review): nothing visible here keeps `local_buffer` referenced after
    # this function returns, although its pointer is registered and handed to
    # the SM channels -- confirm the intended lifetime.

    transports = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
    buffer_ptr = local_buffer.data_ptr()
    buffer_bytes = local_buffer.numel() * local_buffer.element_size()
    local_mem = comm.register_memory(buffer_ptr, buffer_bytes, transports)

    # Start a connection to every other rank and swap registered memories.
    pending_connections = []
    remote_memories = []
    for peer in range(world_size):
        if peer != rank:
            pending_connections.append(comm.connect(peer, 0, mscclpp.Transport.CudaIpc))
            comm.send_memory(local_mem, peer, 0)
            remote_memories.append(comm.recv_memory(peer, 0))

    # Resolve the connection futures only after all of them were requested.
    connections = [pending.get() for pending in pending_connections]

    # Simple proxy channels: one per peer, in connection order.
    proxy_handles = []
    for conn, remote in zip(connections, remote_memories):
        semaphore_id = proxy_service.build_and_add_semaphore(comm, conn)
        channel = mscclpp.SimpleProxyChannel(
            proxy_service.proxy_channel(semaphore_id),
            proxy_service.add_memory(remote.get()),
            proxy_service.add_memory(local_mem),
        )
        proxy_handles.append(channel.device_handle())
    comm.setup()

    # SM channels: the semaphores are created first and followed by a setup
    # call before the channels themselves are built.
    sm_semaphores = [mscclpp.SmDevice2DeviceSemaphore(comm, conn) for conn in connections]
    comm.setup()

    for semaphore, remote in zip(sm_semaphores, remote_memories):
        sm_channels.append(mscclpp.SmChannel(semaphore, remote.get(), buffer_ptr))

    sm_handles = [channel.device_handle() for channel in sm_channels]
    return proxy_handles, sm_handles
def run(rank, args):
    """Per-process worker: bootstrap, set up connections, start the proxy.

    Args:
        rank: Rank of this worker; also used as the CUDA device index.
        args: Parsed command-line namespace. Reads ``gpu_number``,
            ``if_ip_port_trio`` and ``num_elements``.
    """
    num_ranks = args.gpu_number
    torch.cuda.set_device(rank)

    bootstrap = mscclpp.TcpBootstrap.create(rank, num_ranks)
    bootstrap.initialize(args.if_ip_port_trio)
    communicator = mscclpp.Communicator(bootstrap)
    service = mscclpp.ProxyService()

    logging.info("Rank: %d, setting up connections", rank)
    # The returned device handles are not used by this example.
    setup_connections(communicator, rank, num_ranks, args.num_elements, service)

    logging.info("Rank: %d, starting proxy service", rank)
    service.start_proxy()
def main():
    """Parse the command line, then run one worker process per GPU rank.

    Logging is sent to stderr at DEBUG level. The function returns once every
    worker process has been joined.
    """
    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)

    cli = argparse.ArgumentParser()
    cli.add_argument("if_ip_port_trio", type=str)
    cli.add_argument("-n", "--num-elements", type=int, default=10)
    cli.add_argument("-g", "--gpu_number", type=int, default=2)
    args = cli.parse_args()

    workers = [mp.Process(target=run, args=(rank, args)) for rank in range(args.gpu_number)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()


# Only launch the workers when executed as a script.
if __name__ == "__main__":
    main()

View File

@@ -1,80 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import time
import mscclpp
def main(args):
    """Run one side of a two-rank send/receive check over the IB0 transport.

    The root process (rank 0) receives the peer's registered memory, fills its
    own buffer with 1..num_elements and writes it to the peer. The other
    process (rank 1) sends its registered memory and then polls its buffer
    until it holds the expected values or the polling budget runs out.

    Args:
        args: Parsed command-line namespace. Reads ``root``,
            ``if_ip_port_trio``, ``gpu``, ``num_elements`` and ``polling_num``.
    """
    rank = 0 if args.root else 1
    peer = (rank + 1) % 2

    bootstrap = mscclpp.TcpBootstrap.create(rank, 2)
    bootstrap.initialize(args.if_ip_port_trio)
    comm = mscclpp.Communicator(bootstrap)

    # Allocate the buffer on the GPU or on the host; either way we end up with
    # a raw pointer and a size in bytes.
    if args.gpu:
        import torch

        print("Allocating GPU memory")
        memory = torch.zeros(args.num_elements, dtype=torch.int32).to("cuda")
        ptr = memory.data_ptr()
        size = memory.numel() * memory.element_size()
    else:
        from array import array

        print("Allocating host memory")
        memory = array("i", [0] * args.num_elements)
        ptr, elements = memory.buffer_info()
        size = elements * memory.itemsize

    my_reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.IB0)
    conn = comm.connect(peer, 0, mscclpp.Transport.IB0)

    if rank == 0:
        # Sender: obtain the peer's memory, fill the local buffer, push it.
        other_reg_mem = comm.recv_memory(peer, 0).get()
        for i in range(args.num_elements):
            memory[i] = i + 1
        conn.write(other_reg_mem, 0, my_reg_mem, 0, size)
        print("Done sending")
    else:
        # Receiver: publish the local memory, then poll for the data.
        comm.send_memory(my_reg_mem, peer, 0)
        print("Checking for correctness")
        for _ in range(args.polling_num):
            for i in range(args.num_elements):
                if memory[i] != i + 1:
                    print(f"Error: Mismatch at index {i}: expected {i + 1}, got {memory[i]}")
                    break
            else:
                # The scan finished without a mismatch.
                print("All data matched expected values")
                break
            # Mismatch seen on this pass; wait before polling again.
            time.sleep(0.1)
# Script entry point: build the command-line interface and hand the parsed
# arguments to main().
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("if_ip_port_trio", type=str)
    cli.add_argument("-r", "--root", action="store_true")
    cli.add_argument("-n", "--num-elements", type=int, default=10)
    cli.add_argument("--gpu", action="store_true")
    cli.add_argument("--polling_num", type=int, default=100)
    main(cli.parse_args())

View File

@@ -1,17 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import time
import mscclpp
def main():
    """Check that ``mscclpp.Timer`` reports at least the time that was slept.

    Raises:
        AssertionError: If the elapsed value after a 2-second sleep is below
            2000000.
    """
    stopwatch = mscclpp.Timer()
    stopwatch.reset()
    time.sleep(2)
    elapsed = stopwatch.elapsed()
    # A 2 s sleep is compared against 2000000, so elapsed() is presumably in
    # microseconds -- TODO confirm against the Timer binding.
    assert elapsed >= 2000000


# Only run the check when executed as a script.
if __name__ == "__main__":
    main()

View File

@@ -153,8 +153,8 @@ void register_core(nb::module_& m) {
nb::class_<Communicator>(m, "Communicator")
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
nb::arg("context") = nullptr)
.def("bootstrap", &Communicator::bootstrap)
.def("context", &Communicator::context)
.def_prop_ro("bootstrap", &Communicator::bootstrap)
.def_prop_ro("context", &Communicator::context)
.def(
"register_memory",
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {