mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Force both Gemm0 and Gemm1 to use mfma-16x16x32 on gfx950
This commit is contained in:
@@ -736,21 +736,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
#ifdef __gfx950__
|
||||
static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!");
|
||||
|
||||
return WarpGemmDispatcher<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Single>{};
|
||||
#else
|
||||
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)),
|
||||
static_assert(WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32),
|
||||
"Not supported WarpGemm sizes!");
|
||||
#endif
|
||||
|
||||
return WarpGemmDispatcher<
|
||||
typename Problem::QKVDataType,
|
||||
@@ -763,7 +752,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Single>{};
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -815,8 +803,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t WarpGemmK =
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{});
|
||||
|
||||
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)),
|
||||
#ifdef __gfx950__
|
||||
static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!");
|
||||
#else
|
||||
static_assert(WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32),
|
||||
"Not supported WarpGemm sizes!");
|
||||
#endif
|
||||
|
||||
if constexpr(WarpGemmK == 32)
|
||||
return WarpGemmDispatcher<
|
||||
|
||||
@@ -190,7 +190,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile;
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 16, 32>;
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
@@ -214,7 +214,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<128>
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
|
||||
using type = ck_tile::sequence<128, 32, 256, 32, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -227,7 +227,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile;
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 16, 32>;
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
@@ -251,7 +251,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<128>
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
|
||||
using type = ck_tile::sequence<128, 64, 256, 32, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -267,7 +267,7 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<32>
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
HstuAttentionFwdWarpTile2>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -278,7 +278,7 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<64>
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
HstuAttentionFwdWarpTile2>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -300,7 +300,7 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<256>
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
HstuAttentionFwdWarpTile2>;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
@@ -314,7 +314,7 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<32>
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
HstuAttentionFwdWarpTile2>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -325,7 +325,7 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<64>
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
HstuAttentionFwdWarpTile2>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -347,6 +347,6 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<256>
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
HstuAttentionFwdWarpTile2>;
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -212,6 +212,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
static_assert(k1_loops >= NumPrefetchK, "Check failed!");
|
||||
|
||||
static_assert(k1_loops >= 2,
|
||||
"k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
|
||||
|
||||
// only prefetch two k tiles to save vgprs consumption
|
||||
statically_indexed_array<k_tile_type, NumPrefetchK> k_tiles;
|
||||
|
||||
@@ -556,6 +559,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
// k1_loops >= 2 required
|
||||
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
|
||||
v_tiles[number<1>{}],
|
||||
partition_index);
|
||||
|
||||
Reference in New Issue
Block a user