Clarify the using of group_max_seqlens[] and group_input_max_uih_seqlens[] parameters for group attention example

This commit is contained in:
Qianfeng Zhang
2026-04-15 16:18:43 +00:00
parent 9279af33f1
commit 7889844d6b

View File

@@ -100,11 +100,11 @@ auto create_args(int argc, char* argv[])
.insert("hdim_v", "64", "headdim size of V/O")
.insert("seqlens", "400", "uih seqlen of single or all batches for query tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("seqlens_kv", "", "uih seqlen of single or all batches for key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("g_max_seqlens", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal/bigger than the maximum of all uih seqlens")
.insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal/bigger than the maximum of all uih seqlens")
.insert("g_max_seqlens", "0", "max uih_seqlen of groups, can be ignored, or else each must be equal/bigger than maximum of all uih seqlens in its group")
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
.insert("max_target", "0", "max target, can be ignored, or else must be equal/bigger than the maximum of all targets")
.insert("softmax", "0", "use softmax or not")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
@@ -697,10 +697,10 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
else
is_cross_attention = true;
str_of_integers = arg_parser.get_str("g_max_seqlens");
std::vector<int> group_max_seqlens = get_integers_from_string(str_of_integers);
str_of_integers = arg_parser.get_str("g_max_seqlens");
std::vector<int> group_input_max_uih_seqlens = get_integers_from_string(str_of_integers);
HSTU_CHECK(!group_max_seqlens.empty(), "group window sizes shoud be defined!");
HSTU_CHECK(!group_input_max_uih_seqlens.empty(), "group window sizes shoud be defined!");
str_of_integers = arg_parser.get_str("g_context_lens");
std::vector<int> group_contextual_seqlens = get_integers_from_string(str_of_integers);
@@ -733,8 +733,9 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
supplement_array_by_last_element(num_targets, num_batch);
};
// supplement group_max_seqlens using the last input value if user-provided lengths not enough
supplement_array_by_last_element(group_max_seqlens, num_group);
// supplement group_input_max_uih_seqlens using the last input value if user-provided lengths
// not enough
supplement_array_by_last_element(group_input_max_uih_seqlens, num_group);
// supplement group_contextual_seqlens using the last input value if user-provided lengths not
// enough
@@ -754,10 +755,41 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
int phy_seqlen_kv = 0;
int max_max_seqlen = 0;
// only consider num_group values even if more values were provided by the user
for(int i = 0; i < num_group; i++)
std::vector<int> group_max_uih_seqlens;
group_max_uih_seqlens.resize(num_group);
for(int i_grp = 0; i_grp < num_group; i_grp++)
{
max_max_seqlen = max(max_max_seqlen, group_max_seqlens[i]);
group_max_uih_seqlens[i_grp] = 0;
for(int i_batch = 0; i_batch < num_batch_per_group; i_batch++)
{
auto i_global_batch = i_grp * num_batch_per_group + i_batch;
group_max_uih_seqlens[i_grp] =
max(group_max_uih_seqlens[i_grp], seq_lengths_q[i_global_batch]);
};
HSTU_CHECK(group_input_max_uih_seqlens[i_grp] <= 0 ||
group_input_max_uih_seqlens[i_grp] >= group_max_uih_seqlens[i_grp],
"the user input of each group max_uih_seqlen can either be ignored or be bigger "
"than all uih_seqlens[] of the group");
group_max_uih_seqlens[i_grp] = group_input_max_uih_seqlens[i_grp] > 0
? group_input_max_uih_seqlens[i_grp]
: group_max_uih_seqlens[i_grp];
};
std::vector<int> group_max_seqlens;
group_max_seqlens.resize(num_group);
for(int i_grp = 0; i_grp < num_group; i_grp++)
{
group_max_seqlens[i_grp] =
group_max_uih_seqlens[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp];
max_max_seqlen = max(max_max_seqlen, group_max_seqlens[i_grp]);
};
std::vector<int> seq_offsets_q;