From 57d3525983943dd4c67082ff1b39c04439cf3ef3 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 17 Dec 2024 10:17:29 -0800 Subject: [PATCH] Pass build flags to config.h (#1760) * pass the build flags to config.h * fix clang format [ROCm/composable_kernel commit: 689a5ae45be802f51fc947a9f92208dcfb143f77] --- CMakeLists.txt | 4 ++++ include/ck/config.h.in | 16 ++++++++++++++++ include/ck/utility/amd_ck_fp8.hpp | 20 +++++++++++++------- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2c86987561..be4efd3dfd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -183,14 +183,17 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") message("Enabling XDL instances") add_definitions(-DCK_USE_XDL) + set(CK_USE_XDL "ON") endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") message("Enabling FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) + set(CK_USE_GFX94 "ON") endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") message("Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) + set(CK_USE_WMMA "ON") endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") add_definitions(-DCK_USE_OCP_FP8) @@ -204,6 +207,7 @@ endif() option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH) + set(CK_USE_FP8_ON_UNSUPPORTED_ARCH "ON") endif() # CK config file to record supported datatypes, etc. diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 0f0b7bd607..55a498073f 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -111,6 +111,22 @@ #cmakedefine CK_USE_WMMA @CK_USE_WMMA@ #endif +#ifndef CK_USE_GFX94 +#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@ +#endif + +#ifndef DCK_USE_OCP_FP8 +#cmakedefine DCK_USE_OCP_FP8 @DCK_USE_OCP_FP8@ +#endif + +#ifndef CK_USE_FNUZ_FP8 +#cmakedefine CK_USE_FNUZ_FP8 @CK_USE_FNUZ_FP8@ +#endif + +#ifndef CK_USE_FP8_ON_UNSUPPORTED_ARCH +#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@ +#endif + // clang-format on #endif // CK_CONFIG_H_IN diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index 1bdb1d078e..e9174904c9 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -20,9 +20,17 @@ namespace { // https://en.cppreference.com/w/cpp/types/conditional -template struct conditional { using type = T; }; -template struct conditional { using type = F; }; -} +template +struct conditional +{ + using type = T; +}; +template +struct conditional +{ + using type = F; +}; +} // namespace namespace ck { @@ -200,8 +208,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) typename conditional< sizeof(T) == 2, unsigned short int, - typename conditional:: - type>::type retval; + typename conditional::type>::type retval; if constexpr(we == 5 && is_half && !is_fnuz) { @@ -547,8 +554,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn using T_bitwise = typename conditional< sizeof(T) == 2, unsigned short int, - typename conditional:: - type>::type; + typename conditional::type>::type; T_bitwise x_bitwise = bit_cast(_x); unsigned long long x{x_bitwise};