mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-04 13:41:27 +00:00
Remove cuda-python and use CuPy APIs instead --------- Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
293 lines
9.4 KiB
Python
293 lines
9.4 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT license.
|
|
|
|
import cupy as cp
|
|
from mscclpp_op import (
|
|
MscclppAllReduce1,
|
|
MscclppAllReduce2,
|
|
MscclppAllReduce3,
|
|
MscclppAllReduce4,
|
|
MscclppAllReduce5,
|
|
MscclppAllReduce6,
|
|
)
|
|
from nccl_op import NcclAllReduce
|
|
from mpi4py import MPI
|
|
import cupy.cuda.nccl as nccl
|
|
import mscclpp.comm as mscclpp_comm
|
|
from mscclpp import ProxyService, is_nvls_supported
|
|
from prettytable import PrettyTable
|
|
import netifaces as ni
|
|
|
|
# Element type used for every buffer allocated by this benchmark.
data_type = cp.float32

# Short label for the element type; it appears in the results table header.
_DTYPE_LABELS = ((cp.float16, "fp16"), (cp.float32, "fp32"), (cp.int32, "int32"))
dtype_str = next((label for known, label in _DTYPE_LABELS if data_type == known), None)
if dtype_str is None:
    raise RuntimeError("Unknown data type")
|
|
|
|
|
|
def plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups):
    """Plot MSCCLPP and NCCL bandwidth plus the speed-up and save it as a JPEG.

    Args:
        sizes: Message sizes in bytes, used as x coordinates and tick positions.
        mscclpp_algbw: MSCCLPP algorithm bandwidth per size (left y-axis).
        nccl_algbw: NCCL algorithm bandwidth per size (left y-axis).
        speed_ups: NCCL time / MSCCLPP time per size (right y-axis).

    The output file name embeds the node count, computed as the MPI world
    size divided by the module-level ``N_GPUS_PER_NODE``.
    """
    # Imported lazily so matplotlib is only needed by the caller that plots.
    import matplotlib.pyplot as plt

    tick_labels = [human_readable_size(size) for size in sizes]

    _, bw_axis = plt.subplots(figsize=(10, 6))

    # Left y-axis: algorithm bandwidth of both libraries.
    (mscclpp_line,) = bw_axis.plot(sizes, mscclpp_algbw, marker="o", color="blue", label="MSCCLPP AlgBW")
    (nccl_line,) = bw_axis.plot(sizes, nccl_algbw, marker="x", color="red", label="NCCL AlgBW")
    bw_axis.set_ylabel("AlgBW (GB/s)")
    bw_axis.set_xlabel("Data Size")

    # Base-2 log x-axis with one human-readable tick per measured size.
    bw_axis.set_xscale("log", base=2)
    bw_axis.set_xticks(sizes)
    bw_axis.set_xticklabels(tick_labels, rotation=45)

    # Right y-axis: speed-up, anchored at zero.
    speedup_axis = bw_axis.twinx()
    (speedup_line,) = speedup_axis.plot(sizes, speed_ups, marker="^", color="green", label="Speed Up")
    speedup_axis.set_ylabel("Speed Up (NCCL Time / MSCCLPP Time)", color="green")
    speedup_axis.tick_params(axis="y", labelcolor="green")
    speedup_axis.set_ylim(bottom=0)

    # One combined legend for the curves of both axes.
    legend_lines = [mscclpp_line, nccl_line, speedup_line]
    legend_labels = [curve.get_label() for curve in legend_lines]
    bw_axis.legend(legend_lines, legend_labels, loc="upper left")

    # Title, grid and output file.
    num_nodes = MPI.COMM_WORLD.size // N_GPUS_PER_NODE
    bw_axis.set_title(f"MSCCLPP vs NCCL -- {num_nodes} Nodes")
    speedup_axis.grid(True, which="both", ls="--")

    plt.savefig(f"mscclpp_vs_nccl_comparison_num_nodes_{num_nodes}.jpeg", format="jpeg")
