mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
Also fixes bugs in MscclppAllReduce6 Below is the performance when the algorithm is fixed to MscclppAllReduce6 on 8 H100 GPUs connected with NVLink using CUDA 12.2. Float16: +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | Size (fp16) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us) | NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | 2.0 KiB | 11.15 | 0.18 | PASS | 13.82 | 0.15 | PASS | 1.24 | | 4.0 KiB | 11.15 | 0.37 | PASS | 14.74 | 0.28 | PASS | 1.32 | | 8.0 KiB | 11.14 | 0.74 | PASS | 15.17 | 0.54 | PASS | 1.36 | | 16.0 KiB | 11.16 | 1.47 | PASS | 15.77 | 1.04 | PASS | 1.41 | | 32.0 KiB | 11.15 | 2.94 | PASS | 17.50 | 1.87 | PASS | 1.57 | | 64.0 KiB | 11.18 | 5.86 | PASS | 17.64 | 3.71 | PASS | 1.58 | | 128.0 KiB | 11.16 | 11.74 | PASS | 17.83 | 7.35 | PASS | 1.60 | | 256.0 KiB | 11.21 | 23.38 | PASS | 18.00 | 14.57 | PASS | 1.60 | | 512.0 KiB | 11.70 | 44.81 | PASS | 18.42 | 28.46 | PASS | 1.57 | | 1.0 MiB | 13.64 | 76.87 | PASS | 20.23 | 51.83 | PASS | 1.48 | | 2.0 MiB | 17.29 | 121.27 | PASS | 31.60 | 66.36 | PASS | 1.83 | | 4.0 MiB | 25.26 | 166.02 | PASS | 38.74 | 108.26 | PASS | 1.53 | | 8.0 MiB | 40.17 | 208.83 | PASS | 62.86 | 133.45 | PASS | 1.56 | | 16.0 MiB | 70.92 | 236.56 | PASS | 113.36 | 147.99 | PASS | 1.60 | | 32.0 MiB | 131.38 | 255.41 | PASS | 203.21 | 165.13 | PASS | 1.55 | | 64.0 MiB | 253.39 | 264.84 | PASS | 342.12 | 196.15 | PASS | 1.35 | | 128.0 MiB | 496.74 | 270.20 | PASS | 670.62 | 200.14 | PASS | 1.35 | | 256.0 MiB | 982.42 | 273.24 | PASS | 1318.36 | 203.61 | PASS | 1.34 | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ Float32: 
+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | Size (fp32) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us) | NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | 4.0 KiB | 11.04 | 0.37 | PASS | 14.79 | 0.28 | PASS | 1.34 | | 8.0 KiB | 11.15 | 0.73 | PASS | 15.25 | 0.54 | PASS | 1.37 | | 16.0 KiB | 11.12 | 1.47 | PASS | 15.87 | 1.03 | PASS | 1.43 | | 32.0 KiB | 11.13 | 2.95 | PASS | 17.21 | 1.90 | PASS | 1.55 | | 64.0 KiB | 11.11 | 5.90 | PASS | 17.37 | 3.77 | PASS | 1.56 | | 128.0 KiB | 11.08 | 11.83 | PASS | 17.54 | 7.47 | PASS | 1.58 | | 256.0 KiB | 11.15 | 23.50 | PASS | 17.71 | 14.80 | PASS | 1.59 | | 512.0 KiB | 11.56 | 45.34 | PASS | 18.21 | 28.79 | PASS | 1.57 | | 1.0 MiB | 13.64 | 76.90 | PASS | 19.87 | 52.77 | PASS | 1.46 | | 2.0 MiB | 17.24 | 121.67 | PASS | 31.63 | 66.30 | PASS | 1.84 | | 4.0 MiB | 25.19 | 166.47 | PASS | 38.63 | 108.57 | PASS | 1.53 | | 8.0 MiB | 40.38 | 207.72 | PASS | 62.65 | 133.89 | PASS | 1.55 | | 16.0 MiB | 70.72 | 237.23 | PASS | 114.57 | 146.44 | PASS | 1.62 | | 32.0 MiB | 131.49 | 255.18 | PASS | 200.79 | 167.11 | PASS | 1.53 | | 64.0 MiB | 253.98 | 264.23 | PASS | 342.58 | 195.89 | PASS | 1.35 | | 128.0 MiB | 496.96 | 270.08 | PASS | 670.64 | 200.13 | PASS | 1.35 | | 256.0 MiB | 982.83 | 273.12 | PASS | 1318.90 | 203.53 | PASS | 1.34 | | 512.0 MiB | 1954.07 | 274.75 | PASS | 2609.04 | 205.77 | PASS | 1.34 | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
318 lines
10 KiB
Python
318 lines
10 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT license.
|
|
|
|
import cupy as cp
|
|
from mscclpp_op import (
|
|
MscclppAllReduce1,
|
|
MscclppAllReduce2,
|
|
MscclppAllReduce3,
|
|
MscclppAllReduce4,
|
|
MscclppAllReduce5,
|
|
MscclppAllReduce6,
|
|
)
|
|
from nccl_op import NcclAllReduce
|
|
from mpi4py import MPI
|
|
import cupy.cuda.nccl as nccl
|
|
import mscclpp.comm as mscclpp_comm
|
|
from mscclpp import ProxyService, is_nvls_supported
|
|
from prettytable import PrettyTable
|
|
import netifaces as ni
|
|
import ipaddress
|
|
|
|
# Element type used for every benchmark buffer; edit here to benchmark a
# different dtype.
data_type = cp.float32

