Fix FP8 ROCm build/test issues and dtype naming (#792)

## Summary
- Fix ROCm FP8 build failure by using the actual FP8 `DataType` enum
constants in allreduce packet tuning.
- Fix FP8 E4M3FNUZ test encoding so small negative values do not produce
the FNUZ NaN byte (`0x80`).
- Align FP8 `DataType` enum constants and Python bindings with
torch-style names (`FLOAT8_E4M3FN`, `FLOAT8_E4M3FNUZ`, `FLOAT8_E5M2FNUZ`
/ `float8_e4m3fn`, `float8_e4m3fnuz`, `float8_e5m2fnuz`).

## Validation
- `./tools/lint.sh`
- `make -j` from `build/`
- `mpirun --allow-run-as-root -np 8 python3 -m pytest
python/test/test_fp8_accum.py -q` (`36 passed, 9 skipped`)
- `DTYPE=float8_e4m3fnuz ACCUM_DTYPE=float32 torchrun --nnodes=1
--nproc_per_node=8
examples/torch-integration/customized_comm_with_tuning.py`

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-04-28 15:02:22 -07:00
committed by GitHub
parent c97be492d5
commit 2c52937b26
12 changed files with 271 additions and 138 deletions

View File

@@ -45,8 +45,10 @@ void register_core(nb::module_& m) {
.value("float16", DataType::FLOAT16)
.value("float32", DataType::FLOAT32)
.value("bfloat16", DataType::BFLOAT16)
.value("float8_e4m3", DataType::FLOAT8_E4M3)
.value("float8_e4m3fn", DataType::FLOAT8_E4M3FN)
.value("float8_e4m3fnuz", DataType::FLOAT8_E4M3FNUZ)
.value("float8_e5m2", DataType::FLOAT8_E5M2)
.value("float8_e5m2fnuz", DataType::FLOAT8_E5M2FNUZ)
.value("uint8", DataType::UINT8)
.value("float8_e4m3b15", DataType::FLOAT8_E4M3B15);
@@ -328,4 +330,4 @@ NB_MODULE(_mscclpp, m) {
// ext
register_algorithm_collection_builder(m);
}
}