From 1c4c07c6695df587570ead3cb3c4903073832e2e Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 1 Oct 2024 22:13:52 +0800 Subject: [PATCH] [CK_TILE] Change output accum tensor layout of fmha fwd split-kv & combine kernels (#1527) * Use same layout for o_acc and o tensor * Use better param names in partitioner * Remove redundant kargs 'max_seqlen_q' * Use better param names in splitkv kernel * Add comment for additional kernel arguments * Sync empty loop early return logics between pipelines * Pass more arguments to cmake in scripts * Align backslashes * Fix wrong o_acc tensor view strides * Change o_acc layout if o_perm=0 * Handle whole row masked via attn_bias * Use use vector width = 1 for o_acc * Use more even split sizes [ROCm/composable_kernel commit: a1c07e8d913cd03011f4ea3d45033ab4e765e9f1] --- example/ck_tile/01_fmha/fmha_fwd.cpp | 47 +++++++++++------ example/ck_tile/01_fmha/fmha_fwd.hpp | 9 +--- .../ck_tile/ops/fmha/block/block_masking.hpp | 4 +- .../fmha_fwd_splitkv_combine_kernel.hpp | 51 ++++++++----------- ...a_fwd_splitkv_combine_tile_partitioner.hpp | 17 +++---- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 44 ++++++++-------- .../fmha_fwd_splitkv_tile_partitioner.hpp | 4 +- ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 10 ++-- ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 9 ++-- script/cmake-ck-dev.sh | 3 ++ script/cmake-ck-release.sh | 3 ++ 11 files changed, 104 insertions(+), 97 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 723546a452..b9cb9a1ec2 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -552,16 +552,33 @@ bool run(const ck_tile::ArgParser& arg_parser) } #endif - auto get_lengths = [&](bool permute, - ck_tile::index_t b /*batch*/, - ck_tile::index_t h /*nhead*/, - ck_tile::index_t s /*seqlen*/, - ck_tile::index_t d /*hdim*/) { - if(permute) - return std::array{b, h, s, d}; - else - return std::array{b, s, h, d}; - }; + struct + { + auto operator()(bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) + { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + } + + auto operator()(bool permute, + ck_tile::index_t ns /*num_splits*/, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) + { + if(permute) + return std::array{ns, b, h, s, d}; + else + return std::array{ns, b, s, h, d}; + } + } get_lengths; bool is_v_rowmajor = vlayout == std::string("r"); @@ -617,7 +634,7 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{1, 1, 1, 1}); ck_tile::HostTensor o_acc_host( 1 < num_splits || use_kvcache - ? std::array{num_splits, batch, nhead, max_seqlen_q, hdim_v} + ? get_lengths(o_perm, num_splits, shape_batch, nhead, shape_seqlen_q, hdim_v) : std::array{1, 1, 1, 1, 1}); // batch mode of lse data layout is [batch, nhead, seqlen_q] @@ -854,7 +871,7 @@ bool run(const ck_tile::ArgParser& arg_parser) }(); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); - const ck_tile::index_t stride_o_acc = hdim_v; + const ck_tile::index_t stride_o_acc = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); @@ -881,7 +898,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q; - const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o_acc = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); @@ -897,12 +914,12 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q); - const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); + const ck_tile::index_t split_stride_o_acc = (shape_batch * nhead * shape_seqlen_q * hdim_v); args.q_ptr = q_buf.GetDeviceBuffer(); args.k_ptr = k_buf.GetDeviceBuffer(); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 183475064a..5dcad7907f 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -398,10 +398,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.nhead_stride_bias, args.nhead_stride_lse_acc, args.nhead_stride_o_acc, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_lse_acc, - args.batch_stride_o_acc, + args.batch_stride_k, // only used for paged-kvcache + args.batch_stride_v, // only used for paged-kvcache args.split_stride_lse_acc, args.split_stride_o_acc, args.window_size_left, @@ -475,7 +473,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.lse_ptr, args.o_ptr, args.batch, - args.max_seqlen_q, args.seqstart_q_ptr, args.hdim_v, args.num_splits, @@ -486,7 +483,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.nhead_stride_o_acc, args.nhead_stride_lse, args.nhead_stride_o, - args.batch_stride_o_acc, args.split_stride_lse_acc, args.split_stride_o_acc); } @@ -497,7 +493,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.lse_ptr, args.o_ptr, args.batch, - args.max_seqlen_q, args.seqlen_q, args.hdim_v, args.num_splits, diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index c022edf723..1569c93565 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask { auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); - const index_t x_per_split = ck_tile::max(1, x_total / num_splits); + const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); const index_t split_start = x_per_split * i_split; - const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split); + const index_t split_end = split_start + x_per_split; return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), ck_tile::min(origin_end, split_end)); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index e2c7db3e1b..ca9da91a5d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel void* o_ptr; ck_tile::index_t batch; - ck_tile::index_t max_seqlen_q; - ck_tile::index_t seqlen_q; ck_tile::index_t hdim_v; ck_tile::index_t num_splits; @@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o; - ck_tile::index_t batch_stride_o_acc; - ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_o_acc; }; @@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel std::conditional_t>, std::conditional_t> { - ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + ck_tile::index_t batch_stride_o; }; struct GroupModeKargs @@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel void* lse_ptr, void* o_ptr, ck_tile::index_t batch, - ck_tile::index_t max_seqlen_q, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, @@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel o_acc_ptr, o_ptr, batch, - max_seqlen_q, seqlen_q, hdim_v, num_splits, @@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel nhead_stride_lse_acc, nhead_stride_o_acc, nhead_stride_o, - batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for lse {}, // placeholder for fp8_static_quant args - batch_stride_o, - batch_stride_lse_acc}; + batch_stride_lse_acc, + batch_stride_o_acc, + batch_stride_o}; if constexpr(kStoreLSE) { @@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel void* lse_ptr, void* o_ptr, ck_tile::index_t batch, - ck_tile::index_t max_seqlen_q, const void* seqstart_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, @@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc) { @@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel o_acc_ptr, o_ptr, batch, - max_seqlen_q, -1, // seqlen will be updated by another pointer hdim_v, num_splits, @@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel nhead_stride_lse_acc, nhead_stride_o_acc, nhead_stride_o, - batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for lse @@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel return kargs; } - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v) { - return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); + return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - const long_index_t batch_offset_o_acc = - static_cast(i_batch) * kargs.batch_stride_o_acc; - long_index_t batch_offset_lse_acc = 0; + long_index_t batch_offset_o_acc = 0; long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; @@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel // get starting offset for each batch const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - batch_offset_o = query_start * kargs.row_stride_o; batch_offset_lse_acc = query_start; + batch_offset_o_acc = query_start * kargs.row_stride_o_acc; if constexpr(kStoreLSE) { batch_offset_lse = query_start; } + batch_offset_o = query_start * kargs.row_stride_o; + // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; @@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel } else { - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; batch_offset_lse_acc = static_cast(i_batch) * kargs.batch_stride_lse_acc; + batch_offset_o_acc = static_cast(i_batch) * kargs.batch_stride_o_acc; if constexpr(kStoreLSE) { batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } + + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } // for simplicity, batch stride we just modify the pointer @@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel auto o_acc_dram = [&]() { const auto o_acc_dram_naive = make_naive_tensor_view( o_acc_ptr, - make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v), + make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v), make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1), number{}, number<1>{}); @@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel make_tuple(number<1>{}, number{}, number{}), sequence{}); - const index_t padded_max_seqlen_q = + const index_t padded_seqlen_q = o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; const index_t padded_hdim_v = o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; return transform_tensor_view( o_acc_dram_view, - make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)), + make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)), make_pass_through_transform(padded_hdim_v)), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel identity{}, // lse_element_func composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func kargs.num_splits, - kargs.max_seqlen_q, + kargs.seqlen_q, smem_ptr); } else @@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel o_acc_dram_window, lse_dram_window, kargs.num_splits, - kargs.max_seqlen_q, + kargs.seqlen_q, smem_ptr); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp index 9f04843a39..3b73909712 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp @@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner static constexpr ck_tile::index_t kM0 = kM0_; static constexpr ck_tile::index_t kN1 = kN1_; - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v) { // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * - ck_tile::integer_divide_ceil(hdim_v_, kN1), - nhead_, - batch_size_); + return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) * + ck_tile::integer_divide_ceil(hdim_v, kN1), + nhead, + batch_size); } CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) { - // const index_t num_tile_m0 = seqlen_q / kM0; const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); const index_t i_block = blockIdx.x; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 22978f1a3c..34f75990c6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t nhead_stride_lse_acc; ck_tile::index_t nhead_stride_o_acc; - ck_tile::index_t batch_stride_lse_acc; - ck_tile::index_t batch_stride_o_acc; - ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_o_acc; }; @@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; }; struct GroupModeKargs @@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_k; // only used for paged-kvcache + ck_tile::index_t batch_stride_v; // only used for paged-kvcache }; using Kargs = std::conditional_t; @@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel nhead_stride_v, nhead_stride_lse_acc, nhead_stride_o_acc, - batch_stride_lse_acc, - batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for bias @@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel reinterpret_cast(seqlen_k_ptr), batch_stride_q, batch_stride_k, - batch_stride_v}; + batch_stride_v, + batch_stride_lse_acc, + batch_stride_o_acc}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_lse_acc, - ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t batch_stride_k, // only used for paged-kvcache + ck_tile::index_t batch_stride_v, // only used for paged-kvcache ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, @@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel nhead_stride_v, nhead_stride_lse_acc, nhead_stride_o_acc, - batch_stride_lse_acc, - batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for bias @@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, - ck_tile::index_t seqlen_q, + ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits) { - return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits); + return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel long_index_t batch_offset_v = 0; long_index_t batch_offset_bias = 0; long_index_t batch_offset_lse_acc = 0; - const long_index_t batch_offset_o_acc = - static_cast(i_batch) * kargs.batch_stride_o_acc; + long_index_t batch_offset_o_acc = 0; if constexpr(kIsGroupMode) { @@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - batch_offset_lse_acc = query_start; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) { batch_offset_v = key_start * kargs.stride_v; @@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel batch_offset_bias = query_start * kargs.stride_bias + key_start; } + batch_offset_lse_acc = query_start; + batch_offset_o_acc = query_start * kargs.stride_o_acc; + // get real # queries & # keys under group mode kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; @@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel batch_offset_k = static_cast(i_cache_batch) * kargs.batch_stride_k; batch_offset_v = static_cast(i_cache_batch) * kargs.batch_stride_v; batch_offset_lse_acc = static_cast(i_batch) * kargs.batch_stride_lse_acc; + batch_offset_o_acc = static_cast(i_batch) * kargs.batch_stride_o_acc; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel const auto o_acc_dram_naive = make_naive_tensor_view( o_acc_ptr, make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.hdim_v, 1), - number{}, + make_tuple(kargs.stride_o_acc, 1), + number<1>{}, number<1>{}); return pad_tensor_view( diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp index aec37cb36f..2d06ba1762 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp @@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, - ck_tile::index_t seqlen_q, + ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits) { // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) * + return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) * ck_tile::integer_divide_ceil(hdim_v, kN1), nhead * num_splits, batch_size); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 7efdb798cb..842090afbe 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline const LSEElementFunction& lse_element_func, const OaccElementFunction& o_acc_element_func, index_t num_splits, - index_t max_seqlen_q, + index_t seqlen_q, void* smem_ptr) const { // lse_acc tile in LDS @@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto o_acc = make_static_distributed_tensor(o_acc_dist); clear_tile(o_acc); - const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0; + const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0; for(index_t i_split = 0; i_split < num_splits; ++i_split) { @@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline }); } - move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0}); + move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0}); } o_acc = tile_elementwise_in(o_acc_element_func, o_acc); @@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline const OaccDramBlockWindow& o_acc_dram_block_window, LSEDramBlockWindow& lse_dram_block_window, index_t num_splits, - index_t max_seqlen_q, + index_t seqlen_q, void* smem_ptr) const { return operator()(lse_acc_dram_block_window, @@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline identity{}, identity{}, num_splits, - max_seqlen_q, + seqlen_q, smem_ptr); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index b257b9e93d..75af7be82f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); }(); - static constexpr index_t kAlignmentO = - kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); @@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking || kHasUnevenSplits) + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) { const index_t original_num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); @@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); const auto tmp = [&]() { - if constexpr(FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; } diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 26326523f4..5dae86089a 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1 if [ $# -ge 2 ] ; then GPU_TARGETS=$2 + REST_ARGS=${@:3} else GPU_TARGETS="gfx908;gfx90a;gfx940" + REST_ARGS= fi cmake \ @@ -20,4 +22,5 @@ cmake -D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ +$REST_ARGS \ ${MY_PROJECT_SOURCE} diff --git a/script/cmake-ck-release.sh b/script/cmake-ck-release.sh index 25ccb5c799..f65ec610dd 100755 --- a/script/cmake-ck-release.sh +++ b/script/cmake-ck-release.sh @@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1 if [ $# -ge 2 ] ; then GPU_TARGETS=$2 + REST_ARGS=${@:3} else GPU_TARGETS="gfx908;gfx90a;gfx940" + REST_ARGS= fi cmake \ @@ -20,5 +22,6 @@ cmake -D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ +$REST_ARGS \ ${MY_PROJECT_SOURCE}