|
|
|
|
|
|
def human_readable_size(size, decimal_places=1):
    """Format a byte count with a binary unit suffix.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit ("PiB") is reached, then rendered with ``decimal_places`` digits
    after the decimal point, e.g. ``1536`` -> ``"1.5 KiB"``.
    """
    units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB")
    last_index = len(units) - 1
    index = 0
    # Stop on the first unit where the value fits, or on the last unit.
    while index < last_index and not (size < 1024.0):
        size /= 1024.0
        index += 1
    return f"{size:.{decimal_places}f} {units[index]}"
|
|
|
|
|
|
def check_correctness(memory, func, niter=100):
    """Verify that ``func`` produces the element-wise sum of ``memory`` over all ranks.

    For each of ``niter`` rounds ``p``, every rank fills ``memory`` with the
    constant ``p * world_size + rank``, calls ``func(None)`` and compares the
    returned array against the sum of those constants over all ranks
    (``rtol=1e-2``, ``atol=2``). The first mismatching element of a failing
    round is printed.

    Args:
        memory: Device array that ``func`` reduces; overwritten every round.
        func: Callable invoked as ``func(None)`` that returns the result array.
        niter: Number of rounds to check.

    Returns:
        ``True`` only if every round matched on every rank, else ``False``.
    """
    comm = MPI.COMM_WORLD
    world_size = comm.size
    rank = comm.rank

    passed = True
    for p in range(niter):
        # Each rank contributes a distinct constant so a wrong reduction is visible.
        memory[:] = cp.full(memory.shape, p * world_size + rank, dtype=data_type)
        cp.cuda.runtime.deviceSynchronize()
        output_memory = func(None)
        cp.cuda.runtime.deviceSynchronize()

        # Accumulate in data_type, one rank at a time, like the reduction itself.
        expected = cp.zeros_like(memory)
        for i in range(world_size):
            expected += cp.full(memory.shape, p * world_size + i, dtype=data_type)

        is_close = cp.isclose(output_memory, expected, rtol=1.0e-2, atol=2)
        all_close = bool(cp.all(is_close))
        passed = passed and all_close
        if not all_close:
            mismatch = cp.logical_not(is_close)
            print(
                f"not close: p={p}, rank={rank}, output={output_memory[mismatch][0]}, expected={expected[mismatch][0]}",
                flush=True,
            )

    # Logical AND across ranks: a single failing rank must fail the whole check.
    # (A SUM reduction would report success as soon as any one rank passed.)
    return comm.allreduce(passed, op=MPI.LAND)
|
|
|
|
|
|
def bench_time(niter: int, func):
    """Time ``func`` by replaying ``niter`` captured launches as one CUDA graph.

    ``func`` is called ``niter`` times with a non-blocking stream while that
    stream is capturing. The resulting graph is launched once as a warm-up
    and once more between two events.

    Returns:
        The elapsed time reported by ``cp.cuda.get_elapsed_time`` divided by
        ``niter`` and multiplied by 1000 (shown under "Time (us)" in the
        results table).
    """
    # Capture niter launches of the kernel into a single CUDA graph.
    stream = cp.cuda.Stream(non_blocking=True)
    with stream:
        stream.begin_capture()
        for _ in range(niter):
            func(stream)
        graph = stream.end_capture()

    # Warm-up round; not timed.
    graph.launch(stream)

    # Timed round, bracketed by events recorded on the same stream.
    tic = cp.cuda.Event()
    toc = cp.cuda.Event()
    tic.record(stream)
    graph.launch(stream)
    toc.record(stream)
    toc.synchronize()

    elapsed = cp.cuda.get_elapsed_time(tic, toc)
    return elapsed / niter * 1000.0
|
|
|
|
|
|
def find_best_algo(mscclpp_algos, niter):
    """Tune every candidate algorithm and return the fastest one.

    Each candidate is tuned with ``find_best_config``. Whenever a candidate
    beats the best time seen so far it becomes the current winner and its
    tuned configuration is applied through ``set_params``. Rank 0 prints the
    winner.

    Args:
        mscclpp_algos: Non-empty sequence of candidate algorithm objects.
        niter: Iteration count forwarded to ``find_best_config``.

    Returns:
        The fastest candidate, or ``None`` if no candidate beat the initial
        sentinel time.
    """
    assert len(mscclpp_algos) > 0
    fastest_time = 10000000.0
    fastest = None
    for candidate in mscclpp_algos:
        tuned_config, tuned_time = find_best_config(candidate, niter)
        if not tuned_time < fastest_time:
            continue
        fastest_time, fastest = tuned_time, candidate
        candidate.set_params(*tuned_config)
    if MPI.COMM_WORLD.rank == 0:
        print(fastest, end="", flush=True)
    return fastest
