Fix FP8 ROCm build/test issues and dtype naming (#792)

## Summary
- Fix ROCm FP8 build failure by using the actual FP8 `DataType` enum
constants in allreduce packet tuning.
- Fix FP8 E4M3FNUZ test encoding so small negative values do not produce
the FNUZ NaN byte (`0x80`).
- Align FP8 `DataType` enum constants and Python bindings with
torch-style names (`FLOAT8_E4M3FN`, `FLOAT8_E4M3FNUZ`, `FLOAT8_E5M2FNUZ`
/ `float8_e4m3fn`, `float8_e4m3fnuz`, `float8_e5m2fnuz`).

## Validation
- `./tools/lint.sh`
- `make -j` from `build/`
- `mpirun --allow-run-as-root -np 8 python3 -m pytest python/test/test_fp8_accum.py -q` (`36 passed, 9 skipped`)
- `DTYPE=float8_e4m3fnuz ACCUM_DTYPE=float32 torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py`

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-04-28 15:02:22 -07:00
committed by GitHub
parent c97be492d5
commit 2c52937b26
12 changed files with 271 additions and 138 deletions

View File

@@ -3,6 +3,7 @@
#include <filesystem>
#include <mscclpp/algorithm.hpp>
#include <mscclpp/errors.hpp>
#include <mscclpp/gpu_utils.hpp>
#include "logger.hpp"
@@ -182,13 +183,41 @@ CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void*
stream);
break;
#if defined(__FP8_TYPES_EXIST__)
case DataType::FLOAT8_E4M3:
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3,
case DataType::FLOAT8_E4M3FN:
#if defined(__FP8_E4M3_IS_FNUZ__)
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ");
#else
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FN,
plan_, stream);
#endif
break;
case DataType::FLOAT8_E4M3FNUZ:
#if defined(__FP8_E4M3_IS_FNUZ__)
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FNUZ,
plan_, stream);
#else
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN");
#endif
break;
case DataType::FLOAT8_E5M2:
#if defined(__FP8_E5M2_IS_FNUZ__)
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ");
#else
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2,
plan_, stream);
#endif
break;
case DataType::FLOAT8_E5M2FNUZ:
#if defined(__FP8_E5M2_IS_FNUZ__)
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2FNUZ,
plan_, stream);
#else
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2");
#endif
break;
#endif
case DataType::FLOAT8_E4M3B15:
@@ -230,4 +259,4 @@ std::pair<std::shared_ptr<void>, size_t> getFlagBuffer() {
return {ptr, gDefaultFlagCount * sizeof(uint32_t)};
}
} // namespace mscclpp
} // namespace mscclpp