diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index c1e6c3c941..98c912fc5f 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -471,6 +471,9 @@ static int get_hstu_attention_fwd_mtile(int num_batches, int num_heads, int max_ int num_CUs = get_number_of_cu(); auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + if(max_seqlen_q <= 64) + return 64; + int nbatch_nhead_mblocks = num_batches * num_heads * ceildiv(max_seqlen_q, 128); // assuming each CU is assigned two work-groups