mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Add simple handling for situations where max_attn_len is bigger than max_uih_len
This commit is contained in:
@@ -21,8 +21,8 @@ struct HstuBlockMaskWithLocal
|
||||
int seqlen;
|
||||
int contextual_seqlen;
|
||||
|
||||
int max_attn_len;
|
||||
int min_full_attn_seqlen;
|
||||
int max_attn_len;
|
||||
|
||||
int max_uih_len;
|
||||
int max_id;
|
||||
@@ -36,11 +36,13 @@ struct HstuBlockMaskWithLocal
|
||||
: is_tile_in_first_split(is_tile_in_first_split_),
|
||||
seqlen(seqlen_),
|
||||
contextual_seqlen(contextual_seqlen_),
|
||||
max_attn_len(max_attn_len_),
|
||||
min_full_attn_seqlen(min_full_attn_seqlen_)
|
||||
{
|
||||
max_uih_len = seqlen - num_target_;
|
||||
|
||||
// the user-provided max_attn_len_ could be bigger than max_uih_len, so clamp it
|
||||
max_attn_len = min(max_uih_len, max_attn_len_);
|
||||
|
||||
// assuming min_full_attn_seqlen has higher priority, ensure the contextual scope does not collide
|
||||
// with min_full_attn_seqlen scope
|
||||
contextual_seqlen = min(contextual_seqlen, max_uih_len - min_full_attn_seqlen);
|
||||
|
||||
Reference in New Issue
Block a user