From 6fe28f408c5f1976dd5be2bbee694b1580516c8a Mon Sep 17 00:00:00 2001 From: Michael Mcminn <47832147+UD-mmcminn@users.noreply.github.com> Date: Tue, 2 Sep 2025 04:35:07 -0400 Subject: [PATCH] =?UTF-8?q?Adding=20fix=20for=20the=20gfx908=20to=20the=20?= =?UTF-8?q?GEMM=20MFMA=20implementaitons=20of=20WarpGem=E2=80=A6=20(#2751)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Adding fix for the gfx908 to the GEMM MFMA implementaitons of WarpGemmMfmaBf16Bf16F32M4N64K16 WarpGemmMfmaBf16Bf16F32M64N4K16 * Adding support for offload target gfx9-4-generic * This duplication here isn't ideal [ROCm/composable_kernel commit: 022f369deb06e202f6a0dd72b6759c9332e6d395] --- include/ck/ck.hpp | 5 +- include/ck_tile/core/config.hpp | 5 +- .../warp/warp_gemm_attribute_mfma_impl.hpp | 60 +++++++++++++++++-- 3 files changed, 62 insertions(+), 8 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 09801203ba..b8a1afec4e 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -50,10 +50,11 @@ #endif // define general macros for various architectures -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \ + defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx9__ #endif -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx94__ #endif #if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 7b5b862cb1..0d4aa58026 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -3,10 +3,11 @@ #pragma once -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \\ + defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx9__ #endif -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx94__ #endif #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 0831cf85c4..11a8416fb2 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -660,8 +660,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl) else { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx908__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); #else ignore = c_vec; ignore = a_vec; @@ -673,9 +685,23 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; #else ignore = a_vec; ignore = b_vec; @@ -724,8 +750,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl) else { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx908__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); #else ignore = c_vec; ignore = a_vec; @@ -737,9 +775,23 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; #else ignore = a_vec; ignore = b_vec;