Update dtype name (#748)

- Change FP8_E4M3/FP8_E5M2 to FLOAT8_E4M3/FLOAT8_E5M2
- Add torch.uint8 to DataType.uint8 mapping
This commit is contained in:
Binyang Li
2026-02-18 10:35:44 -08:00
committed by GitHub
parent d0d5a8c034
commit 4701ae3a95
10 changed files with 35 additions and 32 deletions

View File

@@ -68,14 +68,14 @@ namespace mscclpp {
/// Data types supported by mscclpp operations.
enum class DataType {
INT32, // 32-bit signed integer.
UINT32, // 32-bit unsigned integer.
FLOAT16, // IEEE 754 half precision.
FLOAT32, // IEEE 754 single precision.
BFLOAT16, // bfloat16 precision.
FP8_E4M3, // FP8 with E4M3 layout.
FP8_E5M2, // FP8 with E5M2 layout.
UINT8, // 8-bit unsigned integer.
INT32, // 32-bit signed integer.
UINT32, // 32-bit unsigned integer.
FLOAT16, // IEEE 754 half precision.
FLOAT32, // IEEE 754 single precision.
BFLOAT16, // bfloat16 precision.
FLOAT8_E4M3, // float8 with E4M3 layout.
FLOAT8_E5M2, // float8 with E5M2 layout.
UINT8, // 8-bit unsigned integer.
};
/// Word array.

View File

@@ -45,8 +45,9 @@ void register_core(nb::module_& m) {
.value("float16", DataType::FLOAT16)
.value("float32", DataType::FLOAT32)
.value("bfloat16", DataType::BFLOAT16)
.value("float8_e4m3", DataType::FP8_E4M3)
.value("float8_e5m2", DataType::FP8_E5M2);
.value("float8_e4m3", DataType::FLOAT8_E4M3)
.value("float8_e5m2", DataType::FLOAT8_E5M2)
.value("uint8", DataType::UINT8);
nb::class_<Bootstrap>(m, "CppBootstrap")
.def("get_rank", &Bootstrap::getRank)

View File

@@ -198,5 +198,7 @@ def torch_dtype_to_mscclpp_dtype(dtype: "torch.dtype") -> DataType:
return DataType.float8_e5m2
elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz:
return DataType.float8_e4m3
elif dtype == torch.uint8:
return DataType.uint8
else:
raise ValueError(f"Unknown data type: {dtype}")

View File

@@ -174,13 +174,13 @@ CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void*
stream);
break;
#if defined(__FP8_TYPES_EXIST__)
case DataType::FP8_E4M3:
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FP8_E4M3, plan_,
stream);
case DataType::FLOAT8_E4M3:
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3,
plan_, stream);
break;
case DataType::FP8_E5M2:
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FP8_E5M2, plan_,
stream);
case DataType::FLOAT8_E5M2:
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2,
plan_, stream);
break;
#endif
case DataType::INT32:

View File

@@ -78,8 +78,8 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
);
#endif
break;
case DataType::FP8_E4M3:
case DataType::FP8_E5M2:
case DataType::FLOAT8_E4M3:
case DataType::FLOAT8_E5M2:
// FP8 is not supported in CUDA execution kernel.
break;
}

View File

@@ -876,7 +876,7 @@ class ExecutionKernel {
#endif
break;
#if defined(__FP8_TYPES_EXIST__)
case DataType::FP8_E4M3:
case DataType::FLOAT8_E4M3:
executionKernel<__fp8_e4m3, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__fp8_e4m3*)src, (__fp8_e4m3*)dst, (__fp8_e4m3*)scratch, scratchOffset, scratchChunkSize, plan,
semaphores, localMemoryIdBegin, flag
@@ -887,7 +887,7 @@ class ExecutionKernel {
);
#endif
break;
case DataType::FP8_E5M2:
case DataType::FLOAT8_E5M2:
executionKernel<__fp8_e5m2, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__fp8_e5m2*)src, (__fp8_e5m2*)dst, (__fp8_e5m2*)scratch, scratchOffset, scratchChunkSize, plan,
semaphores, localMemoryIdBegin, flag

View File

@@ -188,7 +188,7 @@ inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int
#if defined(__FP8_TYPES_EXIST__)
// FP8-specific tuning for 32KB-256KB range
if (dtype == DataType::FP8_E4M3 || dtype == DataType::FP8_E5M2) {
if (dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2) {
if (inputSize < (64 << 10)) {
nThreadsPerBlock = 64;
} else if (inputSize >= (64 << 10) && inputSize <= (128 << 10)) {

View File

@@ -89,9 +89,9 @@ AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype) {
return Adapter<SUM, __bfloat16>::call;
#endif
#if defined(__FP8_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::FP8_E4M3) {
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
return Adapter<SUM, __fp8_e4m3>::call;
} else if (dtype == mscclpp::DataType::FP8_E5M2) {
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
return Adapter<SUM, __fp8_e5m2>::call;
#endif
} else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {
@@ -111,9 +111,9 @@ AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype) {
return Adapter<MIN, __bfloat16>::call;
#endif
#if defined(__FP8_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::FP8_E4M3) {
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
return Adapter<MIN, __fp8_e4m3>::call;
} else if (dtype == mscclpp::DataType::FP8_E5M2) {
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
return Adapter<MIN, __fp8_e5m2>::call;
#endif
} else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {

View File

@@ -19,7 +19,7 @@ static bool isNvlsSupportedForDataType(const AlgorithmSelectorConfig& config, Da
return false;
}
const bool isFp8 = dtype == DataType::FP8_E4M3 || dtype == DataType::FP8_E5M2;
const bool isFp8 = dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2;
if (!isFp8) {
return nvlsSupported;

View File

@@ -28,9 +28,9 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) {
return mscclpp::DataType::BFLOAT16;
#ifdef __FP8_TYPES_EXIST__
case ncclFloat8e4m3:
return mscclpp::DataType::FP8_E4M3;
return mscclpp::DataType::FLOAT8_E4M3;
case ncclFloat8e5m2:
return mscclpp::DataType::FP8_E5M2;
return mscclpp::DataType::FLOAT8_E5M2;
#endif
default:
throw mscclpp::Error("Unsupported ncclDataType_t: " + std::to_string(dtype), mscclpp::ErrorCode::InvalidUsage);
@@ -41,8 +41,8 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) {
inline size_t getDataTypeSize(mscclpp::DataType dtype) {
switch (dtype) {
case mscclpp::DataType::UINT8:
case mscclpp::DataType::FP8_E4M3:
case mscclpp::DataType::FP8_E5M2:
case mscclpp::DataType::FLOAT8_E4M3:
case mscclpp::DataType::FLOAT8_E5M2:
return 1;
case mscclpp::DataType::FLOAT16:
case mscclpp::DataType::BFLOAT16:
@@ -71,9 +71,9 @@ static inline ncclDataType_t mscclppToNcclDataType(mscclpp::DataType dtype) {
case mscclpp::DataType::BFLOAT16:
return ncclBfloat16;
#ifdef __FP8_TYPES_EXIST__
case mscclpp::DataType::FP8_E4M3:
case mscclpp::DataType::FLOAT8_E4M3:
return ncclFloat8e4m3;
case mscclpp::DataType::FP8_E5M2:
case mscclpp::DataType::FLOAT8_E5M2:
return ncclFloat8e5m2;
#endif
default: