Mirror of https://github.com/microsoft/mscclpp.git, synced 2026-05-11 17:00:22 +00:00
Reorganize the current native algorithm implementation and the DSL algorithm
implementation. Provide a unified API for DSL and native algorithms, an interface
for tuning an algorithm, and an interface for PyTorch integration with the native
API and the DSL.

---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
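
"""AllReduce micro-benchmark comparing MSCCLPP kernels against NCCL.

One MPI rank is launched per GPU; ranks sharing a host are grouped to derive
N_GPUS_PER_NODE. A typical invocation (file and launcher names are illustrative,
not part of this script) is:

    mpirun -np 8 python3 allreduce_bench.py

Each buffer size is timed with both the best auto-tuned MSCCLPP algorithm and
NCCL; results are printed as a table and saved as a plot.
"""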

import cupy as cp
from mscclpp_op import (
    MscclppAllReduce1,
    MscclppAllReduce2,
    MscclppAllReduce3,
    MscclppAllReduce4,
    MscclppAllReduce5,
    MscclppAllReduce6,
)
from nccl_op import NcclAllReduce
from mpi4py import MPI
import cupy.cuda.nccl as nccl
from mscclpp import ProxyService, is_nvls_supported, CommGroup, GpuBuffer
from prettytable import PrettyTable
import netifaces as ni
import ipaddress

data_type = cp.float32

if data_type == cp.float16:
    dtype_str = "fp16"
elif data_type == cp.float32:
    dtype_str = "fp32"
elif data_type == cp.int32:
    dtype_str = "int32"
else:
    raise RuntimeError("Unknown data type")
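
# Throughput below is reported as algorithm bandwidth (AlgBW): the input buffer
# size divided by the measured time of one allreduce. With time in microseconds,
# nbytes / time_us / 1e3 yields GB/s. "Speed Up" is nccl_time / mscclpp_time.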

def plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups):
    import matplotlib.pyplot as plt

    human_readable_sizes = [human_readable_size(size) for size in sizes]

    fig, ax1 = plt.subplots(figsize=(10, 6))

    # Plot AlgBW for MSCCLPP and NCCL on the primary y-axis
    (line1,) = ax1.plot(sizes, mscclpp_algbw, marker="o", color="blue", label="MSCCLPP AlgBW")
    (line2,) = ax1.plot(sizes, nccl_algbw, marker="x", color="red", label="NCCL AlgBW")
    ax1.set_ylabel("AlgBW (GB/s)")
    ax1.set_xlabel("Data Size")

    # Logarithmic x-axis with human-readable tick labels
    ax1.set_xscale("log", base=2)
    ax1.set_xticks(sizes)
    ax1.set_xticklabels(human_readable_sizes, rotation=45)

    # Secondary y-axis for the speed-up curve
    ax2 = ax1.twinx()
    (line3,) = ax2.plot(sizes, speed_ups, marker="^", color="green", label="Speed Up")
    ax2.set_ylabel("Speed Up (NCCL Time / MSCCLPP Time)", color="green")
    ax2.tick_params(axis="y", labelcolor="green")

    # Set the lower bound of the secondary y-axis to 0
    ax2.set_ylim(bottom=0)

    # Combine legends from both axes
    lines = [line1, line2, line3]
    labels = [line.get_label() for line in lines]
    ax1.legend(lines, labels, loc="upper left")

    # Title and grid
    num_nodes = MPI.COMM_WORLD.size // N_GPUS_PER_NODE
    ax1.set_title(f"MSCCLPP vs NCCL -- {num_nodes} Nodes")
    ax2.grid(True, which="both", ls="--")

    # Save the plot
    plt.savefig(f"mscclpp_vs_nccl_comparison_num_nodes_{num_nodes}.jpeg", format="jpeg")
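
# A quick worked example: human_readable_size(3 * 2**20) returns "3.0 MiB",
# matching the binary-prefix sizes printed in the results table.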

def human_readable_size(size, decimal_places=1):
    for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
        if size < 1024.0 or unit == "PiB":
            break
        size /= 1024.0
    return f"{size:.{decimal_places}f} {unit}"
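
# Correctness scheme: on iteration p, every rank fills its buffer with the scalar
# (p * world_size + rank), so the correct allreduce result is the constant
#     sum_i (p * world_size + i) = p * world_size**2 + world_size * (world_size - 1) / 2.
# Comparison uses loose tolerances (rtol=1e-2, atol=2) to absorb floating-point
# reduction-order differences.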

def check_correctness(memory, func, niter=100):
    ac = True
    for p in range(niter):
        memory[:] = cp.ones(memory.shape).astype(data_type) * (p * MPI.COMM_WORLD.size + MPI.COMM_WORLD.rank)
        cp.cuda.runtime.deviceSynchronize()
        output_memory = func(None)
        cp.cuda.runtime.deviceSynchronize()
        expected = cp.zeros_like(memory)
        for i in range(MPI.COMM_WORLD.size):
            expected += cp.ones(memory.shape).astype(data_type) * (p * MPI.COMM_WORLD.size + i)

        is_close = cp.isclose(output_memory, expected, rtol=1.0e-2, atol=2)
        icf = is_close == 0
        all_close = bool(cp.all(is_close))
        ac = ac and all_close
        if not all_close:
            print(
                f"not close: p={p}, rank={MPI.COMM_WORLD.rank}, output={output_memory[icf][0]}, expected={expected[icf][0]}",
                flush=True,
            )

    # Reduce with logical AND so a failure on any rank fails the whole check
    # (a SUM here would report PASS as long as at least one rank succeeded).
    ac = MPI.COMM_WORLD.allreduce(ac, op=MPI.LAND)
    return ac
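
# bench_time amortizes launch overhead by recording `niter` kernel launches into
# a single CUDA graph and timing one replay of the whole graph with CUDA events.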

def bench_time(niter: int, func):
    # capture a CUDA graph containing niter launches of the kernel
    stream = cp.cuda.Stream(non_blocking=True)
    with stream:
        stream.begin_capture()
        for i in range(niter):
            func(stream)
        graph = stream.end_capture()

    # run a warm-up round
    graph.launch(stream)

    # run the benchmark and measure time
    start = cp.cuda.Event()
    end = cp.cuda.Event()

    start.record(stream)
    graph.launch(stream)
    end.record(stream)
    end.synchronize()

    # elapsed time is in ms; convert to average microseconds per iteration
    return cp.cuda.get_elapsed_time(start, end) / niter * 1000.0
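
# Auto-tuning contract (inferred from usage below; see mscclpp_op for the real
# definitions): each algorithm object is assumed to expose roughly
#
#     def auto_tune(self):              # generator: applies a candidate config,
#         for config in candidates:     # then yields it so the caller can time it
#             self.set_params(*config)
#             yield config
#
# find_best_config times each yielded candidate and broadcasts the winner from
# rank 0 so all ranks agree; find_best_algo keeps the fastest algorithm overall.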

def find_best_algo(mscclpp_algos, niter):
    assert len(mscclpp_algos) > 0
    best_time = 10000000.0
    best_algo = None
    for algo in mscclpp_algos:
        config, cur_time = find_best_config(algo, niter)
        if cur_time < best_time:
            best_time = cur_time
            best_algo = algo
            algo.set_params(*config)
    if MPI.COMM_WORLD.rank == 0:
        print(best_algo, end="", flush=True)
    return best_algo

def find_best_config(mscclpp_call, niter):
    best_time = 10000000.0
    for config in mscclpp_call.auto_tune():
        cur_time = bench_time(niter, mscclpp_call)
        if cur_time < best_time:
            best_time = cur_time
            best_config = config
        if MPI.COMM_WORLD.rank == 0:
            print("t", end="", flush=True)
    best_config = MPI.COMM_WORLD.bcast(best_config, root=0)
    if MPI.COMM_WORLD.rank == 0:
        print(best_config, end="", flush=True)
    return best_config, best_time
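
# Algorithm choice follows the buffer-size thresholds below: on a single node,
# MscclppAllReduce2 handles buffers under 1 MiB and AllReduce1/3 (plus the
# NVLS-based AllReduce6 where supported) handle larger ones; across two nodes,
# AllReduce5 covers buffers under 4 MiB and AllReduce4 the rest.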

def run_benchmark(mscclpp_group: CommGroup, nccl_op: nccl.NcclCommunicator, table: PrettyTable, niter: int, nelem: int):
    memory = GpuBuffer(nelem, dtype=data_type)
    memory_out = GpuBuffer(nelem, dtype=data_type)
    cp.cuda.runtime.deviceSynchronize()

    proxy_service = ProxyService()
    if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
        if memory.nbytes < 2**20:
            mscclpp_algos = [MscclppAllReduce2(mscclpp_group, memory, memory_out)]
        else:
            mscclpp_algos = [
                MscclppAllReduce1(mscclpp_group, memory),
                MscclppAllReduce3(mscclpp_group, memory, proxy_service),
            ]
            if is_nvls_supported() and (data_type == cp.float32 or data_type == cp.float16):
                mscclpp_algos.append(MscclppAllReduce6(mscclpp_group, nelem, data_type))
    else:
        if memory.nbytes < 2**22:
            mscclpp_algos = [MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)]
        else:
            mscclpp_algos = [MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)]

    proxy_service.start_proxy()
    MPI.COMM_WORLD.barrier()
    mscclpp_call = find_best_algo(mscclpp_algos, 20)
    if isinstance(mscclpp_call, MscclppAllReduce6):
        memory = mscclpp_call.get_memory()

    nccl_call = NcclAllReduce(nccl_op, memory)

    memory_nbytes = memory.nbytes
    mscclpp_time = bench_time(niter, mscclpp_call)
    mscclpp_algBw = memory_nbytes / mscclpp_time / 1e3
    mscclpp_check = "PASS" if check_correctness(memory, mscclpp_call) else "FAIL"

    nccl_time = bench_time(niter, nccl_call)
    nccl_algBw = memory_nbytes / nccl_time / 1e3
    nccl_check = "PASS" if check_correctness(memory, nccl_call) else "FAIL"

    MPI.COMM_WORLD.barrier()
    proxy_service.stop_proxy()

    speed_up = nccl_time / mscclpp_time
    if MPI.COMM_WORLD.rank == 0:
        table.add_row(
            [
                human_readable_size(memory_nbytes),
                "{:.2f}".format(mscclpp_time),
                "{:.2f}".format(mscclpp_algBw),
                mscclpp_check,
                "{:.2f}".format(nccl_time),
                "{:.2f}".format(nccl_algBw),
                nccl_check,
                "{:.2f}".format(speed_up),
            ]
        )
        print(".", end="", flush=True)

    return memory.nbytes, mscclpp_algBw, nccl_algBw, speed_up

def is_valid(ip):
    """
    Check whether the IP address is usable for connecting to other devices.

    This excludes loopback (127.0.0.1), link-local (169.254.x.x), and multicast
    addresses.
    """
    ip_obj = ipaddress.ip_address(ip)
    return not (ip_obj.is_loopback or ip_obj.is_link_local or ip_obj.is_multicast)

def get_netinterface_info():
    """
    Return the name and address of the first network interface found with a
    valid IP address, or (None, None) if no such interface exists.
    """
    interfaces = ni.interfaces()
    for interface in interfaces:
        addresses = ni.ifaddresses(interface)
        if ni.AF_INET in addresses:
            for addr in addresses[ni.AF_INET]:
                ip_address = addr["addr"]
                if is_valid(ip_address):
                    print(f"Selected Interface: {interface}, IP Address: {ip_address}")
                    return interface, ip_address
    return None, None
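
# Bootstrap sequence: derive N_GPUS_PER_NODE from a shared-memory sub-communicator,
# bind each rank to a local GPU, build the MSCCLPP CommGroup from an
# "<interface>:<root ip>:<port>" trio broadcast from rank 0, and create a NCCL
# communicator from a broadcast unique id for the side-by-side comparison.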

if __name__ == "__main__":
    shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
    N_GPUS_PER_NODE = shm_comm.size
    shm_comm.Free()
    cp.cuda.Device(MPI.COMM_WORLD.rank % N_GPUS_PER_NODE).use()

    # create a MSCCLPP CommGroup
    network_interface, my_ip = get_netinterface_info()
    root_ip = MPI.COMM_WORLD.bcast(my_ip, root=0)
    ifIpPortTrio = network_interface + ":" + root_ip + ":50000"  # some arbitrary port
    mscclpp_group = CommGroup(interfaceIpPortTrio=ifIpPortTrio, rank=MPI.COMM_WORLD.rank, size=MPI.COMM_WORLD.size)

    # create a NcclCommunicator
    if MPI.COMM_WORLD.rank == 0:
        uid = nccl.get_unique_id()
    else:
        uid = None
    uid = MPI.COMM_WORLD.bcast(uid, root=0)
    nccl_comm = nccl.NcclCommunicator(MPI.COMM_WORLD.size, uid, MPI.COMM_WORLD.rank)

    table = None
    if MPI.COMM_WORLD.rank == 0:
        # set table headers
        table = PrettyTable()
        table.field_names = [
            f"Size ({dtype_str})",
            "Time (us)",
            "AlgBW (GB/s)",
            "Correctness",
            "NCCL Time (us)",
            "NCCL AlgBW (GB/s)",
            "NCCL Correctness",
            "Speed Up",
        ]

    sizes = []
    mscclpp_algbw = []
    nccl_algbw = []
    speed_ups = []
    end_range = 29
    for i in range(10, end_range):
        if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
            nelems = 2**i
        elif MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 2:
            nelems = 3 * 2**i
        else:
            raise RuntimeError("Only one-node and two-node runs are supported")

        if nelems * data_type().itemsize > 2**32:
            break  # due to the trigger bit-width limitation, only sizes up to 2**32 bytes are supported

        size, mscclpp_algBw, nccl_algBw, speed_up = run_benchmark(mscclpp_group, nccl_comm, table, 100, nelems)
        sizes.append(size)
        mscclpp_algbw.append(mscclpp_algBw)
        nccl_algbw.append(nccl_algBw)
        speed_ups.append(speed_up)

    if MPI.COMM_WORLD.rank == 0:
        print()
        print(table)

        plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups)

    # drop references so communicators are torn down before interpreter exit
    mscclpp_group = None
    nccl_comm = None