From 0206b3442aceac5e28f6c082c3f4c283a18cb170 Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Mon, 14 Jul 2025 16:01:30 +0000
Subject: [PATCH] [Performance] use iglp compiler instructions to tune the code around gemm0 for the window_size > 0 case

---
 .../hstu_attention_fwd_pipeline.hpp | 73 ++++++++++++++++---
 1 file changed, 62 insertions(+), 11 deletions(-)

diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
index ff5da4aca0..ecca0e752d 100644
--- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
+++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
@@ -353,24 +353,75 @@ struct HstuAttentionFwdPipelineQRKSVS
         // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by
         // k_tile
         __builtin_amdgcn_s_barrier();
 
+        using v_tile_type = decltype(load_tile(v_dram_window));
+
+        v_tile_type v_tile;
+
         do
         {
             static_for<0, k1_loops, 1>{}([&](auto i_k1) {
-                // load v_tile for current unroll
-                auto v_tile = load_tile(v_dram_window);
+                if constexpr(HstuMask::kUseLocal)
+                {
+                    constexpr index_t V_VMEM_LOAD_INST = (kN1 * kK1) / kBlockSize / kAlignmentV;
+                    constexpr index_t K_VMEM_LOAD_INST = (kN0 * kK0) / kBlockSize / kAlignmentK;
+                    constexpr index_t K_LDS_WRITE_INST =
+                        (kN0 * kK0) / kBlockSize / Policy::template GetSmemKPackK<Problem>();
+                    constexpr index_t MFMA_INST = (kM0 * kSubQKHeaddim) / kBlockSize / 4;
+                    constexpr index_t K_LDS_READ_INST = MFMA_INST / kGemmNumRepM;
+
+                    // load v_tile for current unroll
+                    v_tile = load_tile(v_dram_window);
 
-                store_tile(k_lds_windows[number<i_k1>{}],
-                           tile_elementwise_in(k_element_func, k_tile));
+                    store_tile(k_lds_windows[number<i_k1>{}],
+                               tile_elementwise_in(k_element_func, k_tile));
 
-                move_tile_window(v_dram_window, {0, kK1});
+                    move_tile_window(v_dram_window, {0, kK1});
 
-                // for i_k1 = k1_loop-1, the loading is for next iteration
-                k_tile = load_tile(k_dram_window);
-                move_tile_window(k_dram_window, {kK1, 0});
+                    // for i_k1 = k1_loop-1, the loading is for next iteration
+                    k_tile = load_tile(k_dram_window);
+                    move_tile_window(k_dram_window, {kK1, 0});
 
-                block_sync_lds();
-                // execute current unroll of gemm_0
-                gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1>{}]);
+                    block_sync_lds();
+                    // execute current unroll of gemm_0
+                    gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1>{}]);
+
+                    __builtin_amdgcn_sched_group_barrier(0x00000200, K_LDS_WRITE_INST, 0);
+
+                    __builtin_amdgcn_sched_group_barrier(0x00000020, V_VMEM_LOAD_INST, 0);
+
+                    __builtin_amdgcn_sched_group_barrier(0x00000100, K_LDS_READ_INST, 0);
+
+                    __builtin_amdgcn_sched_group_barrier(0x00000020, K_VMEM_LOAD_INST, 0);
+
+                    static_for<0, K_LDS_READ_INST - 1, 1>{}([&](auto i) {
+                        ignore = i;
+                        __builtin_amdgcn_sched_group_barrier(0x00000100, K_LDS_READ_INST, 0);
+                        __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
+                    });
+
+                    __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
+
+                    __builtin_amdgcn_sched_barrier(0);
+                }
+                else
+                {
+
+                    // load v_tile for current unroll
+                    v_tile = load_tile(v_dram_window);
+
+                    store_tile(k_lds_windows[number<i_k1>{}],
+                               tile_elementwise_in(k_element_func, k_tile));
+
+                    move_tile_window(v_dram_window, {0, kK1});
+
+                    // for i_k1 = k1_loop-1, the loading is for next iteration
+                    k_tile = load_tile(k_dram_window);
+                    move_tile_window(k_dram_window, {kK1, 0});
+
+                    block_sync_lds();
+                    // execute current unroll of gemm_0
+                    gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1>{}]);
+                }
 
             sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
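
Note on the scheduling hints (explanatory, not part of the applied diff): the first
argument of __builtin_amdgcn_sched_group_barrier is the AMDGPU scheduling-group mask
documented by Clang: 0x00000008 = MFMA/WMMA, 0x00000020 = VMEM read, 0x00000100 = DS
read, 0x00000200 = DS write. The sequence above therefore groups the K LDS writes,
the V global loads, one group of K LDS reads, and the K global loads first, then
alternates DS-read groups with MFMA groups so gemm_0's LDS traffic overlaps its MFMA
issue; the closing __builtin_amdgcn_sched_barrier(0) keeps other instructions from
being reordered across these groups.

A minimal sketch of the same interleaving pattern, outside this patch; DS_READS and
MFMAS_PER_GROUP are hypothetical stand-ins for K_LDS_READ_INST and kGemmNumRepM, and
the actual tile loads/MFMAs of the gemm unroll are elided:

    // HIP/Clang device code for an AMDGPU target. The hints only take effect if
    // instructions of the marked kinds are actually emitted in this region.
    template <int DS_READS, int MFMAS_PER_GROUP>
    __device__ void interleave_ds_reads_with_mfmas()
    {
        // ... the DS reads and MFMAs of one gemm unroll would be issued here ...

    #pragma unroll
        for(int i = 0; i < DS_READS - 1; ++i)
        {
            // one group of DS reads, then one group of MFMAs
            __builtin_amdgcn_sched_group_barrier(0x00000100, DS_READS, 0);
            __builtin_amdgcn_sched_group_barrier(0x00000008, MFMAS_PER_GROUP, 0);
        }
        // trailing MFMA group so every DS-read group is paired with compute
        __builtin_amdgcn_sched_group_barrier(0x00000008, MFMAS_PER_GROUP, 0);
        // hard barrier: nothing may be scheduled across this point
        __builtin_amdgcn_sched_barrier(0);
    }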