Files
mscclpp/python/examples/send_recv.py
Saeed Maleki 8d1b984bed Change device handle interfaces & others (#142)
* Changed device handle interfaces
* Changed proxy service interfaces
* Move device code into separate files
* Fixed FIFO polling issues
* Add configuration arguments in several interface functions

---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
Co-authored-by: Binyang Li <binyli@microsoft.com>
Co-authored-by: root <root@a100-saemal0.qxveptpukjsuthqvv514inp03c.gx.internal.cloudapp.net>
2023-08-16 20:00:56 +08:00

83 lines
2.3 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import time
import mscclpp
def main(args):
if args.root:
rank = 0
else:
rank = 1
boot = mscclpp.TcpBootstrap.create(rank, 2)
boot.initialize(args.if_ip_port_trio)
comm = mscclpp.Communicator(boot)
if args.gpu:
import torch
print("Allocating GPU memory")
memory = torch.zeros(args.num_elements, dtype=torch.int32)
memory = memory.to("cuda")
ptr = memory.data_ptr()
size = memory.numel() * memory.element_size()
else:
from array import array
print("Allocating host memory")
memory = array("i", [0] * args.num_elements)
ptr, elements = memory.buffer_info()
size = elements * memory.itemsize
my_reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.IB0)
conn = comm.connect_on_setup((rank + 1) % 2, 0, mscclpp.Transport.IB0)
other_reg_mem = None
if rank == 0:
other_reg_mem = comm.recv_memory_on_setup((rank + 1) % 2, 0)
else:
comm.send_memory_on_setup(my_reg_mem, (rank + 1) % 2, 0)
comm.setup()
if rank == 0:
other_reg_mem = other_reg_mem.get()
if rank == 0:
for i in range(args.num_elements):
memory[i] = i + 1
conn.write(other_reg_mem, 0, my_reg_mem, 0, size)
print("Done sending")
else:
print("Checking for correctness")
# polling
for _ in range(args.polling_num):
all_correct = True
for i in range(args.num_elements):
if memory[i] != i + 1:
all_correct = False
print(f"Error: Mismatch at index {i}: expected {i + 1}, got {memory[i]}")
break
if all_correct:
print("All data matched expected values")
break
else:
time.sleep(0.1)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("if_ip_port_trio", type=str)
parser.add_argument("-r", "--root", action="store_true")
parser.add_argument("-n", "--num-elements", type=int, default=10)
parser.add_argument("--gpu", action="store_true")
parser.add_argument("--polling_num", type=int, default=100)
args = parser.parse_args()
main(args)