mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 08:00:13 +00:00
Update the NumPrefetchK and NumPrefetchV in the softmax pipeline on mi350 to achieve better interleaving
This commit is contained in:
@@ -160,9 +160,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
constexpr index_t n0_loops = kN0 / kN0Sub;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(n0_loops >= k1_loops, "n0_loops >= k1_loops required by this pipeline");
|
||||
static_assert(k1_loops >= 2,
|
||||
"k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
|
||||
static_assert(n0_loops >= 2, "n0_loops >= 2 required by this pipeline");
|
||||
static_assert(k1_loops >= 2, "k1_loops >= 2 required by this pipeline");
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
@@ -211,13 +210,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
constexpr index_t NumPrefetchK = (n0_loops <= 3) ? 1 : 2;
|
||||
constexpr index_t NumPrefetchK = 1;
|
||||
|
||||
static_assert(n0_loops >= NumPrefetchK, "Check failed!");
|
||||
|
||||
static_assert(k1_loops >= 2,
|
||||
"k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
|
||||
|
||||
// only prefetch two k tiles to save vgprs consumption
|
||||
statically_indexed_array<k_tile_type, NumPrefetchK> k_tiles;
|
||||
|
||||
@@ -321,9 +317,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
constexpr index_t NumPrefetchV = 2;
|
||||
|
||||
static_assert(NumPrefetchV >= NumPrefetchK);
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
|
||||
statically_indexed_array<v_tile_type, NumPrefetchV> v_tiles;
|
||||
|
||||
do
|
||||
{
|
||||
@@ -342,7 +342,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(i_n0 - (n0_loops - NumPrefetchK) < k1_loops)
|
||||
// We assume NumPrefetchV >= NumPrefetchK
|
||||
if constexpr(i_n0 - (n0_loops - NumPrefetchK) < NumPrefetchK)
|
||||
{
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[number<i_n0 - (n0_loops - NumPrefetchK)>{}] =
|
||||
@@ -443,7 +444,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
|
||||
static_for<NumPrefetchK, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
@@ -460,6 +461,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
@@ -518,26 +521,20 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
// check whether second V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 3 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
// k1_loops >= 2 required
|
||||
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
|
||||
v_tiles[number<1>{}],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 < NumPrefetchK)
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
{
|
||||
// load k_tiles used by next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
if constexpr((i_k1 >= k1_loops - NumPrefetchV) &&
|
||||
(i_k1 - (k1_loops - NumPrefetchV) < NumPrefetchK))
|
||||
{
|
||||
k_tiles[number<i_k1 - (k1_loops - NumPrefetchV)>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
@@ -550,12 +547,12 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 2)
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
|
||||
v_tiles[number<i_k1 + 2>{}],
|
||||
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
|
||||
v_tiles[number<i_k1 + 1>{}],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
Reference in New Issue
Block a user