mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-14 02:03:03 +00:00
Fix FP8 ROCm build/test issues and dtype naming (#792)
## Summary

- Fix ROCm FP8 build failure by using the actual FP8 `DataType` enum constants in allreduce packet tuning.
- Fix FP8 E4M3FNUZ test encoding so small negative values do not produce the FNUZ NaN byte (`0x80`).
- Align FP8 `DataType` enum constants and Python bindings with torch-style names (`FLOAT8_E4M3FN`, `FLOAT8_E4M3FNUZ`, `FLOAT8_E5M2FNUZ` / `float8_e4m3fn`, `float8_e4m3fnuz`, `float8_e5m2fnuz`).

## Validation

- `./tools/lint.sh`
- `make -j` from `build/`
- `mpirun --allow-run-as-root -np 8 python3 -m pytest python/test/test_fp8_accum.py -q` (`36 passed, 9 skipped`)
- `DTYPE=float8_e4m3fnuz ACCUM_DTYPE=float32 torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py`

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <filesystem>
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
|
||||
#include "logger.hpp"
|
||||
@@ -182,13 +183,41 @@ CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void*
|
||||
stream);
|
||||
break;
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
case DataType::FLOAT8_E4M3:
|
||||
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3,
|
||||
case DataType::FLOAT8_E4M3FN:
|
||||
#if defined(__FP8_E4M3_IS_FNUZ__)
|
||||
THROW(EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ");
|
||||
#else
|
||||
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FN,
|
||||
plan_, stream);
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT8_E4M3FNUZ:
|
||||
#if defined(__FP8_E4M3_IS_FNUZ__)
|
||||
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FNUZ,
|
||||
plan_, stream);
|
||||
#else
|
||||
THROW(EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN");
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT8_E5M2:
|
||||
#if defined(__FP8_E5M2_IS_FNUZ__)
|
||||
THROW(EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ");
|
||||
#else
|
||||
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2,
|
||||
plan_, stream);
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT8_E5M2FNUZ:
|
||||
#if defined(__FP8_E5M2_IS_FNUZ__)
|
||||
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2FNUZ,
|
||||
plan_, stream);
|
||||
#else
|
||||
THROW(EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2");
|
||||
#endif
|
||||
break;
|
||||
#endif
|
||||
case DataType::FLOAT8_E4M3B15:
|
||||
@@ -230,4 +259,4 @@ std::pair<std::shared_ptr<void>, size_t> getFlagBuffer() {
|
||||
return {ptr, gDefaultFlagCount * sizeof(uint32_t)};
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
Reference in New Issue
Block a user