diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 310b793cfa..e975a10c04 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -41,7 +41,6 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl {F_bd}, {F_bdv}, {F_vlayout}, - {F_mode}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< @@ -51,7 +50,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel, fmha_pipeline_{F_idx}>; -using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, +using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; #include @@ -78,10 +77,10 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co }} """ -FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && +FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) && ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{ - using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; + using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; return fmha_fwd_appendkv_(s, a); }} """ @@ -91,7 +90,6 @@ class FmhaFwdAppendKVApiTrait: # sync with fmha_fwd_traits<>, to generate fallback calls hdim : str dtype : str # data type - mode : str # value from MODE_MAP bs : int # tile size along q seqlen bsk : int # tile size along k seqlen bd : int # tile size along qk gemm unroll @@ -106,13 +104,11 @@ class FmhaFwdAppendKVApiTrait: @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-'+\ - f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-'+\ - f'{self.pagedkv}' + return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\ + f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}' @property def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/' else : return f'a.seqlen_q % {self.bs} == 0' @@ -183,7 +179,7 @@ class FmhaFwdAppendKVApiPool: inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) @@ -210,7 +206,6 @@ class FmhaFwdAppendKVKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim F_dtype : str # data type - F_mode : str # value from MODE_MAP F_tile : FmhaFwdAppendKVTileSize F_pipeline : FmhaFwdAppendKVPipeline mask_impl : str @@ -234,13 +229,12 @@ class FmhaFwdAppendKVKernel: F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_rope = ROPE_MAP[self.F_pipeline.F_rope], F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], - F_occupancy = self.F_tile.F_occupancy, - F_mode = MODE_MAP[self.F_mode]) + F_occupancy = self.F_tile.F_occupancy) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \ self.F_tile.name + '_' + self.F_pipeline.name @property @@ -251,7 +245,6 @@ class FmhaFwdAppendKVKernel: return FmhaFwdAppendKVApiTrait( hdim=str(self.F_hdim), dtype=self.F_dtype, - mode=self.F_mode, bs=self.F_tile.F_bs, bsk=self.F_tile.F_bsk, bd=self.F_tile.F_bd, @@ -320,19 +313,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) if d == None: continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + for hdim_str in d.keys(): tile = d[hdim_str] hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): - if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue k = FmhaFwdAppendKVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, - F_mode=mode, F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index d7beb7eb07..a5b4a2135e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -258,7 +258,7 @@ bool run(const ck_tile::ArgParser& arg_parser) { std::string data_type = arg_parser.get_str("prec"); int do_validation = arg_parser.get_int("v"); - auto mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); ck_tile::index_t nhead = arg_parser.get_int("h"); ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); @@ -278,7 +278,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); -#if !CK_TILE_FMHA_FWD_APPENDKV_API +#if !CK_TILE_FMHA_FWD_APPENDKV_API || !CK_TILE_FMHA_FWD_SPLITKV_API if(seqlen_knew != 0) { std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl; @@ -290,6 +290,29 @@ bool run(const ck_tile::ArgParser& arg_parser) seqlen_knew = randint(1, arg_parser.get_int("s"), seed); } + ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); +#if !CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < page_block_size) + { + std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" + << std::endl; + page_block_size = 0; + } +#endif + if(!(page_block_size % 128 == 0)) + { + std::cerr << "only paged-kvcache block size divisible by 128 are currently supported" + << std::endl; + return false; + } + + auto mode = static_cast(arg_parser.get_uint32("mode")); + if ((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch) { + std::cerr << "kvcache enabled. ignoring the 'mode' option" + << std::endl; + mode = mode_enum::batch; + } + auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode, batch, arg_parser.get_str("s"), @@ -420,21 +443,6 @@ bool run(const ck_tile::ArgParser& arg_parser) num_splits = 1; } #endif - ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); -#if !CK_TILE_FMHA_FWD_SPLITKV_API - if(0 < page_block_size) - { - std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" - << std::endl; - page_block_size = 0; - } -#endif - if(!(page_block_size % 128 == 0)) - { - std::cerr << "only paged-kvcache block size divisible by 128 are currently supported" - << std::endl; - return false; - } int stream_warmup = arg_parser.get_int("warmup"); int stream_repeat = arg_parser.get_int("repeat"); @@ -581,11 +589,11 @@ bool run(const ck_tile::ArgParser& arg_parser) generate_rotary_cos_sin(shape_seqlen_k, rotary_dim, seed); ck_tile::HostTensor lse_acc_host( - 1 < num_splits || 0 < page_block_size + 1 < num_splits || 0 < seqlen_knew || 0 < page_block_size ? std::array{num_splits, batch, nhead, max_seqlen_q} : std::array{1, 1, 1, 1}); ck_tile::HostTensor o_acc_host( - 1 < num_splits || 0 < page_block_size + 1 < num_splits || 0 < seqlen_knew || 0 < page_block_size ? std::array{num_splits, batch, nhead, max_seqlen_q, hdim_v} : std::array{1, 1, 1, 1, 1}); @@ -762,7 +770,6 @@ bool run(const ck_tile::ArgParser& arg_parser) traits.hdim_q = hdim_q; traits.hdim_v = hdim_v; traits.data_type = data_type; - traits.is_group_mode = (mode == mode_enum::group); traits.is_v_rowmajor = is_v_rowmajor; if constexpr(std::is_same_v>) @@ -773,6 +780,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } else // fmha_fwd_traits or fmha_splitkv_traits { + traits.is_group_mode = (mode == mode_enum::group); traits.mask_type = mask.type; traits.bias_type = bias.type; traits.has_lse = lse; @@ -863,12 +871,8 @@ bool run(const ck_tile::ArgParser& arg_parser) args.k_ptr = k_buf.GetDeviceBuffer(); args.v_ptr = v_buf.GetDeviceBuffer(); - args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer(); - args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer(); - - args.seqlen_q = shape_seqlen_q; args.batch = batch; - args.max_seqlen_q = max_seqlen_q; + args.seqlen_q = shape_seqlen_q; args.hdim_q = hdim_q; args.hdim_v = hdim_v; args.nhead_q = nhead; @@ -891,7 +895,6 @@ bool run(const ck_tile::ArgParser& arg_parser) args.seqlen_knew = seqlen_knew; args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer(); - args.seqlen_k = shape_seqlen_k - seqlen_knew; // kvcache seqlen for batch mode args.rotary_cos_ptr = rotary_cos_buf.GetDeviceBuffer(); args.rotary_sin_ptr = rotary_sin_buf.GetDeviceBuffer(); @@ -916,8 +919,12 @@ bool run(const ck_tile::ArgParser& arg_parser) args.lse_ptr = lse_buf.GetDeviceBuffer(); args.o_ptr = o_buf.GetDeviceBuffer(); - args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(); - args.seqlen_k = shape_seqlen_k; + args.seqstart_q_ptr = (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr); + + args.seqlen_k = (args.seqlen_k_ptr == nullptr ? shape_seqlen_k : -1); + args.max_seqlen_q = max_seqlen_q; args.scale_s = scale_s; args.scale_p = scale_p; @@ -990,7 +997,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const float fwd_ave_time = [&] { #if CK_TILE_FMHA_FWD_SPLITKV_API - if(1 < num_splits || 0 < page_block_size) + if(1 < num_splits || 0 < seqlen_knew || 0 < page_block_size) { fmha_fwd_splitkv_traits fmha_splitkv_traits; init_traits(fmha_splitkv_traits); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ad40061355..9ea074f80b 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -165,10 +165,11 @@ struct fmha_fwd_splitkv_args const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* - seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr, or + // kvcache is used ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; + ck_tile::index_t seqlen_k; // only used if 'seqlen_k_ptr' is nullptr ck_tile::index_t batch; ck_tile::index_t max_seqlen_q; ck_tile::index_t hdim_q; @@ -219,16 +220,11 @@ struct fmha_fwd_appendkv_args void* v_ptr; const void* vnew_ptr; - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* - seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + const void* seqlen_k_ptr; ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; ck_tile::index_t seqlen_knew; ck_tile::index_t batch; - ck_tile::index_t max_seqlen_q; ck_tile::index_t hdim_q; ck_tile::index_t hdim_v; ck_tile::index_t nhead_q; @@ -371,7 +367,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.lse_acc_ptr, args.o_acc_ptr, args.batch, - args.max_seqlen_q, args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_k_ptr, @@ -415,9 +410,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.lse_acc_ptr, args.o_acc_ptr, args.batch, - args.max_seqlen_q, args.seqlen_q, args.seqlen_k, + args.seqlen_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, @@ -525,53 +520,13 @@ template auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) { assert(args.nhead_q % args.nhead_k == 0); - auto kargs = [&] { - // create group mode kernel arguments - if constexpr(Kernel::kIsGroupMode) - { - return Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.knew_ptr, - args.v_ptr, - args.vnew_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.seqlen_knew, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.rotary_cos_ptr, - args.rotary_sin_ptr, - args.rotary_dim, - args.block_table_ptr, - args.batch_stride_block_table, - args.page_block_size, - args.stride_q, - args.stride_k, - args.stride_knew, - args.stride_v, - args.stride_vnew, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_knew, - args.nhead_stride_v, - args.nhead_stride_vnew, - args.batch_stride_k, - args.batch_stride_knew, - args.batch_stride_v, - args.batch_stride_vnew); - } - else - { // create batch mode kernel arguments - return Kernel::MakeKargs(args.q_ptr, + auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, args.knew_ptr, args.v_ptr, args.vnew_ptr, args.seqlen_q, - args.seqlen_k, + args.seqlen_k_ptr, args.seqlen_knew, args.hdim_q, args.hdim_v, @@ -598,10 +553,8 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) args.batch_stride_knew, args.batch_stride_v, args.batch_stride_vnew); - } - }(); - dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.seqlen_knew); + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew); return ck_tile::make_tuple(kargs, grids); } @@ -735,7 +688,6 @@ std::string fmha_fwd_splitkv_combine_get_name_(); // this is used to pattern-match internl kernel implementation, not to instantiate kernel template ; - static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_; static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_; static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_; @@ -807,7 +758,6 @@ struct fmha_fwd_appendkv_traits int hdim_q; int hdim_v; std::string data_type; - bool is_group_mode; bool is_v_rowmajor; rope_enum rope_type; }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 4d46e283d0..23ee6d1b61 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -26,7 +26,6 @@ struct FmhaFwdAppendKVKernel using VLayout = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; @@ -58,8 +57,7 @@ struct FmhaFwdAppendKVKernel if (kPadHeadDimV) n += "dv"; return n.empty() ? n : std::string("p") + n; }(); return - _SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s::name) + "_" "b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" + _TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) @@ -79,7 +77,7 @@ struct FmhaFwdAppendKVKernel // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size // user need to use MakeKargs() function to create kargs. - struct CommonKargs + struct BasicKargs { void* q_ptr; void* k_ptr; @@ -87,6 +85,8 @@ struct FmhaFwdAppendKVKernel void* v_ptr; const void* vnew_ptr; + const int32_t* seqlen_k_ptr; + ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t seqlen_knew; @@ -114,47 +114,32 @@ struct FmhaFwdAppendKVKernel ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_vnew; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_knew; + ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_vnew; }; - struct CommonRoPEKargs + struct RoPEKargs { const void* rotary_cos_ptr; const void* rotary_sin_ptr; ck_tile::index_t rotary_dim; }; - struct BatchModeKargs : CommonKargs, - std::conditional_t> - { - ck_tile::index_t batch_stride_q; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; - }; + struct Kargs : BasicKargs, + std::conditional_t> + {}; - struct GroupModeKargs : CommonKargs, - std::conditional_t> - { - const int32_t* seqstart_q_ptr; - const int32_t* seqstart_k_ptr; - const int32_t* seqlen_k_ptr; - - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; - }; - - using Kargs = std::conditional_t; - - template - __host__ static constexpr std::enable_if_t + __host__ static constexpr Kargs MakeKargs(void* q_ptr, void* k_ptr, const void* knew_ptr, void* v_ptr, const void* vnew_ptr, ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, + const void* seqlen_k_ptr, ck_tile::index_t seqlen_knew, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -187,8 +172,9 @@ struct FmhaFwdAppendKVKernel knew_ptr, v_ptr, vnew_ptr, + reinterpret_cast(seqlen_k_ptr), seqlen_q, - seqlen_k, + -1, // seqlen_k will be updated by content of seqlen_k_ptr seqlen_knew, hdim_q, hdim_v, @@ -207,92 +193,13 @@ struct FmhaFwdAppendKVKernel nhead_stride_knew, nhead_stride_v, nhead_stride_vnew, + batch_stride_q, + batch_stride_k, batch_stride_knew, + batch_stride_v, batch_stride_vnew}, // args for common karg - {}, // placeholder for rope - batch_stride_q, - batch_stride_k, - batch_stride_v}; - - if constexpr(kApplyRoPE) - { - kargs.rotary_cos_ptr = rotary_cos_ptr; - kargs.rotary_sin_ptr = rotary_sin_ptr; - kargs.rotary_dim = rotary_dim; - } - - return kargs; - } - - template - __host__ static constexpr std::enable_if_t - MakeKargs(void* q_ptr, - void* k_ptr, - const void* knew_ptr, - void* v_ptr, - const void* vnew_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t seqlen_knew, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - const void* rotary_cos_ptr, - const void* rotary_sin_ptr, - ck_tile::index_t rotary_dim, - const void* block_table_ptr, - ck_tile::index_t batch_stride_block_table, - ck_tile::index_t page_block_size, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_knew, - ck_tile::index_t stride_v, - ck_tile::index_t stride_vnew, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_knew, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_vnew, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_knew, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_vnew) - { - Kargs kargs{{q_ptr, - k_ptr, - knew_ptr, - v_ptr, - vnew_ptr, - -1, // seqlen will be updated by another pointer - -1, // - seqlen_knew, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - block_table_ptr, - batch_stride_block_table, - page_block_size, - stride_q, - stride_k, - stride_knew, - stride_v, - stride_vnew, - nhead_stride_q, - nhead_stride_k, - nhead_stride_knew, - nhead_stride_v, - nhead_stride_vnew, - batch_stride_knew, - batch_stride_vnew}, // args for common karg - {}, // placeholder for rope - reinterpret_cast(seqstart_q_ptr), - reinterpret_cast(seqstart_k_ptr), - reinterpret_cast(seqlen_k_ptr), - batch_stride_k, - batch_stride_v}; + {} // placeholder for rope + }; if constexpr(kApplyRoPE) { @@ -322,51 +229,15 @@ struct FmhaFwdAppendKVKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0); const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_knew = + const long_index_t batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + const long_index_t batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + const long_index_t batch_offset_knew = static_cast(i_batch) * kargs.batch_stride_knew; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_vnew = + const long_index_t batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + const long_index_t batch_offset_vnew = static_cast(i_batch) * kargs.batch_stride_vnew; - if constexpr(kIsGroupMode) - { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) - { - batch_offset_v = key_start * kargs.stride_v; - } - else - { - batch_offset_v = key_start; - } - - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; - - if(kargs.seqlen_k_ptr != nullptr) - { - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; - } - else - { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; - } - } - else - { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - } + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { if constexpr(kIsPagedKV) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp index 1190520995..97c9b960c2 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp @@ -19,11 +19,11 @@ struct FmhaFwdAppendKVTilePartitioner CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, - ck_tile::index_t max_seqlen_q, + ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_knew) { // TODO: this may need tuning - return dim3(std::max(ck_tile::integer_divide_ceil(max_seqlen_q, kM0), + return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0), ck_tile::integer_divide_ceil(seqlen_knew, kN0)), nhead, batch_size); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index fdc9eea8b8..c60221d602 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -108,7 +108,6 @@ struct FmhaFwdSplitKVKernel void* o_acc_ptr; ck_tile::index_t batch; - ck_tile::index_t max_seqlen_q; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -186,6 +185,8 @@ struct FmhaFwdSplitKVKernel std::conditional_t>, std::conditional_t> { + const int32_t* seqlen_k_ptr; + ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; @@ -220,9 +221,9 @@ struct FmhaFwdSplitKVKernel void* lse_acc_ptr, void* o_acc_ptr, ck_tile::index_t batch, - ck_tile::index_t max_seqlen_q, ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, + ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified + const void* seqlen_k_ptr, // only used for (paged-) kvcache ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -262,7 +263,6 @@ struct FmhaFwdSplitKVKernel lse_acc_ptr, o_acc_ptr, batch, - max_seqlen_q, seqlen_q, seqlen_k, hdim_q, @@ -294,6 +294,7 @@ struct FmhaFwdSplitKVKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for fp8_static_quant args + reinterpret_cast(seqlen_k_ptr), batch_stride_q, batch_stride_k, batch_stride_v}; @@ -333,7 +334,6 @@ struct FmhaFwdSplitKVKernel void* lse_acc_ptr, void* o_acc_ptr, ck_tile::index_t batch, - ck_tile::index_t max_seqlen_q, const void* seqstart_q_ptr, const void* seqstart_k_ptr, const void* seqlen_k_ptr, @@ -374,9 +374,8 @@ struct FmhaFwdSplitKVKernel lse_acc_ptr, o_acc_ptr, batch, - max_seqlen_q, - -1, // seqlen will be updated by another pointer - -1, // + -1, // seqlen_q will be updated by another pointer + -1, // seqlen_k will be updated by another pointer hdim_q, hdim_v, num_head_q, @@ -496,8 +495,7 @@ struct FmhaFwdSplitKVKernel } // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; // # of required blocks is different in each groups, terminate unnecessary blocks // earlier @@ -512,8 +510,7 @@ struct FmhaFwdSplitKVKernel } else { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; } } else @@ -526,6 +523,11 @@ struct FmhaFwdSplitKVKernel { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } } auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 734abefe63..def3477055 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -27,7 +27,6 @@ struct BlockFmhaFwdAppendKVPipeline static constexpr index_t kK0 = Problem::kK0; static constexpr index_t kN1 = Problem::kN1; - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index a2c9faea88..e8ab04cb78 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -139,7 +139,6 @@ template struct BlockFmhaFwdAppendKVPipelineProblem { @@ -149,7 +148,6 @@ struct BlockFmhaFwdAppendKVPipelineProblem using Traits = remove_cvref_t; static constexpr index_t kBlockSize = 256; - static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr index_t kM0 = kM0_; static constexpr index_t kN0 = kN0_;