Refactor f8_t, add bf8_t (#792)

* Refactor f8_t to add bf8_t

* Add check_err impl for f8_t

* Update fp8 test

* Format

* Revert the fix

* Update vector_type implementation

* Add bf8 test

* Add bf8, use BitInt types

* Add bf8 conversion methods

* Update type_convert for fp8/bf8

* Add check_err fp8/bf8 support

* Add subnorm fp8 tests

* Add subnorm bf8 tests

* Fix conversion

* Add bf8 cmake bindings

* Add macros to enable build with disabled fp8/bf8

* Remove is_native method

* Update flag combination for mixed precision instances

* Add more flag checks

* Add another flag to a client example

* Add type traits, decouple f8/bf8 casting

* Clean up

* Decouple fp8 and bf8 flags

* Remove more redundant flags

* Remove leftover comments

[ROCm/composable_kernel commit: 62d4af7449]
This commit is contained in:
Rostyslav Geyyer
2023-09-12 17:04:27 -05:00
committed by GitHub
parent e885110f62
commit 2e227b8581
23 changed files with 739 additions and 172 deletions

View File

@@ -1,7 +1,13 @@
# NOTE(review): this span looks like a diff hunk rendered without +/- markers
# (the hunk header above reports 7 old lines -> 13 new lines), so two versions
# of the file appear interleaved. Presumably the literal
# add_instance_library(...) call that opens here, with its four .cpp arguments
# and closing ")", is the removed version, and the set()/if()/list(APPEND)
# lines plus the final add_instance_library(... ${GEMM_MULTIPLY_ADD_INSTANCES})
# call are the replacement -- TODO confirm against the repository before
# treating this text as a buildable CMakeLists.txt.
add_instance_library(device_gemm_multiply_add_instance
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
set(GEMM_MULTIPLY_ADD_INSTANCES)
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp
)
# The two f16-only sources are appended when DTYPES matches "fp16", or when
# DTYPES is not defined at all.
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
endif()
# The two mixed f16/f8 sources require DTYPES to match BOTH "fp16" and "fp8";
# they are also appended when DTYPES is not defined.
if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
endif()
# Hands the accumulated source list to add_instance_library (defined outside
# this view) under the same target name used by the literal call above.
add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES})

View File

@@ -14,7 +14,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp)