Remove the use of the IGLP method for instruction scheduling in the kUseLocal==true path

This commit is contained in:
Qianfeng Zhang
2025-10-11 06:38:32 +00:00
parent 6b40ce4074
commit d308b09fae

View File

@@ -375,95 +375,20 @@ struct HstuAttentionFwdPipelineQRKSVS
do
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
if constexpr(HstuMask::kUseLocal)
{
constexpr index_t V_VMEM_LOAD_INST = (kN1 * kK1) / kBlockSize / kAlignmentV;
constexpr index_t K_VMEM_LOAD_INST =
(kK1 * kSubQKHeaddim) / kBlockSize / kAlignmentK;
constexpr index_t MFMA_INST = (kM0 * kSubQKHeaddim) / kBlockSize / 4;
constexpr index_t K_LDS_READ_INST = MFMA_INST / kGemmNumRepM;
// load v_tile for current unroll
v_tile = load_tile(v_dram_window);
// load v_tile for current unroll
v_tile = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
move_tile_window(v_dram_window, {0, kK1});
// for i_k1 = k1_loop-1, the loading is for next iteration
k_tile = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
// for i_k1 = k1_loop-1, the loading is for next iteration
k_tile = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
__builtin_amdgcn_sched_barrier(0x00000001);
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
if constexpr(V_VMEM_LOAD_INST + K_VMEM_LOAD_INST < K_LDS_READ_INST)
{
static_for<0, V_VMEM_LOAD_INST + K_VMEM_LOAD_INST, 1>{}([&](auto i) {
ignore = i;
// ds_load for K
__builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0);
// buffer_load for V & K
__builtin_amdgcn_sched_group_barrier(0x00000020, 1, 0);
});
static_for<0, K_LDS_READ_INST - (V_VMEM_LOAD_INST + K_VMEM_LOAD_INST), 1>{}(
[&](auto i) {
ignore = i;
// ds_load for K
__builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0);
// MFMA
__builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
});
static_for<0, V_VMEM_LOAD_INST + K_VMEM_LOAD_INST, 1>{}([&](auto i) {
ignore = i;
// MFMA
__builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
});
}
else
{
// buffer_load for V
__builtin_amdgcn_sched_group_barrier(0x00000020, V_VMEM_LOAD_INST, 0);
// ds_load for K
__builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0);
// buffer_load for K
__builtin_amdgcn_sched_group_barrier(0x00000020, K_VMEM_LOAD_INST, 0);
static_for<0, K_LDS_READ_INST - 1, 1>{}([&](auto i) {
ignore = i;
// ds_load for K
__builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0);
// MFMA
__builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
});
// MFMA
__builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0);
};
__builtin_amdgcn_sched_barrier(0x00000001);
}
else
{
// load v_tile for current unroll
v_tile = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
// for i_k1 = k1_loop-1, the loading is for next iteration
k_tile = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
__builtin_amdgcn_sched_barrier(0x00000001);
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
};
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);