From 9996270087c532d874ad9a8a70c37aca535a3c8a Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 25 Apr 2025 15:36:58 +0000 Subject: [PATCH] Tiny update in IsTokenPairInsideMask() --- .../18_hstu_attention/hstu_block_masking.hpp | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 161b0cb5cb..c616198cab 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -96,16 +96,24 @@ struct HstuBlockMaskWithLocal if(row < contextual_seqlen) return true; - bool result = false; if constexpr(kUseCausal) - result = (row >= col) && (row - col <= max_attn_len); + { + bool result = (row >= col) && (row - col <= max_attn_len); + + if(min_full_attn_seqlen > 0) + result = result || (row >= max_uih_len - min_full_attn_seqlen); + + return result; + } else - result = abs(row - col) <= max_attn_len; + { + bool result = abs(row - col) <= max_attn_len; - if(min_full_attn_seqlen > 0) - result = result || (row >= max_uih_len - min_full_attn_seqlen); + if(min_full_attn_seqlen > 0) + result = result || (row >= max_uih_len - min_full_attn_seqlen); - return result; + return result; + } }; };