diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp
index 7e1cd9e643..acff2ecfd6 100644
--- a/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp
+++ b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp
@@ -848,9 +848,24 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
     for(int i_grp = 0; i_grp < num_group; i_grp++)
     {
+        // Largest num_targets entry over the batches of this group; stays 0 when
+        // num_targets is empty.
+        int max_num_target = 0;
+
+        if(!num_targets.empty())
+        {
+            for(int i_batch = 0; i_batch < num_batch_per_group; i_batch++)
+            {
+                // Flat batch index: group i_grp covers the contiguous range starting at
+                // i_grp * num_batch_per_group (assumes num_targets has one entry per
+                // batch, ordered by group -- TODO confirm against the caller).
+                const int i_global_batch = i_grp * num_batch_per_group + i_batch;
+
+                max_num_target = max(max_num_target, num_targets[i_global_batch]);
+            };
+        };
+
         group_max_seqlens_q[i_grp] =
-            group_max_uih_seqlens_q[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp];
+            group_max_uih_seqlens_q[i_grp] + group_contextual_seqlens[i_grp] + max_num_target;
         max_max_seqlen_q = max(max_max_seqlen_q, group_max_seqlens_q[i_grp]);
         group_max_seqlens_kv[i_grp] =
-            group_max_uih_seqlens_kv[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp];
+            group_max_uih_seqlens_kv[i_grp] + group_contextual_seqlens[i_grp] + max_num_target;
         max_max_seqlen_kv = max(max_max_seqlen_kv, group_max_seqlens_kv[i_grp]);
     };