From ed8630416e6c0989e70f2b7ee9a8bfc11ba63876 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Thu, 21 Nov 2024 14:53:10 +0800 Subject: [PATCH] [CK_TILE] Add paged-kvcache support in group mode fmha fwd splitkv kernels (#1678) * Generate group mode paged-attn kernel * Enable paged-kvcache + group mode support * Add missing header: fused_moe.hpp * Add comment to explain kernel arg usage * Make error message more clear * Add comment for confusing data member names * Add more comment for confusing variable names * Fix typo in option description [ROCm/composable_kernel commit: fb1ccfa9df534c8c9f351dd959a0ff692d6f9210] --- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 3 - example/ck_tile/01_fmha/fmha_fwd.cpp | 61 ++++++++++++------- example/ck_tile/01_fmha/fmha_fwd.hpp | 10 ++- example/ck_tile/01_fmha/utils.hpp | 4 +- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 49 ++++++++++----- include/ck_tile/ops/fused_moe.hpp | 11 ++++ 6 files changed, 95 insertions(+), 43 deletions(-) create mode 100644 include/ck_tile/ops/fused_moe.hpp diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index b084e9d0fc..d1da951567 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -655,9 +655,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - if pipeline.F_pagedkv == 't': - # we only use batch mode kernels to handle (paged-) kvcache problems - continue k = Kernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 14291715fb..00e0a16536 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -62,7 +62,7 @@ auto create_args(int argc, char* argv[]) "-1 to choose s_knew in [1, s] randomly.") .insert("s_kpad", "-1", - "seqlen_k stride between 2 tokens, currently used in group-mode only\n" + "seqlen_k stride between 2 batches, currently used in group-mode only\n" "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" "along seqlen, instead of packed. same as xformer kv_padding") .insert("d", "128", "head dim for q, k") @@ -294,7 +294,8 @@ bool run(const ck_tile::ArgParser& arg_parser) #if !CK_TILE_FMHA_FWD_APPENDKV_API if(seqlen_knew != 0) { - std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl; + std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option" + << std::endl; seqlen_knew = 0; } #endif @@ -321,6 +322,13 @@ bool run(const ck_tile::ArgParser& arg_parser) rotary_dim = 0; } #endif + // to use fmha_fwd_appendkv(), make sure it's in batch mode + const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim); + if(need_append_kvcache && mode == mode_enum::group) + { + std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl; + mode = mode_enum::batch; + } if(!(rotary_dim <= hdim_q)) { std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; @@ -356,22 +364,26 @@ bool run(const ck_tile::ArgParser& arg_parser) << std::endl; use_cache_batch_idx = false; } +#else + if(use_cache_batch_idx) + { + if(0 < page_block_size) + { + std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " + "'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } + else if(mode == mode_enum::group) + { + std::cerr << "group mode will not use cache_batch_idx. ignoring the " + "'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } + } #endif - if(0 < page_block_size && use_cache_batch_idx) - { - std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " - "'cache_batch_idx' option" - << std::endl; - use_cache_batch_idx = false; - } - // the input tensor layout for kvcache is same as batch mode - const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim); const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); - if(use_kvcache && mode != mode_enum::batch) - { - std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl; - mode = mode_enum::batch; - } auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode, @@ -380,7 +392,7 @@ bool run(const ck_tile::ArgParser& arg_parser) arg_parser.get_str("s_k"), arg_parser.get_str("s_kpad"), /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0, - use_kvcache); + need_append_kvcache); // compute kvcache seqlen_k (before appending knew/vnew) auto cache_seqlen_ks = seqlen_ks; std::transform(cache_seqlen_ks.begin(), @@ -741,8 +753,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem seqlen_k_buf( - use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0); + ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || + 0 <= seqlen_kpads[0] + ? seqlen_ks.size() * sizeof(int32_t) + : 0); ck_tile::DeviceMem cache_seqlen_k_buf( need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); @@ -763,7 +777,9 @@ bool run(const ck_tile::ArgParser& arg_parser) seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() : seqstart_k_with_padding_host.data()); - seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); + seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] + ? seqlen_ks.data() + : nullptr); cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); rotary_cos_buf.ToDevice(rotary_cos_host.data()); rotary_sin_buf.ToDevice(rotary_sin_host.data()); @@ -976,8 +992,9 @@ bool run(const ck_tile::ArgParser& arg_parser) (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); - args.seqlen_k_ptr = - (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr); args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) args.max_seqlen_q = max_seqlen_q; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 251e61bc76..41edac67ba 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -173,8 +173,11 @@ struct fmha_fwd_splitkv_args // seqlen_k = kargs.seqlen_k // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] - // kvcache mode (use same kernel as batch mode): + // batch mode (kvcache): // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k_ptr[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] const void* seqstart_q_ptr; const void* seqstart_k_ptr; @@ -251,7 +254,7 @@ struct fmha_fwd_appendkv_args ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr - const void* cache_batch_idx; + const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache) ck_tile::index_t stride_q; ck_tile::index_t stride_k; @@ -389,6 +392,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, args.scale_s, args.scale_p, args.stride_q, diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index 996032a717..faf3f08437 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode, std::string k_val, std::string k_pad_val, ck_tile::index_t seqlen_k_min = 0, - bool use_kvcache = false, + bool need_append_kvcache = false, std::optional seed = std::nullopt) { #define _S2I_(str_) static_cast(std::atoi((str_).c_str())) @@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode, const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k); std::vector seqlen_ks(batch, seqlen_k_max); - if(1 < batch && use_kvcache) + if(1 < batch && need_append_kvcache) { // to keep the original s_k value, we always use seqlen_k_max in first batch randints(std::next(seqlen_ks.begin()), diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 4ffebc3c9c..98a4329d75 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -46,8 +46,7 @@ struct FmhaFwdSplitKVKernel static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; - static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV), - "paged-kvcache only supported by batch mode kernels"); + using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -198,8 +197,10 @@ struct FmhaFwdSplitKVKernel const int32_t* seqlen_k_ptr; ck_tile::index_t batch_stride_q; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for + // single kcache page-block + ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for + // single vcache page-block ck_tile::index_t batch_stride_lse_acc; ck_tile::index_t batch_stride_o_acc; }; @@ -212,14 +213,17 @@ struct FmhaFwdSplitKVKernel AlibiKargs, EmptyKargs<0>>>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; - ck_tile::index_t batch_stride_k; // only used for paged-kvcache - ck_tile::index_t batch_stride_v; // only used for paged-kvcache + ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size + // for single kcache page-block + ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size + // for single vcache page-block }; using Kargs = std::conditional_t; @@ -363,6 +367,9 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, float scale_s, float scale_p, ck_tile::index_t stride_q, @@ -416,6 +423,7 @@ struct FmhaFwdSplitKVKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for fp8_static_quant args + {}, // placeholder for paged-block table reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr), @@ -443,6 +451,12 @@ struct FmhaFwdSplitKVKernel { kargs.scale_p = scale_p; } + if constexpr(kIsPagedKV) + { + kargs.block_table_ptr = reinterpret_cast(block_table_ptr); + kargs.batch_stride_block_table = batch_stride_block_table; + kargs.page_block_size = page_block_size; + } return kargs; } @@ -489,15 +503,22 @@ struct FmhaFwdSplitKVKernel const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - - if constexpr(std::is_same_v) + if constexpr(kIsPagedKV) { - batch_offset_v = key_start * kargs.stride_v; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; } else { - batch_offset_v = key_start; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -685,7 +706,7 @@ struct FmhaFwdSplitKVKernel return make_page_block_navigator( kargs.k_ptr, - kargs.batch_stride_k, + kargs.batch_stride_k, // kcache page-block stride/size fixed_offset, block_indices, num_blocks, @@ -715,7 +736,7 @@ struct FmhaFwdSplitKVKernel return make_page_block_navigator( kargs.v_ptr, - kargs.batch_stride_v, + kargs.batch_stride_v, // vcache page-block stride/size fixed_offset, block_indices, num_blocks, diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp new file mode 100644 index 0000000000..b74607f061 --- /dev/null +++ b/include/ck_tile/ops/fused_moe.hpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp" +#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" +#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" +#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp"