mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_TILE] Add paged-kvcache support in group mode fmha fwd splitkv kernels (#1678)
* Generate group mode paged-attn kernel
* Enable paged-kvcache + group mode support
* Add missing header: fused_moe.hpp
* Add comment to explain kernel arg usage
* Make error message more clear
* Add comment for confusing data member names
* Add more comment for confusing variable names
* Fix typo in option description
[ROCm/composable_kernel commit: fb1ccfa9df]
This commit is contained in:
@@ -655,9 +655,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if pipeline.F_pagedkv == 't':
|
||||
# we only use batch mode kernels to handle (paged-) kvcache problems
|
||||
continue
|
||||
k = Kernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
|
||||
@@ -62,7 +62,7 @@ auto create_args(int argc, char* argv[])
|
||||
"-1 to choose s_knew in [1, s] randomly.")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"seqlen_k stride between 2 tokens, currently used in group-mode only\n"
|
||||
"seqlen_k stride between 2 batches, currently used in group-mode only\n"
|
||||
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
|
||||
"along seqlen, instead of packed. same as xformer kv_padding")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
@@ -294,7 +294,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
#if !CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
if(seqlen_knew != 0)
|
||||
{
|
||||
std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
|
||||
std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option"
|
||||
<< std::endl;
|
||||
seqlen_knew = 0;
|
||||
}
|
||||
#endif
|
||||
@@ -321,6 +322,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
rotary_dim = 0;
|
||||
}
|
||||
#endif
|
||||
// to use fmha_fwd_appendkv(), make sure it's in batch mode
|
||||
const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
|
||||
if(need_append_kvcache && mode == mode_enum::group)
|
||||
{
|
||||
std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl;
|
||||
mode = mode_enum::batch;
|
||||
}
|
||||
if(!(rotary_dim <= hdim_q))
|
||||
{
|
||||
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
|
||||
@@ -356,22 +364,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
#else
|
||||
if(use_cache_batch_idx)
|
||||
{
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
|
||||
"'cache_batch_idx' option"
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
else if(mode == mode_enum::group)
|
||||
{
|
||||
std::cerr << "group mode will not use cache_batch_idx. ignoring the "
|
||||
"'cache_batch_idx' option"
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if(0 < page_block_size && use_cache_batch_idx)
|
||||
{
|
||||
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
|
||||
"'cache_batch_idx' option"
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
// the input tensor layout for kvcache is same as batch mode
|
||||
const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
|
||||
const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
|
||||
if(use_kvcache && mode != mode_enum::batch)
|
||||
{
|
||||
std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
|
||||
mode = mode_enum::batch;
|
||||
}
|
||||
|
||||
auto [seqlen_qs, seqlen_ks, seqlen_kpads] =
|
||||
decode_seqlen(mode,
|
||||
@@ -380,7 +392,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
arg_parser.get_str("s_k"),
|
||||
arg_parser.get_str("s_kpad"),
|
||||
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
|
||||
use_kvcache);
|
||||
need_append_kvcache);
|
||||
// compute kvcache seqlen_k (before appending knew/vnew)
|
||||
auto cache_seqlen_ks = seqlen_ks;
|
||||
std::transform(cache_seqlen_ks.begin(),
|
||||
@@ -741,8 +753,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_k_buf(
|
||||
use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
|
||||
0 <= seqlen_kpads[0]
|
||||
? seqlen_ks.size() * sizeof(int32_t)
|
||||
: 0);
|
||||
ck_tile::DeviceMem cache_seqlen_k_buf(
|
||||
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
|
||||
@@ -763,7 +777,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
|
||||
: seqstart_k_with_padding_host.data());
|
||||
seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
|
||||
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
|
||||
? seqlen_ks.data()
|
||||
: nullptr);
|
||||
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
|
||||
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
@@ -976,8 +992,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr =
|
||||
(use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
|
||||
args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
@@ -173,8 +173,11 @@ struct fmha_fwd_splitkv_args
|
||||
// seqlen_k = kargs.seqlen_k
|
||||
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// kvcache mode (use same kernel as batch mode):
|
||||
// batch mode (kvcache):
|
||||
// seqlen_q = kargs.seqlen_q
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// group mode (kvcache):
|
||||
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
@@ -251,7 +254,7 @@ struct fmha_fwd_appendkv_args
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx;
|
||||
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
@@ -389,6 +392,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
|
||||
@@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode,
|
||||
std::string k_val,
|
||||
std::string k_pad_val,
|
||||
ck_tile::index_t seqlen_k_min = 0,
|
||||
bool use_kvcache = false,
|
||||
bool need_append_kvcache = false,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
|
||||
@@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode,
|
||||
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
|
||||
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
|
||||
|
||||
if(1 < batch && use_kvcache)
|
||||
if(1 < batch && need_append_kvcache)
|
||||
{
|
||||
// to keep the original s_k value, we always use seqlen_k_max in first batch
|
||||
randints(std::next(seqlen_ks.begin()),
|
||||
|
||||
@@ -46,8 +46,7 @@ struct FmhaFwdSplitKVKernel
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
||||
static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV),
|
||||
"paged-kvcache only supported by batch mode kernels");
|
||||
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
@@ -198,8 +197,10 @@ struct FmhaFwdSplitKVKernel
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
|
||||
// single kcache page-block
|
||||
ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
|
||||
// single vcache page-block
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
};
|
||||
@@ -212,14 +213,17 @@ struct FmhaFwdSplitKVKernel
|
||||
AlibiKargs,
|
||||
EmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
|
||||
std::conditional_t<kIsPagedKV, PageBlockTableKargs, EmptyKargs<3>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t batch_stride_k; // only used for paged-kvcache
|
||||
ck_tile::index_t batch_stride_v; // only used for paged-kvcache
|
||||
ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
|
||||
// for single kcache page-block
|
||||
ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
|
||||
// for single vcache page-block
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
|
||||
@@ -363,6 +367,9 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
ck_tile::index_t num_splits,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
ck_tile::index_t stride_q,
|
||||
@@ -416,6 +423,7 @@ struct FmhaFwdSplitKVKernel
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for paged-block table
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
@@ -443,6 +451,12 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
kargs.scale_p = scale_p;
|
||||
}
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
|
||||
kargs.batch_stride_block_table = batch_stride_block_table;
|
||||
kargs.page_block_size = page_block_size;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -489,15 +503,22 @@ struct FmhaFwdSplitKVKernel
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
}
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -685,7 +706,7 @@ struct FmhaFwdSplitKVKernel
|
||||
|
||||
return make_page_block_navigator<const KDataType, 0>(
|
||||
kargs.k_ptr,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_k, // kcache page-block stride/size
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
@@ -715,7 +736,7 @@ struct FmhaFwdSplitKVKernel
|
||||
|
||||
return make_page_block_navigator<const VDataType, 1>(
|
||||
kargs.v_ptr,
|
||||
kargs.batch_stride_v,
|
||||
kargs.batch_stride_v, // vcache page-block stride/size
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
|
||||
11
include/ck_tile/ops/fused_moe.hpp
Normal file
11
include/ck_tile/ops/fused_moe.hpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
Reference in New Issue
Block a user