mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Code re-arrangement in pipeline
This commit is contained in:
@@ -182,11 +182,11 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
auto k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// K tile in LDS
|
||||
@@ -310,8 +310,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
store_tile(k_lds_windows[number<i_k1 % NumKLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
|
||||
clear_tile(sacc_tiles[i_k1]);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
k_tile = load_tile(k_dram_window);
|
||||
@@ -319,12 +317,12 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NumPrefetchV, 1>{}([&](auto i_buf) {
|
||||
v_tiles[i_buf] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
clear_tile(sacc_tiles[i_k1]);
|
||||
|
||||
block_sync_lds();
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number<i_k1 % NumKLdsBuffers>{}]);
|
||||
@@ -379,9 +377,10 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
seqlen_k_curr += kK1;
|
||||
});
|
||||
|
||||
// load one k_tile for next iteration
|
||||
k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
|
||||
v_tiles[i_buf] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -433,6 +432,12 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
{
|
||||
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
}
|
||||
else if constexpr(i_k1 == k1_loops - NumPrefetchV)
|
||||
{
|
||||
// load one k_tile for next iteration
|
||||
k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
Reference in New Issue
Block a user