Change while() do to do while() for the main loop to let the compiler to generate more elegant codes

This commit is contained in:
Qianfeng Zhang
2025-06-21 12:58:27 +00:00
parent f9caae2d8b
commit a5f24d7470

View File

@@ -337,6 +337,9 @@ struct HstuAttentionFwdPipelineQRKSVS
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
__builtin_amdgcn_s_waitcnt(0xc07f);
// the following codes will not generate actual instructions by the compiler
static_for<0, splitted_tile_thread_buf_size, 1>{}([&](auto i_buf) {
q_tile.get_thread_buffer()[i_rep * splitted_tile_thread_buf_size + i_buf] =
q_reg_tiles[i_rep].get_thread_buffer()[i_buf];
@@ -356,7 +359,7 @@ struct HstuAttentionFwdPipelineQRKSVS
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
__builtin_amdgcn_s_barrier();
while(i_loop < num_loops)
do
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
// load v_tile for current unroll
@@ -525,9 +528,7 @@ struct HstuAttentionFwdPipelineQRKSVS
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 3
if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())
__builtin_amdgcn_s_barrier();
i_loop++;
};
} while(i_loop++ < num_loops);
tile_elementwise_inout(
[&](auto& x) {