From 598e3fec417eb0ff8089c260e758aa2c305ccd1d Mon Sep 17 00:00:00 2001 From: Kevin Choi Date: Thu, 14 Aug 2025 12:11:17 +0000 Subject: [PATCH] remove innerloop, move restrict parameters to mainloop and add noinline attribute. --- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 38 ++++++++----------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index d71f448415..5587346de3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -638,11 +638,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload typename LSEaccDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& __restrict__ bias_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& __restrict__ lse_acc_dram_window_tmp, // M0*1 tile FmhaMask mask, PositionEncoding position_encoding, float scale_s, @@ -854,12 +854,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload __builtin_amdgcn_sched_barrier(0); - auto mainloop = [&](index_t cur_loop) { - const bool is_even_loop = (cur_loop % 2 == 0); - auto innerloop = [&](KDataType* __restrict__ k_lds_write_ptr, - KDataType* __restrict__ k_lds_read_ptr, - KDataType* __restrict__ v_lds_write_ptr, - KDataType* __restrict__ v_lds_read_ptr) { + auto mainloop = [&] __attribute__((noinline)) (index_t cur_loop, + KDataType* __restrict__ k_lds_write_ptr, + KDataType* __restrict__ k_lds_read_ptr, + KDataType* __restrict__ v_lds_write_ptr, + KDataType* __restrict__ v_lds_read_ptr) { // move V tile windows block_sync_lds(); @@ -1105,8 +1104,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ }); - - }; // innerloop + }; // mainloop + + do + { + bool is_even_loop = i_total_loops % 2 == 0; auto k_lds_write_ptr = is_even_loop ? static_cast(smem_ptrk0) : static_cast(smem_ptrk1); auto k_lds_read_ptr = is_even_loop ? static_cast(smem_ptrk1) @@ -1115,15 +1117,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload : static_cast(smem_ptrv0); auto v_lds_read_ptr = is_even_loop ? static_cast(smem_ptrv0) : static_cast(smem_ptrv1); - innerloop(k_lds_write_ptr, - k_lds_read_ptr, - v_lds_write_ptr, - v_lds_read_ptr); - - }; // mainloop - do - { - mainloop(i_total_loops); + mainloop(i_total_loops, k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr); i_total_loops++; } while(i_total_loops < num_total_loop);