Use 16x16x32 for Gemm1 on MI350 and adjust the NumPrefetchK for with_softmax trload pipeline

This commit is contained in:
Qianfeng Zhang
2025-11-27 15:30:53 +00:00
parent 69c97c06d7
commit a0e4315d4e
2 changed files with 3 additions and 3 deletions

View File

@@ -243,7 +243,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64>
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
{
using type = ck_tile::sequence<128, 64, 128, 16, 128>;
using type = ck_tile::sequence<128, 64, 128, 32, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
@@ -336,7 +336,7 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<128>
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
HstuAttentionFwdWarpTile2>;
};
template <>

View File

@@ -208,7 +208,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
using k_tile_type = decltype(load_tile(k_dram_window));
constexpr index_t NumPrefetchK = 2;
constexpr index_t NumPrefetchK = (k1_loops <= 3) ? 1 : 2;
static_assert(k1_loops >= NumPrefetchK, "Check failed!");