diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 7bfedb8ff0..4ff04a69af 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -375,95 +375,20 @@ struct HstuAttentionFwdPipelineQRKSVS do { static_for<0, k1_loops, 1>{}([&](auto i_k1) { - if constexpr(HstuMask::kUseLocal) - { - constexpr index_t V_VMEM_LOAD_INST = (kN1 * kK1) / kBlockSize / kAlignmentV; - constexpr index_t K_VMEM_LOAD_INST = - (kK1 * kSubQKHeaddim) / kBlockSize / kAlignmentK; - constexpr index_t MFMA_INST = (kM0 * kSubQKHeaddim) / kBlockSize / 4; - constexpr index_t K_LDS_READ_INST = MFMA_INST / kGemmNumRepM; + // load v_tile for current unroll + v_tile = load_tile(v_dram_window); - // load v_tile for current unroll - v_tile = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); - move_tile_window(v_dram_window, {0, kK1}); + // for i_k1 = k1_loop-1, the loading is for next iteration + k_tile = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); - // for i_k1 = k1_loop-1, the loading is for next iteration - k_tile = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + __builtin_amdgcn_sched_barrier(0x00000001); - block_sync_lds(); - // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); - - if constexpr(V_VMEM_LOAD_INST + K_VMEM_LOAD_INST < K_LDS_READ_INST) - { - static_for<0, V_VMEM_LOAD_INST + K_VMEM_LOAD_INST, 1>{}([&](auto i) { - ignore = i; - - // ds_load for K - __builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0); - // buffer_load for V & K - __builtin_amdgcn_sched_group_barrier(0x00000020, 1, 0); - }); - - static_for<0, K_LDS_READ_INST - (V_VMEM_LOAD_INST + K_VMEM_LOAD_INST), 1>{}( - [&](auto i) { - ignore = i; - // ds_load for K - __builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0); - // MFMA - __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0); - }); - - static_for<0, V_VMEM_LOAD_INST + K_VMEM_LOAD_INST, 1>{}([&](auto i) { - ignore = i; - // MFMA - __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0); - }); - } - else - { - // buffer_load for V - __builtin_amdgcn_sched_group_barrier(0x00000020, V_VMEM_LOAD_INST, 0); - - // ds_load for K - __builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0); - - // buffer_load for K - __builtin_amdgcn_sched_group_barrier(0x00000020, K_VMEM_LOAD_INST, 0); - - static_for<0, K_LDS_READ_INST - 1, 1>{}([&](auto i) { - ignore = i; - // ds_load for K - __builtin_amdgcn_sched_group_barrier(0x00000100, 1, 0); - // MFMA - __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0); - }); - - // MFMA - __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0); - }; - - __builtin_amdgcn_sched_barrier(0x00000001); - } - else - { - // load v_tile for current unroll - v_tile = load_tile(v_dram_window); - - move_tile_window(v_dram_window, {0, kK1}); - - // for i_k1 = k1_loop-1, the loading is for next iteration - k_tile = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); - - __builtin_amdgcn_sched_barrier(0x00000001); - - block_sync_lds(); - // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); - }; + block_sync_lds(); + // execute current unroll of gemm_0 + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);