mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 16:26:10 +00:00
Fix wrong seqlen for kvcache
This commit is contained in:
@@ -867,7 +867,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer();
|
||||
|
||||
args.seqlen_q = shape_seqlen_q;
|
||||
args.seqlen_k = shape_seqlen_k;
|
||||
args.batch = batch;
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
args.hdim_q = hdim_q;
|
||||
@@ -892,6 +891,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.seqlen_knew = seqlen_knew;
|
||||
|
||||
args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
|
||||
args.seqlen_k = shape_seqlen_k - seqlen_knew; // kvcache seqlen for batch mode
|
||||
|
||||
args.rotary_cos_ptr = rotary_cos_buf.GetDeviceBuffer();
|
||||
args.rotary_sin_ptr = rotary_sin_buf.GetDeviceBuffer();
|
||||
@@ -909,7 +909,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.batch_stride_knew = batch_stride_knew;
|
||||
args.batch_stride_vnew = batch_stride_vnew;
|
||||
}
|
||||
else
|
||||
else // fmha_fwd_args or fmha_fwd_splitkv_args
|
||||
{
|
||||
args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer();
|
||||
@@ -917,6 +917,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer();
|
||||
args.seqlen_k = shape_seqlen_k;
|
||||
|
||||
args.scale_s = scale_s;
|
||||
args.scale_p = scale_p;
|
||||
|
||||
Reference in New Issue
Block a user