Fix wrong seqlen_k for kvcache

This commit is contained in:
PoYen, Chen
2024-08-08 20:39:36 +00:00
parent 6a399ea47e
commit 822d5dcd8e

View File

@@ -867,7 +867,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer();
args.seqlen_q = shape_seqlen_q;
args.seqlen_k = shape_seqlen_k;
args.batch = batch;
args.max_seqlen_q = max_seqlen_q;
args.hdim_q = hdim_q;
@@ -892,6 +891,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.seqlen_knew = seqlen_knew;
args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
args.seqlen_k = shape_seqlen_k - seqlen_knew; // kvcache seqlen for batch mode
args.rotary_cos_ptr = rotary_cos_buf.GetDeviceBuffer();
args.rotary_sin_ptr = rotary_sin_buf.GetDeviceBuffer();
@@ -909,7 +909,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.batch_stride_knew = batch_stride_knew;
args.batch_stride_vnew = batch_stride_vnew;
}
else
else // fmha_fwd_args or fmha_fwd_splitkv_args
{
args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer();
@@ -917,6 +917,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.o_ptr = o_buf.GetDeviceBuffer();
args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer();
args.seqlen_k = shape_seqlen_k;
args.scale_s = scale_s;
args.scale_p = scale_p;