From 2ff2610908035c52a8ea3815066e6a9ba1073b7e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 9 May 2025 08:07:08 +0000 Subject: [PATCH] Merge commit 'ef72a4b9bc2e5ddc63d9138cae4e5eba23d35b16' into develop --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 10 +++++++++- tile_engine/ops/gemm/gemm_instance_builder.py | 6 +++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index c98d46e3a0..5cc5ddc70e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -97,12 +97,20 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl>; // fp16 2:4 structured sparsity - +#if defined(__gfx94__) || defined(__gfx95__) using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; +#else // gfx 90a does not support smfmac +using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmImpl, + 2>>; +using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmImpl, + 2>>; +#endif // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 3839523e3d..c00554df8f 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -535,7 +535,11 @@ struct GemmDispatcher { ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) content += f""" - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" +#if defined(__gfx908__) + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); +#else + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); +#endif""" content += f""" }} else {{""" for tile in tile_params: