diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index e6350a8827..4732027e57 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -20,9 +20,15 @@ using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl< using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +#else using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, 2>>; +#endif #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl< @@ -105,9 +111,15 @@ using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +#else using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, 2>>; +#endif #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<