diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 68db468a7c..3d79f2f6d3 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -213,8 +213,20 @@ list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS -Wno-undefined-func-template --save-temps ) -target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) +set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) +check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) +if(HAS_DISABLE_PACKED_FP32) + list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS + -mllvm --amdgpu-disable-packed-fp32=1 + ) + list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS + -DCK_TILE_DISABLE_PACKED_FP32=1 + ) +endif() + +target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) +target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index d2428e5152..569c98a458 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -45,18 +45,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair(hdim)); - mask = mask_info::decode(args.get_str("mask"), seqlen_q, seqlen_k); + + const auto is_causal = args.get_bool("causal"); + if(is_causal) + { + mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); + } + else + { + mask = mask_info::decode("0", seqlen_q, seqlen_k); + } input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; output_layout = args.get_int("operm") == 1 ? 
TensorLayout::bhsd : TensorLayout::bshd; diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp index 5361d27f0f..10cb5149a4 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -34,7 +34,8 @@ struct fmha_fwd_v3_args index_t window_size_left; index_t window_size_right; - index_t mask_type; + index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and + // window_size_right == 0). const void* q_ptr; index_t stride_q; diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index d6e4ac4c60..e0fbad39a5 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -18,6 +18,7 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "fmha_fwd_v3.hpp" +#include "mask.hpp" #define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ template <> \ @@ -79,7 +80,7 @@ struct fmha_fwd_v3_kernel_traits -1 // kBlockPerCu >; - using fmha_mask = SimplifiedGenericAttentionMask; + using fmha_mask = GenericAttentionMask; using fmha_pipeline_problem = BlockFmhaFwdV3PipelineProblem::qkvp_dtype, @@ -112,6 +113,22 @@ struct fmha_fwd_v3_kernel_traits template float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config) { + /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly + /// maximizes the kernel's performance. 
+ int remap_opt = 2; + if(args.mask_type != static_cast(mask_enum::no_mask) && + ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) + { + if(65536 <= args.seqlen_q) + { + remap_opt = 0; + } + else + { + remap_opt = 1; + } + } + auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, @@ -140,7 +157,8 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + remap_opt); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); constexpr dim3 blocks = Kernel::BlockSize(); diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh index 9c500edf9d..b847e85398 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -8,22 +8,16 @@ for prec in "fp16" "bf16" ; do for hdim in 128 ; do for perm in 0 ; do -if [ $causal -eq 0 ]; then - mask=0 -else - mask=b:-1,0 -fi - -$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal 
-iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID done done diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index f5c12e11d2..2c45945fac 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -203,27 +203,36 @@ struct GenericAttentionMask CK_TILE_HOST_DEVICE constexpr auto IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number, number) const { - if constexpr(IsLocal) + if constexpr(!IsMasking) { - // check top-right corner > x or left-borrom corner < x - index_t i_tile_right = i_tile_left + TileWidth; - index_t i_tile_bottom = i_tile_top + TileHeight; - index_t x_end = min(i_tile_top + x, x_total); - - bool top_right_edge = i_tile_right > (i_tile_top + x); - bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); - bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now - - return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + // TODO: no need to check begin + return (i_tile_left + TileWidth) > x_total; } else { - // only need to check top-right corner > x - index_t i_tile_right = i_tile_left + TileWidth; - index_t 
x_end = min(i_tile_top + x, x_total); + if constexpr(IsLocal) + { + // check top-right corner > x or left-bottom corner < x + index_t i_tile_right = i_tile_left + TileWidth; + index_t i_tile_bottom = i_tile_top + TileHeight; + index_t x_end = min(i_tile_top + x, x_total); - bool top_right_edge = i_tile_right > (i_tile_top + x); - bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); - bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now - - return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + bool top_right_edge = i_tile_right > (i_tile_top + x); + bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); + bool is_partial_out_of_bound = + i_tile_right > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + } + else + { + // only need to check top-right corner > x + index_t i_tile_right = i_tile_left + TileWidth; + index_t x_end = min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > x_end; + return top_right_edge; + } } } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index 87021354aa..c5e5745817 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -81,6 +81,7 @@ struct FmhaFwdV3Kernel // ck_tile::index_t window_size_left, window_size_right; ck_tile::index_t window_size_left, window_size_right; ck_tile::GenericAttentionMaskEnum mask_type; + ck_tile::index_t remap_opt; }; struct FmhaFwdCommonLSEKargs @@ -143,7 +144,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + ck_tile::index_t remap_opt) { Kargs kargs{{q_ptr, k_ptr, @@ -176,6 +178,7 @@ struct FmhaFwdV3Kernel kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); + kargs.remap_opt = remap_opt; } if constexpr(kStoreLSE) { @@ -213,7 +216,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t nhead_stride_o, ck_tile::index_t 
window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + ck_tile::index_t remap_opt) { Kargs kargs{{q_ptr, k_ptr, @@ -245,6 +249,7 @@ struct FmhaFwdV3Kernel kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); + kargs.remap_opt = remap_opt; } if constexpr(kStoreLSE) { @@ -261,39 +266,81 @@ struct FmhaFwdV3Kernel ck_tile::index_t hdim_v_) { // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - nhead_, - batch_size_); - } - - CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) - { - using namespace ck_tile; - - // const index_t num_tile_m0 = seqlen_q / kM0; - const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - if constexpr(kHasMask) { - // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + return dim3(nhead_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + batch_size_); } else { - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + return dim3(nhead_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + batch_size_); + } + } + + CK_TILE_DEVICE static constexpr auto + RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option) + { + 
if(remap_option < 1) + { + return make_tuple(static_cast(gridDim.x - tg_idx - 1), tg_idy); + } + + int32_t remapped_tg_idx = tg_idx; + int32_t remapped_tg_idy = tg_idy; + + if(remap_option == 2) + { // special remapping + int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx; + int32_t tmp1 = tmp0 & 0x7; + + remapped_tg_idx = tmp0 >> 3; + remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1; + } + else + { // normal remapping + int32_t cus_per_xdim_per_xcc = gridDim.x >> 3; + int32_t tgs_cu_id = remapped_tg_idx >> 3; + + if(tgs_cu_id < cus_per_xdim_per_xcc) + { + int32_t tgs_xcc_id = remapped_tg_idx & 0x7; + int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id; + + remapped_tg_idx = new_tg_idx; + } + } + + return make_tuple(remapped_tg_idx, remapped_tg_idy); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&) + { + using namespace ck_tile; + + // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, + // FmhaPipeline::kN1); + + // assume that num_tile_n1 is always 1 + if constexpr(kHasMask) + { + const index_t i_nhead = blockIdx.x; + const index_t i_block = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + const index_t i_nhead = blockIdx.x; + const index_t i_block = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 20d84116d4..5e2a4e898b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -57,7 +57,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 1) {} + else if constexpr(Phase == 1) + { + 
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 2) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -68,11 +72,19 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU }); } - else if constexpr(Phase == 3) {} + else if constexpr(Phase == 3) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } } else { - if constexpr(Phase == 0) {} + if constexpr(Phase == 0) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 1) { static_for<0, 8, 1>{}([&](auto) { @@ -81,7 +93,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 2) {} + else if constexpr(Phase == 2) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 3) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -115,7 +131,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 1) {} + else if constexpr(Phase == 1) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 2) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -126,11 +146,19 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU }); } - else if constexpr(Phase == 3) {} + else if constexpr(Phase == 3) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } } else { - if constexpr(Phase == 0) {} + if constexpr(Phase == 0) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // 
SALU + } else if constexpr(Phase == 1) { static_for<0, 8, 1>{}([&](auto) { @@ -139,7 +167,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 2) {} + else if constexpr(Phase == 2) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 3) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -177,6 +209,15 @@ CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs) return result; } +CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs) +{ + float result; + asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} + CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b) { fp16x2_t result; @@ -466,7 +507,7 @@ struct BlockFmhaFwdV3Pipeline statically_indexed_array sp; decltype(gemm_1.MakeCBlockTile()) o_acc; - constexpr index_t fmha_alu_D_reg_cnt = 0; // threshold to decide how many fmha_alu_D_upd() + constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd() // instructions should we move to fmha_alu1() static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size()); @@ -631,8 +672,8 @@ struct BlockFmhaFwdV3Pipeline // K_mem_su_ld_insts = 1 for 32 x 128 // V_mem_su_ld_insts = 1 for 128 x 32 - static constexpr int K_mem_su_ld_insts = 1; - static constexpr int V_mem_su_ld_insts = 1; + constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access(); + constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access(); auto K_mem_load = [&](auto k_lds_write_idx) { async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); @@ -648,7 +689,6 @@ struct BlockFmhaFwdV3Pipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - __builtin_amdgcn_sched_barrier(0); /// FIXME: use the future-predicting method to move the window 
move_tile_window(v_dram_window, {kK1, 0}); @@ -726,11 +766,12 @@ struct BlockFmhaFwdV3Pipeline #else block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); #endif - // update partial o_acc [0, 2) - static_for<0, ck_tile::min(2, fmha_alu_D_reg_cnt), 1>{}( - [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); // l{j} + /// Note: The compiler keeps moving the following instructions elsewhere because 'l' + /// is first consumed later. To anchor them here, we rewrite the final addition in + /// inline assembly to create a dependency, forcing the dependent instructions to + /// be emitted at this point. constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); @@ -739,13 +780,15 @@ struct BlockFmhaFwdV3Pipeline l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]); }); - // update partial o_acc [2, fmha_alu_D_reg_cnt) - static_for<2, ck_tile::max(2, fmha_alu_D_reg_cnt), 1>{}( - [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); + // update partial o_acc [0, fmha_alu_D_reg_cnt) + static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) { + o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale); + }); - /// NOTICE: Compiler keep moving the conversion instructions to other places. We rewite - /// the cast_tile() call into inline asm to force the conversion instructions to be - /// generated here. The fmha_alu1() call should be placed at the end of a phase. + /// Note: The compiler keeps sinking the conversion instructions because the + /// result 'p' is only consumed later. To anchor them here, we rewrite + /// the cast_tile() call as inline assembly, forcing the conversions to be + /// emitted at this point. 
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0); static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) { float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]); @@ -763,6 +806,10 @@ struct BlockFmhaFwdV3Pipeline sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; } }); + + /// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly + /// can interfere with the behavior of sched_group_barrier(), so ending the phase here + /// avoids unintended reordering. }; auto gemm = [&](auto sp_reg_idx, auto gemm_idx) { @@ -937,9 +984,9 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); + Scheduler::schedule(cl_p, number<1>{}); fmha_mask(xdl_SP_p01_reg_idx); - Scheduler::schedule(cl_p, number<1>{}); __builtin_amdgcn_sched_barrier(0); // phase2 ASM_MARKER("phase2 Wave0-3"); @@ -947,6 +994,8 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 0"); + __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p23_reg_idx, gemm1); Scheduler::schedule(cl_p, number<2>{}); @@ -995,6 +1044,8 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p01_reg_idx, gemm0); fmha_alu1(xdl_SP_p23_reg_idx); @@ -1005,9 +1056,9 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx); + Scheduler::schedule(cl_p, number<2>{}); fmha_mask(xdl_SP_p01_reg_idx); - Scheduler::schedule(cl_p, number<2>{}); kv_token_start += kN0; if(num_total_loop <= ++i_total_loops) { @@ -1021,6 +1072,8 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); 
__builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p23_reg_idx, gemm1); Scheduler::schedule(cl_p, number<3>{}); @@ -1036,7 +1089,14 @@ struct BlockFmhaFwdV3Pipeline auto ps_pi = number<1>{} - d; auto V_lds_rd_idx = ps_pi; - s_waitcnt_vmcnt(); + if(1 < num_total_loop) + { + s_waitcnt_vmcnt(); + } + else + { + s_waitcnt_vmcnt<0>(); + } __builtin_amdgcn_s_barrier(); V_lds_load(V_lds_rd_idx); @@ -1102,14 +1162,14 @@ struct BlockFmhaFwdV3Pipeline V_mem_load(number<1>{}); // V1 K_lds_load(number<1>{}); // K1 - asm volatile("s_setprio 0"); + __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); while(core_loop(number<0>{})) ; } if(warp_group_id != 0) { - asm volatile("s_setprio 1"); + __builtin_amdgcn_s_setprio(1); __builtin_amdgcn_s_barrier(); while(core_loop(number<1>{})) ; @@ -1167,14 +1227,13 @@ struct BlockFmhaFwdV3Pipeline typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename LSEDramBlockWindowTmp> - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile - FmhaMask mask, - float scale_s, - void* smem_ptr) const + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale_s, + void* smem_ptr) const { using namespace ck_tile;