diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 6d4d2a2500..42e0dd25ff 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -449,10 +449,10 @@ struct UnifiedAttentionKernel FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( - cur_batch_query_len, // x (i.e. extend) seq_len - cur_batch_query_len, // y (i.e. context) - cur_batch_query_len, // x_total + cur_batch_query_len, // x (i.e. extend) seq_len, // y_total (x + y) + cur_batch_query_len, // x_total num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv times along x dim of the tile kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); else