mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Update allreduce_bench.py (#318)
Replacing hardcoded network interface name for generic discovery strategy. --------- Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
@@ -17,6 +17,7 @@ import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp import ProxyService, is_nvls_supported
|
||||
from prettytable import PrettyTable
|
||||
import netifaces as ni
|
||||
import ipaddress
|
||||
|
||||
data_type = cp.float32
|
||||
|
||||
@@ -222,6 +223,31 @@ def run_benchmark(
|
||||
return memory.nbytes, mscclpp_algBw, nccl_algBw, speed_up
|
||||
|
||||
|
||||
def is_valid(ip):
|
||||
"""
|
||||
Check if the IP address is valid for connecting to other devices.
|
||||
This excludes loopback (127.0.0.1) and link-local (169.254.x.x) addresses.
|
||||
"""
|
||||
ip_obj = ipaddress.ip_address(ip)
|
||||
return not (ip_obj.is_loopback or ip_obj.is_link_local or ip_obj.is_multicast)
|
||||
|
||||
|
||||
def get_netinterface_info():
|
||||
"""
|
||||
Returns the name of the first network interface with a valid IP address that it finds.
|
||||
"""
|
||||
interfaces = ni.interfaces()
|
||||
for interface in interfaces:
|
||||
addresses = ni.ifaddresses(interface)
|
||||
if ni.AF_INET in addresses:
|
||||
for addr in addresses[ni.AF_INET]:
|
||||
ip_address = addr["addr"]
|
||||
if is_valid(ip_address):
|
||||
print(f"Selected Interface: {interface}, IP Address: {ip_address}")
|
||||
return interface, ip_address
|
||||
return None, None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
|
||||
N_GPUS_PER_NODE = shm_comm.size
|
||||
@@ -229,8 +255,7 @@ if __name__ == "__main__":
|
||||
cp.cuda.Device(MPI.COMM_WORLD.rank % N_GPUS_PER_NODE).use()
|
||||
|
||||
# create a MscclppGroup
|
||||
network_interface = "eth0"
|
||||
my_ip = ni.ifaddresses(network_interface)[ni.AF_INET][0]["addr"]
|
||||
network_interface, my_ip = get_netinterface_info()
|
||||
root_ip = MPI.COMM_WORLD.bcast(my_ip, root=0)
|
||||
ifIpPortTrio = network_interface + ":" + root_ip + ":50000" # some random port
|
||||
mscclpp_group = mscclpp_comm.CommGroup(
|
||||
|
||||
Reference in New Issue
Block a user