From e5db71cc596c182d30f51632e8f838ac9c2599b9 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sun, 18 Aug 2024 17:42:32 +0000 Subject: [PATCH] Use randomized seqlen_k for kvcache --- example/ck_tile/01_fmha/fmha_fwd.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 4b5635455c..6e9133806e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -371,7 +371,8 @@ bool run(const ck_tile::ArgParser& arg_parser) arg_parser.get_str("s"), arg_parser.get_str("s_k"), arg_parser.get_str("s_kpad"), - 0 < seqlen_knew ? seqlen_knew : 0); + 0 < seqlen_knew ? seqlen_knew : 0, + use_kvcache); // compute kvcache seqlen_k (before appending knew/vnew) auto cache_seqlen_ks = seqlen_ks; std::transform(cache_seqlen_ks.begin(), @@ -562,9 +563,9 @@ bool run(const ck_tile::ArgParser& arg_parser) // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); const ck_tile::index_t shape_seqlen_q = - (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); + (mode == mode_enum::batch ? max_seqlen_q : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = - (mode == mode_enum::batch ? seqlen_ks[0] + (mode == mode_enum::batch ? max_seqlen_k : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() : seqstart_k_with_padding_host.back())); @@ -726,7 +727,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_k_buf( + use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0); ck_tile::DeviceMem cache_seqlen_k_buf(cache_seqlen_ks.size() * sizeof(int32_t)); ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes()); @@ -744,7 +746,7 @@ bool run(const ck_tile::ArgParser& arg_parser) seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() : seqstart_k_with_padding_host.data()); - seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data()); + seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); cache_seqlen_k_buf.ToDevice(cache_seqlen_ks.data()); rotary_cos_buf.ToDevice(rotary_cos_host.data()); rotary_sin_buf.ToDevice(rotary_sin_host.data()); @@ -954,11 +956,10 @@ bool run(const ck_tile::ArgParser& arg_parser) (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); - args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0] - ? seqlen_k_buf.GetDeviceBuffer() - : nullptr); + args.seqlen_k_ptr = + (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr); - args.seqlen_k = (args.seqlen_k_ptr == nullptr ? shape_seqlen_k : -1); + args.seqlen_k = shape_seqlen_k; args.max_seqlen_q = max_seqlen_q; args.scale_s = scale_s;