diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index ecbadfe1ac..4019294153 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -296,9 +296,6 @@ struct HstuAttentionFwdPipelineQRKSVS q_tile = tile_elementwise_in(q_element_func, q_tile); - auto v_tile = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - auto seqlen_k_curr = seqlen_k_start; index_t i_loop = 0; @@ -309,9 +306,11 @@ struct HstuAttentionFwdPipelineQRKSVS store_tile(k_lds_windows[number{}], tile_elementwise_in(k_element_func, k_tile)); - // for i_k1 = k1_loop-1, the loading is for next iteration - k_tile = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + // load v_tile for current unroll + auto v_tile = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + + __builtin_amdgcn_sched_barrier(0); clear_tile(sacc_tiles[i_k1]); @@ -433,9 +432,11 @@ struct HstuAttentionFwdPipelineQRKSVS tile_elementwise_in(v_element_func, v_tile)); // store the prefetch }; - // for i_k1 = k1_loops-1, the loading is for next iteration - v_tile = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); + // for i_k1 = k1_loop-1, the loading is for next iteration + k_tile = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + + __builtin_amdgcn_sched_barrier(0); tile_elementwise_inout(f_silu, pcomp_tiles[i_k1]);