From 1ed315638eec74e5df80e29196a8aa0d08abb6a3 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Tue, 2 Sep 2025 23:40:18 -0600 Subject: [PATCH] [CK TILE GEMM] Fix building issues (#2772) - Add `WarpGemmMfma_f32_16x16x128_[fp8|bf8]_[fp8|bf8]_CTransposed` - Replace `__gfx950__` with `CK_GFX950_SUPPORT` [ROCm/composable_kernel commit: e1ab460d2d2f58c3bfc18f1ff360a34aeb7f478f] --- .../38_block_scale_gemm/gemm_utils.hpp | 4 ++-- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 24 +++++++++++++++++++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 4 ++++ .../test_gemm_aquant_utils.hpp | 2 +- 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index d64297cb35..930cdefb7e 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -18,7 +18,7 @@ template constexpr ck_tile::index_t get_k_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) constexpr bool is_8bit_float = std::is_same_v || std::is_same_v; if constexpr(M_Warp_Tile == 32) @@ -35,7 +35,7 @@ constexpr ck_tile::index_t get_k_warp_tile() template constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) if constexpr(M_Warp_Tile == 32) return sizeof(PrecType) == 2 ? 16 : 64; else diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 87772f78fc..f83bbc2a18 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -301,6 +301,30 @@ using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; +template +using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + template using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 5021fb9907..1d3dd2ae6f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -112,6 +112,10 @@ template<> struct WarpGemmDispatcher struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp index cf9bf18c5a..61bb1a8bdd 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp @@ -18,7 +18,7 @@ template constexpr ck_tile::index_t get_k_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) constexpr bool is_8bit_float = std::is_same_v || std::is_same_v; if constexpr(M_Warp_Tile == 32)