mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Update dtype name (#748)
- Change FP8_E4M3/FP8_E5M2 to FLOAT8_E4M3/FLOAT8_E5M2 - Add torch.uint8 to DataType.uint8 mapping
This commit is contained in:
@@ -68,14 +68,14 @@ namespace mscclpp {
|
||||
|
||||
/// Data types supported by mscclpp operations.
/// Values are 1-byte (UINT8, FLOAT8_*), 2-byte (FLOAT16, BFLOAT16), or 4-byte
/// (INT32, UINT32, FLOAT32) element types used by collective kernels.
enum class DataType {
  INT32,        // 32-bit signed integer.
  UINT32,       // 32-bit unsigned integer.
  FLOAT16,      // IEEE 754 half precision.
  FLOAT32,      // IEEE 754 single precision.
  BFLOAT16,     // bfloat16 precision.
  FLOAT8_E4M3,  // float8 with E4M3 layout (renamed from FP8_E4M3 in #748).
  FLOAT8_E5M2,  // float8 with E5M2 layout (renamed from FP8_E5M2 in #748).
  UINT8,        // 8-bit unsigned integer.
};
|
||||
|
||||
/// Word array.
|
||||
|
||||
@@ -45,8 +45,9 @@ void register_core(nb::module_& m) {
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32)
|
||||
.value("bfloat16", DataType::BFLOAT16)
|
||||
.value("float8_e4m3", DataType::FLOAT8_E4M3)
.value("float8_e5m2", DataType::FLOAT8_E5M2)
.value("uint8", DataType::UINT8);
|
||||
|
||||
nb::class_<Bootstrap>(m, "CppBootstrap")
|
||||
.def("get_rank", &Bootstrap::getRank)
|
||||
|
||||
@@ -198,5 +198,7 @@ def torch_dtype_to_mscclpp_dtype(dtype: "torch.dtype") -> DataType:
|
||||
return DataType.float8_e5m2
|
||||
elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz:
|
||||
return DataType.float8_e4m3
|
||||
elif dtype == torch.uint8:
|
||||
return DataType.uint8
|
||||
else:
|
||||
raise ValueError(f"Unknown data type: {dtype}")
|
||||
|
||||
@@ -174,13 +174,13 @@ CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void*
|
||||
stream);
|
||||
break;
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
case DataType::FLOAT8_E4M3:
  executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3,
                    plan_, stream);
  break;
case DataType::FLOAT8_E5M2:
  executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2,
                    plan_, stream);
  break;
|
||||
#endif
|
||||
case DataType::INT32:
|
||||
|
||||
@@ -78,8 +78,8 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT8_E4M3:
case DataType::FLOAT8_E5M2:
|
||||
// FP8 is not supported in CUDA execution kernel.
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -876,7 +876,7 @@ class ExecutionKernel {
|
||||
#endif
|
||||
break;
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
case DataType::FLOAT8_E4M3:
|
||||
executionKernel<__fp8_e4m3, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (__fp8_e4m3*)src, (__fp8_e4m3*)dst, (__fp8_e4m3*)scratch, scratchOffset, scratchChunkSize, plan,
|
||||
semaphores, localMemoryIdBegin, flag
|
||||
@@ -887,7 +887,7 @@ class ExecutionKernel {
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT8_E5M2:
|
||||
executionKernel<__fp8_e5m2, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (__fp8_e5m2*)src, (__fp8_e5m2*)dst, (__fp8_e5m2*)scratch, scratchOffset, scratchChunkSize, plan,
|
||||
semaphores, localMemoryIdBegin, flag
|
||||
|
||||
@@ -188,7 +188,7 @@ inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int
|
||||
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
// FP8-specific tuning for 32KB-256KB range
|
||||
if (dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2) {
|
||||
if (inputSize < (64 << 10)) {
|
||||
nThreadsPerBlock = 64;
|
||||
} else if (inputSize >= (64 << 10) && inputSize <= (128 << 10)) {
|
||||
|
||||
@@ -89,9 +89,9 @@ AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype) {
|
||||
return Adapter<SUM, __bfloat16>::call;
|
||||
#endif
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
|
||||
return Adapter<SUM, __fp8_e4m3>::call;
|
||||
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
|
||||
return Adapter<SUM, __fp8_e5m2>::call;
|
||||
#endif
|
||||
} else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {
|
||||
@@ -111,9 +111,9 @@ AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype) {
|
||||
return Adapter<MIN, __bfloat16>::call;
|
||||
#endif
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
|
||||
return Adapter<MIN, __fp8_e4m3>::call;
|
||||
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
|
||||
return Adapter<MIN, __fp8_e5m2>::call;
|
||||
#endif
|
||||
} else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {
|
||||
|
||||
@@ -19,7 +19,7 @@ static bool isNvlsSupportedForDataType(const AlgorithmSelectorConfig& config, Da
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool isFp8 = dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2;
|
||||
|
||||
if (!isFp8) {
|
||||
return nvlsSupported;
|
||||
|
||||
@@ -28,9 +28,9 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) {
|
||||
return mscclpp::DataType::BFLOAT16;
|
||||
#ifdef __FP8_TYPES_EXIST__
|
||||
case ncclFloat8e4m3:
|
||||
return mscclpp::DataType::FLOAT8_E4M3;
|
||||
case ncclFloat8e5m2:
|
||||
return mscclpp::DataType::FLOAT8_E5M2;
|
||||
#endif
|
||||
default:
|
||||
throw mscclpp::Error("Unsupported ncclDataType_t: " + std::to_string(dtype), mscclpp::ErrorCode::InvalidUsage);
|
||||
@@ -41,8 +41,8 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) {
|
||||
inline size_t getDataTypeSize(mscclpp::DataType dtype) {
|
||||
switch (dtype) {
|
||||
case mscclpp::DataType::UINT8:
|
||||
case mscclpp::DataType::FLOAT8_E4M3:
case mscclpp::DataType::FLOAT8_E5M2:
|
||||
return 1;
|
||||
case mscclpp::DataType::FLOAT16:
|
||||
case mscclpp::DataType::BFLOAT16:
|
||||
@@ -71,9 +71,9 @@ static inline ncclDataType_t mscclppToNcclDataType(mscclpp::DataType dtype) {
|
||||
case mscclpp::DataType::BFLOAT16:
|
||||
return ncclBfloat16;
|
||||
#ifdef __FP8_TYPES_EXIST__
|
||||
case mscclpp::DataType::FLOAT8_E4M3:
|
||||
return ncclFloat8e4m3;
|
||||
case mscclpp::DataType::FLOAT8_E5M2:
|
||||
return ncclFloat8e5m2;
|
||||
#endif
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user