mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
Move bootstrap example to notebook
Also change bootstrap and context into properties in Communicator
This commit is contained in:
committed by
Saeed Maleki
parent
6b39fd9a54
commit
44e612cfbd
@@ -1,108 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
|
||||
import mscclpp
|
||||
import torch
|
||||
|
||||
IB_TRANSPORTS = [
|
||||
mscclpp.Transport.IB0,
|
||||
mscclpp.Transport.IB1,
|
||||
mscclpp.Transport.IB2,
|
||||
mscclpp.Transport.IB3,
|
||||
mscclpp.Transport.IB4,
|
||||
mscclpp.Transport.IB5,
|
||||
mscclpp.Transport.IB6,
|
||||
mscclpp.Transport.IB7,
|
||||
]
|
||||
|
||||
# Use to hold the sm channels so they don't get garbage collected
|
||||
sm_channels = []
|
||||
|
||||
|
||||
def setup_connections(comm, rank, world_size, element_size, proxy_service):
|
||||
simple_proxy_channels = []
|
||||
sm_semaphores = []
|
||||
connections = []
|
||||
remote_memories = []
|
||||
memory = torch.zeros(element_size, dtype=torch.int32)
|
||||
memory = memory.to("cuda")
|
||||
|
||||
transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
|
||||
ptr = memory.data_ptr()
|
||||
size = memory.numel() * memory.element_size()
|
||||
reg_mem = comm.register_memory(ptr, size, transport_flag)
|
||||
|
||||
for r in range(world_size):
|
||||
if r == rank:
|
||||
continue
|
||||
conn = comm.connect(r, 0, mscclpp.Transport.CudaIpc)
|
||||
connections.append(conn)
|
||||
comm.send_memory(reg_mem, r, 0)
|
||||
remote_mem = comm.recv_memory(r, 0)
|
||||
remote_memories.append(remote_mem)
|
||||
|
||||
connections = [conn.get() for conn in connections]
|
||||
|
||||
# Create simple proxy channels
|
||||
for i, conn in enumerate(connections):
|
||||
proxy_channel = mscclpp.SimpleProxyChannel(
|
||||
proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)),
|
||||
proxy_service.add_memory(remote_memories[i].get()),
|
||||
proxy_service.add_memory(reg_mem),
|
||||
)
|
||||
simple_proxy_channels.append(proxy_channel.device_handle())
|
||||
comm.setup()
|
||||
|
||||
# Create sm channels
|
||||
for i, conn in enumerate(connections):
|
||||
sm_chan = mscclpp.SmDevice2DeviceSemaphore(comm, conn)
|
||||
sm_semaphores.append(sm_chan)
|
||||
comm.setup()
|
||||
|
||||
for i, conn in enumerate(sm_semaphores):
|
||||
sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr)
|
||||
sm_channels.append(sm_chan)
|
||||
return simple_proxy_channels, [sm_chan.device_handle() for sm_chan in sm_channels]
|
||||
|
||||
|
||||
def run(rank, args):
|
||||
world_size = args.gpu_number
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
boot = mscclpp.TcpBootstrap.create(rank, world_size)
|
||||
boot.initialize(args.if_ip_port_trio)
|
||||
comm = mscclpp.Communicator(boot)
|
||||
proxy_service = mscclpp.ProxyService()
|
||||
|
||||
logging.info("Rank: %d, setting up connections", rank)
|
||||
setup_connections(comm, rank, world_size, args.num_elements, proxy_service)
|
||||
|
||||
logging.info("Rank: %d, starting proxy service", rank)
|
||||
proxy_service.start_proxy()
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("if_ip_port_trio", type=str)
|
||||
parser.add_argument("-n", "--num-elements", type=int, default=10)
|
||||
parser.add_argument("-g", "--gpu_number", type=int, default=2)
|
||||
args = parser.parse_args()
|
||||
processes = []
|
||||
|
||||
for rank in range(args.gpu_number):
|
||||
p = mp.Process(target=run, args=(rank, args))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,80 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import mscclpp
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.root:
|
||||
rank = 0
|
||||
else:
|
||||
rank = 1
|
||||
|
||||
boot = mscclpp.TcpBootstrap.create(rank, 2)
|
||||
boot.initialize(args.if_ip_port_trio)
|
||||
|
||||
comm = mscclpp.Communicator(boot)
|
||||
|
||||
if args.gpu:
|
||||
import torch
|
||||
|
||||
print("Allocating GPU memory")
|
||||
memory = torch.zeros(args.num_elements, dtype=torch.int32)
|
||||
memory = memory.to("cuda")
|
||||
ptr = memory.data_ptr()
|
||||
size = memory.numel() * memory.element_size()
|
||||
else:
|
||||
from array import array
|
||||
|
||||
print("Allocating host memory")
|
||||
memory = array("i", [0] * args.num_elements)
|
||||
ptr, elements = memory.buffer_info()
|
||||
size = elements * memory.itemsize
|
||||
my_reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.IB0)
|
||||
|
||||
conn = comm.connect((rank + 1) % 2, 0, mscclpp.Transport.IB0)
|
||||
|
||||
other_reg_mem = None
|
||||
if rank == 0:
|
||||
other_reg_mem = comm.recv_memory((rank + 1) % 2, 0)
|
||||
else:
|
||||
comm.send_memory(my_reg_mem, (rank + 1) % 2, 0)
|
||||
|
||||
if rank == 0:
|
||||
other_reg_mem = other_reg_mem.get()
|
||||
|
||||
if rank == 0:
|
||||
for i in range(args.num_elements):
|
||||
memory[i] = i + 1
|
||||
conn.write(other_reg_mem, 0, my_reg_mem, 0, size)
|
||||
print("Done sending")
|
||||
else:
|
||||
print("Checking for correctness")
|
||||
# polling
|
||||
for _ in range(args.polling_num):
|
||||
all_correct = True
|
||||
for i in range(args.num_elements):
|
||||
if memory[i] != i + 1:
|
||||
all_correct = False
|
||||
print(f"Error: Mismatch at index {i}: expected {i + 1}, got {memory[i]}")
|
||||
break
|
||||
if all_correct:
|
||||
print("All data matched expected values")
|
||||
break
|
||||
else:
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("if_ip_port_trio", type=str)
|
||||
parser.add_argument("-r", "--root", action="store_true")
|
||||
parser.add_argument("-n", "--num-elements", type=int, default=10)
|
||||
parser.add_argument("--gpu", action="store_true")
|
||||
parser.add_argument("--polling_num", type=int, default=100)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@@ -1,17 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import time
|
||||
|
||||
import mscclpp
|
||||
|
||||
|
||||
def main():
|
||||
timer = mscclpp.Timer()
|
||||
timer.reset()
|
||||
time.sleep(2)
|
||||
assert timer.elapsed() >= 2000000
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -153,8 +153,8 @@ void register_core(nb::module_& m) {
|
||||
nb::class_<Communicator>(m, "Communicator")
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
|
||||
nb::arg("context") = nullptr)
|
||||
.def("bootstrap", &Communicator::bootstrap)
|
||||
.def("context", &Communicator::context)
|
||||
.def_prop_ro("bootstrap", &Communicator::bootstrap)
|
||||
.def_prop_ro("context", &Communicator::context)
|
||||
.def(
|
||||
"register_memory",
|
||||
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
|
||||
|
||||
Reference in New Issue
Block a user