From fcd41a6f39aea1fa6afe00834acf12cea4c33450 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 21 Jul 2025 08:16:22 +0000 Subject: [PATCH] Re-arrange the codes section for using sched_group_barrier --- .../hstu_attention_fwd_pipeline.hpp | 52 ++++++++++++++++--- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 6667019f78..97b0656f5c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -385,19 +385,55 @@ struct HstuAttentionFwdPipelineQRKSVS // execute current unroll of gemm_0 gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); - __builtin_amdgcn_sched_group_barrier(0x00000020, V_VMEM_LOAD_INST, 0); + if constexpr(V_VMEM_LOAD_INST + K_VMEM_LOAD_INST < K_LDS_READ_INST) + { + static_for<0, V_VMEM_LOAD_INST + K_VMEM_LOAD_INST, 1>{}([&](auto i) { + ignore = i; - __builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0); + // ds_load for K + __builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0); + // buffer_load for V & K + __builtin_amdgcn_sched_group_barrier( + 0x00000020, 1, 0); // buffer_load for K & V + }); - __builtin_amdgcn_sched_group_barrier(0x00000020, K_VMEM_LOAD_INST, 0); + static_for<0, K_LDS_READ_INST - (V_VMEM_LOAD_INST + K_VMEM_LOAD_INST), 1>{}( + [&](auto i) { + ignore = i; + // ds_load for K + __builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0); + // MFMA + __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0); + }); - static_for<0, K_LDS_READ_INST - 1, 1>{}([&](auto i) { - ignore = i; + static_for<0, V_VMEM_LOAD_INST + K_VMEM_LOAD_INST, 1>{}([&](auto i) { + ignore = i; + // MFMA + __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0); + }); + } + else + { + // buffer_load for V + __builtin_amdgcn_sched_group_barrier(0x00000020, V_VMEM_LOAD_INST, 0); + + // ds_load for K __builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0); - __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0); - }); - __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0); + // buffer_load for K + __builtin_amdgcn_sched_group_barrier(0x00000020, K_VMEM_LOAD_INST, 0); + + static_for<0, K_LDS_READ_INST - 1, 1>{}([&](auto i) { + ignore = i; + // ds_load for K + __builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0); + // MFMA + __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0); + }); + + // MFMA + __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0); + }; __builtin_amdgcn_sched_barrier(0x00000001); }