# Short dtype label shown in the result-table header.
_DTYPE_LABELS = {cp.float16: "fp16", cp.float32: "fp32", cp.int32: "int32"}
if data_type not in _DTYPE_LABELS:
    raise RuntimeError("Unknown data type")
dtype_str = _DTYPE_LABELS[data_type]
|
|
|
|
|
|
def plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups):
    """Plot MSCCLPP vs NCCL bandwidth plus speed-up and save it as a JPEG.

    Args:
        sizes: message sizes in bytes (x-axis, plotted on a log2 scale).
        mscclpp_algbw: MSCCLPP algorithm bandwidth per size (GB/s).
        nccl_algbw: NCCL algorithm bandwidth per size (GB/s).
        speed_ups: NCCL time / MSCCLPP time per size.
    """
    import matplotlib.pyplot as plt

    size_labels = [human_readable_size(s) for s in sizes]

    fig, bw_axis = plt.subplots(figsize=(10, 6))

    # Both bandwidth curves share the primary y-axis.
    (mscclpp_line,) = bw_axis.plot(sizes, mscclpp_algbw, marker="o", color="blue", label="MSCCLPP AlgBW")
    (nccl_line,) = bw_axis.plot(sizes, nccl_algbw, marker="x", color="red", label="NCCL AlgBW")
    bw_axis.set_ylabel("AlgBW (GB/s)")
    bw_axis.set_xlabel("Data Size")

    # Sizes double at each step, so a log2 x-axis spaces the points evenly.
    bw_axis.set_xscale("log", base=2)
    bw_axis.set_xticks(sizes)
    bw_axis.set_xticklabels(size_labels, rotation=45)

    # Speed-up gets its own scale on a secondary y-axis.
    speedup_axis = bw_axis.twinx()
    (speedup_line,) = speedup_axis.plot(sizes, speed_ups, marker="^", color="green", label="Speed Up")
    speedup_axis.set_ylabel("Speed Up (NCCL Time / MSCCLPP Time)", color="green")
    speedup_axis.tick_params(axis="y", labelcolor="green")

    # Anchor the speed-up axis at zero so ratios are visually comparable.
    speedup_axis.set_ylim(bottom=0)

    # One combined legend covering all three curves.
    all_lines = [mscclpp_line, nccl_line, speedup_line]
    bw_axis.legend(all_lines, [line.get_label() for line in all_lines], loc="upper left")

    # Title reflects the node count of this run; grid on the secondary axis.
    num_nodes = MPI.COMM_WORLD.size // N_GPUS_PER_NODE
    bw_axis.set_title(f"MSCCLPP vs NCCL -- {num_nodes} Nodes")
    speedup_axis.grid(True, which="both", ls="--")

    plt.savefig(f"mscclpp_vs_nccl_comparison_num_nodes_{num_nodes}.jpeg", format="jpeg")
|
|
|
|
|
|
def human_readable_size(size, decimal_places=1):
    """Format a byte count as a human-readable string, e.g. '1.5 MiB'.

    Divides by 1024 until the value drops below 1024 or the largest unit
    (PiB) is reached; values above 1024 PiB are still reported in PiB.
    """
    unit = "B"
    for next_unit in ["KiB", "MiB", "GiB", "TiB", "PiB"]:
        if size < 1024.0:
            break
        size /= 1024.0
        unit = next_unit
    return f"{size:.{decimal_places}f} {unit}"
