Auto-tune vector sizes for NVLS allreduce6 (#338)

Also fixes bugs in MscclppAllReduce6

Below is the performance when the algorithm is fixed to
MscclppAllReduce6 on 8 H100 GPUs connected with NVLink using CUDA 12.2.

Float16:

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
| Size (fp16) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us)
| NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up |

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
| 2.0 KiB | 11.15 | 0.18 | PASS | 13.82 | 0.15 | PASS | 1.24 |
| 4.0 KiB | 11.15 | 0.37 | PASS | 14.74 | 0.28 | PASS | 1.32 |
| 8.0 KiB | 11.14 | 0.74 | PASS | 15.17 | 0.54 | PASS | 1.36 |
| 16.0 KiB | 11.16 | 1.47 | PASS | 15.77 | 1.04 | PASS | 1.41 |
| 32.0 KiB | 11.15 | 2.94 | PASS | 17.50 | 1.87 | PASS | 1.57 |
| 64.0 KiB | 11.18 | 5.86 | PASS | 17.64 | 3.71 | PASS | 1.58 |
| 128.0 KiB | 11.16 | 11.74 | PASS | 17.83 | 7.35 | PASS | 1.60 |
| 256.0 KiB | 11.21 | 23.38 | PASS | 18.00 | 14.57 | PASS | 1.60 |
| 512.0 KiB | 11.70 | 44.81 | PASS | 18.42 | 28.46 | PASS | 1.57 |
| 1.0 MiB | 13.64 | 76.87 | PASS | 20.23 | 51.83 | PASS | 1.48 |
| 2.0 MiB | 17.29 | 121.27 | PASS | 31.60 | 66.36 | PASS | 1.83 |
| 4.0 MiB | 25.26 | 166.02 | PASS | 38.74 | 108.26 | PASS | 1.53 |
| 8.0 MiB | 40.17 | 208.83 | PASS | 62.86 | 133.45 | PASS | 1.56 |
| 16.0 MiB | 70.92 | 236.56 | PASS | 113.36 | 147.99 | PASS | 1.60 |
| 32.0 MiB | 131.38 | 255.41 | PASS | 203.21 | 165.13 | PASS | 1.55 |
| 64.0 MiB | 253.39 | 264.84 | PASS | 342.12 | 196.15 | PASS | 1.35 |
| 128.0 MiB | 496.74 | 270.20 | PASS | 670.62 | 200.14 | PASS | 1.35 |
| 256.0 MiB | 982.42 | 273.24 | PASS | 1318.36 | 203.61 | PASS | 1.34 |

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
 
 Float32:

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
| Size (fp32) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us)
| NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up |

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
| 4.0 KiB | 11.04 | 0.37 | PASS | 14.79 | 0.28 | PASS | 1.34 |
| 8.0 KiB | 11.15 | 0.73 | PASS | 15.25 | 0.54 | PASS | 1.37 |
| 16.0 KiB | 11.12 | 1.47 | PASS | 15.87 | 1.03 | PASS | 1.43 |
| 32.0 KiB | 11.13 | 2.95 | PASS | 17.21 | 1.90 | PASS | 1.55 |
| 64.0 KiB | 11.11 | 5.90 | PASS | 17.37 | 3.77 | PASS | 1.56 |
| 128.0 KiB | 11.08 | 11.83 | PASS | 17.54 | 7.47 | PASS | 1.58 |
| 256.0 KiB | 11.15 | 23.50 | PASS | 17.71 | 14.80 | PASS | 1.59 |
| 512.0 KiB | 11.56 | 45.34 | PASS | 18.21 | 28.79 | PASS | 1.57 |
| 1.0 MiB | 13.64 | 76.90 | PASS | 19.87 | 52.77 | PASS | 1.46 |
| 2.0 MiB | 17.24 | 121.67 | PASS | 31.63 | 66.30 | PASS | 1.84 |
| 4.0 MiB | 25.19 | 166.47 | PASS | 38.63 | 108.57 | PASS | 1.53 |
| 8.0 MiB | 40.38 | 207.72 | PASS | 62.65 | 133.89 | PASS | 1.55 |
| 16.0 MiB | 70.72 | 237.23 | PASS | 114.57 | 146.44 | PASS | 1.62 |
| 32.0 MiB | 131.49 | 255.18 | PASS | 200.79 | 167.11 | PASS | 1.53 |
| 64.0 MiB | 253.98 | 264.23 | PASS | 342.58 | 195.89 | PASS | 1.35 |
| 128.0 MiB | 496.96 | 270.08 | PASS | 670.64 | 200.13 | PASS | 1.35 |
| 256.0 MiB | 982.83 | 273.12 | PASS | 1318.90 | 203.53 | PASS | 1.34 |
| 512.0 MiB | 1954.07 | 274.75 | PASS | 2609.04 | 205.77 | PASS | 1.34 |

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
This commit is contained in:
Roshan Dathathri
2024-08-15 20:11:54 -07:00
committed by GitHub
parent ead4efc315
commit 7ed13ec4b5
6 changed files with 114 additions and 42 deletions

View File

@@ -175,7 +175,7 @@ def run_benchmark(
MscclppAllReduce1(mscclpp_group, memory),
MscclppAllReduce3(mscclpp_group, memory, proxy_service),
]
if is_nvls_supported():
if is_nvls_supported() and (data_type == cp.float32 or data_type == cp.float16):
mscclpp_algos.append(MscclppAllReduce6(mscclpp_group, nelem, data_type))
else:
if memory.nbytes < 2**22: