diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 09801203ba..b8a1afec4e 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -50,10 +50,11 @@ #endif // define general macros for various architectures -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \ + defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx9__ #endif -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx94__ #endif #if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 7b5b862cb1..0d4aa58026 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -3,10 +3,11 @@ #pragma once -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \\ + defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx9__ #endif -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx94__ #endif #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 0831cf85c4..11a8416fb2 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -660,8 +660,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl) else { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx908__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); #else ignore = c_vec; ignore = a_vec; @@ -673,9 +685,23 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; #else ignore = a_vec; ignore = b_vec; @@ -724,8 +750,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl) else { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx908__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); #else ignore = c_vec; ignore = a_vec; @@ -737,9 +775,23 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; #else ignore = a_vec; ignore = b_vec;