diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index 2eebcc866b..d39e2f8207 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -100,11 +100,11 @@ auto create_args(int argc, char* argv[]) .insert("hdim_v", "64", "headdim size of V/O") .insert("seqlens", "400", "uih seqlen of single or all batches for query tensor, actually allocated seqlen will include the target of each batch and context_len") .insert("seqlens_kv", "", "uih seqlen of single or all batches for key/value tensor, actually allocated seqlen will include the target of each batch and context_len") - .insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens") - .insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens") - .insert("g_max_seqlens", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens") + .insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal/bigger than the maximum of all uih seqlens") + .insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal/bigger than the maximum of all uih seqlens") + .insert("g_max_seqlens", "0", "max uih_seqlen of groups, can be ignored, or else each must be equal/bigger than maximum of all uih seqlens in its group") .insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention") - .insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets") + .insert("max_target", "0", "max target, can be ignored, or else must be equal/bigger than the maximum of all targets") .insert("softmax", "0", "use softmax or not") 
.insert("causal", "1", "enable causal mask or not") .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable") @@ -697,10 +697,10 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) else is_cross_attention = true; - str_of_integers = arg_parser.get_str("g_max_seqlens"); - std::vector group_max_seqlens = get_integers_from_string(str_of_integers); + str_of_integers = arg_parser.get_str("g_max_seqlens"); + std::vector group_input_max_uih_seqlens = get_integers_from_string(str_of_integers); - HSTU_CHECK(!group_max_seqlens.empty(), "group window sizes shoud be defined!"); + HSTU_CHECK(!group_input_max_uih_seqlens.empty(), "group window sizes should be defined!"); str_of_integers = arg_parser.get_str("g_context_lens"); std::vector group_contextual_seqlens = get_integers_from_string(str_of_integers); @@ -733,8 +733,9 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) supplement_array_by_last_element(num_targets, num_batch); }; - // supplement group_max_seqlens using the last input value if user-provided lengths not enough - supplement_array_by_last_element(group_max_seqlens, num_group); + // supplement group_input_max_uih_seqlens using the last input value if user-provided lengths + // not enough + supplement_array_by_last_element(group_input_max_uih_seqlens, num_group); // supplement group_contextual_seqlens using the last input value if user-provided lengths not // enough @@ -754,10 +755,41 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) int phy_seqlen_kv = 0; int max_max_seqlen = 0; - // only consider num_group values even if more values were provided by the user - for(int i = 0; i < num_group; i++) + std::vector group_max_uih_seqlens; + + group_max_uih_seqlens.resize(num_group); + + for(int i_grp = 0; i_grp < num_group; i_grp++) { - max_max_seqlen = max(max_max_seqlen, group_max_seqlens[i]); + group_max_uih_seqlens[i_grp] = 0; + + for(int i_batch = 0; 
i_batch < num_batch_per_group; i_batch++) + { + auto i_global_batch = i_grp * num_batch_per_group + i_batch; + + group_max_uih_seqlens[i_grp] = + max(group_max_uih_seqlens[i_grp], seq_lengths_q[i_global_batch]); + }; + + HSTU_CHECK(group_input_max_uih_seqlens[i_grp] <= 0 || + group_input_max_uih_seqlens[i_grp] >= group_max_uih_seqlens[i_grp], + "the user input of each group max_uih_seqlen can either be ignored or be equal/bigger " + "than all uih_seqlens[] of the group"); + + group_max_uih_seqlens[i_grp] = group_input_max_uih_seqlens[i_grp] > 0 + ? group_input_max_uih_seqlens[i_grp] + : group_max_uih_seqlens[i_grp]; + }; + + std::vector group_max_seqlens; + + group_max_seqlens.resize(num_group); + + for(int i_grp = 0; i_grp < num_group; i_grp++) + { + group_max_seqlens[i_grp] = + group_max_uih_seqlens[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp]; + max_max_seqlen = max(max_max_seqlen, group_max_seqlens[i_grp]); + }; + + std::vector seq_offsets_q;