From acb6cd89d9369f5561fb6e2d9c61d12978492783 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 21 Jul 2025 04:17:10 +0000 Subject: [PATCH] Move store_tile() called before the current iteration --- .../hstu_attention_fwd_pipeline.hpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 795b9349da..01ef3cc496 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -353,10 +353,14 @@ struct HstuAttentionFwdPipelineQRKSVS // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + using v_tile_type = decltype(load_tile(v_dram_window)); v_tile_type v_tile; + store_tile(k_lds_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile)); + do { static_for<0, k1_loops, 1>{}([&](auto i_k1) { @@ -365,17 +369,12 @@ struct HstuAttentionFwdPipelineQRKSVS constexpr index_t V_VMEM_LOAD_INST = (kN1 * kK1) / kBlockSize / kAlignmentV; constexpr index_t K_VMEM_LOAD_INST = (kK1 * kSubQKHeaddim) / kBlockSize / kAlignmentK; - constexpr index_t K_LDS_WRITE_INST = (kK1 * kSubQKHeaddim) / kBlockSize / - Policy::template GetSmemKPackK(); constexpr index_t MFMA_INST = (kM0 * kSubQKHeaddim) / kBlockSize / 4; constexpr index_t K_LDS_READ_INST = MFMA_INST / kGemmNumRepM; // load v_tile for current unroll v_tile = load_tile(v_dram_window); - store_tile(k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tile)); - move_tile_window(v_dram_window, {0, kK1}); // for i_k1 = k1_loop-1, the loading is for next iteration @@ -386,8 +385,6 @@ struct HstuAttentionFwdPipelineQRKSVS // execute current unroll of gemm_0 gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); - 
__builtin_amdgcn_sched_group_barrier(0x00000200, K_LDS_WRITE_INST, 0); - __builtin_amdgcn_sched_group_barrier(0x00000020, V_VMEM_LOAD_INST, 0); __builtin_amdgcn_sched_group_barrier(0x00000100, K_LDS_READ_INST, 0); @@ -410,9 +407,6 @@ struct HstuAttentionFwdPipelineQRKSVS // load v_tile for current unroll v_tile = load_tile(v_dram_window); - store_tile(k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tile)); - move_tile_window(v_dram_window, {0, kK1}); // for i_k1 = k1_loop-1, the loading is for next iteration @@ -516,6 +510,9 @@ struct HstuAttentionFwdPipelineQRKSVS { __builtin_amdgcn_s_barrier(); }; + + store_tile(k_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}], + tile_elementwise_in(k_element_func, k_tile)); } else { @@ -525,6 +522,9 @@ struct HstuAttentionFwdPipelineQRKSVS { __builtin_amdgcn_s_barrier(); }; + + store_tile(k_lds_windows[number<0>{}], + tile_elementwise_in(k_element_func, k_tile)); } }); } while(seqlen_k_curr < seqlen_k_end);