Tiny fix in HstuBlockMaskWithLocal::GetTileRangeAlongX()

This commit is contained in:
Qianfeng Zhang
2025-08-12 01:58:10 +00:00
parent 30dd274e7e
commit d2b0f7503e

View File

@@ -88,7 +88,7 @@ struct HstuBlockMaskWithLocal
// in [0, max_uih_len-min_full_attn_seqlen)
if constexpr(!kUseCausal)
{
if(i_y >= contextual_seqlen + max_attn_len)
if(i_y >= min(contextual_seqlen, 1) + max_attn_len)
{
// some row of the tile in [contextual_seqlen+max_attn_len, max_uih_len)
if(i_y < max_uih_len)
@@ -132,7 +132,7 @@ struct HstuBlockMaskWithLocal
}
else // kUseCausal && kUseLocal
{
if(i_y >= contextual_seqlen + max_attn_len)
if(i_y >= min(contextual_seqlen, 1) + max_attn_len)
{
index_t x_end = min(i_y + YTile, seqlen);