mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK_TILE] Fix incorrect computation of group mode PagedAttention (#1688)
* Allow getting batch size from splitkv tile partitioner * Fix wrong paged-kvcache impl for group mode * Fix wrong example code for page-kvcache * Undo changes in fmha_fwd.cpp * Always use 2D block table * Add is_gappy kernel argument for paged-kvcache The is_gappy argument is used for differentiating seqstart_k_ptr usage in flash-attention & xformers * Remove out-of-date comments * Remove no-longer used method * Fix wrong # page-block calculation * Fix wrong comment --------- Co-authored-by: Qianfeng <qianfeng.zhang@amd.com>
This commit is contained in:
@@ -1046,6 +1046,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
|
||||
args.batch_stride_block_table = batch_stride_block_table;
|
||||
args.page_block_size = page_block_size;
|
||||
args.is_gappy = false; // use 'false' for flash-attention integration
|
||||
|
||||
args.cache_batch_idx =
|
||||
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
|
||||
|
||||
@@ -165,6 +165,8 @@ struct fmha_fwd_splitkv_args
|
||||
void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
|
||||
// nullptr.
|
||||
|
||||
const void* cache_batch_idx;
|
||||
|
||||
@@ -173,12 +175,21 @@ struct fmha_fwd_splitkv_args
|
||||
// seqlen_k = kargs.seqlen_k
|
||||
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// or kargs.seqlen_k_ptr[b]
|
||||
//
|
||||
// batch mode (kvcache):
|
||||
// seqlen_q = kargs.seqlen_q
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// group mode (kvcache):
|
||||
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
//
|
||||
// when is_gappy=true:
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// seqstart_k_ptr[b] now store local offset of each batch
|
||||
//
|
||||
// when is_gappy=false:
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// or kargs.seqlen_k_ptr[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
@@ -395,6 +406,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.is_gappy,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
|
||||
Reference in New Issue
Block a user