Move store_tile() to be called before the current iteration

This commit is contained in:
Qianfeng Zhang
2025-07-21 04:17:10 +00:00
parent fed1474e4f
commit acb6cd89d9

View File

@@ -353,10 +353,14 @@ struct HstuAttentionFwdPipelineQRKSVS
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
using v_tile_type = decltype(load_tile(v_dram_window));
v_tile_type v_tile;
store_tile(k_lds_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile));
do
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
@@ -365,17 +369,12 @@ struct HstuAttentionFwdPipelineQRKSVS
constexpr index_t V_VMEM_LOAD_INST = (kN1 * kK1) / kBlockSize / kAlignmentV;
constexpr index_t K_VMEM_LOAD_INST =
(kK1 * kSubQKHeaddim) / kBlockSize / kAlignmentK;
constexpr index_t K_LDS_WRITE_INST = (kK1 * kSubQKHeaddim) / kBlockSize /
Policy::template GetSmemKPackK<Problem>();
constexpr index_t MFMA_INST = (kM0 * kSubQKHeaddim) / kBlockSize / 4;
constexpr index_t K_LDS_READ_INST = MFMA_INST / kGemmNumRepM;
// load v_tile for current unroll
v_tile = load_tile(v_dram_window);
store_tile(k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tile));
move_tile_window(v_dram_window, {0, kK1});
// for i_k1 = k1_loop-1, the loading is for next iteration
@@ -386,8 +385,6 @@ struct HstuAttentionFwdPipelineQRKSVS
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
__builtin_amdgcn_sched_group_barrier(0x00000200, K_LDS_WRITE_INST, 0);
__builtin_amdgcn_sched_group_barrier(0x00000020, V_VMEM_LOAD_INST, 0);
__builtin_amdgcn_sched_group_barrier(0x00000100, K_LDS_READ_INST, 0);
@@ -410,9 +407,6 @@ struct HstuAttentionFwdPipelineQRKSVS
// load v_tile for current unroll
v_tile = load_tile(v_dram_window);
store_tile(k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tile));
move_tile_window(v_dram_window, {0, kK1});
// for i_k1 = k1_loop-1, the loading is for next iteration
@@ -516,6 +510,9 @@ struct HstuAttentionFwdPipelineQRKSVS
{
__builtin_amdgcn_s_barrier();
};
store_tile(k_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tile));
}
else
{
@@ -525,6 +522,9 @@ struct HstuAttentionFwdPipelineQRKSVS
{
__builtin_amdgcn_s_barrier();
};
store_tile(k_lds_windows[number<0>{}],
tile_elementwise_in(k_element_func, k_tile));
}
});
} while(seqlen_k_curr < seqlen_k_end);