diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 5ed97dc05c..f050a8e382 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -109,20 +109,11 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl>; // fp16 2:4 structured sparsity -#if defined(__gfx94__) || defined(__gfx95__) using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; -#else // gfx 90a does not support smfmac -using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmImpl, - 2>>; -using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmImpl, - 2>>; -#endif // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp index 97fd2a8742..cd6cd3a399 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp @@ -49,7 +49,7 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 const int32_t& idx, bool_constant = {}) const { -#if defined(__gfx9__) +#if defined(__gfx94_) or defined(__gfx95_) c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0); #else ck_tile::ignore = c_vec; @@ -100,7 +100,7 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 const int32_t& idx, bool_constant = {}) const { -#if defined(__gfx9__) +#if defined(__gfx94_) or defined(__gfx95_) c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0); #else ck_tile::ignore = c_vec; diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index c00554df8f..3839523e3d 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -535,11 +535,7 @@ struct GemmDispatcher { ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) content += f""" -#if defined(__gfx908__) - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); -#else - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); -#endif""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" content += f""" }} else {{""" for tile in tile_params: