bfloat16 support (#336)

* Add bfloat16 support for executor and NCCL interface
* Change `gpu_data_types.hpp` into an internal header file
This commit is contained in:
Changho Hwang
2024-08-12 15:41:58 -07:00
committed by GitHub
parent faadc75649
commit 8c6fb429e9
9 changed files with 88 additions and 5 deletions

View File

@@ -16,7 +16,8 @@ void register_executor(nb::module_& m) {
.value("int32", DataType::INT32)
.value("uint32", DataType::UINT32)
.value("float16", DataType::FLOAT16)
.value("float32", DataType::FLOAT32);
.value("float32", DataType::FLOAT32)
.value("bfloat16", DataType::BFLOAT16);
nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);