diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 014929bd08..023de7e655 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -337,6 +337,9 @@ struct HstuAttentionFwdPipelineQRKSVS q_reg_tiles[i_rep] = load_tile(q_lds_read_window); + __builtin_amdgcn_s_waitcnt(0xc07f); // NOTE(review): presumably lgkmcnt(0) in the gfx9 s_waitcnt encoding, i.e. wait for the LDS read above; confirm for the target ISA + + // the following code does not make the compiler generate any actual instructions static_for<0, splitted_tile_thread_buf_size, 1>{}([&](auto i_buf) { q_tile.get_thread_buffer()[i_rep * splitted_tile_thread_buf_size + i_buf] = q_reg_tiles[i_rep].get_thread_buffer()[i_buf]; @@ -356,7 +359,7 @@ struct HstuAttentionFwdPipelineQRKSVS // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile __builtin_amdgcn_s_barrier(); - while(i_loop < num_loops) + do { static_for<0, k1_loops, 1>{}([&](auto i_k1) { // load v_tile for current unroll @@ -525,9 +528,7 @@ struct HstuAttentionFwdPipelineQRKSVS // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 3 if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) __builtin_amdgcn_s_barrier(); - - i_loop++; - }; + } while(++i_loop < num_loops); tile_elementwise_inout( [&](auto& x) {