From 391210ed9e57f655daf110f80afb1317e21ca48e Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sun, 14 Jul 2024 23:18:32 +0000 Subject: [PATCH] Pass RoPE kernel args --- .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 30 +++++++----- example/ck_tile/01_fmha/fmha_fwd.cpp | 2 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 23 ++++++--- .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 49 +++++++++++++++++-- .../block_fmha_fwd_appendkv_pipeline.hpp | 1 + ...ock_fmha_fwd_appendkv_pipeline_problem.hpp | 1 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 2 + 7 files changed, 83 insertions(+), 25 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index f0d0fd720a..d3f128cc9a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -28,6 +28,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_rope}, {F_occupancy}>; using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem< @@ -50,7 +51,7 @@ using fmha_kernel_{F_idx} = fmha_pipeline_{F_idx}>; using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, - {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}>; #include @@ -77,8 +78,8 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co """ FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.apply_rope == {F_rope})) {{ + using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}>; return fmha_fwd_appendkv_(s, a); }} """ @@ -98,11 +99,12 @@ class FmhaFwdAppendKVApiTrait: skpad : str dpad : str dvpad : str + rope : str # t/f, apply RoPE to Q/K or not @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-'+\ - f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}' @property def scheck(self) -> str: @@ -133,6 +135,7 @@ class FmhaFwdAppendKVPipeline: F_skpad : str # F_dpad : str # F_dvpad : str # + F_rope : str # t/f, apply RoPE to Q/K or not @property def name(self) -> str: @@ -147,6 +150,7 @@ class FmhaFwdAppendKVPipeline: pn = pad_name() n = f'v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' + if self.F_rope == 't': n += '_rope' return n class FmhaFwdAppendKVApiPool: @@ -176,7 +180,7 @@ class FmhaFwdAppendKVApiPool: inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + F_rope=BOOL_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' @@ -222,6 +226,7 @@ class FmhaFwdAppendKVKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_rope = BOOL_MAP[self.F_pipeline.F_rope], F_occupancy = self.F_tile.F_occupancy, F_mode = MODE_MAP[self.F_mode]) @@ -248,7 +253,8 @@ class FmhaFwdAppendKVKernel: spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) + dvpad=self.F_pipeline.F_dvpad, + rope=self.F_pipeline.F_rope) # TODO: design a more practical way to do it # this is current supported tile size per hdim @@ -280,15 +286,15 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - # pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f')) - # pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f')) - - pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't')) - pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't')) + for rope in ["t", "f"]: + # pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f', rope)) + # pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f', rope)) + pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't', rope)) + pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', rope)) elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels - pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f')) + pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f', 'f')) else: assert False return pipelines diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 77f236422a..ddae8d162f 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -676,7 +676,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(0 < seqlen_knew) { auto appendkv_traits = fmha_fwd_appendkv_traits{ - hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor}; + hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor, 0 < rotary_dim}; auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() { // setup stride_* arguments diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 23474f428a..ba9cd06833 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -479,6 +479,10 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k, + args.rotary_cos_ptr, + args.rotary_sin_ptr, + args.rotary_dim, + args.is_rotary_interleaved, args.stride_q, args.stride_k, args.stride_knew, @@ -506,6 +510,10 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k, + args.rotary_cos_ptr, + args.rotary_sin_ptr, + args.rotary_dim, + args.is_rotary_interleaved, args.stride_q, args.stride_k, args.stride_knew, @@ -623,11 +631,11 @@ template + bool kPadDv_, + bool kApplyRoPE_> struct fmha_fwd_appendkv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -638,11 +646,11 @@ struct fmha_fwd_appendkv_traits_ static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_; static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_; static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; - // static constexpr bool kApplyRotray = kApplyRotray_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadSk = kPadSk_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kPadDv = kPadDv_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSk = kPadSk_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kApplyRoPE = kApplyRoPE_; }; template @@ -673,6 +681,7 @@ struct fmha_fwd_appendkv_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; + bool apply_rope; }; float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 1efe6e115d..a66cd86007 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -31,6 +31,7 @@ struct FmhaFwdAppendKVKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kApplyRoPE = FmhaPipeline::kApplyRoPE; // clang-format off template struct t2s; @@ -43,8 +44,9 @@ struct FmhaFwdAppendKVKernel __host__ static std::string GetName() { -// sync with generate.py -// clang-format off + // sync with generate.py + // clang-format off + #define _SS_ std::string #define _TS_ std::to_string auto pn = [&] () { @@ -59,7 +61,8 @@ struct FmhaFwdAppendKVKernel "_" + (kIsGroupMode ? "group" : "batch") + "_" "b" + _TS_(FmhaPipeline::kTileSizeS) + "x" + _TS_(FmhaPipeline::kTileSizeSk) + "x" + _TS_(FmhaPipeline::kTileSizeD) + "x" + _TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn); + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + + (kApplyRoPE ? "_rope" : ""); #undef _SS_ #undef _TS_ // clang-format on @@ -109,14 +112,24 @@ struct FmhaFwdAppendKVKernel ck_tile::index_t batch_stride_vnew; }; - struct BatchModeKargs : CommonKargs + struct CommonRoPEKargs + { + const void* rotary_cos_ptr; + const void* rotary_sin_ptr; + ck_tile::index_t rotary_dim; + bool is_rotary_interleaved; + }; + + struct BatchModeKargs : CommonKargs, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; }; - struct GroupModeKargs : CommonKargs + struct GroupModeKargs : CommonKargs, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -139,6 +152,10 @@ struct FmhaFwdAppendKVKernel ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, + const void* rotary_cos_ptr, + const void* rotary_sin_ptr, + ck_tile::index_t rotary_dim, + bool is_rotary_interleaved, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_knew, @@ -179,10 +196,19 @@ struct FmhaFwdAppendKVKernel nhead_stride_vnew, batch_stride_knew, batch_stride_vnew}, // args for common karg + {}, // placeholder for rope batch_stride_q, batch_stride_k, batch_stride_v}; + if constexpr(kApplyRoPE) + { + kargs.rotary_cos_ptr = rotary_cos_ptr; + kargs.rotary_sin_ptr = rotary_sin_ptr; + kargs.rotary_dim = rotary_dim; + kargs.is_rotary_interleaved = is_rotary_interleaved; + } + return kargs; } @@ -201,6 +227,10 @@ struct FmhaFwdAppendKVKernel ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, + const void* rotary_cos_ptr, + const void* rotary_sin_ptr, + ck_tile::index_t rotary_dim, + bool is_rotary_interleaved, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_knew, @@ -238,10 +268,19 @@ struct FmhaFwdAppendKVKernel nhead_stride_vnew, batch_stride_knew, batch_stride_vnew}, // args for common karg + {}, // placeholder for rope reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; + if constexpr(kApplyRoPE) + { + kargs.rotary_cos_ptr = rotary_cos_ptr; + kargs.rotary_sin_ptr = rotary_sin_ptr; + kargs.rotary_dim = rotary_dim; + kargs.is_rotary_interleaved = is_rotary_interleaved; + } + return kargs; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 5892e0abfe..3e90e97564 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -31,6 +31,7 @@ struct BlockFmhaFwdAppendKVPipeline static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kApplyRoPE = Problem::kApplyRoPE; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp index 9e831e5f1e..34a9f70125 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp @@ -41,6 +41,7 @@ struct BlockFmhaFwdAppendKVPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kApplyRoPE = Traits::kApplyRoPE; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index d79c95dfb6..8638bf2408 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -80,6 +80,7 @@ template struct TileFmhaFwdAppendKVTraits { @@ -87,6 +88,7 @@ struct TileFmhaFwdAppendKVTraits static constexpr bool kPadSeqLenK = kPadSeqLenK_; static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kApplyRoPE = kApplyRoPE_; static constexpr index_t kBlockPerCu = kBlockPerCu_; };