From 17e404be3b83a64e54d1d7c6c76032a1d2eeb6fe Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 31 Oct 2025 14:04:32 +0000 Subject: [PATCH] Support separate sequence lengths for q and kv --- .../example_hstu_attention.cpp | 160 +++++++++------ ...stu_attention_batched_forward_dispatch.hpp | 7 +- .../hstu_attention_fwd_kernel.hpp | 59 +++--- ...hstu_attention_jagged_forward_dispatch.hpp | 3 +- .../hstu_attention_params.hpp | 8 +- .../18_hstu_attention/hstu_block_masking.hpp | 185 ++++++++++-------- .../reference_hstu_attention.hpp | 43 ++-- 7 files changed, 277 insertions(+), 188 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index 6f9e1b9391..ea4c6b1165 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -100,7 +100,8 @@ auto create_args(int argc, char* argv[]) .insert("nhead", "4", "number of heads") .insert("hdim_qk", "64", "headdim size of Q/K") .insert("hdim_v", "64", "headdim size of V/O") - .insert("seqlens", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len") + .insert("seqlens", "400", "uih seqlen of single or all batches for query tensor, actually allocated seqlen will include the target of each batch and context_len") + .insert("seqlens_kv", "", "uih seqlen of single or all batches for key/value tensor, actually allocated seqlen will include the target of each batch and context_len") .insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens") .insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention") .insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets") @@ -238,13 +239,15 @@ bool run(const ck_tile::ArgParser& arg_parser) std::string str_of_targets = arg_parser.get_str("targets"); std::vector num_targets = get_integers_from_string(str_of_targets); - std::string str_of_lengths = arg_parser.get_str("seqlens"); - std::vector seq_lengths = get_integers_from_string(str_of_lengths); + std::string str_of_lengths_q = arg_parser.get_str("seqlens"); + std::vector seq_lengths_q = get_integers_from_string(str_of_lengths_q); + + std::string str_of_lengths_kv = arg_parser.get_str("seqlens_kv"); + std::vector seq_lengths_kv = get_integers_from_string(str_of_lengths_kv); int input_max_uih_seqlen = arg_parser.get_int("max_seqlen"); int input_max_target = arg_parser.get_int("max_target"); - int uih_seqlen = 0; // means total seq lengths for jagged int max_uih_seqlen = 0; int max_target = 0; @@ -264,31 +267,43 @@ bool run(const ck_tile::ArgParser& arg_parser) max_target = max(max_target, num_targets[i]); }; - HSTU_CHECK(!seq_lengths.empty(), "sequence lengths shoud be defined!"); + HSTU_CHECK(!seq_lengths_q.empty(), "sequence lengths of q shoud be defined!"); + + // assume seq_lengths_kv is same as seq_lengths_q if not defined + if(seq_lengths_kv.empty()) + seq_lengths_kv = seq_lengths_q; if(is_jagged) { // supplement seq_lengths using the last input value if user-provided lengths not enough - if(static_cast(seq_lengths.size()) < num_batch) + if(static_cast(seq_lengths_q.size()) < num_batch) { - auto last_len = seq_lengths.back(); + auto last_len = seq_lengths_q.back(); - for(int i = seq_lengths.size(); i < num_batch; i++) - seq_lengths.push_back(last_len); + for(int i = seq_lengths_q.size(); i < num_batch; i++) + seq_lengths_q.push_back(last_len); + }; + + // supplement seq_lengths_kv using the last input value if user-provided lengths not enough + if(static_cast(seq_lengths_kv.size()) < num_batch) + { + auto last_len = seq_lengths_kv.back(); + + for(int i = seq_lengths_kv.size(); i < num_batch; i++) + seq_lengths_kv.push_back(last_len); }; // only consider num_batch values even if more values are provided by the user for(int i = 0; i < num_batch; i++) { - max_uih_seqlen = max(max_uih_seqlen, seq_lengths[i]); + max_uih_seqlen = max(max_uih_seqlen, seq_lengths_q[i]); }; } else { - HSTU_CHECK(1 == seq_lengths.size(), + HSTU_CHECK(1 == seq_lengths_q.size() && 1 == seq_lengths_kv.size(), "sequence lengths for batched mode shoud have single element!"); - uih_seqlen = seq_lengths[0]; - max_uih_seqlen = uih_seqlen; + max_uih_seqlen = max(seq_lengths_q[0], seq_lengths_kv[0]); }; // the user input of max_uih_seqlen can either be ignored or be bigger than all uih_seqlens @@ -304,28 +319,43 @@ bool run(const ck_tile::ArgParser& arg_parser) max_uih_seqlen = (input_max_uih_seqlen > 0) ? input_max_uih_seqlen : max_uih_seqlen; max_target = (input_max_target > 0) ? input_max_target : max_target; - int phy_seqlen = 0; - int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen; + int phy_seqlen_q = 0; + int phy_seqlen_kv = 0; + int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen; - std::vector seq_offsets; + std::vector seq_offsets_q; + std::vector seq_offsets_kv; if(is_jagged) { - seq_offsets.push_back(0); + seq_offsets_q.push_back(0); for(int i = 0; i < num_batch; i++) { int batch_seqlen = num_targets.empty() - ? seq_lengths[i] + contextual_seqlen - : seq_lengths[i] + num_targets[i] + contextual_seqlen; + ? seq_lengths_q[i] + contextual_seqlen + : seq_lengths_q[i] + num_targets[i] + contextual_seqlen; - phy_seqlen += batch_seqlen; - seq_offsets.push_back(phy_seqlen); + phy_seqlen_q += batch_seqlen; + seq_offsets_q.push_back(phy_seqlen_q); + }; + + seq_offsets_kv.push_back(0); + + for(int i = 0; i < num_batch; i++) + { + int batch_seqlen = num_targets.empty() + ? seq_lengths_kv[i] + contextual_seqlen + : seq_lengths_kv[i] + num_targets[i] + contextual_seqlen; + + phy_seqlen_kv += batch_seqlen; + seq_offsets_kv.push_back(phy_seqlen_kv); }; } else { - phy_seqlen = max_seqlen; + phy_seqlen_q = max_seqlen; + phy_seqlen_kv = max_seqlen; }; long total_flops = 0; @@ -335,10 +365,11 @@ bool run(const ck_tile::ArgParser& arg_parser) { for(int i = 0; i < num_batch; i++) { - int len = seq_offsets[i + 1] - seq_offsets[i]; - total_flops += - (static_cast(len) * len * hdim_qk + static_cast(len) * hdim_v * len) * - 2; + int len_q = seq_offsets_q[i + 1] - seq_offsets_q[i]; + int len_kv = seq_offsets_kv[i + 1] - seq_offsets_kv[i]; + total_flops += (static_cast(len_q) * len_kv * hdim_qk + + static_cast(len_q) * hdim_v * len_kv) * + 2; }; total_flops *= num_head; @@ -346,21 +377,21 @@ bool run(const ck_tile::ArgParser& arg_parser) else { total_flops = static_cast(num_batch) * num_head * - (static_cast(phy_seqlen) * phy_seqlen * hdim_qk + - static_cast(phy_seqlen) * hdim_v * phy_seqlen) * + (static_cast(phy_seqlen_q) * phy_seqlen_kv * hdim_qk + + static_cast(phy_seqlen_q) * hdim_v * phy_seqlen_kv) * 2; }; int batches_for_alloc = is_jagged ? 1 : num_batch; ck_tile::HostTensor q_host( - std::array{batches_for_alloc, phy_seqlen, num_head, hdim_qk}); + std::array{batches_for_alloc, phy_seqlen_q, num_head, hdim_qk}); ck_tile::HostTensor k_host( - std::array{batches_for_alloc, phy_seqlen, num_head, hdim_qk}); + std::array{batches_for_alloc, phy_seqlen_kv, num_head, hdim_qk}); ck_tile::HostTensor v_host( - std::array{batches_for_alloc, phy_seqlen, num_head, hdim_v}); + std::array{batches_for_alloc, phy_seqlen_kv, num_head, hdim_v}); ck_tile::HostTensor o_host_ref( - std::array{batches_for_alloc, phy_seqlen, num_head, hdim_v}); + std::array{batches_for_alloc, phy_seqlen_q, num_head, hdim_v}); ck_tile::HostTensor mask_host( save_mask ? std::array{num_batch, num_head, max_seqlen, max_seqlen} @@ -393,7 +424,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_dev(o_host_ref.get_element_space_size_in_bytes()); - ck_tile::DeviceMem seq_offsets_dev(seq_offsets.size() * sizeof(int)); + ck_tile::DeviceMem seq_offsets_q_dev(seq_offsets_q.size() * sizeof(int)); + ck_tile::DeviceMem seq_offsets_kv_dev(seq_offsets_kv.size() * sizeof(int)); ck_tile::DeviceMem num_targets_dev(num_targets.size() * sizeof(int)); q_dev.ToDevice(q_host.data()); @@ -401,7 +433,10 @@ bool run(const ck_tile::ArgParser& arg_parser) v_dev.ToDevice(v_host.data()); if(is_jagged) - seq_offsets_dev.ToDevice(seq_offsets.data()); + { + seq_offsets_q_dev.ToDevice(seq_offsets_q.data()); + seq_offsets_kv_dev.ToDevice(seq_offsets_kv.data()); + }; if(!num_targets.empty()) num_targets_dev.ToDevice(num_targets.data()); @@ -411,30 +446,31 @@ bool run(const ck_tile::ArgParser& arg_parser) if(is_jagged) { - params.is_jagged = true; - params.num_batch = num_batch; - params.seq_offsets_ptr = seq_offsets_dev.GetDeviceBuffer(); - params.max_seqlen = max_seqlen; - params.q_ptr = q_dev.GetDeviceBuffer(); - params.k_ptr = k_dev.GetDeviceBuffer(); - params.v_ptr = v_dev.GetDeviceBuffer(); - params.bias_ptr = nullptr; // bias is not supported at present - params.o_ptr = o_dev.GetDeviceBuffer(); - params.hdim_qk = hdim_qk; - params.hdim_v = hdim_v; - params.num_head = num_head; - params.scale_s = scale_s; - params.attn_scale = attn_scale; - params.seq_stride_q = q_host.get_strides()[1]; - params.seq_stride_k = k_host.get_strides()[1]; - params.seq_stride_v = v_host.get_strides()[1]; - params.seq_stride_bias = 0; - params.seq_stride_o = o_host_ref.get_strides()[1]; - params.nhead_stride_q = q_host.get_strides()[2]; - params.nhead_stride_k = k_host.get_strides()[2]; - params.nhead_stride_v = v_host.get_strides()[2]; - params.nhead_stride_bias = 0; - params.nhead_stride_o = o_host_ref.get_strides()[2]; + params.is_jagged = true; + params.num_batch = num_batch; + params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer(); + params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer(); + params.max_seqlen = max_seqlen; + params.q_ptr = q_dev.GetDeviceBuffer(); + params.k_ptr = k_dev.GetDeviceBuffer(); + params.v_ptr = v_dev.GetDeviceBuffer(); + params.bias_ptr = nullptr; // bias is not supported at present + params.o_ptr = o_dev.GetDeviceBuffer(); + params.hdim_qk = hdim_qk; + params.hdim_v = hdim_v; + params.num_head = num_head; + params.scale_s = scale_s; + params.attn_scale = attn_scale; + params.seq_stride_q = q_host.get_strides()[1]; + params.seq_stride_k = k_host.get_strides()[1]; + params.seq_stride_v = v_host.get_strides()[1]; + params.seq_stride_bias = 0; + params.seq_stride_o = o_host_ref.get_strides()[1]; + params.nhead_stride_q = q_host.get_strides()[2]; + params.nhead_stride_k = k_host.get_strides()[2]; + params.nhead_stride_v = v_host.get_strides()[2]; + params.nhead_stride_bias = 0; + params.nhead_stride_o = o_host_ref.get_strides()[2]; params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer(); params.use_softmax = use_softmax; params.use_causal = use_causal; @@ -449,7 +485,8 @@ bool run(const ck_tile::ArgParser& arg_parser) { params.is_jagged = false; params.num_batch = num_batch; - params.seqlen = max_seqlen; + params.seqlen_q = phy_seqlen_q; + params.seqlen_kv = phy_seqlen_kv; params.q_ptr = q_dev.GetDeviceBuffer(); params.k_ptr = k_dev.GetDeviceBuffer(); params.v_ptr = v_dev.GetDeviceBuffer(); @@ -532,7 +569,8 @@ bool run(const ck_tile::ArgParser& arg_parser) scale_s, attn_scale, max_seqlen, - seq_offsets, + seq_offsets_q, + seq_offsets_kv, num_targets, contextual_seqlen, window_size, @@ -540,7 +578,7 @@ bool run(const ck_tile::ArgParser& arg_parser) }); ck_tile::HostTensor o_host( - std::array{batches_for_alloc, phy_seqlen, num_head, hdim_v}); + std::array{batches_for_alloc, phy_seqlen_q, num_head, hdim_v}); o_dev.FromDevice(o_host.data()); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index f17aa3a31e..87b53a133e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -60,7 +60,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch { constexpr ck_tile::index_t occupancy = -1; - const bool pad_seqlen_k = !(param.seqlen % HstuAttentionTileSetting::kN0 == 0); + const bool pad_seqlen_k = !(param.seqlen_kv % HstuAttentionTileSetting::kN0 == 0); const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0); const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0); @@ -125,7 +125,8 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch param.v_ptr, param.bias_ptr, param.o_ptr, - param.seqlen, + param.seqlen_q, + param.seqlen_kv, param.hdim_qk, param.hdim_v, param.num_head, @@ -157,7 +158,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0); dim3 kGridSize = HstuKernel::GridSize( - param.num_batch, param.num_head, param.seqlen, param.hdim_v, has_minfull_attn_seqlen); + param.num_batch, param.num_head, param.seqlen_q, param.hdim_v, has_minfull_attn_seqlen); constexpr dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 0d8282fc5a..fcd8f6e229 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -78,7 +78,8 @@ struct HstuAttentionFwdKernel ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_o; - ck_tile::index_t seqlen; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_kv; ck_tile::index_t hdim_qk; ck_tile::index_t hdim_v; @@ -98,7 +99,8 @@ struct HstuAttentionFwdKernel struct HstuAttentionFwdJaggModeBaseKargs { - const int32_t* seq_offsets_ptr; + const int32_t* seq_q_offsets_ptr; + const int32_t* seq_kv_offsets_ptr; ck_tile::index_t seq_stride_q; ck_tile::index_t seq_stride_k; @@ -120,7 +122,8 @@ struct HstuAttentionFwdKernel ck_tile::index_t hdim_qk; ck_tile::index_t hdim_v; - ck_tile::index_t seqlen; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_kv; ck_tile::index_t num_head; float scale_s; // scaling value exerted on the immediate Q@K result @@ -196,7 +199,8 @@ struct HstuAttentionFwdKernel const void* v_ptr, const void* bias_ptr, void* o_ptr, - ck_tile::index_t seqlen, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_kv, ck_tile::index_t hdim_qk, ck_tile::index_t hdim_v, ck_tile::index_t num_head, @@ -239,7 +243,8 @@ struct HstuAttentionFwdKernel nhead_stride_k, nhead_stride_v, nhead_stride_o, - seqlen, + seqlen_q, + seqlen_kv, hdim_qk, hdim_v, seq_stride_q, @@ -248,7 +253,8 @@ struct HstuAttentionFwdKernel seq_stride_o, num_head, scale_s, - attn_scale ? attn_scale : 1.0f / static_cast(seqlen), // max_seqlen + attn_scale ? attn_scale + : 1.0f / static_cast(max(seqlen_q, seqlen_kv)), // max_seqlen contextual_seqlen, window_size, min_full_attn_seqlen}, // args for common karg @@ -278,7 +284,8 @@ struct HstuAttentionFwdKernel const void* v_ptr, const void* bias_ptr, void* o_ptr, - const void* seq_offsets_ptr, + const void* seq_q_offsets_ptr, + const void* seq_kv_offsets_ptr, ck_tile::index_t max_seqlen, ck_tile::index_t hdim_qk, ck_tile::index_t hdim_v, @@ -304,7 +311,8 @@ struct HstuAttentionFwdKernel uint64_t philox_offset) { Kargs kargs{ - {reinterpret_cast(seq_offsets_ptr), + {reinterpret_cast(seq_q_offsets_ptr), + reinterpret_cast(seq_kv_offsets_ptr), seq_stride_q, seq_stride_k, seq_stride_v, @@ -320,7 +328,8 @@ struct HstuAttentionFwdKernel nhead_stride_o, hdim_qk, hdim_v, - -1, // seqlen will be updated by another pointer + -1, // seqlen_q will be updated by another pointer + -1, // seqlen_kv will be updated by another pointer num_head, scale_s, attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen), @@ -465,8 +474,8 @@ struct HstuAttentionFwdKernel if constexpr(kIsJagged) { // get starting offset for each batch - const long_index_t query_start = kargs.seq_offsets_ptr[i_batch]; - const long_index_t key_start = query_start; + const long_index_t query_start = kargs.seq_q_offsets_ptr[i_batch]; + const long_index_t key_start = kargs.seq_kv_offsets_ptr[i_batch]; batch_offset_q = query_start * kargs.seq_stride_q; batch_offset_k = key_start * kargs.seq_stride_k; @@ -478,7 +487,10 @@ struct HstuAttentionFwdKernel } batch_offset_o = query_start * kargs.seq_stride_o; - kargs.seqlen = kargs.seq_offsets_ptr[i_batch + 1] - kargs.seq_offsets_ptr[i_batch]; + kargs.seqlen_q = + kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch]; + kargs.seqlen_kv = + kargs.seq_kv_offsets_ptr[i_batch + 1] - kargs.seq_kv_offsets_ptr[i_batch]; } else { @@ -494,16 +506,16 @@ struct HstuAttentionFwdKernel int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch]; - index_t seqlen_in_first_split = kargs.seqlen; + index_t seqlen_in_first_split = kargs.seqlen_q; bool is_tile_in_first_split = true; index_t i_m0; if(kargs.min_full_attn_seqlen > 0) { // need consider for cases where min_full_attn_seqlen be bigger than max_uih_len - if(kargs.seqlen - num_target > kargs.min_full_attn_seqlen) + if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen) { - seqlen_in_first_split = kargs.seqlen - num_target - kargs.min_full_attn_seqlen; + seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen; index_t num_tile_in_first_split = ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0); @@ -522,7 +534,7 @@ struct HstuAttentionFwdKernel is_tile_in_first_split = false; // adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor - kargs.min_full_attn_seqlen = kargs.seqlen - num_target; + kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target; i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); }; @@ -532,7 +544,7 @@ struct HstuAttentionFwdKernel const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1); - index_t seqlen_q_in_ctrl = is_tile_in_first_split ? seqlen_in_first_split : kargs.seqlen; + index_t seqlen_q_in_ctrl = is_tile_in_first_split ? seqlen_in_first_split : kargs.seqlen_q; if(seqlen_q_in_ctrl <= i_m0) return; @@ -567,7 +579,7 @@ struct HstuAttentionFwdKernel const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.seqlen, kargs.hdim_qk), + make_tuple(kargs.seqlen_kv, kargs.hdim_qk), make_tuple(kargs.seq_stride_k, 1), number{}, number<1>{}); @@ -580,7 +592,7 @@ struct HstuAttentionFwdKernel const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.seqlen, kargs.hdim_v), + make_tuple(kargs.seqlen_kv, kargs.hdim_v), make_tuple(kargs.seq_stride_v, 1), number{}, number<1>{}); @@ -590,7 +602,7 @@ struct HstuAttentionFwdKernel const auto v_dram_transposed = transform_tensor_view(v_dram_naive, make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen)), + make_pass_through_transform(kargs.seqlen_kv)), make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -641,7 +653,7 @@ struct HstuAttentionFwdKernel const auto bias_dram = [&]() { const auto bias_dram_naive = make_naive_tensor_view( bias_ptr, - make_tuple(seqlen_q_in_ctrl, kargs.seqlen), + make_tuple(seqlen_q_in_ctrl, kargs.seqlen_kv), make_tuple(kargs.seq_stride_bias, 1), number{}, number<1>{}); @@ -682,7 +694,8 @@ struct HstuAttentionFwdKernel using HstuMaskType = typename ck_tile::HstuBlockMasking::Type; const auto mask = make_hstu_block_mask_with_local(is_tile_in_first_split, - kargs.seqlen, + kargs.seqlen_q, + kargs.seqlen_kv, kargs.contextual_seqlen, num_target, kargs.window_size, @@ -703,7 +716,7 @@ struct HstuAttentionFwdKernel using HstuMaskType = typename ck_tile::HstuBlockMasking::Type; const auto mask = make_hstu_block_mask_without_local( - kargs.seqlen, kargs.contextual_seqlen, num_target); + kargs.seqlen_q, kargs.seqlen_kv, kargs.contextual_seqlen, num_target); return HstuAttentionPipeline{}(q_dram_window, k_dram_window, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index d72526d1b3..c72eb99b5d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -118,7 +118,8 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch param.v_ptr, param.bias_ptr, param.o_ptr, - param.seq_offsets_ptr, + param.seq_q_offsets_ptr, + param.seq_kv_offsets_ptr, param.max_seqlen, param.hdim_qk, param.hdim_v, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp index 2e85e11971..68404422f5 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp @@ -10,9 +10,11 @@ struct HstuAttentionFwdParams bool is_jagged; ck_tile::index_t num_batch; - ck_tile::index_t seqlen; // batched mode only - const void* seq_offsets_ptr; // jagged mode only - ck_tile::index_t max_seqlen; // jagged mode only + ck_tile::index_t seqlen_q; // batched mode only + ck_tile::index_t seqlen_kv; // batched mode only + const void* seq_q_offsets_ptr; // jagged mode only + const void* seq_kv_offsets_ptr; // jagged mode only + ck_tile::index_t max_seqlen; // jagged mode only const void* q_ptr; const void* k_ptr; diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 24c8de1d50..22b0fdbfe6 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -15,42 +15,54 @@ struct HstuBlockMaskWithLocal static constexpr bool IsMasking = true; // is_tile_in_first_split is false only when min_full_attn_seqlen > 0 and the current - // tile is inside scope [max_uih_len - min_full_attn_seqlen, seqlen); for other cases + // tile is inside scope [max_uih_len - min_full_attn_seqlen, seqlen_q); for other cases // and tiles, is_tile_in_first_split is true bool is_tile_in_first_split; - int seqlen; + int seqlen_q; + int seqlen_k; int contextual_seqlen; int min_full_attn_seqlen; int max_attn_len; - int max_uih_len; - int max_id; + int max_q_uih_len; + int max_k_uih_len; + int max_row_id; + int max_col_id; CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(bool is_tile_in_first_split_, - int seqlen_, + int seqlen_q_, + int seqlen_k_, int contextual_seqlen_, int max_attn_len_, int min_full_attn_seqlen_, int num_target_) : is_tile_in_first_split(is_tile_in_first_split_), - seqlen(seqlen_), + seqlen_q(seqlen_q_), + seqlen_k(seqlen_k_), contextual_seqlen(contextual_seqlen_), min_full_attn_seqlen(min_full_attn_seqlen_) { - max_uih_len = seqlen - num_target_; + max_q_uih_len = seqlen_q - num_target_; + max_k_uih_len = seqlen_k - num_target_; // in case user provided max_attn_len_ could be bigger than max_uih_len - max_attn_len = min(max_uih_len, max_attn_len_); + max_attn_len = min(max_k_uih_len, min(max_q_uih_len, max_attn_len_)); // assuming min_full_attn_seqlen has higher priority, ensure contextual scope not collide // with min_full_attn_seqlen scope - contextual_seqlen = min(contextual_seqlen, max_uih_len - min_full_attn_seqlen); + contextual_seqlen = min(contextual_seqlen, max_q_uih_len - min_full_attn_seqlen); if(contextual_seqlen > 0) - max_id = max_uih_len - (contextual_seqlen - 1); + { + max_row_id = max_q_uih_len - (contextual_seqlen - 1); + max_col_id = max_k_uih_len - (contextual_seqlen - 1); + } else - max_id = max_uih_len; + { + max_row_id = max_q_uih_len; + max_col_id = max_k_uih_len; + } }; // to get the loop length along X axis, return index:[start, end), end-start=length @@ -65,20 +77,20 @@ struct HstuBlockMaskWithLocal { if constexpr(kUseCausal) { - index_t x_end = min(i_y + YTile, seqlen); + index_t x_end = min(i_y + YTile, seqlen_k); return ck_tile::make_tuple(0, x_end); } else { // tile is partitially or completely in [max_uih_len-min_full_attn_seqlen, - // max_uih_len) - if(i_y < max_uih_len) + // max_q_uih_len) + if(i_y < max_q_uih_len) { - return ck_tile::make_tuple(0, seqlen); + return ck_tile::make_tuple(0, seqlen_k); } - else // tile is completely inside [max_uih_len, seqlen) + else // tile is completely inside [max_q_uih_len, seqlen_q) { - index_t x_end = min(i_y + YTile, seqlen); + index_t x_end = min(i_y + YTile, seqlen_k); return ck_tile::make_tuple(0, x_end); }; }; @@ -90,18 +102,18 @@ struct HstuBlockMaskWithLocal { if(i_y >= min(contextual_seqlen, 1) + max_attn_len) { - // some row of the tile in [contextual_seqlen+max_attn_len, max_uih_len) - if(i_y < max_uih_len) + // some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len) + if(i_y < max_q_uih_len) { index_t x_start = i_y - max_attn_len; index_t x_start_aligned = x_start - x_start % XTile; - // some rows of the tile in [max_uih_len -max_attn_len, max_uih_len) - if(i_y + YTile > max_uih_len - max_attn_len) + // some rows of the tile in [max_q_uih_len - max_attn_len, max_q_uih_len) + if(i_y + YTile > max_q_uih_len - max_attn_len) { - return ck_tile::make_tuple(x_start_aligned, seqlen); + return ck_tile::make_tuple(x_start_aligned, seqlen_k); } - else // whole tile in [contextual_seqlen+max_attn_len, max_uih_len + else // whole tile in [contextual_seqlen+max_attn_len, max_q_uih_len // -max_attn_len) { index_t x_end = i_y + YTile + max_attn_len; @@ -110,8 +122,8 @@ struct HstuBlockMaskWithLocal } else // whole tile in [max_uih_len, seqlen) { - index_t x_start = max_uih_len - max_attn_len; - index_t x_end = min(i_y + YTile, seqlen); + index_t x_start = max_k_uih_len - max_attn_len; + index_t x_end = min(i_y + YTile, seqlen_k); return ck_tile::make_tuple(x_start - x_start % XTile, x_end); } @@ -120,12 +132,12 @@ struct HstuBlockMaskWithLocal { if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen) { - index_t x_end = min(max(i_y + YTile + max_attn_len, max_uih_len), seqlen); + index_t x_end = min(max(i_y + YTile + max_attn_len, max_k_uih_len), seqlen_k); return ck_tile::make_tuple(0, x_end); } else // whole tile in [contextual_seqlen, seqlen) { - index_t x_end = min(i_y + YTile + max_attn_len, seqlen); + index_t x_end = min(i_y + YTile + max_attn_len, seqlen_k); return ck_tile::make_tuple(0, x_end); } } @@ -134,17 +146,17 @@ struct HstuBlockMaskWithLocal { if(i_y >= min(contextual_seqlen, 1) + max_attn_len) { - index_t x_end = min(i_y + YTile, seqlen); + index_t x_end = min(i_y + YTile, seqlen_k); - // some row of the tile in [contextual_seqlen+max_attn_len, max_uih_len) - if(i_y < max_uih_len) + // some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len) + if(i_y < max_q_uih_len) { index_t x_start = i_y - max_attn_len; return ck_tile::make_tuple(x_start - x_start % XTile, x_end); } else // whole tile in [max_uih_len, seqlen) { - index_t x_start = max_uih_len - max_attn_len; + index_t x_start = max_k_uih_len - max_attn_len; return ck_tile::make_tuple(x_start - x_start % XTile, x_end); } } @@ -152,12 +164,12 @@ struct HstuBlockMaskWithLocal { if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen) { - index_t x_end = min(max(i_y + YTile, max_uih_len), seqlen); + index_t x_end = min(max(i_y + YTile, max_k_uih_len), seqlen_k); return ck_tile::make_tuple(0, x_end); } else // whole tile in [contextual_seqlen, seqlen) { - index_t x_end = min(i_y + YTile, seqlen); + index_t x_end = min(i_y + YTile, seqlen_k); return ck_tile::make_tuple(0, x_end); } } @@ -176,18 +188,18 @@ struct HstuBlockMaskWithLocal row_id = max(row - contextual_seqlen + 1, 0); col_id = max(col - contextual_seqlen + 1, 0); - row_id = min(row_id, max_id); - col_id = min(col_id, max_id); + row_id = min(row_id, max_row_id); + col_id = min(col_id, max_col_id); - if(row_id == 0 && col_id < max_id) + if(row_id == 0 && col_id < max_col_id) return true; } else { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - row_id = min(row, max_id); - col_id = min(col, max_id); + row_id = min(row, max_row_id); + col_id = min(col, max_col_id); }; // use row_id/col_id to check the dist between two q/k token pair, token pairs on the @@ -195,7 +207,7 @@ struct HstuBlockMaskWithLocal if constexpr(kUseCausal) { bool in_min_full_scope = - (min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false; + (min_full_attn_seqlen > 0) ? (row_id >= max_row_id - min_full_attn_seqlen) : false; return (((row_id > col_id) || (row == col)) && ((row_id - col_id <= max_attn_len) || in_min_full_scope)); @@ -203,7 +215,7 @@ struct HstuBlockMaskWithLocal else { bool in_min_full_scope = - (min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false; + (min_full_attn_seqlen > 0) ? (row_id >= max_row_id - min_full_attn_seqlen) : false; return (((row_id != col_id) || (row == col)) && ((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope)); @@ -222,18 +234,18 @@ struct HstuBlockMaskWithLocal row_id = max(row - contextual_seqlen + 1, 0); col_id = max(col - contextual_seqlen + 1, 0); - row_id = min(row_id, max_id); - col_id = min(col_id, max_id); + row_id = min(row_id, max_row_id); + col_id = min(col_id, max_col_id); - if(row_id == 0 && col_id < max_id) + if(row_id == 0 && col_id < max_col_id) return true; } else { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - row_id = min(row, max_id); - col_id = min(col, max_id); + row_id = min(row, max_row_id); + col_id = min(col, max_col_id); }; // use row_id/col_id to check the dist between two q/k token pair, token pairs on the @@ -269,7 +281,7 @@ struct HstuBlockMaskWithLocal { index_t i_tile_right = i_tile_left + TileWidth; - if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_uih_len)) + if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_k_uih_len)) return true; } else @@ -277,11 +289,11 @@ struct HstuBlockMaskWithLocal index_t i_tile_right = i_tile_left + TileWidth; index_t i_tile_bottom = i_tile_top + TileHeight; - // 1) tile is completely in [max_uih_len-min_full_attn_seqlen, max_uih_len] - // 2) some row of tile is in [max_uih_len, seqlen], requires i_tile_right <= max_uih_len - // to return true + // 1) tile is completely in [max_q_uih_len-min_full_attn_seqlen, max_q_uih_len] + // 2) some row of tile is in [max_q_uih_len, seqlen_q], requires i_tile_right <= + // max_k_uih_len to return true if(!is_tile_in_first_split && - (i_tile_bottom <= max_uih_len || i_tile_right <= max_uih_len)) + (i_tile_bottom <= max_q_uih_len || i_tile_right <= max_k_uih_len)) return true; }; @@ -295,21 +307,32 @@ struct HstuBlockMaskNoLocal static constexpr bool kUseLocal = false; static constexpr bool IsMasking = kUseCausal; - int seqlen; + int seqlen_q; + int seqlen_k; int contextual_seqlen; - int max_uih_len; - int max_id; + int max_q_uih_len; + int max_k_uih_len; + int max_row_id; + int max_col_id; - CK_TILE_HOST_DEVICE HstuBlockMaskNoLocal(int seqlen_, int contextual_seqlen_, int num_target_) - : seqlen(seqlen_), contextual_seqlen(contextual_seqlen_) + CK_TILE_HOST_DEVICE + HstuBlockMaskNoLocal(int seqlen_q_, int seqlen_k_, int contextual_seqlen_, int num_target_) + : seqlen_q(seqlen_q_), seqlen_k(seqlen_k_), contextual_seqlen(contextual_seqlen_) { - max_uih_len = seqlen - num_target_; + max_q_uih_len = seqlen_q - num_target_; + max_k_uih_len = seqlen_k - num_target_; if(contextual_seqlen > 0) - max_id = max_uih_len - (contextual_seqlen - 1); + { + max_row_id = max_q_uih_len - (contextual_seqlen - 1); + max_col_id = max_k_uih_len - (contextual_seqlen - 1); + } else - max_id = max_uih_len; + { + max_row_id = max_q_uih_len; + max_col_id = max_k_uih_len; + } }; // to get the loop length along X axis, return index:[start, end), end-start=length @@ -321,21 +344,21 @@ struct HstuBlockMaskNoLocal { if constexpr(!IsMasking) { - return ck_tile::make_tuple(0, seqlen); + return ck_tile::make_tuple(0, seqlen_k); } else { - index_t x_end = min(i_y + YTile, seqlen); + index_t x_end = min(i_y + YTile, seqlen_k); if(i_y < contextual_seqlen) { - if(i_y + YTile > max_uih_len) + if(i_y + YTile > max_k_uih_len) { return ck_tile::make_tuple(0, x_end); } else { - return ck_tile::make_tuple(0, max_uih_len); + return ck_tile::make_tuple(0, max_k_uih_len); }; } else @@ -357,18 +380,18 @@ struct HstuBlockMaskNoLocal row_id = max(row - contextual_seqlen + 1, 0); col_id = max(col - contextual_seqlen + 1, 0); - row_id = min(row_id, max_id); - col_id = min(col_id, max_id); + row_id = min(row_id, max_row_id); + col_id = min(col_id, max_col_id); - if(row_id == 0 && col_id < max_id) + if(row_id == 0 && col_id < max_col_id) return true; } else { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - row_id = min(row, max_id); - col_id = min(col, max_id); + row_id = min(row, max_row_id); + col_id = min(col, max_col_id); }; // use row_id/col_id to check the dist between two q/k token pair, token pairs on the @@ -395,18 +418,18 @@ struct HstuBlockMaskNoLocal row_id = max(row - contextual_seqlen + 1, 0); col_id = max(col - contextual_seqlen + 1, 0); - row_id = min(row_id, max_id); - col_id = min(col_id, max_id); + row_id = min(row_id, max_row_id); + col_id = min(col_id, max_col_id); - if(row_id == 0 && col_id < max_id) + if(row_id == 0 && col_id < max_col_id) return true; } else { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - row_id = min(row, max_id); - col_id = min(col, max_id); + row_id = min(row, max_row_id); + col_id = min(col, max_col_id); }; // use row_id/col_id to check the dist between two q/k token pair, token pairs on the @@ -439,7 +462,7 @@ struct HstuBlockMaskNoLocal // assume num_target > 0 with high probability, don't check whether num_target is 0; // so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile - if(i_tile_bottom >= max_uih_len || i_tile_right > i_tile_top) + if(i_tile_bottom >= max_q_uih_len || i_tile_right > i_tile_top) return false; return true; @@ -451,7 +474,7 @@ struct HstuBlockMaskNoLocal // assume num_target > 0 with high probability, don't check whether num_target is 0; // so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile - if(i_tile_bottom >= max_uih_len || i_tile_right >= max_uih_len) + if(i_tile_bottom >= max_q_uih_len || i_tile_right >= max_k_uih_len) return false; return true; @@ -469,14 +492,16 @@ struct HstuBlockMasking template CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(bool is_tile_in_first_split_, - int seqlen_, + int seqlen_q_, + int seqlen_k_, int contextual_seqlen_, int num_target, int max_attn_len_, int min_full_attn_seqlen_) { return HstuBlockMaskType{is_tile_in_first_split_, - seqlen_, + seqlen_q_, + seqlen_k_, contextual_seqlen_, max_attn_len_, min_full_attn_seqlen_, @@ -484,10 +509,12 @@ CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(bool is_tile_ }; template -CK_TILE_HOST_DEVICE constexpr auto -make_hstu_block_mask_without_local(int seqlen_, int contextual_seqlen_, int num_target) +CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_without_local(int seqlen_q_, + int seqlen_k_, + int contextual_seqlen_, + int num_target) { - return HstuBlockMaskType{seqlen_, contextual_seqlen_, num_target}; + return HstuBlockMaskType{seqlen_q_, seqlen_k_, contextual_seqlen_, num_target}; }; } // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 798f73c4c5..68ec7e514e 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -42,7 +42,8 @@ struct reference_hstu_attention float alpha, float attn_scale, int max_seqlen, - std::vector seq_offsets, + std::vector seq_q_offsets, + std::vector seq_kv_offsets, std::vector num_targets, // define masking length at the end of token // sequence to be excluded for attention int contextual_seqlen, // define masking length at the begin of query token @@ -54,7 +55,8 @@ struct reference_hstu_attention if constexpr(kIsJagged) { // check the number of batches - assert(!seq_offsets.empty() && seq_offsets.size() == num_batch + 1); + assert(!seq_q_offsets.empty() && seq_q_offsets.size() == num_batch + 1); + assert(!seq_kv_offsets.empty() && seq_kv_offsets.size() == num_batch + 1); assert(q_batch_seq_nhead_hdim.get_lengths()[0] == 1); assert(k_batch_seq_nhead_hdim.get_lengths()[0] == 1); assert(v_batch_seq_nhead_hdim.get_lengths()[0] == 1); @@ -62,7 +64,8 @@ struct reference_hstu_attention } else { - assert(seq_offsets.empty()); + assert(seq_q_offsets.empty()); + assert(seq_kv_offsets.empty()); assert(q_batch_seq_nhead_hdim.get_lengths()[0] == num_batch); assert(k_batch_seq_nhead_hdim.get_lengths()[0] == num_batch); assert(v_batch_seq_nhead_hdim.get_lengths()[0] == num_batch); @@ -104,8 +107,10 @@ struct reference_hstu_attention }; auto f = [&](auto i_batch, auto i_head) { - int seqlen = kIsJagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch]) - : q_batch_seq_nhead_hdim.get_lengths()[1]; + int seqlen_q = kIsJagged ? (seq_q_offsets[i_batch + 1] - seq_q_offsets[i_batch]) + : q_batch_seq_nhead_hdim.get_lengths()[1]; + int seqlen_kv = kIsJagged ? (seq_kv_offsets[i_batch + 1] - seq_kv_offsets[i_batch]) + : k_batch_seq_nhead_hdim.get_lengths()[1]; int num_target = num_targets.empty() ? 0 : num_targets[i_batch]; @@ -118,10 +123,11 @@ struct reference_hstu_attention if constexpr(kHasLocal) // need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the // user passed min_full_attn_seqlen is bigger than max_uih_len - if(seqlen - num_target > min_full_attn_seqlen) + if(seqlen_q - num_target > min_full_attn_seqlen) return ck_tile::make_hstu_block_mask_with_local( true, - seqlen, + seqlen_q, + seqlen_kv, contextual_seqlen, num_target, window_size, @@ -129,14 +135,15 @@ struct reference_hstu_attention else return ck_tile::make_hstu_block_mask_with_local( true, - seqlen, + seqlen_q, + seqlen_kv, contextual_seqlen, num_target, window_size, - seqlen - num_target); + seqlen_q - num_target); else return ck_tile::make_hstu_block_mask_without_local( - seqlen, contextual_seqlen, num_target); + seqlen_q, seqlen_kv, contextual_seqlen, num_target); }(); if(save_mask) @@ -149,7 +156,7 @@ struct reference_hstu_attention } // for all rows in the batch - for(int sq = 0; sq < seqlen; sq++) + for(int sq = 0; sq < seqlen_q; sq++) { CompDataType m = -ck_tile::numeric::infinity(); // max value of the row @@ -159,7 +166,7 @@ struct reference_hstu_attention std::vector locals; // for all cols in the batch - for(int sk = 0; sk < seqlen; sk++) + for(int sk = 0; sk < seqlen_kv; sk++) { if(mask.IsTokenPairInsideMask(sq, sk)) { @@ -169,9 +176,9 @@ struct reference_hstu_attention if constexpr(kIsJagged) { InOutDataType qreg = q_batch_seq_nhead_hdim( - 0, seq_offsets[i_batch] + sq, i_head, k); + 0, seq_q_offsets[i_batch] + sq, i_head, k); InOutDataType kreg = k_batch_seq_nhead_hdim( - 0, seq_offsets[i_batch] + sk, i_head, k); + 0, seq_kv_offsets[i_batch] + sk, i_head, k); dot_prod += ck_tile::type_convert(qreg) * ck_tile::type_convert(kreg); @@ -233,14 +240,14 @@ struct reference_hstu_attention { GemmAccDataType dot_prod = 0.f; - for(int sk = 0; sk < seqlen; sk++) + for(int sk = 0; sk < seqlen_kv; sk++) { if constexpr(kIsJagged) { InOutDataType preg = ck_tile::type_convert(locals[sk]); - InOutDataType vreg = - v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k); + InOutDataType vreg = v_batch_seq_nhead_hdim( + 0, seq_kv_offsets[i_batch] + sk, i_head, k); dot_prod += ck_tile::type_convert(preg) * ck_tile::type_convert(vreg); @@ -257,7 +264,7 @@ struct reference_hstu_attention }; if constexpr(kIsJagged) - o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) = + o_batch_seq_nhead_hdim(0, seq_q_offsets[i_batch] + sq, i_head, k) = ck_tile::type_convert(dot_prod); else o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =