mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
Use 16x16x32 for Gemm1 on MI350 and adjust the NumPrefetchK for with_softmax trload pipeline
This commit is contained in:
@@ -243,7 +243,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64>
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 128, 16, 128>;
|
||||
using type = ck_tile::sequence<128, 64, 128, 32, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -336,7 +336,7 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<128>
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
HstuAttentionFwdWarpTile2>;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
@@ -208,7 +208,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
constexpr index_t NumPrefetchK = 2;
|
||||
constexpr index_t NumPrefetchK = (k1_loops <= 3) ? 1 : 2;
|
||||
|
||||
static_assert(k1_loops >= NumPrefetchK, "Check failed!");
|
||||
|
||||
|
||||
Reference in New Issue
Block a user