mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Add ROCm FP8 E4M3B15 support (#774)
## Summary

Add ROCm (gfx942) support for the FP8 E4M3B15 data type, including optimized conversion routines between FP8 E4M3B15 and FP16/FP32 using inline assembly. Extends the allpair packet and fullmesh allreduce kernels to support higher-precision accumulation (e.g., FP16/FP32) when reducing FP8 data, improving numerical accuracy. Adds Python tests to verify that higher-precision accumulation is at least as accurate as native FP8 accumulation across all algorithm variants. --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -24,4 +24,7 @@ set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
|
||||
set_target_properties(mscclpp_py PROPERTIES INSTALL_RPATH "\$ORIGIN/lib")
|
||||
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp mscclpp_collectives ${GPU_LIBRARIES})
|
||||
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
if(MSCCLPP_USE_ROCM)
|
||||
target_compile_definitions(mscclpp_py PRIVATE MSCCLPP_USE_ROCM)
|
||||
endif()
|
||||
install(TARGETS mscclpp_py LIBRARY DESTINATION .)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
mpi4py==4.1.1
|
||||
cupy==13.6.0
|
||||
mpi4py
|
||||
cupy
|
||||
prettytable
|
||||
netifaces
|
||||
pytest
|
||||
|
||||
@@ -21,9 +21,8 @@ from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs.
|
||||
# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed.
|
||||
_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip
|
||||
# TODO(binyli): Skip hip for now, will fix it in the next PR
|
||||
_skip_fp8 = _is_hip or int(cp.cuda.Device().compute_capability) < 89
|
||||
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA (HIP not yet supported)")
|
||||
_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89
|
||||
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP8 E4M3FN helpers (bias=7, no infinity, NaN = exp=15 & mant=7)
|
||||
@@ -208,6 +207,7 @@ def run_allreduce(algo, comm_group, buffer, dtype, accum_dtype=None, nblocks=0,
|
||||
"default_allreduce_nvls_packet",
|
||||
"default_allreduce_fullmesh",
|
||||
"default_allreduce_rsag_zero_copy",
|
||||
"default_allreduce_allpair_packet",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("size", [1024, 4096, 16384, 65536, 262144, 1048576])
|
||||
@@ -220,6 +220,8 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
|
||||
if algo_name not in algo_map:
|
||||
pytest.skip(f"{algo_name} not available")
|
||||
if "nvls" in algo_name and not is_nvls_supported():
|
||||
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
|
||||
algo = algo_map[algo_name]
|
||||
|
||||
buf = GpuBuffer(size, dtype=cp.uint8)
|
||||
@@ -243,9 +245,9 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank data
|
||||
cp.random.seed(42 + rank)
|
||||
src_f32 = cp.random.randn(size).astype(cp.float32)
|
||||
# Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm)
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
src_f32 = cp.asarray(rng.randn(size).astype(np.float32))
|
||||
src_f32 = cp.clip(src_f32, -240.0, 240.0)
|
||||
src_fp8 = float_to_e4m3fn(src_f32)
|
||||
|
||||
@@ -268,8 +270,8 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
# Compute float32 reference: sum all ranks' quantized FP8 inputs in float32
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
for r in range(world_size):
|
||||
cp.random.seed(42 + r)
|
||||
rank_data = cp.random.randn(size).astype(cp.float32)
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
rank_data = cp.asarray(rng_r.randn(size).astype(np.float32))
|
||||
rank_data = cp.clip(rank_data, -240.0, 240.0)
|
||||
rank_data_fp8 = float_to_e4m3fn(rank_data)
|
||||
ref_f32 += e4m3fn_to_float(rank_data_fp8)
|
||||
@@ -303,6 +305,8 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
"default_allreduce_packet",
|
||||
"default_allreduce_nvls_packet",
|
||||
"default_allreduce_rsag_zero_copy",
|
||||
"default_allreduce_fullmesh",
|
||||
"default_allreduce_allpair_packet",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("size", [1024, 4096, 65536])
|
||||
@@ -315,6 +319,8 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
|
||||
if algo_name not in algo_map:
|
||||
pytest.skip(f"{algo_name} not available")
|
||||
if "nvls" in algo_name and not is_nvls_supported():
|
||||
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
|
||||
|
||||
algo = algo_map[algo_name]
|
||||
buf = GpuBuffer(size, dtype=cp.uint8)
|
||||
@@ -336,9 +342,9 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank random uint8 values in valid e4m3b15 range
|
||||
cp.random.seed(42 + rank)
|
||||
raw = cp.random.randint(0, 0x78, (size,), dtype=cp.uint8)
|
||||
signs = cp.random.randint(0, 2, (size,), dtype=cp.uint8).astype(cp.uint8) << 7
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
src_uint8 = raw | signs
|
||||
# Fix negative zero -> positive zero
|
||||
src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8)
|
||||
@@ -364,9 +370,9 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
# Compute float32 reference
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
for r in range(world_size):
|
||||
cp.random.seed(42 + r)
|
||||
raw_r = cp.random.randint(0, 0x78, (size,), dtype=cp.uint8)
|
||||
signs_r = cp.random.randint(0, 2, (size,), dtype=cp.uint8).astype(cp.uint8) << 7
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
bits_r = raw_r | signs_r
|
||||
bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r)
|
||||
ref_f32 += e4m3b15_to_float(bits_r)
|
||||
|
||||
Reference in New Issue
Block a user