Update dtype name (#748)

- Change FP8_E4M3/FP8_E5M2 to FLOAT8_E4M3/FLOAT8_E5M2
- Add torch.uint8 to DataType.uint8 mapping
This commit is contained in:
Binyang Li
2026-02-18 10:35:44 -08:00
committed by GitHub
parent d0d5a8c034
commit 4701ae3a95
10 changed files with 35 additions and 32 deletions

View File

@@ -174,13 +174,13 @@ CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void*
stream);
break;
#if defined(__FP8_TYPES_EXIST__)
case DataType::FP8_E4M3:
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FP8_E4M3, plan_,
stream);
case DataType::FLOAT8_E4M3:
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3,
plan_, stream);
break;
case DataType::FP8_E5M2:
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FP8_E5M2, plan_,
stream);
case DataType::FLOAT8_E5M2:
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2,
plan_, stream);
break;
#endif
case DataType::INT32: