mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Move store_tile() caled before the current iteration
This commit is contained in:
@@ -353,10 +353,14 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
v_tile_type v_tile;
|
||||
|
||||
store_tile(k_lds_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile));
|
||||
|
||||
do
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
@@ -365,17 +369,12 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
constexpr index_t V_VMEM_LOAD_INST = (kN1 * kK1) / kBlockSize / kAlignmentV;
|
||||
constexpr index_t K_VMEM_LOAD_INST =
|
||||
(kK1 * kSubQKHeaddim) / kBlockSize / kAlignmentK;
|
||||
constexpr index_t K_LDS_WRITE_INST = (kK1 * kSubQKHeaddim) / kBlockSize /
|
||||
Policy::template GetSmemKPackK<Problem>();
|
||||
constexpr index_t MFMA_INST = (kM0 * kSubQKHeaddim) / kBlockSize / 4;
|
||||
constexpr index_t K_LDS_READ_INST = MFMA_INST / kGemmNumRepM;
|
||||
|
||||
// load v_tile for current unroll
|
||||
v_tile = load_tile(v_dram_window);
|
||||
|
||||
store_tile(k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
// for i_k1 = k1_loop-1, the loading is for next iteration
|
||||
@@ -386,8 +385,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x00000200, K_LDS_WRITE_INST, 0);
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x00000020, V_VMEM_LOAD_INST, 0);
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x00000100, K_LDS_READ_INST, 0);
|
||||
@@ -410,9 +407,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
// load v_tile for current unroll
|
||||
v_tile = load_tile(v_dram_window);
|
||||
|
||||
store_tile(k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
// for i_k1 = k1_loop-1, the loading is for next iteration
|
||||
@@ -516,6 +510,9 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(k_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -525,6 +522,9 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(k_lds_windows[number<0>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
}
|
||||
});
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
Reference in New Issue
Block a user