|
|
|
|
|
|
def find_best_config(mscclpp_call, niter):
    """Benchmark every configuration from ``auto_tune()`` and return rank 0's best.

    ``mscclpp_call.auto_tune()`` is iterated, and ``mscclpp_call`` is timed
    with ``bench_time`` after each yielded configuration. Rank 0 prints a
    "t" per configuration tried and finally the chosen configuration.

    Args:
        mscclpp_call: Algorithm object providing ``auto_tune()`` and accepted
            by ``bench_time``.
        niter: Number of iterations per timing run.

    Returns:
        ``(best_config, best_time)`` as decided by rank 0 and identical on
        every rank.

    Raises:
        RuntimeError: If rank 0 found no configuration faster than the
            initial sentinel time (raised on all ranks).
    """
    comm = MPI.COMM_WORLD
    best_time = 10000000.0
    best_config = None
    for config in mscclpp_call.auto_tune():
        cur_time = bench_time(niter, mscclpp_call)
        if cur_time < best_time:
            best_time = cur_time
            best_config = config
        if comm.rank == 0:
            print("t", end="", flush=True)

    # Timings differ per rank. Broadcast rank 0's configuration together with
    # its time so callers that compare times pick the same winner on every
    # rank; diverging choices would issue mismatched collectives.
    best_config, best_time = comm.bcast((best_config, best_time), root=0)
    if best_config is None:
        # Checked after the broadcast so all ranks fail together instead of
        # one rank raising while the others wait in the collective.
        raise RuntimeError("auto_tune() produced no usable configuration")
    if comm.rank == 0:
        print(best_config, end="", flush=True)
    return best_config, best_time
|
|
|
|
|
|
def run_benchmark(
    mscclpp_group: mscclpp_comm.CommGroup, nccl_op: nccl.NcclCommunicator, table: PrettyTable, niter: int, nelem: int
):
    """Benchmark MSCCLPP against NCCL all-reduce for one message size.

    Candidate MSCCLPP algorithms are chosen from the node count and the
    buffer size, the fastest one is selected with ``find_best_algo``, and
    both it and the NCCL all-reduce are timed and checked for correctness.
    Rank 0 appends a row to ``table`` and prints a progress dot.

    Args:
        mscclpp_group: MSCCLPP communication group.
        nccl_op: NCCL communicator.
        table: Results table; only used on rank 0.
        niter: Iterations per timing run.
        nelem: Number of ``data_type`` elements to reduce.

    Returns:
        ``(nbytes, mscclpp_algBw, nccl_algBw, speed_up)`` for this size.
    """
    memory = cp.zeros(nelem, dtype=data_type)
    memory_out = cp.zeros(nelem, dtype=data_type)
    cp.cuda.runtime.deviceSynchronize()

    proxy_service = ProxyService()
    single_node = MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1
    if single_node and memory.nbytes < 2**20:
        mscclpp_algos = [MscclppAllReduce2(mscclpp_group, memory, memory_out)]
    elif single_node:
        mscclpp_algos = [
            MscclppAllReduce1(mscclpp_group, memory),
            MscclppAllReduce3(mscclpp_group, memory, proxy_service),
        ]
        if is_nvls_supported():
            mscclpp_algos.append(MscclppAllReduce6(mscclpp_group, nelem, data_type))
    elif memory.nbytes < 2**22:
        mscclpp_algos = [MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)]
    else:
        mscclpp_algos = [MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)]

    proxy_service.start_proxy()
    MPI.COMM_WORLD.barrier()
    mscclpp_call = find_best_algo(mscclpp_algos, 20)
    if isinstance(mscclpp_call, MscclppAllReduce6):
        # This algorithm exposes its own buffer; use it for NCCL and the checks too.
        memory = mscclpp_call.get_memory()

    nccl_call = NcclAllReduce(nccl_op, memory)
    memory_nbytes = memory.nbytes

    # MSCCLPP: time first, then verify.
    mscclpp_time = bench_time(niter, mscclpp_call)
    mscclpp_algBw = memory_nbytes / mscclpp_time / 1e3
    mscclpp_check = "PASS" if check_correctness(memory, mscclpp_call) else "FAIL"

    # NCCL: same procedure on the same buffer.
    nccl_time = bench_time(niter, nccl_call)
    nccl_algBw = memory_nbytes / nccl_time / 1e3
    nccl_check = "PASS" if check_correctness(memory, nccl_call) else "FAIL"

    MPI.COMM_WORLD.barrier()
    proxy_service.stop_proxy()

    speed_up = nccl_time / mscclpp_time
    if MPI.COMM_WORLD.rank == 0:
        row = [
            human_readable_size(memory_nbytes),
            f"{mscclpp_time:.2f}",
            f"{mscclpp_algBw:.2f}",
            mscclpp_check,
            f"{nccl_time:.2f}",
            f"{nccl_algBw:.2f}",
            nccl_check,
            f"{speed_up:.2f}",
        ]
        table.add_row(row)
        print(".", end="", flush=True)

    return memory.nbytes, mscclpp_algBw, nccl_algBw, speed_up
