Re-arrange the code section to use sched_group_barrier

This commit is contained in:
Qianfeng Zhang
2025-07-21 08:16:22 +00:00
parent 906ab84d6c
commit fcd41a6f39

View File

@@ -385,19 +385,55 @@ struct HstuAttentionFwdPipelineQRKSVS
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
__builtin_amdgcn_sched_group_barrier(0x00000020, V_VMEM_LOAD_INST, 0);
if constexpr(V_VMEM_LOAD_INST + K_VMEM_LOAD_INST < K_LDS_READ_INST)
{
static_for<0, V_VMEM_LOAD_INST + K_VMEM_LOAD_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0);
// ds_load for K
__builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0);
// buffer_load for V & K
__builtin_amdgcn_sched_group_barrier(
0x00000020, 1, 0); // buffer_load for K & V
});
__builtin_amdgcn_sched_group_barrier(0x00000020, K_VMEM_LOAD_INST, 0);
static_for<0, K_LDS_READ_INST - (V_VMEM_LOAD_INST + K_VMEM_LOAD_INST), 1>{}(
[&](auto i) {
ignore = i;
// ds_load for K
__builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0);
// MFMA
__builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
});
static_for<0, K_LDS_READ_INST - 1, 1>{}([&](auto i) {
ignore = i;
static_for<0, V_VMEM_LOAD_INST + K_VMEM_LOAD_INST, 1>{}([&](auto i) {
ignore = i;
// MFMA
__builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
});
}
else
{
// buffer_load for V
__builtin_amdgcn_sched_group_barrier(0x00000020, V_VMEM_LOAD_INST, 0);
// ds_load for K
__builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0);
__builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
});
__builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
// buffer_load for K
__builtin_amdgcn_sched_group_barrier(0x00000020, K_VMEM_LOAD_INST, 0);
static_for<0, K_LDS_READ_INST - 1, 1>{}([&](auto i) {
ignore = i;
// ds_load for K
__builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0);
// MFMA
__builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
});
// MFMA
__builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
};
__builtin_amdgcn_sched_barrier(0x00000001);
}