Use randomized seqlen_k for kvcache

This commit is contained in:
PoYen, Chen
2024-08-18 17:42:32 +00:00
parent 996f46b0d1
commit e5db71cc59

View File

@@ -371,7 +371,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
arg_parser.get_str("s"),
arg_parser.get_str("s_k"),
arg_parser.get_str("s_kpad"),
0 < seqlen_knew ? seqlen_knew : 0);
0 < seqlen_knew ? seqlen_knew : 0,
use_kvcache);
// compute kvcache seqlen_k (before appending knew/vnew)
auto cache_seqlen_ks = seqlen_ks;
std::transform(cache_seqlen_ks.begin(),
@@ -562,9 +563,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host memory for storing all the tensor elements
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
const ck_tile::index_t shape_seqlen_q =
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
(mode == mode_enum::batch ? max_seqlen_q : seqstart_q_host.back());
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_ks[0]
(mode == mode_enum::batch ? max_seqlen_k
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
: seqstart_k_with_padding_host.back()));
@@ -726,7 +727,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_buf(
use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem cache_seqlen_k_buf(cache_seqlen_ks.size() * sizeof(int32_t));
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
@@ -744,7 +746,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
: seqstart_k_with_padding_host.data());
seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data());
seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
cache_seqlen_k_buf.ToDevice(cache_seqlen_ks.data());
rotary_cos_buf.ToDevice(rotary_cos_host.data());
rotary_sin_buf.ToDevice(rotary_sin_host.data());
@@ -954,11 +956,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
args.seqlen_k_ptr =
(use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
args.seqlen_k = (args.seqlen_k_ptr == nullptr ? shape_seqlen_k : -1);
args.seqlen_k = shape_seqlen_k;
args.max_seqlen_q = max_seqlen_q;
args.scale_s = scale_s;