From fffd6799e68a85a88545c6a3cdcec7bce5089d50 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sat, 20 Jul 2024 02:28:21 +0000 Subject: [PATCH] Instantiate multiple kernels for RoPE approaches --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 13 ++++ .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 16 ++--- example/ck_tile/01_fmha/fmha_fwd.cpp | 12 +++- example/ck_tile/01_fmha/fmha_fwd.hpp | 27 ++++---- example/ck_tile/01_fmha/rotary.hpp | 17 +++-- ...ence_batched_rotary_position_embedding.hpp | 1 + include/ck_tile/ops/fmha.hpp | 1 + .../block/block_rotary_embedding_enum.hpp | 37 +++++++++++ .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 24 +++---- .../block_fmha_fwd_appendkv_pipeline.hpp | 65 ++++++++++++++++--- ...ock_fmha_fwd_appendkv_pipeline_problem.hpp | 2 +- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 5 +- 12 files changed, 163 insertions(+), 57 deletions(-) create mode 100644 include/ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index d3d215f7f5..5e8d1fecbc 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -66,6 +66,19 @@ BIAS_CHECK_MAP = { "alibi" : "bias_enum::alibi" } +ROPE_MAP = { + "no" : "ck_tile::BlockRotaryEmbeddingEnum::NONE", + "inter" : "ck_tile::BlockRotaryEmbeddingEnum::INTERLEAVED", + "half" : "ck_tile::BlockRotaryEmbeddingEnum::HALF_ROTATED" +} + +# TODO: avoid duplication +ROPE_CHECK_MAP = { + "no" : "rope_enum::none", + "inter" : "rope_enum::interleaved", + "half" : "rope_enum::half_rotated" +} + MODE_MAP = { "batch" : "false", "group" : "true" diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index d3f128cc9a..ab572debbd 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -78,7 +78,7 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co """ FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.apply_rope == {F_rope})) {{ + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check})) {{ using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}>; return fmha_fwd_appendkv_(s, a); }} @@ -99,7 +99,7 @@ class FmhaFwdAppendKVApiTrait: skpad : str dpad : str dvpad : str - rope : str # t/f, apply RoPE to Q/K or not + rope : str # key from ROPE_MAP @property def name(self) -> str: @@ -135,7 +135,7 @@ class FmhaFwdAppendKVPipeline: F_skpad : str # F_dpad : str # F_dvpad : str # - F_rope : str # t/f, apply RoPE to Q/K or not + F_rope : str # key from ROPE_MAP @property def name(self) -> str: @@ -150,7 +150,7 @@ class FmhaFwdAppendKVPipeline: pn = pad_name() n = f'v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' - if self.F_rope == 't': n += '_rope' + if self.F_rope != 'no': n += f'_{self.F_rope}' return n class FmhaFwdAppendKVApiPool: @@ -178,9 +178,9 @@ class FmhaFwdAppendKVApiPool: for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_rope=BOOL_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' @@ -226,7 +226,7 @@ class FmhaFwdAppendKVKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_rope = BOOL_MAP[self.F_pipeline.F_rope], + F_rope = ROPE_MAP[self.F_pipeline.F_rope], F_occupancy = self.F_tile.F_occupancy, F_mode = MODE_MAP[self.F_mode]) @@ -286,7 +286,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for rope in ["t", "f"]: + for rope in ROPE_MAP.keys(): # pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f', rope)) # pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f', rope)) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 98d2a2ec62..1aa8eadf8a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -558,7 +558,7 @@ bool run(const ck_tile::ArgParser& arg_parser) printf("\n"); } #endif -#if 1 +#if 0 printf("rotary_sin's shape: (%2zu, %2zu)\n", rotary_sin_host.get_length(0), rotary_sin_host.get_length(1)); @@ -727,7 +727,14 @@ bool run(const ck_tile::ArgParser& arg_parser) if(0 < seqlen_knew) { auto appendkv_traits = fmha_fwd_appendkv_traits{ - hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor, 0 < rotary_dim}; + hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + is_v_rowmajor, + (0 < rotary_dim + ? (is_rotary_interleaved ? rope_enum::interleaved : rope_enum::half_rotated) + : rope_enum::none)}; auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() { // setup stride_* arguments @@ -790,7 +797,6 @@ bool run(const ck_tile::ArgParser& arg_parser) rotary_cos_buf.GetDeviceBuffer(), rotary_sin_buf.GetDeviceBuffer(), rotary_dim, - is_rotary_interleaved, stride_q, stride_k, stride_knew, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ec2aa85da2..8b8f08fc20 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -5,10 +5,13 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" -#include "mask.hpp" +#include "ck_tile/ops/fmha.hpp" + #include "bias.hpp" +#include "mask.hpp" +#include "rotary.hpp" + #include template @@ -176,7 +179,6 @@ struct fmha_fwd_appendkv_args const void* rotary_cos_ptr; const void* rotary_sin_ptr; ck_tile::index_t rotary_dim; - bool is_rotary_interleaved; ck_tile::index_t stride_q; ck_tile::index_t stride_k; @@ -486,7 +488,6 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) args.rotary_cos_ptr, args.rotary_sin_ptr, args.rotary_dim, - args.is_rotary_interleaved, args.stride_q, args.stride_k, args.stride_knew, @@ -517,7 +518,6 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) args.rotary_cos_ptr, args.rotary_sin_ptr, args.rotary_dim, - args.is_rotary_interleaved, args.stride_q, args.stride_k, args.stride_knew, @@ -537,10 +537,13 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) }(); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_knew, args.hdim_v); - printf("[POYENC][HOST] grid size: %2d,%2d,%2d\n", - static_cast(grids.x), - static_cast(grids.y), - static_cast(grids.z)); + HOST_DEBUG_STMTS + { + printf("[HOST] grid size: %2d,%2d,%2d\n", + static_cast(grids.x), + static_cast(grids.y), + static_cast(grids.z)); + } return ck_tile::make_tuple(kargs, grids); } @@ -639,7 +642,7 @@ template + ck_tile::BlockRotaryEmbeddingEnum RotaryEnum_> struct fmha_fwd_appendkv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -654,7 +657,7 @@ struct fmha_fwd_appendkv_traits_ static constexpr bool kPadSk = kPadSk_; static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; - static constexpr bool kApplyRoPE = kApplyRoPE_; + static constexpr auto RotaryEnum = RotaryEnum_; }; template @@ -685,7 +688,7 @@ struct fmha_fwd_appendkv_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; - bool apply_rope; + rope_enum rope_type; }; float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, diff --git a/example/ck_tile/01_fmha/rotary.hpp b/example/ck_tile/01_fmha/rotary.hpp index a4eacb157f..423c313a48 100644 --- a/example/ck_tile/01_fmha/rotary.hpp +++ b/example/ck_tile/01_fmha/rotary.hpp @@ -1,6 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" @@ -12,6 +14,14 @@ #include #include +// keep sync with BlockRotaryEmbeddingEnum +enum class rope_enum +{ + none = 0, + interleaved = 1, + half_rotated = 2, +}; + template std::tuple, ck_tile::HostTensor> generate_rotary_cos_sin(ck_tile::index_t seqlen_k, @@ -49,13 +59,6 @@ generate_rotary_cos_sin(ck_tile::index_t seqlen_k, return std::make_tuple(cos, sin); } -ck_tile::index_t generate_seqlen_offset(ck_tile::index_t seqlen, - std::optional seed = std::nullopt) -{ - std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); - return std::uniform_int_distribution{0, seqlen}(random_engine); -} - template std::tuple, ck_tile::HostTensor> index_cos_sin(const ck_tile::HostTensor& cos, diff --git a/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp b/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp index 6475a0fb8a..3f5218b19b 100644 --- a/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp +++ b/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp @@ -40,6 +40,7 @@ CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor( interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % sin_sd.get_length(1))); + const ComputeDataType half_rotated_input = [&] { const index_t i_b = i[0]; diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 81f9a0aa6f..1846ebcece 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/block/block_position_encoding.hpp" +#include "ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp b/include/ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp new file mode 100644 index 0000000000..32e7b66976 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockRotaryEmbeddingEnum +{ + NONE = 0, + INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc + HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc +}; + +template +struct BlockRotaryEmbeddingEnumToStr; + +template <> +struct BlockRotaryEmbeddingEnumToStr +{ + static constexpr const char* name = ""; +}; +template <> +struct BlockRotaryEmbeddingEnumToStr +{ + static constexpr const char* name = "inter"; +}; +template <> +struct BlockRotaryEmbeddingEnumToStr +{ + static constexpr const char* name = "half"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 3558cb3513..b64aeb48d7 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -31,7 +31,7 @@ struct FmhaFwdAppendKVKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr bool kApplyRoPE = FmhaPipeline::kApplyRoPE; + static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != BlockRotaryEmbeddingEnum::NONE; // clang-format off template struct t2s; @@ -62,7 +62,7 @@ struct FmhaFwdAppendKVKernel "b" + _TS_(FmhaPipeline::kTileSizeS) + "x" + _TS_(FmhaPipeline::kTileSizeSk) + "x" + _TS_(FmhaPipeline::kTileSizeD) + "x" + _TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) - + (kApplyRoPE ? "_rope" : ""); + + (!kApplyRoPE ? _SS_("") : (_SS_("_") + BlockRotaryEmbeddingEnumToStr::name)); #undef _SS_ #undef _TS_ // clang-format on @@ -117,7 +117,6 @@ struct FmhaFwdAppendKVKernel const void* rotary_cos_ptr; const void* rotary_sin_ptr; ck_tile::index_t rotary_dim; - bool is_rotary_interleaved; }; struct BatchModeKargs : CommonKargs, @@ -155,7 +154,6 @@ struct FmhaFwdAppendKVKernel const void* rotary_cos_ptr, const void* rotary_sin_ptr, ck_tile::index_t rotary_dim, - bool is_rotary_interleaved, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_knew, @@ -203,10 +201,9 @@ struct FmhaFwdAppendKVKernel if constexpr(kApplyRoPE) { - kargs.rotary_cos_ptr = rotary_cos_ptr; - kargs.rotary_sin_ptr = rotary_sin_ptr; - kargs.rotary_dim = rotary_dim; - kargs.is_rotary_interleaved = is_rotary_interleaved; + kargs.rotary_cos_ptr = rotary_cos_ptr; + kargs.rotary_sin_ptr = rotary_sin_ptr; + kargs.rotary_dim = rotary_dim; } return kargs; @@ -230,7 +227,6 @@ struct FmhaFwdAppendKVKernel const void* rotary_cos_ptr, const void* rotary_sin_ptr, ck_tile::index_t rotary_dim, - bool is_rotary_interleaved, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_knew, @@ -275,10 +271,9 @@ struct FmhaFwdAppendKVKernel if constexpr(kApplyRoPE) { - kargs.rotary_cos_ptr = rotary_cos_ptr; - kargs.rotary_sin_ptr = rotary_sin_ptr; - kargs.rotary_dim = rotary_dim; - kargs.is_rotary_interleaved = is_rotary_interleaved; + kargs.rotary_cos_ptr = rotary_cos_ptr; + kargs.rotary_sin_ptr = rotary_sin_ptr; + kargs.rotary_dim = rotary_dim; } return kargs; @@ -626,8 +621,7 @@ struct FmhaFwdAppendKVKernel rotary_cos_dram_window, rotary_sin_dram_window, smem_ptr, - kargs.rotary_dim, - kargs.is_rotary_interleaved); + kargs.rotary_dim); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index ab3fea23ed..11794e67be 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" namespace ck_tile { @@ -31,7 +32,7 @@ struct BlockFmhaFwdAppendKVPipeline static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kApplyRoPE = Problem::kApplyRoPE; + static constexpr auto RotaryEnum = Problem::RotaryEnum; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -101,8 +102,7 @@ struct BlockFmhaFwdAppendKVPipeline const RotaryCosBlockWindowTemp rotary_cos_block_window_tmp, const RotarySinBlockWindowTemp rotary_sin_block_window_tmp, void* smem_ptr, - index_t rotary_dim = 0, - bool is_rotary_interleaved = false) const + index_t rotary_dim = 0) const { auto* const ksmem = reinterpret_cast(smem_ptr); @@ -125,7 +125,6 @@ struct BlockFmhaFwdAppendKVPipeline (void)rotary_sin_block_window_tmp; (void)smem_ptr; (void)rotary_dim; - (void)is_rotary_interleaved; auto knew_dram_block_window = make_tile_window(knew_dram_block_window_tmp.get_bottom_tensor_view(), @@ -140,7 +139,7 @@ struct BlockFmhaFwdAppendKVPipeline auto knew_tile = load_tile(knew_dram_window); - if constexpr(kApplyRoPE) + if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE) { auto rotary_cos_window = make_tile_window( rotary_cos_block_window_tmp.get_bottom_tensor_view(), @@ -188,12 +187,54 @@ struct BlockFmhaFwdAppendKVPipeline } } #endif +#define DUMP_KNEW 0 constexpr index_t KPerThread = 16 / sizeof(KDataType); static_assert(kTileSizeD % KPerThread == 0); constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + if((start_x + KPerThread) <= rotary_dim) { + bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); + + auto knew_other_dram_window = knew_dram_window; + DEVICE_DEBUG_STMTS + { + auto origin = knew_other_dram_window.get_window_origin(); + printf("after move window, origin = (%3d, %3d)\n", + origin.at(number<0>{}), + origin.at(number<1>{})); + } + move_tile_window(knew_other_dram_window, + {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); + DEVICE_DEBUG_STMTS + { + auto origin = knew_other_dram_window.get_window_origin(); + printf("after move window, origin = (%3d, %3d)\n", + origin.at(number<0>{}), + origin.at(number<1>{})); + } + auto knew_other_tile = load_tile(knew_other_dram_window); + +#if !DUMP_KNEW + { + constexpr auto spans = decltype(knew_other_tile)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + knew_other_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = tile_idx.at(number<0>{}); + const auto col = tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + ksmem[row * kTileSizeD + col] = knew_other_tile(i_j_idx); + }); + }); + } +#endif + +#if !defined(DUMP_KNEW) constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size(); static_assert(thread_buffer_size % KPerThread == 0); static_for<0, thread_buffer_size, 2>{}([&](auto idx) { @@ -206,10 +247,12 @@ struct BlockFmhaFwdAppendKVPipeline knew_tile.thread_buf_[idx] = left * cos - right * sin; knew_tile.thread_buf_[idx + 1] = right * cos + left * sin; }); +#endif } #if defined(ENABLE_DEVICE_DEBUG_STMTS) DEVICE_DEBUG_STMTS { printf("[DEVICE] kTileSizeD: %3d\n", kTileSizeD); } +#if DUMP_KNEW { constexpr auto spans = decltype(knew_tile)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { @@ -225,6 +268,7 @@ struct BlockFmhaFwdAppendKVPipeline }); }); } +#endif block_sync_lds(); @@ -232,7 +276,12 @@ struct BlockFmhaFwdAppendKVPipeline { for(int row = 0; row < 7; ++row) { +#if DUMP_KNEW printf("[DEVICE] knew_tile[%3d] = ", row); +#else + printf("[DEVICE] knew_other_tile[%3d] = ", row); +#endif + for(int col = 0; col < kTileSizeD; ++col) { printf("%11.7f", type_convert(ksmem[row * kTileSizeD + col])); @@ -297,8 +346,7 @@ struct BlockFmhaFwdAppendKVPipeline const RotaryCosBlockWindowTemp& rotary_cos_block_window_tmp, const RotarySinBlockWindowTemp& rotary_sin_block_window_tmp, void* smem_ptr, - index_t rotary_dim = 0, - bool is_rotary_interleaved = false) const + index_t rotary_dim = 0) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -313,8 +361,7 @@ struct BlockFmhaFwdAppendKVPipeline rotary_cos_block_window_tmp, rotary_sin_block_window_tmp, smem_ptr, - rotary_dim, - is_rotary_interleaved); + rotary_dim); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp index 34a9f70125..5b4ab00eed 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp @@ -41,7 +41,7 @@ struct BlockFmhaFwdAppendKVPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kApplyRoPE = Traits::kApplyRoPE; + static constexpr auto RotaryEnum = Traits::RotaryEnum; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 8638bf2408..c3ae9772d5 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp" namespace ck_tile { @@ -80,7 +81,7 @@ template struct TileFmhaFwdAppendKVTraits { @@ -88,7 +89,7 @@ struct TileFmhaFwdAppendKVTraits static constexpr bool kPadSeqLenK = kPadSeqLenK_; static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr bool kApplyRoPE = kApplyRoPE_; + static constexpr auto RotaryEnum = RotaryEnum_; static constexpr index_t kBlockPerCu = kBlockPerCu_; };