From db3263469c9b96d59cb40919c777bb83284f3191 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 17 Apr 2026 09:13:45 +0000 Subject: [PATCH] Clarify the usage of max_seqlen and max_seqlen_q --- .../example_hstu_attention.cpp | 155 ++++++++++-------- .../hstu_attention_fwd_kernel.hpp | 19 +-- .../hstu_attention_fwd_splitkv_kernel.hpp | 19 +-- .../hstu_attention_group_forward_dispatch.hpp | 8 +- ...tention_group_forward_splitkv_dispatch.hpp | 10 +- ...hstu_attention_jagged_forward_dispatch.hpp | 8 +- ...ention_jagged_forward_splitkv_dispatch.hpp | 10 +- .../hstu_attention_params.hpp | 6 +- .../reference_hstu_attention.hpp | 29 ++-- 9 files changed, 143 insertions(+), 121 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index d39e2f8207..51ddce2e87 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -103,6 +103,7 @@ auto create_args(int argc, char* argv[]) .insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal/bigger than the maximum of all uih seqlens") .insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal/bigger than the maximum of all uih seqlens") .insert("g_max_seqlens", "0", "max uih_seqlen of groups, can be ignored, or else each must be equal/bigger than maximum of all uih seqlens in its group") + .insert("g_max_seqlens_kv", "0", "max uih_seqlen_kv of groups, can be ignored, or else each must be equal/bigger than maximum of all uih seqlens in its group") .insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention") .insert("max_target", "0", "max target, can be ignored, or else must be equal/bigger than the maximum of all targets") .insert("softmax", "0", "use softmax or not") @@ -478,7 +479,7 @@ bool run_no_group_hstu(const ck_tile::ArgParser&
arg_parser, bool is_jagged) params.num_batch = num_batch; params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer(); params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer(); - params.max_seqlen = max(max_seqlen_q, max_seqlen_kv); + params.max_seqlen_q = max_seqlen_q; params.q_ptr = q_dev.GetDeviceBuffer(); params.k_ptr = k_dev.GetDeviceBuffer(); params.v_ptr = v_dev.GetDeviceBuffer(); @@ -697,10 +698,11 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) else is_cross_attention = true; - str_of_integers = arg_parser.get_str("g_max_seqlens"); - std::vector group_input_max_uih_seqlens = get_integers_from_string(str_of_integers); + str_of_integers = arg_parser.get_str("g_max_seqlens"); + std::vector group_input_max_uih_seqlens_q = get_integers_from_string(str_of_integers); - HSTU_CHECK(!group_input_max_uih_seqlens.empty(), "group window sizes shoud be defined!"); + str_of_integers = arg_parser.get_str("g_max_seqlens_kv"); + std::vector group_input_max_uih_seqlens_kv = get_integers_from_string(str_of_integers); str_of_integers = arg_parser.get_str("g_context_lens"); std::vector group_contextual_seqlens = get_integers_from_string(str_of_integers); @@ -721,10 +723,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) std::vector group_attn_scales = get_floats_from_string(str_of_floats); HSTU_CHECK(!group_attn_scales.empty(), "group attn_scales shoud be defined!"); - // supplement seq_lengths_q using the last input value if user-provided lengths not enough + // supplement seq_lengths using the last input value if user-provided lengths not enough supplement_array_by_last_element(seq_lengths_q, num_batch); - - // supplement seq_lengths_kv using the last input value if user-provided lengths not enough supplement_array_by_last_element(seq_lengths_kv, num_batch); if(!num_targets.empty()) @@ -735,7 +735,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) // supplement 
group_input_max_uih_seqlens using the last input value if user-provided lengths // not enough - supplement_array_by_last_element(group_input_max_uih_seqlens, num_group); + supplement_array_by_last_element(group_input_max_uih_seqlens_q, num_group); + supplement_array_by_last_element(group_input_max_uih_seqlens_kv, num_group); // supplement group_contextual_seqlens using the last input value if user-provided lengths not // enough @@ -751,45 +752,64 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) // supplement group_attn_scales using the last input value if user-provided values not enough supplement_array_by_last_element(group_attn_scales, num_group); - int phy_seqlen_q = 0; - int phy_seqlen_kv = 0; - int max_max_seqlen = 0; + int phy_seqlen_q = 0; + int phy_seqlen_kv = 0; + int max_max_seqlen_q = 0; + int max_max_seqlen_kv = 0; - std::vector group_max_uih_seqlens; + std::vector group_max_uih_seqlens_q; + std::vector group_max_uih_seqlens_kv; - group_max_uih_seqlens.resize(num_group); + group_max_uih_seqlens_q.resize(num_group); + group_max_uih_seqlens_kv.resize(num_group); for(int i_grp = 0; i_grp < num_group; i_grp++) { - group_max_uih_seqlens[i_grp] = 0; + group_max_uih_seqlens_q[i_grp] = 0; + group_max_uih_seqlens_kv[i_grp] = 0; for(int i_batch = 0; i_batch < num_batch_per_group; i_batch++) { auto i_global_batch = i_grp * num_batch_per_group + i_batch; - group_max_uih_seqlens[i_grp] = - max(group_max_uih_seqlens[i_grp], seq_lengths_q[i_global_batch]); + group_max_uih_seqlens_q[i_grp] = + max(group_max_uih_seqlens_q[i_grp], seq_lengths_q[i_global_batch]); + group_max_uih_seqlens_kv[i_grp] = + max(group_max_uih_seqlens_kv[i_grp], seq_lengths_kv[i_global_batch]); }; - HSTU_CHECK(group_input_max_uih_seqlens[i_grp] <= 0 || - group_input_max_uih_seqlens[i_grp] >= group_max_uih_seqlens[i_grp], + HSTU_CHECK(group_input_max_uih_seqlens_q[i_grp] <= 0 || + group_input_max_uih_seqlens_q[i_grp] >= group_max_uih_seqlens_q[i_grp], "the user input of 
each group max_uih_seqlen can either be ignored or be bigger " "than all uih_seqlens[] of the group"); - group_max_uih_seqlens[i_grp] = group_input_max_uih_seqlens[i_grp] > 0 - ? group_input_max_uih_seqlens[i_grp] - : group_max_uih_seqlens[i_grp]; + HSTU_CHECK(group_input_max_uih_seqlens_kv[i_grp] <= 0 || + group_input_max_uih_seqlens_kv[i_grp] >= group_max_uih_seqlens_kv[i_grp], + "the user input of each group max_uih_seqlen can either be ignored or be bigger " + "than all uih_seqlens[] of the group"); + + group_max_uih_seqlens_q[i_grp] = group_input_max_uih_seqlens_q[i_grp] > 0 + ? group_input_max_uih_seqlens_q[i_grp] + : group_max_uih_seqlens_q[i_grp]; + group_max_uih_seqlens_kv[i_grp] = group_input_max_uih_seqlens_kv[i_grp] > 0 + ? group_input_max_uih_seqlens_kv[i_grp] + : group_max_uih_seqlens_kv[i_grp]; }; - std::vector group_max_seqlens; + std::vector group_max_seqlens_q; + std::vector group_max_seqlens_kv; - group_max_seqlens.resize(num_group); + group_max_seqlens_q.resize(num_group); + group_max_seqlens_kv.resize(num_group); for(int i_grp = 0; i_grp < num_group; i_grp++) { - group_max_seqlens[i_grp] = - group_max_uih_seqlens[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp]; - max_max_seqlen = max(max_max_seqlen, group_max_seqlens[i_grp]); + group_max_seqlens_q[i_grp] = + group_max_uih_seqlens_q[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp]; + max_max_seqlen_q = max(max_max_seqlen_q, group_max_seqlens_q[i_grp]); + group_max_seqlens_kv[i_grp] = + group_max_uih_seqlens_kv[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp]; + max_max_seqlen_kv = max(max_max_seqlen_kv, group_max_seqlens_kv[i_grp]); }; std::vector seq_offsets_q; @@ -859,10 +879,12 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) ck_tile::HostTensor o_host_ref( std::array{batches_for_alloc, phy_seqlen_q, num_head, hdim_v}); - ck_tile::HostTensor mask_host( - save_mask - ? 
std::array{num_batch, num_head, max_max_seqlen, max_max_seqlen} - : std::array{1, 1, 1, 1}); + ck_tile::HostTensor mask_host(save_mask + ? std::array{num_batch, + num_head, + max_max_seqlen_q, + max_max_seqlen_q} + : std::array{1, 1, 1, 1}); if(!initialize_qkv) { @@ -904,14 +926,14 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) if(!num_targets.empty()) num_targets_dev.ToDevice(num_targets.data()); - ck_tile::DeviceMem group_max_seqlens_dev(group_max_seqlens.size() * sizeof(int)); + ck_tile::DeviceMem group_max_seqlens_q_dev(group_max_seqlens_q.size() * sizeof(int)); ck_tile::DeviceMem group_contextual_seqlens_dev(group_contextual_seqlens.size() * sizeof(int)); ck_tile::DeviceMem group_window_sizes_dev(group_window_sizes.size() * sizeof(int)); ck_tile::DeviceMem group_min_full_attn_seqlens_dev(group_min_full_attn_seqlens.size() * sizeof(int)); ck_tile::DeviceMem group_attn_scales_dev(group_attn_scales.size() * sizeof(float)); - group_max_seqlens_dev.ToDevice(group_max_seqlens.data()); + group_max_seqlens_q_dev.ToDevice(group_max_seqlens_q.data()); group_contextual_seqlens_dev.ToDevice(group_contextual_seqlens.data()); group_window_sizes_dev.ToDevice(group_window_sizes.data()); group_min_full_attn_seqlens_dev.ToDevice(group_min_full_attn_seqlens.data()); @@ -921,38 +943,38 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) float scale_s = (alpha != 0.f) ? 
alpha : 1.0f / std::sqrt(hdim_qk); - params.is_cross_attention = is_cross_attention; - params.num_batch = num_batch; - params.num_group = num_group; - params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer(); - params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer(); - params.max_seqlen = max_max_seqlen; - params.q_ptr = q_dev.GetDeviceBuffer(); - params.k_ptr = k_dev.GetDeviceBuffer(); - params.v_ptr = v_dev.GetDeviceBuffer(); - params.bias_ptr = nullptr; // bias is not supported at present - params.o_ptr = o_dev.GetDeviceBuffer(); - params.hdim_qk = hdim_qk; - params.hdim_v = hdim_v; - params.num_head = num_head; - params.scale_s = scale_s; - params.seq_stride_q = q_host.get_strides()[1]; - params.seq_stride_k = k_host.get_strides()[1]; - params.seq_stride_v = v_host.get_strides()[1]; - params.seq_stride_bias = 0; - params.seq_stride_o = o_host_ref.get_strides()[1]; - params.nhead_stride_q = q_host.get_strides()[2]; - params.nhead_stride_k = k_host.get_strides()[2]; - params.nhead_stride_v = v_host.get_strides()[2]; - params.nhead_stride_bias = 0; - params.nhead_stride_o = o_host_ref.get_strides()[2]; - params.num_targets_ptr = num_targets.empty() ? 
nullptr : num_targets_dev.GetDeviceBuffer(); - params.use_softmax = use_softmax; - params.use_causal = use_causal; - params.p_drop = 0.0f; // dropout is not supported at present - params.philox_seed = 0UL; - params.philox_offset = 0UL; - params.group_max_seqlen_ptr = group_max_seqlens_dev.GetDeviceBuffer(); + params.is_cross_attention = is_cross_attention; + params.num_batch = num_batch; + params.num_group = num_group; + params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer(); + params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer(); + params.max_seqlen_q = max_max_seqlen_q; + params.q_ptr = q_dev.GetDeviceBuffer(); + params.k_ptr = k_dev.GetDeviceBuffer(); + params.v_ptr = v_dev.GetDeviceBuffer(); + params.bias_ptr = nullptr; // bias is not supported at present + params.o_ptr = o_dev.GetDeviceBuffer(); + params.hdim_qk = hdim_qk; + params.hdim_v = hdim_v; + params.num_head = num_head; + params.scale_s = scale_s; + params.seq_stride_q = q_host.get_strides()[1]; + params.seq_stride_k = k_host.get_strides()[1]; + params.seq_stride_v = v_host.get_strides()[1]; + params.seq_stride_bias = 0; + params.seq_stride_o = o_host_ref.get_strides()[1]; + params.nhead_stride_q = q_host.get_strides()[2]; + params.nhead_stride_k = k_host.get_strides()[2]; + params.nhead_stride_v = v_host.get_strides()[2]; + params.nhead_stride_bias = 0; + params.nhead_stride_o = o_host_ref.get_strides()[2]; + params.num_targets_ptr = num_targets.empty() ? 
nullptr : num_targets_dev.GetDeviceBuffer(); + params.use_softmax = use_softmax; + params.use_causal = use_causal; + params.p_drop = 0.0f; // dropout is not supported at present + params.philox_seed = 0UL; + params.philox_offset = 0UL; + params.group_max_seqlen_q_ptr = group_max_seqlens_q_dev.GetDeviceBuffer(); params.group_contextual_seqlen_ptr = group_contextual_seqlens_dev.GetDeviceBuffer(); params.group_window_size_ptr = group_window_sizes_dev.GetDeviceBuffer(); params.group_min_full_attn_seqlen_ptr = group_min_full_attn_seqlens_dev.GetDeviceBuffer(); @@ -994,11 +1016,12 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) num_batch, num_batch / num_group, scale_s, - max_max_seqlen, + max_max_seqlen_q, + max_max_seqlen_kv, seq_offsets_q, seq_offsets_kv, num_targets, - group_max_seqlens, + group_max_seqlens_q, group_contextual_seqlens, group_window_sizes, group_min_full_attn_seqlens, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index aa2bee442e..1ab7178c0a 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -174,7 +174,7 @@ struct HstuAttentionFwdKernel int32_t window_size; // to be set by the per-group window_size int32_t min_full_attn_seqlen; // to be set by the per-group min_full_attn_seqlen - const int32_t* group_max_seqlen_ptr; + const int32_t* group_max_seqlen_q_ptr; const int32_t* group_contextual_seqlen_ptr; const int32_t* group_window_size_ptr; const int32_t* group_min_full_attn_seqlen_ptr; @@ -318,8 +318,7 @@ struct HstuAttentionFwdKernel seq_stride_o, num_head, scale_s, - attn_scale ? attn_scale - : 1.0f / static_cast(max(seqlen_q, seqlen_kv)), // max_seqlen + attn_scale ? 
attn_scale : 1.0f / static_cast(seqlen_q), contextual_seqlen, window_size, min_full_attn_seqlen}, // args for common karg @@ -351,7 +350,7 @@ struct HstuAttentionFwdKernel void* o_ptr, const void* seq_q_offsets_ptr, const void* seq_kv_offsets_ptr, - ck_tile::index_t max_seqlen, + ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_qk, ck_tile::index_t hdim_v, ck_tile::index_t num_head, @@ -397,7 +396,7 @@ struct HstuAttentionFwdKernel -1, // seqlen_kv will be updated by another pointer num_head, scale_s, - attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen), + attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen_q), contextual_seqlen, window_size, min_full_attn_seqlen}, // args for common karg @@ -429,7 +428,7 @@ struct HstuAttentionFwdKernel ck_tile::index_t num_batch_per_group, const void* seq_q_offsets_ptr, const void* seq_kv_offsets_ptr, - const void* group_max_seqlen_ptr, + const void* group_max_seqlen_q_ptr, const void* group_contextual_seqlen_ptr, const void* group_window_size_ptr, const void* group_min_full_attn_seqlen_ptr, @@ -480,7 +479,7 @@ struct HstuAttentionFwdKernel 0, // to be set by the per-group contextual_seqlen 0, // to be set by the per-group window_size 0, // to be set by the per-group min_full_attn_seqlen - reinterpret_cast(group_max_seqlen_ptr), + reinterpret_cast(group_max_seqlen_q_ptr), reinterpret_cast(group_contextual_seqlen_ptr), reinterpret_cast(group_window_size_ptr), reinterpret_cast(group_min_full_attn_seqlen_ptr), @@ -654,9 +653,9 @@ struct HstuAttentionFwdKernel index_t i_group = __builtin_amdgcn_readfirstlane(i_batch / kargs.num_batch_per_group); - float attn_scale = kargs.group_attn_scale_ptr[i_group]; - index_t max_seqlen = kargs.group_max_seqlen_ptr[i_group]; - kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen)); + float attn_scale = kargs.group_attn_scale_ptr[i_group]; + index_t max_seqlen_q = kargs.group_max_seqlen_q_ptr[i_group]; + kargs.scale_p = (attn_scale ? 
attn_scale : 1.0f / static_cast(max_seqlen_q)); kargs.contextual_seqlen = kargs.group_contextual_seqlen_ptr[i_group]; kargs.window_size = kargs.group_window_size_ptr[i_group]; kargs.min_full_attn_seqlen = kargs.group_min_full_attn_seqlen_ptr[i_group]; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp index 36760dcd78..96c14ebb04 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp @@ -173,7 +173,7 @@ struct HstuAttentionFwdSplitKVKernel int32_t window_size; // to be set by the per-group window_size int32_t min_full_attn_seqlen; // to be set by the per-group min_full_attn_seqlen - const int32_t* group_max_seqlen_ptr; + const int32_t* group_max_seqlen_q_ptr; const int32_t* group_contextual_seqlen_ptr; const int32_t* group_window_size_ptr; const int32_t* group_min_full_attn_seqlen_ptr; @@ -313,8 +313,7 @@ struct HstuAttentionFwdSplitKVKernel seq_stride_v, num_head, scale_s, - attn_scale ? attn_scale - : 1.0f / static_cast(max(seqlen_q, seqlen_kv)), // max_seqlen + attn_scale ? attn_scale : 1.0f / static_cast(seqlen_q), contextual_seqlen, window_size, min_full_attn_seqlen}, // args for common karg @@ -347,7 +346,7 @@ struct HstuAttentionFwdSplitKVKernel ck_tile::index_t num_splits, // number of splitted seqlen_kv const void* seq_q_offsets_ptr, const void* seq_kv_offsets_ptr, - ck_tile::index_t max_seqlen, + ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_qk, ck_tile::index_t hdim_v, ck_tile::index_t num_head, @@ -390,7 +389,7 @@ struct HstuAttentionFwdSplitKVKernel -1, // seqlen_kv will be updated by another pointer num_head, scale_s, - attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen), + attn_scale ? 
attn_scale : 1.0f / static_cast(max_seqlen_q), contextual_seqlen, window_size, min_full_attn_seqlen}, // args for common karg @@ -423,7 +422,7 @@ struct HstuAttentionFwdSplitKVKernel ck_tile::index_t num_batch_per_group, const void* seq_q_offsets_ptr, const void* seq_kv_offsets_ptr, - const void* group_max_seqlen_ptr, + const void* group_max_seqlen_q_ptr, const void* group_contextual_seqlen_ptr, const void* group_window_size_ptr, const void* group_min_full_attn_seqlen_ptr, @@ -471,7 +470,7 @@ struct HstuAttentionFwdSplitKVKernel 0, // to be set by the per-group contextual_seqlen 0, // to be set by the per-group window_size 0, // to be set by the per-group min_full_attn_seqlen - reinterpret_cast(group_max_seqlen_ptr), + reinterpret_cast(group_max_seqlen_q_ptr), reinterpret_cast(group_contextual_seqlen_ptr), reinterpret_cast(group_window_size_ptr), reinterpret_cast(group_min_full_attn_seqlen_ptr), @@ -674,9 +673,9 @@ struct HstuAttentionFwdSplitKVKernel index_t i_group = __builtin_amdgcn_readfirstlane(i_batch / kargs.num_batch_per_group); - float attn_scale = kargs.group_attn_scale_ptr[i_group]; - index_t max_seqlen = kargs.group_max_seqlen_ptr[i_group]; - kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen)); + float attn_scale = kargs.group_attn_scale_ptr[i_group]; + index_t max_seqlen_q = kargs.group_max_seqlen_q_ptr[i_group]; + kargs.scale_p = (attn_scale ? 
attn_scale : 1.0f / static_cast(max_seqlen_q)); kargs.contextual_seqlen = kargs.group_contextual_seqlen_ptr[i_group]; kargs.window_size = kargs.group_window_size_ptr[i_group]; kargs.min_full_attn_seqlen = kargs.group_min_full_attn_seqlen_ptr[i_group]; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp index 987e3cf198..fa69ab7cb8 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp @@ -133,7 +133,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch param.seq_q_offsets_ptr, param.is_cross_attention ? param.seq_kv_offsets_ptr : param.seq_q_offsets_ptr, - param.group_max_seqlen_ptr, + param.group_max_seqlen_q_ptr, param.group_contextual_seqlen_ptr, param.group_window_size_ptr, param.group_min_full_attn_seqlen_ptr, @@ -159,7 +159,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch }(); dim3 kGridSize = - HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen, param.hdim_v); + HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v); constexpr dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; @@ -178,7 +178,7 @@ template (param.num_batch) * param.max_seqlen * + size_t workspace_bytes = static_cast(param.num_batch) * param.max_seqlen_q * param.num_head * param.num_splits * param.hdim_v * sizeof(OaccDataType); @@ -197,7 +197,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch param.seq_q_offsets_ptr, param.is_cross_attention ? 
param.seq_kv_offsets_ptr : param.seq_q_offsets_ptr, - param.group_max_seqlen_ptr, + param.group_max_seqlen_q_ptr, param.group_contextual_seqlen_ptr, param.group_window_size_ptr, param.group_min_full_attn_seqlen_ptr, @@ -221,7 +221,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch }(); dim3 kGridSize = HstuKernel::GridSize( - param.num_batch, param.num_head, param.max_seqlen, param.hdim_v, param.num_splits); + param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, param.num_splits); constexpr dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; @@ -245,7 +245,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch param.hdim_v); }(); - dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen); + dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q); constexpr dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 717b414f82..0e8dadee2f 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -132,7 +132,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch param.seq_q_offsets_ptr, param.is_cross_attention ? 
param.seq_kv_offsets_ptr : param.seq_q_offsets_ptr, - param.max_seqlen, + param.max_seqlen_q, param.hdim_qk, param.hdim_v, param.num_head, @@ -160,7 +160,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0); dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, - param.max_seqlen, + param.max_seqlen_q, param.hdim_v, has_minfull_attn_seqlen); constexpr dim3 kBlockSize = HstuKernel::BlockSize(); @@ -181,7 +181,7 @@ template (param.num_batch) * param.max_seqlen * + size_t workspace_bytes = static_cast(param.num_batch) * param.max_seqlen_q * param.num_head * param.num_splits * param.hdim_v * sizeof(OaccDataType); @@ -195,7 +195,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch param.seq_q_offsets_ptr, param.is_cross_attention ? param.seq_kv_offsets_ptr : param.seq_q_offsets_ptr, - param.max_seqlen, + param.max_seqlen_q, param.hdim_qk, param.hdim_v, param.num_head, @@ -221,7 +221,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0); dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, - param.max_seqlen, + param.max_seqlen_q, param.hdim_v, param.num_splits, has_minfull_attn_seqlen); @@ -248,7 +248,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch param.hdim_v); }(); - dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen); + dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q); constexpr dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp index 03b8fc27e3..8acf9dbf1d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp +++ 
b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp @@ -19,7 +19,7 @@ struct HstuAttentionNoGroupFwdParams ck_tile::index_t seqlen_kv; // batched mode only const void* seq_q_offsets_ptr; // jagged mode only const void* seq_kv_offsets_ptr; // jagged mode only - ck_tile::index_t max_seqlen; // jagged mode only + ck_tile::index_t max_seqlen_q; // jagged mode only const void* q_ptr; const void* k_ptr; @@ -84,7 +84,7 @@ struct HstuAttentionGroupFwdParams ck_tile::index_t num_batch; const void* seq_q_offsets_ptr; const void* seq_kv_offsets_ptr; - ck_tile::index_t max_seqlen; // the maximum of all the groups' max_seqlen + ck_tile::index_t max_seqlen_q; // the maximum of all the groups' max_seqlen_q const void* q_ptr; const void* k_ptr; @@ -122,7 +122,7 @@ struct HstuAttentionGroupFwdParams // parameters used by Group HSTU const void* group_attn_scale_ptr; - const void* group_max_seqlen_ptr; + const void* group_max_seqlen_q_ptr; // use for setting attn_scales const void* group_window_size_ptr; const void* group_contextual_seqlen_ptr; const void* group_min_full_attn_seqlen_ptr; diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index dacb0e3059..f5d6e80bd0 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -43,7 +43,7 @@ struct reference_no_group_hstu_attention float alpha, float attn_scale, int max_seqlen_q, - int max_seqlen_kv, + int max_seqlen_kv, // only used as last dim of the tensor for saving the mask std::vector seq_q_offsets, std::vector seq_kv_offsets, std::vector num_targets, // define masking length at the end of token @@ -116,9 +116,7 @@ struct reference_no_group_hstu_attention int num_target = num_targets.empty() ? 0 : num_targets[i_batch]; - float scale_p = attn_scale - ? 
attn_scale - : 1.0f / static_cast(max(max_seqlen_q, max_seqlen_kv)); + float scale_p = attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen_q); BOOL_SWITCH_2(window_size > 0, kHasLocal, is_cross_attention, kIsCrossAttention, [&] { using HstuMaskType = @@ -335,12 +333,15 @@ struct reference_group_hstu_attention int num_batch, int num_batch_per_group, float alpha, - int max_max_seqlen, // the maximum of all groups's max_seqlen + int max_max_seqlen_q, // the maximum of all groups's max_seqlen_q, only used as second last + // dim of the tensor for saving the mask + int max_max_seqlen_kv, // the maximum of all groups's max_seqlen_k, only used as last dim of + // the tensor for saving the mask const std::vector& seq_q_offsets, const std::vector& seq_kv_offsets, - const std::vector& num_targets, // define masking length at the end of token - // sequence to be excluded for attention - const std::vector& group_max_seqlens, // max seqlen list by groups + const std::vector& num_targets, // define masking length at the end of token + // sequence to be excluded for attention + const std::vector& group_max_seqlens_q, // max seqlen_q list by groups const std::vector& group_contextual_seqlens, // contextual seqlen list by groups const std::vector& group_window_sizes, // window_size list by groups const std::vector& group_min_full_attn_seqlens, // min_full_attn_seqlen list by groups @@ -374,8 +375,8 @@ struct reference_group_hstu_attention if(static_cast(mask_batch_nhead_seq_seq.get_lengths()[0]) == num_batch && static_cast(mask_batch_nhead_seq_seq.get_lengths()[1]) == num_head && - static_cast(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_max_seqlen && - static_cast(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_max_seqlen) + static_cast(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_max_seqlen_q && + static_cast(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_max_seqlen_kv) save_mask = true; // check num_tagets @@ -394,10 +395,10 @@ struct 
reference_group_hstu_attention int num_target = num_targets.empty() ? 0 : num_targets[i_batch]; - int max_seqlen = group_max_seqlens[i_group]; + int max_seqlen_q = group_max_seqlens_q[i_group]; float attn_scale = group_attn_scales[i_group]; - float scale_p = (attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen)); + float scale_p = (attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen_q)); int contextual_seqlen = group_contextual_seqlens[i_group]; int window_size = group_window_sizes[i_group]; @@ -468,8 +469,8 @@ struct reference_group_hstu_attention if(save_mask) { - for(int sq = 0; sq < max_seqlen; sq++) - for(int sk = 0; sk < max_seqlen; sk++) + for(int sq = 0; sq < max_max_seqlen_q; sq++) + for(int sk = 0; sk < max_max_seqlen_kv; sk++) mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) = 0; for(int sq = 0; sq < seqlen_q; sq++)