diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 790f6e2b90..ea82f9c43e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -182,11 +182,11 @@ struct HstuAttentionFwdPipelineQRKSVS {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); + auto q_tile = load_tile(q_dram_window); + auto k_tile = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); - auto q_tile = load_tile(q_dram_window); - __builtin_amdgcn_sched_barrier(0); // K tile in LDS @@ -310,8 +310,6 @@ struct HstuAttentionFwdPipelineQRKSVS store_tile(k_lds_windows[number{}], tile_elementwise_in(k_element_func, k_tile)); - clear_tile(sacc_tiles[i_k1]); - if constexpr(i_k1 < k1_loops - 1) { k_tile = load_tile(k_dram_window); @@ -319,12 +317,12 @@ struct HstuAttentionFwdPipelineQRKSVS } else { - static_for<0, NumPrefetchV, 1>{}([&](auto i_buf) { - v_tiles[i_buf] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - }); + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); }; + clear_tile(sacc_tiles[i_k1]); + block_sync_lds(); // execute current unroll of gemm_0 gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number{}]); @@ -379,9 +377,10 @@ struct HstuAttentionFwdPipelineQRKSVS seqlen_k_curr += kK1; }); - // load one k_tile for next iteration - k_tile = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { + v_tiles[i_buf] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }); if constexpr(std::is_same_v) { @@ -433,6 +432,12 @@ struct HstuAttentionFwdPipelineQRKSVS { v_tiles[number{}] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); + } + else if constexpr(i_k1 == k1_loops - NumPrefetchV) + { + // load one k_tile for next iteration + k_tile = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); }; block_sync_lds();