From 2f6345845e72c0b4b0ce23519746c3cb9b355826 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Mon, 12 May 2025 09:56:23 -0700 Subject: [PATCH] Disable SMFMA gfx90a (#2184) * sparsity fix for gfx90a * reverting tile_engine changes [ROCm/composable_kernel commit: f05e45ba59b76cb6ea83c471860ded65d5fc623f] --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 9 --------- .../ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp | 4 ++-- tile_engine/ops/gemm/gemm_instance_builder.py | 6 +----- 3 files changed, 3 insertions(+), 16 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 5ed97dc05c..f050a8e382 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -109,20 +109,11 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl>; // fp16 2:4 structured sparsity -#if defined(__gfx94__) || defined(__gfx95__) using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; -#else // gfx 90a does not support smfmac -using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmImpl, - 2>>; -using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmImpl, - 2>>; -#endif // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp index 97fd2a8742..cd6cd3a399 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp @@ -49,7 +49,7 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 const int32_t& idx, bool_constant = {}) const { -#if defined(__gfx9__) +#if defined(__gfx94_) or defined(__gfx95_) c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0); #else ck_tile::ignore = c_vec; @@ -100,7 +100,7 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 const int32_t& idx, bool_constant = {}) const { -#if defined(__gfx9__) +#if defined(__gfx94_) or defined(__gfx95_) c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0); #else ck_tile::ignore = c_vec; diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index c00554df8f..3839523e3d 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -535,11 +535,7 @@ struct GemmDispatcher { ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) content += f""" -#if defined(__gfx908__) - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); -#else - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); -#endif""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" content += f""" }} else {{""" for tile in tile_params: