diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 0e59c3c017..f5fc884b65 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -256,8 +256,8 @@ int override_num_splits_if_necessary( template bool run(const ck_tile::ArgParser& arg_parser) { - std::string data_type = arg_parser.get_str("prec"); - int do_validation = arg_parser.get_int("v"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); ck_tile::index_t batch = arg_parser.get_int("b"); ck_tile::index_t nhead = arg_parser.get_int("h"); @@ -307,9 +307,9 @@ bool run(const ck_tile::ArgParser& arg_parser) } auto mode = static_cast(arg_parser.get_uint32("mode")); - if((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch) { - std::cerr << "kvcache enabled. ignoring the 'mode' option" - << std::endl; + if((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch) + { + std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl; mode = mode_enum::batch; } @@ -780,7 +780,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } else // fmha_fwd_traits or fmha_splitkv_traits { - traits.is_group_mode = (mode == mode_enum::group); + traits.is_group_mode = (mode == mode_enum::group); traits.mask_type = mask.type; traits.bias_type = bias.type; traits.has_lse = lse; @@ -871,12 +871,12 @@ bool run(const ck_tile::ArgParser& arg_parser) args.k_ptr = k_buf.GetDeviceBuffer(); args.v_ptr = v_buf.GetDeviceBuffer(); - args.batch = batch; - args.seqlen_q = shape_seqlen_q; - args.hdim_q = hdim_q; - args.hdim_v = hdim_v; - args.nhead_q = nhead; - args.nhead_k = nhead_k; + args.batch = batch; + args.seqlen_q = shape_seqlen_q; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; args.stride_q = stride_q; args.stride_k = stride_k; @@ -919,9 +919,13 @@ bool run(const ck_tile::ArgParser& arg_parser) args.lse_ptr = lse_buf.GetDeviceBuffer(); args.o_ptr = o_buf.GetDeviceBuffer(); - args.seqstart_q_ptr = (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); - args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); - args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr); + args.seqstart_q_ptr = + (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + args.seqstart_k_ptr = + (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0] + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr); args.seqlen_k = (args.seqlen_k_ptr == nullptr ? shape_seqlen_k : -1); args.max_seqlen_q = max_seqlen_q; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 9ea074f80b..ad88fed9b2 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -164,9 +164,8 @@ struct fmha_fwd_splitkv_args const void* seqstart_q_ptr; const void* seqstart_k_ptr; - const void* - seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr, or - // kvcache is used + const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not + // nullptr, or kvcache is used ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; // only used if 'seqlen_k_ptr' is nullptr @@ -521,38 +520,38 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.knew_ptr, - args.v_ptr, - args.vnew_ptr, - args.seqlen_q, - args.seqlen_k_ptr, - args.seqlen_knew, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.rotary_cos_ptr, - args.rotary_sin_ptr, - args.rotary_dim, - args.block_table_ptr, - args.batch_stride_block_table, - args.page_block_size, - args.stride_q, - args.stride_k, - args.stride_knew, - args.stride_v, - args.stride_vnew, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_knew, - args.nhead_stride_v, - args.nhead_stride_vnew, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_knew, - args.batch_stride_v, - args.batch_stride_vnew); + args.k_ptr, + args.knew_ptr, + args.v_ptr, + args.vnew_ptr, + args.seqlen_q, + args.seqlen_k_ptr, + args.seqlen_knew, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.rotary_cos_ptr, + args.rotary_sin_ptr, + args.rotary_dim, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.stride_q, + args.stride_k, + args.stride_knew, + args.stride_v, + args.stride_vnew, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_knew, + args.nhead_stride_v, + args.nhead_stride_vnew, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_knew, + args.batch_stride_v, + args.batch_stride_vnew); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 23ee6d1b61..404a84e3d6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -128,78 +128,78 @@ struct FmhaFwdAppendKVKernel ck_tile::index_t rotary_dim; }; - struct Kargs : BasicKargs, - std::conditional_t> - {}; - - __host__ static constexpr Kargs - MakeKargs(void* q_ptr, - void* k_ptr, - const void* knew_ptr, - void* v_ptr, - const void* vnew_ptr, - ck_tile::index_t seqlen_q, - const void* seqlen_k_ptr, - ck_tile::index_t seqlen_knew, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - const void* rotary_cos_ptr, - const void* rotary_sin_ptr, - ck_tile::index_t rotary_dim, - const void* block_table_ptr, - ck_tile::index_t batch_stride_block_table, - ck_tile::index_t page_block_size, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_knew, - ck_tile::index_t stride_v, - ck_tile::index_t stride_vnew, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_knew, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_vnew, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_knew, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_vnew) + struct Kargs : BasicKargs, std::conditional_t> { - Kargs kargs{{q_ptr, - k_ptr, - knew_ptr, - v_ptr, - vnew_ptr, - reinterpret_cast(seqlen_k_ptr), - seqlen_q, - -1, // seqlen_k will be updated by content of seqlen_k_ptr - seqlen_knew, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - block_table_ptr, - batch_stride_block_table, - page_block_size, - stride_q, - stride_k, - stride_knew, - stride_v, - stride_vnew, - nhead_stride_q, - nhead_stride_k, - nhead_stride_knew, - nhead_stride_v, - nhead_stride_vnew, - batch_stride_q, - batch_stride_k, - batch_stride_knew, - batch_stride_v, - batch_stride_vnew}, // args for common karg - {} // placeholder for rope - }; + }; + + __host__ static constexpr Kargs MakeKargs(void* q_ptr, + void* k_ptr, + const void* knew_ptr, + void* v_ptr, + const void* vnew_ptr, + ck_tile::index_t seqlen_q, + const void* seqlen_k_ptr, + ck_tile::index_t seqlen_knew, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + const void* rotary_cos_ptr, + const void* rotary_sin_ptr, + ck_tile::index_t rotary_dim, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_knew, + ck_tile::index_t stride_v, + ck_tile::index_t stride_vnew, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_knew, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_vnew, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_knew, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_vnew) + { + Kargs kargs{ + {q_ptr, + k_ptr, + knew_ptr, + v_ptr, + vnew_ptr, + reinterpret_cast(seqlen_k_ptr), + seqlen_q, + -1, // seqlen_k will be updated by content of seqlen_k_ptr + seqlen_knew, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + block_table_ptr, + batch_stride_block_table, + page_block_size, + stride_q, + stride_k, + stride_knew, + stride_v, + stride_vnew, + nhead_stride_q, + nhead_stride_k, + nhead_stride_knew, + nhead_stride_v, + nhead_stride_vnew, + batch_stride_q, + batch_stride_k, + batch_stride_knew, + batch_stride_v, + batch_stride_vnew}, // args for common karg + {} // placeholder for rope + }; if constexpr(kApplyRoPE) { @@ -229,11 +229,14 @@ struct FmhaFwdAppendKVKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0); const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0); - const long_index_t batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - const long_index_t batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + const long_index_t batch_offset_q = + static_cast(i_batch) * kargs.batch_stride_q; + const long_index_t batch_offset_k = + static_cast(i_batch) * kargs.batch_stride_k; const long_index_t batch_offset_knew = static_cast(i_batch) * kargs.batch_stride_knew; - const long_index_t batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + const long_index_t batch_offset_v = + static_cast(i_batch) * kargs.batch_stride_v; const long_index_t batch_offset_vnew = static_cast(i_batch) * kargs.batch_stride_vnew; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index c60221d602..6e77151f8a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -223,7 +223,7 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified - const void* seqlen_k_ptr, // only used for (paged-) kvcache + const void* seqlen_k_ptr, // only used for (paged-) kvcache ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -495,7 +495,7 @@ struct FmhaFwdSplitKVKernel } // get real # queries & # keys under group mode - kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; // # of required blocks is different in each groups, terminate unnecessary blocks // earlier