mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Use randomized seqlen_k for kvcache
This commit is contained in:
@@ -371,7 +371,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
arg_parser.get_str("s"),
|
||||
arg_parser.get_str("s_k"),
|
||||
arg_parser.get_str("s_kpad"),
|
||||
0 < seqlen_knew ? seqlen_knew : 0);
|
||||
0 < seqlen_knew ? seqlen_knew : 0,
|
||||
use_kvcache);
|
||||
// compute kvcache seqlen_k (before appending knew/vnew)
|
||||
auto cache_seqlen_ks = seqlen_ks;
|
||||
std::transform(cache_seqlen_ks.begin(),
|
||||
@@ -562,9 +563,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// host memory for storing all the tensor elements
|
||||
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
const ck_tile::index_t shape_seqlen_q =
|
||||
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
|
||||
(mode == mode_enum::batch ? max_seqlen_q : seqstart_q_host.back());
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_ks[0]
|
||||
(mode == mode_enum::batch ? max_seqlen_k
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
|
||||
: seqstart_k_with_padding_host.back()));
|
||||
|
||||
@@ -726,7 +727,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_k_buf(
|
||||
use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem cache_seqlen_k_buf(cache_seqlen_ks.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
|
||||
@@ -744,7 +746,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
|
||||
: seqstart_k_with_padding_host.data());
|
||||
seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data());
|
||||
seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
|
||||
cache_seqlen_k_buf.ToDevice(cache_seqlen_ks.data());
|
||||
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
@@ -954,11 +956,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
args.seqlen_k_ptr =
|
||||
(use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
|
||||
|
||||
args.seqlen_k = (args.seqlen_k_ptr == nullptr ? shape_seqlen_k : -1);
|
||||
args.seqlen_k = shape_seqlen_k;
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
args.scale_s = scale_s;
|
||||
|
||||
Reference in New Issue
Block a user