diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 161b0cb5cb..c616198cab 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -96,16 +96,24 @@ struct HstuBlockMaskWithLocal if(row < contextual_seqlen) return true; - bool result = false; if constexpr(kUseCausal) - result = (row >= col) && (row - col <= max_attn_len); + { + bool result = (row >= col) && (row - col <= max_attn_len); + + if(min_full_attn_seqlen > 0) + result = result || (row >= max_uih_len - min_full_attn_seqlen); + + return result; + } else - result = abs(row - col) <= max_attn_len; + { + bool result = abs(row - col) <= max_attn_len; - if(min_full_attn_seqlen > 0) - result = result || (row >= max_uih_len - min_full_attn_seqlen); + if(min_full_attn_seqlen > 0) + result = result || (row >= max_uih_len - min_full_attn_seqlen); - return result; + return result; + } }; };