Code re-arrangement in pipeline

This commit is contained in:
Qianfeng Zhang
2025-04-25 14:16:29 +00:00
parent 4a49119d98
commit 80677eb6e0

View File

@@ -182,11 +182,11 @@ struct HstuAttentionFwdPipelineQRKSVS
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
auto q_tile = load_tile(q_dram_window);
auto k_tile = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
auto q_tile = load_tile(q_dram_window);
__builtin_amdgcn_sched_barrier(0);
// K tile in LDS
@@ -310,8 +310,6 @@ struct HstuAttentionFwdPipelineQRKSVS
store_tile(k_lds_windows[number<i_k1 % NumKLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tile));
clear_tile(sacc_tiles[i_k1]);
if constexpr(i_k1 < k1_loops - 1)
{
k_tile = load_tile(k_dram_window);
@@ -319,12 +317,12 @@ struct HstuAttentionFwdPipelineQRKSVS
}
else
{
static_for<0, NumPrefetchV, 1>{}([&](auto i_buf) {
v_tiles[i_buf] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
});
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
clear_tile(sacc_tiles[i_k1]);
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number<i_k1 % NumKLdsBuffers>{}]);
@@ -379,9 +377,10 @@ struct HstuAttentionFwdPipelineQRKSVS
seqlen_k_curr += kK1;
});
// load one k_tile for next iteration
k_tile = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
v_tiles[i_buf] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
});
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
@@ -433,6 +432,12 @@ struct HstuAttentionFwdPipelineQRKSVS
{
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
}
else if constexpr(i_k1 == k1_loops - NumPrefetchV)
{
// load one k_tile for next iteration
k_tile = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
};
block_sync_lds();