|
|
|
|
|
|
def check_correctness(memory, func, niter=100):
    """Verify that `func` computes a correct allreduce on every rank.

    Each iteration fills `memory` in place with a rank- and
    iteration-dependent constant, runs one allreduce via `func(None)`, and
    compares the output with the analytically known sum over all ranks.

    Args:
        memory: cupy input buffer the collective reads from (overwritten
            each iteration).
        func: callable running one allreduce; takes a stream (None here)
            and returns the output buffer.
        niter: number of independent iterations to check.

    Returns:
        True only if ALL iterations on ALL ranks matched (global logical
        AND across the communicator).
    """
    all_correct = True
    for p in range(niter):
        memory[:] = cp.ones(memory.shape).astype(data_type) * (p * MPI.COMM_WORLD.size + MPI.COMM_WORLD.rank)
        cp.cuda.runtime.deviceSynchronize()
        output_memory = func(None)
        cp.cuda.runtime.deviceSynchronize()
        # The expected result is the sum of every rank's fill constant.
        expected = cp.zeros_like(memory)
        for i in range(MPI.COMM_WORLD.size):
            expected += cp.ones(memory.shape).astype(data_type) * (p * MPI.COMM_WORLD.size + i)

        # Loose tolerances: low-precision (fp16) reductions accumulate
        # rounding error.
        is_close = cp.isclose(output_memory, expected, rtol=1.0e-2, atol=2)
        icf = is_close == 0
        # bool() keeps the accumulator a Python bool instead of a cupy
        # 0-d array (which `and` would otherwise propagate).
        all_close = bool(cp.all(is_close))
        all_correct = all_correct and all_close
        if not all_close:
            print(
                f"not close: p={p}, rank={MPI.COMM_WORLD.rank}, output={output_memory[icf][0]}, expected={expected[icf][0]}",
                flush=True,
            )

    # Bug fix: this must be a logical AND across ranks. With op=MPI.SUM a
    # single failing rank still produced a nonzero (truthy) sum, so partial
    # failures were reported as PASS.
    all_correct = MPI.COMM_WORLD.allreduce(all_correct, op=MPI.LAND)
    return all_correct
|
|
|
|
|
|
def bench_time(niter: int, func):
    """Return the average time of one `func` launch in microseconds.

    Captures `niter` launches of `func` into a single CUDA graph so that
    per-launch CPU overhead is excluded, replays the graph once to warm up,
    then times one full replay with CUDA events.
    """
    stream = cp.cuda.Stream(non_blocking=True)
    with stream:
        stream.begin_capture()
        for _ in range(niter):
            func(stream)
        graph = stream.end_capture()

    # One warm-up replay before the measured run.
    graph.launch(stream)

    # Time a single replay of the whole captured graph.
    begin_event = cp.cuda.Event()
    finish_event = cp.cuda.Event()

    begin_event.record(stream)
    graph.launch(stream)
    finish_event.record(stream)
    finish_event.synchronize()

    # get_elapsed_time reports milliseconds; convert to microseconds per
    # captured iteration.
    return cp.cuda.get_elapsed_time(begin_event, finish_event) / niter * 1000.0
|
|
|
|
|
|
def find_best_algo(mscclpp_algos, niter):
    """Tune every candidate algorithm and return the fastest one.

    Each candidate is auto-tuned with find_best_config; whenever a
    candidate beats the current best time its tuned configuration is
    applied via set_params, so the returned algorithm is ready to run.
    """
    assert len(mscclpp_algos) > 0
    best_algo = None
    best_time = 10000000.0
    for candidate in mscclpp_algos:
        config, candidate_time = find_best_config(candidate, niter)
        if candidate_time < best_time:
            best_time = candidate_time
            best_algo = candidate
            candidate.set_params(*config)
    if MPI.COMM_WORLD.rank == 0:
        print(best_algo, end="", flush=True)
    return best_algo
