mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
[CK_TILE] Add paged-kvcache support in group mode fmha fwd splitkv kernels (#1678)
* Generate group mode paged-attn kernel
* Enable paged-kvcache + group mode support
* Add missing header: fused_moe.hpp
* Add comment to explain kernel arg usage
* Make error message clearer
* Add comment for confusing data member names
* Add more comments for confusing variable names
* Fix typo in option description
[ROCm/composable_kernel commit: fb1ccfa9df]
This commit is contained in:
@@ -655,9 +655,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if pipeline.F_pagedkv == 't':
|
||||
# we only use batch mode kernels to handle (paged-) kvcache problems
|
||||
continue
|
||||
k = Kernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
|
||||
@@ -62,7 +62,7 @@ auto create_args(int argc, char* argv[])
|
||||
"-1 to choose s_knew in [1, s] randomly.")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"seqlen_k stride between 2 tokens, currently used in group-mode only\n"
|
||||
"seqlen_k stride between 2 batches, currently used in group-mode only\n"
|
||||
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
|
||||
"along seqlen, instead of packed. same as xformer kv_padding")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
@@ -294,7 +294,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
#if !CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
if(seqlen_knew != 0)
|
||||
{
|
||||
std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
|
||||
std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option"
|
||||
<< std::endl;
|
||||
seqlen_knew = 0;
|
||||
}
|
||||
#endif
|
||||
@@ -321,6 +322,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
rotary_dim = 0;
|
||||
}
|
||||
#endif
|
||||
// to use fmha_fwd_appendkv(), make sure it's in batch mode
|
||||
const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
|
||||
if(need_append_kvcache && mode == mode_enum::group)
|
||||
{
|
||||
std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl;
|
||||
mode = mode_enum::batch;
|
||||
}
|
||||
if(!(rotary_dim <= hdim_q))
|
||||
{
|
||||
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
|
||||
@@ -356,22 +364,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
#else
|
||||
if(use_cache_batch_idx)
|
||||
{
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
|
||||
"'cache_batch_idx' option"
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
else if(mode == mode_enum::group)
|
||||
{
|
||||
std::cerr << "group mode will not use cache_batch_idx. ignoring the "
|
||||
"'cache_batch_idx' option"
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if(0 < page_block_size && use_cache_batch_idx)
|
||||
{
|
||||
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
|
||||
"'cache_batch_idx' option"
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
// the input tensor layout for kvcache is same as batch mode
|
||||
const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
|
||||
const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
|
||||
if(use_kvcache && mode != mode_enum::batch)
|
||||
{
|
||||
std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
|
||||
mode = mode_enum::batch;
|
||||
}
|
||||
|
||||
auto [seqlen_qs, seqlen_ks, seqlen_kpads] =
|
||||
decode_seqlen(mode,
|
||||
@@ -380,7 +392,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
arg_parser.get_str("s_k"),
|
||||
arg_parser.get_str("s_kpad"),
|
||||
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
|
||||
use_kvcache);
|
||||
need_append_kvcache);
|
||||
// compute kvcache seqlen_k (before appending knew/vnew)
|
||||
auto cache_seqlen_ks = seqlen_ks;
|
||||
std::transform(cache_seqlen_ks.begin(),
|
||||
@@ -741,8 +753,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_k_buf(
|
||||
use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
|
||||
0 <= seqlen_kpads[0]
|
||||
? seqlen_ks.size() * sizeof(int32_t)
|
||||
: 0);
|
||||
ck_tile::DeviceMem cache_seqlen_k_buf(
|
||||
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
|
||||
@@ -763,7 +777,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
|
||||
: seqstart_k_with_padding_host.data());
|
||||
seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
|
||||
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
|
||||
? seqlen_ks.data()
|
||||
: nullptr);
|
||||
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
|
||||
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
@@ -976,8 +992,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr =
|
||||
(use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
|
||||
args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
@@ -173,8 +173,11 @@ struct fmha_fwd_splitkv_args
|
||||
// seqlen_k = kargs.seqlen_k
|
||||
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// kvcache mode (use same kernel as batch mode):
|
||||
// batch mode (kvcache):
|
||||
// seqlen_q = kargs.seqlen_q
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// group mode (kvcache):
|
||||
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
@@ -251,7 +254,7 @@ struct fmha_fwd_appendkv_args
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx;
|
||||
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
@@ -389,6 +392,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
|
||||
@@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode,
|
||||
std::string k_val,
|
||||
std::string k_pad_val,
|
||||
ck_tile::index_t seqlen_k_min = 0,
|
||||
bool use_kvcache = false,
|
||||
bool need_append_kvcache = false,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
|
||||
@@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode,
|
||||
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
|
||||
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
|
||||
|
||||
if(1 < batch && use_kvcache)
|
||||
if(1 < batch && need_append_kvcache)
|
||||
{
|
||||
// to keep the original s_k value, we always use seqlen_k_max in first batch
|
||||
randints(std::next(seqlen_ks.begin()),
|
||||
|
||||
Reference in New Issue
Block a user