|
|
|
|
|
|
if __name__ == "__main__":
    # The number of ranks sharing this node's memory is used as the GPU count per node.
    shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
    N_GPUS_PER_NODE = shm_comm.size
    shm_comm.Free()
    cp.cuda.Device(MPI.COMM_WORLD.rank % N_GPUS_PER_NODE).use()

    # Create the MSCCLPP group, bootstrapped through rank 0's address.
    network_interface = "eth0"
    my_ip = ni.ifaddresses(network_interface)[ni.AF_INET][0]["addr"]
    root_ip = MPI.COMM_WORLD.bcast(my_ip, root=0)
    ifIpPortTrio = f"{network_interface}:{root_ip}:50000"  # some random port
    mscclpp_group = mscclpp_comm.CommGroup(
        interfaceIpPortTrio=ifIpPortTrio, rank=MPI.COMM_WORLD.rank, size=MPI.COMM_WORLD.size
    )

    # Create the NCCL communicator from a unique id generated on rank 0.
    uid = nccl.get_unique_id() if MPI.COMM_WORLD.rank == 0 else None
    uid = MPI.COMM_WORLD.bcast(uid, root=0)
    nccl_comm = nccl.NcclCommunicator(MPI.COMM_WORLD.size, uid, MPI.COMM_WORLD.rank)

    # Only rank 0 owns a results table.
    table = None
    if MPI.COMM_WORLD.rank == 0:
        table = PrettyTable()
        table.field_names = [
            f"Size ({dtype_str})",
            "Time (us)",
            "AlgBW (GB/s)",
            "Correctness",
            "NCCL Time (us)",
            "NCCL AlgBW (GB/s)",
            "NCCL Correctness",
            "Speed Up",
        ]

    sizes, mscclpp_algbw, nccl_algbw, speed_ups = [], [], [], []
    num_nodes = MPI.COMM_WORLD.size // N_GPUS_PER_NODE
    end_range = 28 if is_nvls_supported() else 29
    for i in range(10, end_range):
        if num_nodes == 1:
            nelems = 2**i
        elif num_nodes == 2:
            nelems = 3 * 2**i
        else:
            raise RuntimeError("Only support one node/two nodes communication")

        if nelems * data_type().itemsize > 2**32:
            break  # due to trigger bit width limitation, we can only support up to 2**32

        size, mscclpp_algBw, nccl_algBw, speed_up = run_benchmark(mscclpp_group, nccl_comm, table, 100, nelems)
        sizes.append(size)
        mscclpp_algbw.append(mscclpp_algBw)
        nccl_algbw.append(nccl_algBw)
        speed_ups.append(speed_up)

    if MPI.COMM_WORLD.rank == 0:
        print()
        print(table)

        plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups)

    # Drop the references to both communicators before the script ends.
    mscclpp_group = None
    nccl_comm = None
|