remove innerloop, move restrict parameters to mainloop and add noinline attribute.

This commit is contained in:
Kevin Choi
2025-08-14 12:11:17 +00:00
parent 3340408537
commit 598e3fec41

View File

@@ -638,11 +638,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& __restrict__ bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& __restrict__ lse_acc_dram_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
@@ -854,12 +854,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
__builtin_amdgcn_sched_barrier(0);
auto mainloop = [&](index_t cur_loop) {
const bool is_even_loop = (cur_loop % 2 == 0);
auto innerloop = [&](KDataType* __restrict__ k_lds_write_ptr,
KDataType* __restrict__ k_lds_read_ptr,
KDataType* __restrict__ v_lds_write_ptr,
KDataType* __restrict__ v_lds_read_ptr) {
auto mainloop = [&] __attribute__((noinline)) (index_t cur_loop,
KDataType* __restrict__ k_lds_write_ptr,
KDataType* __restrict__ k_lds_read_ptr,
KDataType* __restrict__ v_lds_write_ptr,
KDataType* __restrict__ v_lds_read_ptr) {
// move V tile windows
block_sync_lds<k_lds_insts>();
@@ -1105,8 +1104,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
});
}; // innerloop
}; // mainloop
do
{
bool is_even_loop = i_total_loops % 2 == 0;
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
: static_cast<KDataType* __restrict__>(smem_ptrk1);
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
@@ -1115,15 +1117,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
: static_cast<VDataType* __restrict__>(smem_ptrv0);
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
: static_cast<VDataType* __restrict__>(smem_ptrv1);
innerloop(k_lds_write_ptr,
k_lds_read_ptr,
v_lds_write_ptr,
v_lds_read_ptr);
}; // mainloop
do
{
mainloop(i_total_loops);
mainloop(i_total_loops, k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
i_total_loops++;
} while(i_total_loops < num_total_loop);