mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 08:50:21 +00:00
Auto-tune vector sizes for NVLS allreduce6 (#338)
Also fixes bugs in MscclppAllReduce6 Below is the performance when the algorithm is fixed to MscclppAllReduce6 on 8 H100 GPUs connected with NVLink using CUDA 12.2. Float16: +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | Size (fp16) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us) | NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | 2.0 KiB | 11.15 | 0.18 | PASS | 13.82 | 0.15 | PASS | 1.24 | | 4.0 KiB | 11.15 | 0.37 | PASS | 14.74 | 0.28 | PASS | 1.32 | | 8.0 KiB | 11.14 | 0.74 | PASS | 15.17 | 0.54 | PASS | 1.36 | | 16.0 KiB | 11.16 | 1.47 | PASS | 15.77 | 1.04 | PASS | 1.41 | | 32.0 KiB | 11.15 | 2.94 | PASS | 17.50 | 1.87 | PASS | 1.57 | | 64.0 KiB | 11.18 | 5.86 | PASS | 17.64 | 3.71 | PASS | 1.58 | | 128.0 KiB | 11.16 | 11.74 | PASS | 17.83 | 7.35 | PASS | 1.60 | | 256.0 KiB | 11.21 | 23.38 | PASS | 18.00 | 14.57 | PASS | 1.60 | | 512.0 KiB | 11.70 | 44.81 | PASS | 18.42 | 28.46 | PASS | 1.57 | | 1.0 MiB | 13.64 | 76.87 | PASS | 20.23 | 51.83 | PASS | 1.48 | | 2.0 MiB | 17.29 | 121.27 | PASS | 31.60 | 66.36 | PASS | 1.83 | | 4.0 MiB | 25.26 | 166.02 | PASS | 38.74 | 108.26 | PASS | 1.53 | | 8.0 MiB | 40.17 | 208.83 | PASS | 62.86 | 133.45 | PASS | 1.56 | | 16.0 MiB | 70.92 | 236.56 | PASS | 113.36 | 147.99 | PASS | 1.60 | | 32.0 MiB | 131.38 | 255.41 | PASS | 203.21 | 165.13 | PASS | 1.55 | | 64.0 MiB | 253.39 | 264.84 | PASS | 342.12 | 196.15 | PASS | 1.35 | | 128.0 MiB | 496.74 | 270.20 | PASS | 670.62 | 200.14 | PASS | 1.35 | | 256.0 MiB | 982.42 | 273.24 | PASS | 1318.36 | 203.61 | PASS | 1.34 | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ Float32: +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | Size (fp32) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us) | NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | 4.0 KiB | 11.04 | 0.37 | PASS | 14.79 | 0.28 | PASS | 1.34 | | 8.0 KiB | 11.15 | 0.73 | PASS | 15.25 | 0.54 | PASS | 1.37 | | 16.0 KiB | 11.12 | 1.47 | PASS | 15.87 | 1.03 | PASS | 1.43 | | 32.0 KiB | 11.13 | 2.95 | PASS | 17.21 | 1.90 | PASS | 1.55 | | 64.0 KiB | 11.11 | 5.90 | PASS | 17.37 | 3.77 | PASS | 1.56 | | 128.0 KiB | 11.08 | 11.83 | PASS | 17.54 | 7.47 | PASS | 1.58 | | 256.0 KiB | 11.15 | 23.50 | PASS | 17.71 | 14.80 | PASS | 1.59 | | 512.0 KiB | 11.56 | 45.34 | PASS | 18.21 | 28.79 | PASS | 1.57 | | 1.0 MiB | 13.64 | 76.90 | PASS | 19.87 | 52.77 | PASS | 1.46 | | 2.0 MiB | 17.24 | 121.67 | PASS | 31.63 | 66.30 | PASS | 1.84 | | 4.0 MiB | 25.19 | 166.47 | PASS | 38.63 | 108.57 | PASS | 1.53 | | 8.0 MiB | 40.38 | 207.72 | PASS | 62.65 | 133.89 | PASS | 1.55 | | 16.0 MiB | 70.72 | 237.23 | PASS | 114.57 | 146.44 | PASS | 1.62 | | 32.0 MiB | 131.49 | 255.18 | PASS | 200.79 | 167.11 | PASS | 1.53 | | 64.0 MiB | 253.98 | 264.23 | PASS | 342.58 | 195.89 | PASS | 1.35 | | 128.0 MiB | 496.96 | 270.08 | PASS | 670.64 | 200.13 | PASS | 1.35 | | 256.0 MiB | 982.83 | 273.12 | PASS | 1318.90 | 203.53 | PASS | 1.34 | | 512.0 MiB | 1954.07 | 274.75 | PASS | 2609.04 | 205.77 | PASS | 1.34 | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
This commit is contained in:
@@ -175,7 +175,7 @@ def run_benchmark(
|
||||
MscclppAllReduce1(mscclpp_group, memory),
|
||||
MscclppAllReduce3(mscclpp_group, memory, proxy_service),
|
||||
]
|
||||
if is_nvls_supported():
|
||||
if is_nvls_supported() and (data_type == cp.float32 or data_type == cp.float16):
|
||||
mscclpp_algos.append(MscclppAllReduce6(mscclpp_group, nelem, data_type))
|
||||
else:
|
||||
if memory.nbytes < 2**22:
|
||||
|
||||
Reference in New Issue
Block a user