diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 1d9691da65..31049ca955 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -314,8 +314,6 @@ struct HstuAttentionFwdPipelineQRKSVS q_tile_type q_tile; { - clear_tile(o_acc); - constexpr index_t complete_tile_thread_buf_size = q_tile_type::get_thread_buffer_size(); constexpr index_t splitted_tile_thread_buf_size = q_reg_tile_type::get_thread_buffer_size(); @@ -344,16 +342,20 @@ struct HstuAttentionFwdPipelineQRKSVS // no need to call __builtin_amdgcn_s_barrier() since the tile-slice read // by each wavefront is over-written by itself }); + + clear_tile(o_acc); }; q_tile = tile_elementwise_in(q_element_func, q_tile); auto seqlen_k_curr = seqlen_k_start; + __builtin_amdgcn_sched_barrier(0x00000001); + // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile __builtin_amdgcn_s_barrier(); - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(0x00000001); using v_tile_type = decltype(load_tile(v_dram_window)); @@ -438,7 +440,6 @@ struct HstuAttentionFwdPipelineQRKSVS } else { - // load v_tile for current unroll v_tile = load_tile(v_dram_window); @@ -448,6 +449,8 @@ struct HstuAttentionFwdPipelineQRKSVS k_tile = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); + __builtin_amdgcn_sched_barrier(0x00000001); + block_sync_lds(); // execute current unroll of gemm_0 gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); @@ -548,6 +551,8 @@ struct HstuAttentionFwdPipelineQRKSVS store_tile(k_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}], tile_elementwise_in(k_element_func, k_tile)); + + __builtin_amdgcn_sched_barrier(0x00000001); } else {