mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Adding fix for the gfx908 to the GEMM MFMA implementations of WarpGem… (#2751)
* Adding fix for the gfx908 to the GEMM MFMA implementations of WarpGemmMfmaBf16Bf16F32M4N64K16 and WarpGemmMfmaBf16Bf16F32M64N4K16 * Adding support for offload target gfx9-4-generic * This duplication here isn't ideal
This commit is contained in:
@@ -50,10 +50,11 @@
|
||||
#endif
|
||||
|
||||
// define general macros for various architectures
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \
|
||||
defined(__gfx950__) || defined(__gfx9_4_generic__)
|
||||
#define __gfx9__
|
||||
#endif
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \
|
||||
defined(__gfx950__) || defined(__gfx9_4_generic__)
|
||||
#define __gfx9__
|
||||
#endif
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
|
||||
|
||||
@@ -660,8 +660,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
#if defined(__gfx90a__) || defined(__gfx94__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__)
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
#else
|
||||
ignore = c_vec;
|
||||
ignore = a_vec;
|
||||
@@ -673,9 +685,23 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
#if defined(__gfx90a__) || defined(__gfx94__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__)
|
||||
CVecType c_vec{0.f};
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
return c_vec;
|
||||
#else
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
@@ -724,8 +750,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
#if defined(__gfx90a__) || defined(__gfx94__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__)
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
#else
|
||||
ignore = c_vec;
|
||||
ignore = a_vec;
|
||||
@@ -737,9 +775,23 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
#if defined(__gfx90a__) || defined(__gfx94__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__)
|
||||
CVecType c_vec{0.f};
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
return c_vec;
|
||||
#else
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
|
||||
Reference in New Issue
Block a user