Force both Gemm0 and Gemm1 to use mfma-16x16x32 on gfx950

This commit is contained in:
Qianfeng Zhang
2025-11-28 13:45:20 +00:00
parent a0e4315d4e
commit f952d3571c
3 changed files with 21 additions and 25 deletions

View File

@@ -736,21 +736,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
#ifdef __gfx950__
static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!");
return WarpGemmDispatcher<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}),
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<1>{}),
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Single>{};
#else
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)),
static_assert(WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32),
"Not supported WarpGemm sizes!");
#endif
return WarpGemmDispatcher<
typename Problem::QKVDataType,
@@ -763,7 +752,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
false,
false,
WGAttrNumAccessEnum::Single>{};
#endif
}
else
{
@@ -815,8 +803,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t WarpGemmK =
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{});
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)),
#ifdef __gfx950__
static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!");
#else
static_assert(WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32),
"Not supported WarpGemm sizes!");
#endif
if constexpr(WarpGemmK == 32)
return WarpGemmDispatcher<

View File

@@ -190,7 +190,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile;
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<32>
{
using type = ck_tile::sequence<64, 64, 32, 16, 32>;
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
};
@@ -214,7 +214,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<128>
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<256>
{
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
using type = ck_tile::sequence<128, 32, 256, 32, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
@@ -227,7 +227,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile;
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<32>
{
using type = ck_tile::sequence<64, 64, 32, 16, 32>;
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
};
@@ -251,7 +251,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<128>
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<256>
{
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
using type = ck_tile::sequence<128, 64, 256, 32, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
@@ -267,7 +267,7 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<32>
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
HstuAttentionFwdWarpTile2>;
};
template <>
@@ -278,7 +278,7 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<64>
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
HstuAttentionFwdWarpTile2>;
};
template <>
@@ -300,7 +300,7 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<256>
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
HstuAttentionFwdWarpTile2>;
};
template <ck_tile::index_t MaxK>
@@ -314,7 +314,7 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<32>
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
HstuAttentionFwdWarpTile2>;
};
template <>
@@ -325,7 +325,7 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<64>
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
HstuAttentionFwdWarpTile2>;
};
template <>
@@ -347,6 +347,6 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<256>
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
HstuAttentionFwdWarpTile2>;
};
#endif

View File

@@ -212,6 +212,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
static_assert(k1_loops >= NumPrefetchK, "Check failed!");
static_assert(k1_loops >= 2,
"k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
// only prefetch two k tiles to reduce VGPR consumption
statically_indexed_array<k_tile_type, NumPrefetchK> k_tiles;
@@ -556,6 +559,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
__builtin_amdgcn_s_barrier();
};
// k1_loops >= 2 required
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
v_tiles[number<1>{}],
partition_index);