From 9b67733bd19b6c8b23576bd1a9cbdb6d2f9b1ef1 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Thu, 24 Apr 2025 10:20:22 -0700 Subject: [PATCH] MFMA_32x32x16 for gfx950 (#2121) * Enable MFMA_32x32x16 for fp16/BF16 for gfx950 * clang formatted [ROCm/composable_kernel commit: a2ed34a112982664132db5283ee4d1b1aac746d5] --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index e6350a8827..4732027e57 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -20,9 +20,15 @@ using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl< using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +#else using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, 2>>; +#endif #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl< @@ -105,9 +111,15 @@ using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +#else using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, 2>>; +#endif #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<