Adding fix for the gfx908 to the GEMM MFMA implementaitons of WarpGem… (#2751)

* Adding fix for the gfx908 to the GEMM MFMA implementaitons of WarpGemmMfmaBf16Bf16F32M4N64K16 WarpGemmMfmaBf16Bf16F32M64N4K16

* Adding support for offload target gfx9-4-generic

* This duplication here isn't ideal
This commit is contained in:
Michael Mcminn
2025-09-02 04:35:07 -04:00
committed by GitHub
parent 33418b201f
commit 022f369deb
3 changed files with 62 additions and 8 deletions

View File

@@ -50,10 +50,11 @@
#endif
// define general macros for various architectures
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \
defined(__gfx950__) || defined(__gfx9_4_generic__)
#define __gfx9__
#endif
#if defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__)
#define __gfx94__
#endif
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)

View File

@@ -3,10 +3,11 @@
#pragma once
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \\
defined(__gfx950__) || defined(__gfx9_4_generic__)
#define __gfx9__
#endif
#if defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__)
#define __gfx94__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \

View File

@@ -660,8 +660,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
else
{
#if defined(__gfx9__)
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
#else
ignore = c_vec;
ignore = a_vec;
@@ -673,9 +685,23 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
#if defined(__gfx90a__) || defined(__gfx94__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#elif defined(__gfx908__)
CVecType c_vec{0.f};
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
return c_vec;
#else
ignore = a_vec;
ignore = b_vec;
@@ -724,8 +750,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
else
{
#if defined(__gfx9__)
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
#else
ignore = c_vec;
ignore = a_vec;
@@ -737,9 +775,23 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
#if defined(__gfx90a__) || defined(__gfx94__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#elif defined(__gfx908__)
CVecType c_vec{0.f};
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
return c_vec;
#else
ignore = a_vec;
ignore = b_vec;