From d7de39fab350f3bbfb4662eccaabd070da2fac6f Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Sun, 29 Dec 2024 14:29:56 +0800 Subject: [PATCH] Remove using partitioner for all fmha kernels (#1778) * Remove using tile partitioner for fmha_fwd_kernel * Remove using tile partitioner for fmha_fwd_splitkv and splitkv-combine kernels * Remove using tile partitioner for fmha_fwd_appendkv kernel * Unify the format of GetTileIndex [ROCm/composable_kernel commit: 4e076909b6c1e1404d9ff5dc0e71e3be1c06569e] --- example/ck_tile/01_fmha/README.md | 3 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 20 +--- .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 6 +- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 10 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 14 ++- include/ck_tile/ops/fmha.hpp | 3 - .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 28 +++-- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 78 +++++++++++-- .../fmha_fwd_splitkv_combine_kernel.hpp | 39 +++++-- ...a_fwd_splitkv_combine_tile_partitioner.hpp | 48 -------- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 40 +++++-- .../fmha_fwd_splitkv_tile_partitioner.hpp | 54 --------- .../fmha/kernel/fmha_fwd_tile_partitioner.hpp | 105 ------------------ 13 files changed, 171 insertions(+), 277 deletions(-) delete mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp delete mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp delete mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index c7ab296c3b..e9806e7a67 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -15,8 +15,7 @@ This will result in an executable `build/bin/tile_example_fmha_fwd` ## kernel The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. -There are 3 template parameters for this kernel template. -* `TilePartitioner` is used to map the workgroup to corresponding tile, `fmha_fwd_tile_partitioner.hpp` in this folder served as this purpose. +There are 2 template parameters for this kernel template. * `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). * `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support. diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 66814f5a16..1c9d743f3d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -29,11 +29,6 @@ K0_MAX_SUBMAX_MAP = { 256: 256 } -TILE_PARTITIONER_MAP = { - "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", - "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", -} - FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n // auto generated by generate.py @@ -90,9 +85,7 @@ using fmha_epilogue_{F_idx} = {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel<{F_tile_partitioner}, - fmha_pipeline_{F_idx}, - fmha_epilogue_{F_idx}>; + ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; @@ -329,12 +322,6 @@ class FmhaFwdKernel: F_pipeline : FmhaFwdPipeline mask_impl : str - def get_tp(self) -> str: - if self.F_mode == 'group': - return 'hbs' - else: - return 'shb' - @property def template(self) -> str: kernel_body = str() @@ -374,13 +361,12 @@ class FmhaFwdKernel: F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ + return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ self.F_tile.name + '_' + self.F_pipeline.name @property diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index fb998a33d7..2f20819302 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< fmha_pipeline_problem_{F_idx}>; -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdAppendKVKernel, - fmha_pipeline_{F_idx}>; +using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel; using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; @@ -355,4 +353,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 2f7edd5477..fb8a4389f3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -96,9 +96,7 @@ using fmha_epilogue = {F_spad}, {F_dvpad}>>; using fmha_kernel = - ck_tile::FmhaFwdSplitKVKernel, - fmha_pipeline, - fmha_epilogue>; + ck_tile::FmhaFwdSplitKVKernel; static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ @@ -176,11 +174,7 @@ using fmha_epilogue = false, false>>; using fmha_kernel = - ck_tile::FmhaFwdSplitKVCombineKernel< - ck_tile::FmhaFwdSplitKVCombineTilePartitioner< - fmha_pipeline_problem::kM0, fmha_pipeline_problem::kN1>, - fmha_pipeline, - fmha_epilogue>; + ck_tile::FmhaFwdSplitKVCombineKernel; static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 0e821ed5d9..0368de352f 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -400,8 +400,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } }(); - dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); - return ck_tile::make_tuple(kargs, grids); + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } } template diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 7a09e4622d..d5920f4837 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -14,10 +14,7 @@ #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index d598f97433..9fec9a320c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -10,10 +10,9 @@ namespace ck_tile { -template +template struct FmhaFwdAppendKVKernel { - using TilePartitioner = ck_tile::remove_cvref_t; using FmhaPipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; @@ -234,12 +233,25 @@ struct FmhaFwdAppendKVKernel return kargs; } - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, - ck_tile::index_t nhead, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_knew) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_knew) { - return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, seqlen_knew); + // TODO: this may need tuning + return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, FmhaPipeline::kM0), + ck_tile::integer_divide_ceil(seqlen_knew, FmhaPipeline::kN0)), + nhead, + batch_size); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& /* kargs */) + { + const index_t i_tile = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_tile, i_nhead, i_batch); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -247,7 +259,7 @@ struct FmhaFwdAppendKVKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // divide problem - const auto [i_tile, i_nhead, i_batch] = TilePartitioner{}(); + const auto [i_tile, i_nhead, i_batch] = GetTileIndex(kargs); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0); const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 90102a6c6f..f107b10dff 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -20,10 +20,9 @@ namespace ck_tile { -template +template struct FmhaFwdKernel { - using TilePartitioner = ck_tile::remove_cvref_t; using FmhaPipeline = ck_tile::remove_cvref_t; using EpiloguePipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; @@ -84,7 +83,7 @@ struct FmhaFwdKernel return n.empty() ? n : std::string("p") + n; }(); return _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_" + "_" + (kIsGroupMode ? "group" : "batch") + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + @@ -867,9 +866,75 @@ struct FmhaFwdKernel CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + ck_tile::index_t hdim_v_, + bool has_padded_seqlen_k = false) { - return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); + // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr) + if(has_padded_seqlen_k) + { + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1)); + } + else + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); + } + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + bool has_padded_seqlen_k = false; + + if constexpr(kIsGroupMode) + has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); + + if(has_padded_seqlen_k) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -885,8 +950,7 @@ struct FmhaFwdKernel __shared__ char smem_ptr[GetSmemSize()]; // divide problem - const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = - TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index a0adfdc127..a342a91f10 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -5,10 +5,9 @@ namespace ck_tile { -template +template struct FmhaFwdSplitKVCombineKernel { - using TilePartitioner = remove_cvref_t; using FmhaPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; @@ -235,12 +234,35 @@ struct FmhaFwdSplitKVCombineKernel return kargs; } - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, - ck_tile::index_t nhead, - ck_tile::index_t max_seqlen_q, - ck_tile::index_t hdim_v) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v) { - return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v); + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1), + nhead, + batch_size); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -256,8 +278,7 @@ struct FmhaFwdSplitKVCombineKernel __shared__ char smem_ptr[GetSmemSize()]; // divide problem - const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = - TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp deleted file mode 100644 index 3b73909712..0000000000 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp +++ /dev/null @@ -1,48 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { - -template -struct FmhaFwdSplitKVCombineTilePartitioner -{ - static constexpr ck_tile::index_t kM0 = kM0_; - static constexpr ck_tile::index_t kN1 = kN1_; - - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, - ck_tile::index_t nhead, - ck_tile::index_t max_seqlen_q, - ck_tile::index_t hdim_v) - { - // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) * - ck_tile::integer_divide_ceil(hdim_v, kN1), - nhead, - batch_size); - } - - CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) - { - const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); - - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index dc17487262..10ab25119b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -17,10 +17,9 @@ namespace ck_tile { -template +template struct FmhaFwdSplitKVKernel { - using TilePartitioner = ck_tile::remove_cvref_t; using FmhaPipeline = ck_tile::remove_cvref_t; using EpiloguePipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; @@ -476,13 +475,35 @@ struct FmhaFwdSplitKVKernel return kargs; } - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, - ck_tile::index_t nhead, - ck_tile::index_t max_seqlen_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_splits) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits) { - return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits); + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits, + nhead, + batch_size); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits); + const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1); + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -498,8 +519,7 @@ struct FmhaFwdSplitKVKernel __shared__ char smem_ptr[GetSmemSize()]; // divide problem - const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = - TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v, kargs.num_splits); + const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp deleted file mode 100644 index 5a52fa0f67..0000000000 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { - -template -struct FmhaFwdSplitKVTilePartitioner -{ - using BlockFmhaShape = ck_tile::remove_cvref_t; - - static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; - static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; - static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; - static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; - static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; - - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, - ck_tile::index_t nhead, - ck_tile::index_t max_seqlen_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_splits) - { - // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) * - ck_tile::integer_divide_ceil(hdim_v, kN1) * num_splits, - nhead, - batch_size); - } - - CK_TILE_DEVICE auto - operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v, ck_tile::index_t num_splits) - { - const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); - - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - - const auto [mn, i_split] = f(blockIdx.x, num_splits); - const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1); - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp deleted file mode 100644 index 2dca84b786..0000000000 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp +++ /dev/null @@ -1,105 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { - -template -struct FmhaFwdTilePartitioner -{ - using BlockFmhaShape = ck_tile::remove_cvref_t; - - static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; - static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; - static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; - static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; - static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; - - static constexpr const char* name = "shb"; - - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) - { - // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * - ck_tile::integer_divide_ceil(hdim_v_, kN1), - nhead_, - batch_size_); - } - - CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) - { - // const index_t num_tile_m0 = seqlen_q / kM0; - const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); - - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); - } -}; - -template -using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner; - -template -struct FmhaFwdTilePartitioner_HBS -{ - using BlockFmhaShape = ck_tile::remove_cvref_t; - - static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; - static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; - static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; - static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; - static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; - - static constexpr const char* name = "hbs"; - - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) - { - // TODO: this may need tuning - return dim3(nhead_, - batch_size_, - ck_tile::integer_divide_ceil(seqlen_q_, kM0) * - ck_tile::integer_divide_ceil(hdim_v_, kN1)); - } - - CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) - { - // const index_t num_tile_m0 = seqlen_q / kM0; - const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); - - const index_t i_block = blockIdx.z; - const index_t i_nhead = blockIdx.x; - const index_t i_batch = blockIdx.y; - - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); - } -}; - -} // namespace ck_tile