mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
Fix FP8 ROCm build/test issues and dtype naming (#792)
## Summary - Fix ROCm FP8 build failure by using the actual FP8 `DataType` enum constants in allreduce packet tuning. - Fix FP8 E4M3FNUZ test encoding so small negative values do not produce the FNUZ NaN byte (`0x80`). - Align FP8 `DataType` enum constants and Python bindings with torch-style names (`FLOAT8_E4M3FN`, `FLOAT8_E4M3FNUZ`, `FLOAT8_E5M2FNUZ` / `float8_e4m3fn`, `float8_e4m3fnuz`, `float8_e5m2fnuz`). ## Validation - `./tools/lint.sh` - `make -j` from `build/` - `mpirun --allow-run-as-root -np 8 python3 -m pytest python/test/test_fp8_accum.py -q` (`36 passed, 9 skipped`) - `DTYPE=float8_e4m3fnuz ACCUM_DTYPE=float32 torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py` --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -45,8 +45,10 @@ void register_core(nb::module_& m) {
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32)
|
||||
.value("bfloat16", DataType::BFLOAT16)
|
||||
.value("float8_e4m3", DataType::FLOAT8_E4M3)
|
||||
.value("float8_e4m3fn", DataType::FLOAT8_E4M3FN)
|
||||
.value("float8_e4m3fnuz", DataType::FLOAT8_E4M3FNUZ)
|
||||
.value("float8_e5m2", DataType::FLOAT8_E5M2)
|
||||
.value("float8_e5m2fnuz", DataType::FLOAT8_E5M2FNUZ)
|
||||
.value("uint8", DataType::UINT8)
|
||||
.value("float8_e4m3b15", DataType::FLOAT8_E4M3B15);
|
||||
|
||||
@@ -328,4 +330,4 @@ NB_MODULE(_mscclpp, m) {
|
||||
|
||||
// ext
|
||||
register_algorithm_collection_builder(m);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user