Add ROCm FP8 E4M3B15 support (#774)

## Summary

Add ROCm (gfx942) support for the FP8 E4M3B15 data type, including
optimized conversion routines between FP8 E4M3B15 and FP16/FP32 using
inline assembly.

Extends the allpair packet and fullmesh allreduce kernels to support
higher-precision accumulation (e.g., FP16/FP32) when reducing FP8 data,
improving numerical accuracy.

Adds Python tests to verify that higher-precision accumulation is at
least as accurate as native FP8 accumulation across all algorithm
variants.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Author: Binyang Li
Date: 2026-04-08 09:53:45 -07:00
Committed by: GitHub
Parent commit: e66ce39647
This commit: 8896cd909a
7 changed files with 538 additions and 268 deletions

View File

@@ -24,4 +24,7 @@ set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
set_target_properties(mscclpp_py PROPERTIES INSTALL_RPATH "\$ORIGIN/lib")
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp mscclpp_collectives ${GPU_LIBRARIES})
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
if(MSCCLPP_USE_ROCM)
target_compile_definitions(mscclpp_py PRIVATE MSCCLPP_USE_ROCM)
endif()
install(TARGETS mscclpp_py LIBRARY DESTINATION .)

View File

@@ -1,5 +1,5 @@
mpi4py==4.1.1
cupy==13.6.0
mpi4py
cupy
prettytable
netifaces
pytest

View File

@@ -21,9 +21,8 @@ from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs.
# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed.
_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip
# TODO(binyli): Skip hip for now, will fix it in the next PR
_skip_fp8 = _is_hip or int(cp.cuda.Device().compute_capability) < 89
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA (HIP not yet supported)")
_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA")
# ---------------------------------------------------------------------------
# FP8 E4M3FN helpers (bias=7, no infinity, NaN = exp=15 & mant=7)
@@ -208,6 +207,7 @@ def run_allreduce(algo, comm_group, buffer, dtype, accum_dtype=None, nblocks=0,
"default_allreduce_nvls_packet",
"default_allreduce_fullmesh",
"default_allreduce_rsag_zero_copy",
"default_allreduce_allpair_packet",
],
)
@pytest.mark.parametrize("size", [1024, 4096, 16384, 65536, 262144, 1048576])
@@ -220,6 +220,8 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
if algo_name not in algo_map:
pytest.skip(f"{algo_name} not available")
if "nvls" in algo_name and not is_nvls_supported():
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
algo = algo_map[algo_name]
buf = GpuBuffer(size, dtype=cp.uint8)
@@ -243,9 +245,9 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
errors = {}
for accum_label, accum_dtype in accum_configs:
# Generate deterministic per-rank data
cp.random.seed(42 + rank)
src_f32 = cp.random.randn(size).astype(cp.float32)
# Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm)
rng = np.random.RandomState(42 + rank)
src_f32 = cp.asarray(rng.randn(size).astype(np.float32))
src_f32 = cp.clip(src_f32, -240.0, 240.0)
src_fp8 = float_to_e4m3fn(src_f32)
@@ -268,8 +270,8 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
# Compute float32 reference: sum all ranks' quantized FP8 inputs in float32
ref_f32 = cp.zeros(size, dtype=cp.float32)
for r in range(world_size):
cp.random.seed(42 + r)
rank_data = cp.random.randn(size).astype(cp.float32)
rng_r = np.random.RandomState(42 + r)
rank_data = cp.asarray(rng_r.randn(size).astype(np.float32))
rank_data = cp.clip(rank_data, -240.0, 240.0)
rank_data_fp8 = float_to_e4m3fn(rank_data)
ref_f32 += e4m3fn_to_float(rank_data_fp8)
@@ -303,6 +305,8 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
"default_allreduce_packet",
"default_allreduce_nvls_packet",
"default_allreduce_rsag_zero_copy",
"default_allreduce_fullmesh",
"default_allreduce_allpair_packet",
],
)
@pytest.mark.parametrize("size", [1024, 4096, 65536])
@@ -315,6 +319,8 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
if algo_name not in algo_map:
pytest.skip(f"{algo_name} not available")
if "nvls" in algo_name and not is_nvls_supported():
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
algo = algo_map[algo_name]
buf = GpuBuffer(size, dtype=cp.uint8)
@@ -336,9 +342,9 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
errors = {}
for accum_label, accum_dtype in accum_configs:
# Generate deterministic per-rank random uint8 values in valid e4m3b15 range
cp.random.seed(42 + rank)
raw = cp.random.randint(0, 0x78, (size,), dtype=cp.uint8)
signs = cp.random.randint(0, 2, (size,), dtype=cp.uint8).astype(cp.uint8) << 7
rng = np.random.RandomState(42 + rank)
raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8))
signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7
src_uint8 = raw | signs
# Fix negative zero -> positive zero
src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8)
@@ -364,9 +370,9 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
# Compute float32 reference
ref_f32 = cp.zeros(size, dtype=cp.float32)
for r in range(world_size):
cp.random.seed(42 + r)
raw_r = cp.random.randint(0, 0x78, (size,), dtype=cp.uint8)
signs_r = cp.random.randint(0, 2, (size,), dtype=cp.uint8).astype(cp.uint8) << 7
rng_r = np.random.RandomState(42 + r)
raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8))
signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7
bits_r = raw_r | signs_r
bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r)
ref_f32 += e4m3b15_to_float(bits_r)