|
|
|
|
|
|
def find_best_config(mscclpp_call, niter):
    """Auto-tune `mscclpp_call` and return its best (config, time_us).

    Walks the algorithm's tuning space (`auto_tune()` applies each
    configuration as it yields it) and times each one with bench_time.
    Rank 0's winning configuration and time are broadcast so every rank
    reaches an identical tuning decision.

    Raises:
        RuntimeError: if auto_tune() yields no configuration at all.
    """
    best_config = None
    best_time = 10000000.0
    for config in mscclpp_call.auto_tune():
        cur_time = bench_time(niter, mscclpp_call)
        if cur_time < best_time:
            best_time = cur_time
            best_config = config
            if MPI.COMM_WORLD.rank == 0:
                print("t", end="", flush=True)
    if best_config is None:
        # Bug fix: previously an empty tuning space fell through to an
        # UnboundLocalError on `best_config`; fail with a clear message.
        raise RuntimeError("auto_tune() produced no configurations")
    # Bug fix: broadcast the time along with the config. Per-rank timings
    # differ, and only the config was synchronized before — ranks could
    # then disagree in find_best_algo about which algorithm is fastest,
    # leading to mismatched collective calls.
    best_config, best_time = MPI.COMM_WORLD.bcast((best_config, best_time), root=0)
    if MPI.COMM_WORLD.rank == 0:
        print(best_config, end="", flush=True)
    return best_config, best_time
|
|
|
|
|
|
def run_benchmark(
    mscclpp_group: mscclpp_comm.CommGroup, nccl_op: nccl.NcclCommunicator, table: PrettyTable, niter: int, nelem: int
):
    """Benchmark MSCCLPP (best auto-tuned algorithm) against NCCL for one size.

    Allocates `nelem`-element buffers, picks candidate MSCCLPP allreduce
    algorithms based on node count and message size, tunes them, then times
    and correctness-checks both the MSCCLPP winner and NCCL. On rank 0 a row
    is appended to `table`.

    Returns:
        (size_in_bytes, mscclpp_algbw_GBps, nccl_algbw_GBps, speed_up).
    """
    memory = cp.zeros(nelem, dtype=data_type)
    memory_out = cp.zeros(nelem, dtype=data_type)
    cp.cuda.runtime.deviceSynchronize()

    # Proxy service drives the proxy-channel-based algorithms (3/4/5).
    proxy_service = ProxyService()
    if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
        # Single-node: pick candidates by message size.
        if memory.nbytes < 2**20:
            mscclpp_algos = [MscclppAllReduce2(mscclpp_group, memory, memory_out)]
        else:
            mscclpp_algos = [
                MscclppAllReduce1(mscclpp_group, memory),
                MscclppAllReduce3(mscclpp_group, memory, proxy_service),
            ]
            # NVLS-based algorithm only where the hardware and dtype allow it.
            if is_nvls_supported() and (data_type == cp.float32 or data_type == cp.float16):
                mscclpp_algos.append(MscclppAllReduce6(mscclpp_group, nelem, data_type))
    else:
        # Multi-node: algorithms 5 (small) and 4 (large) use the proxy service.
        if memory.nbytes < 2**22:
            mscclpp_algos = [MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)]
        else:
            mscclpp_algos = [MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)]

    proxy_service.start_proxy()
    MPI.COMM_WORLD.barrier()
    mscclpp_call = find_best_algo(mscclpp_algos, 20)
    # AllReduce6 owns its own (NVLS-bound) buffer; benchmark against that one.
    if isinstance(mscclpp_call, MscclppAllReduce6):
        memory = mscclpp_call.get_memory()

    nccl_call = NcclAllReduce(nccl_op, memory)

    memory_nbytes = memory.nbytes
    mscclpp_time = bench_time(niter, mscclpp_call)
    # bytes / us / 1e3 == GB/s.
    mscclpp_algBw = memory_nbytes / mscclpp_time / 1e3
    mscclpp_check = "PASS" if check_correctness(memory, mscclpp_call) else "FAIL"

    nccl_time = bench_time(niter, nccl_call)
    nccl_algBw = memory_nbytes / nccl_time / 1e3
    nccl_check = "PASS" if check_correctness(memory, nccl_call) else "FAIL"

    # All ranks must be done with the proxy channels before stopping the proxy.
    MPI.COMM_WORLD.barrier()
    proxy_service.stop_proxy()

    speed_up = nccl_time / mscclpp_time
    if MPI.COMM_WORLD.rank == 0:
        table.add_row(
            [
                human_readable_size(memory_nbytes),
                "{:.2f}".format(mscclpp_time),
                "{:.2f}".format(mscclpp_algBw),
                mscclpp_check,
                "{:.2f}".format(nccl_time),
                "{:.2f}".format(nccl_algBw),
                nccl_check,
                "{:.2f}".format(speed_up),
            ]
        )
    # Progress dot per completed size.
    if MPI.COMM_WORLD.rank == 0:
        print(".", end="", flush=True)

    return memory.nbytes, mscclpp_algBw, nccl_algBw, speed_up
|
|
|
|
|
|
def is_valid(ip):
    """
    Check if the IP address is valid for connecting to other devices.
    This excludes loopback (127.0.0.1) and link-local (169.254.x.x) addresses.
    """
    parsed = ipaddress.ip_address(ip)
    if parsed.is_loopback:
        return False
    if parsed.is_link_local:
        return False
    return not parsed.is_multicast
|
|
|
|
|
|
def get_netinterface_info():
    """
    Returns the name of the first network interface with a valid IP address that it finds.

    Scans interfaces in netifaces order; the first IPv4 address accepted by
    is_valid wins. Returns (None, None) when no interface qualifies.
    """
    for interface in ni.interfaces():
        ipv4_entries = ni.ifaddresses(interface).get(ni.AF_INET, [])
        for entry in ipv4_entries:
            candidate_ip = entry["addr"]
            if is_valid(candidate_ip):
                print(f"Selected Interface: {interface}, IP Address: {candidate_ip}")
                return interface, candidate_ip
    return None, None
|
|
|
|
|
|
if __name__ == "__main__":
    # Ranks sharing a node also share memory: the shared-memory split size
    # gives the number of GPUs (ranks) per node.
    shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
    N_GPUS_PER_NODE = shm_comm.size
    shm_comm.Free()
    # Bind each rank to a distinct local GPU.
    cp.cuda.Device(MPI.COMM_WORLD.rank % N_GPUS_PER_NODE).use()

    # create a MscclppGroup; all ranks bootstrap against rank 0's IP.
    network_interface, my_ip = get_netinterface_info()
    root_ip = MPI.COMM_WORLD.bcast(my_ip, root=0)
    ifIpPortTrio = network_interface + ":" + root_ip + ":50000"  # some random port
    mscclpp_group = mscclpp_comm.CommGroup(
        interfaceIpPortTrio=ifIpPortTrio, rank=MPI.COMM_WORLD.rank, size=MPI.COMM_WORLD.size
    )

    # create a NcclComm; the unique id is generated on rank 0 and broadcast.
    if MPI.COMM_WORLD.rank == 0:
        uid = nccl.get_unique_id()
    else:
        uid = None
    uid = MPI.COMM_WORLD.bcast(uid, root=0)
    nccl_comm = nccl.NcclCommunicator(MPI.COMM_WORLD.size, uid, MPI.COMM_WORLD.rank)

    # Only rank 0 collects and prints results.
    table = None
    if MPI.COMM_WORLD.rank == 0:
        # Set table headers
        table = PrettyTable()
        table.field_names = [
            f"Size ({dtype_str})",
            "Time (us)",
            "AlgBW (GB/s)",
            "Correctness",
            "NCCL Time (us)",
            "NCCL AlgBW (GB/s)",
            "NCCL Correctness",
            "Speed Up",
        ]

    sizes = []
    mscclpp_algbw = []
    nccl_algbw = []
    speed_ups = []
    # NOTE(review): NVLS runs stop one power of two earlier — presumably an
    # NVLS buffer-size limit; confirm against MscclppAllReduce6.
    end_range = 28 if is_nvls_supported() else 29
    for i in range(10, end_range):
        # Element counts: powers of two on one node, 3 * 2**i across two
        # nodes (divisible by the 2-node rank layout).
        if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
            nelems = 2**i
        elif MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 2:
            nelems = 3 * 2**i
        else:
            raise RuntimeError("Only support one node/two nodes communication")

        if nelems * data_type().itemsize > 2**32:
            break  # due to trigger bit width limitation, we can only support up to 2**32

        size, mscclpp_algBw, nccl_algBw, speed_up = run_benchmark(mscclpp_group, nccl_comm, table, 100, nelems)
        sizes.append(size)
        mscclpp_algbw.append(mscclpp_algBw)
        nccl_algbw.append(nccl_algBw)
        speed_ups.append(speed_up)

    if MPI.COMM_WORLD.rank == 0:
        print()
        print(table)

        plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups)

    # Drop references so communicators are destroyed before interpreter exit.
    mscclpp_group = None
    nccl_comm = None
|