diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 37d2ea8751..d6a7290e1b 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -275,8 +275,11 @@ struct HstuBlockMaskWithLocal index_t i_tile_right = i_tile_left + TileWidth; index_t i_tile_bottom = i_tile_top + TileHeight; - if(!is_tile_in_first_split && i_tile_bottom <= max_uih_len && - i_tile_right <= i_tile_top + 1) + // 1) tile is completely in [max_uih_len-min_full_attn_seqlen, max_uih_len] + // 2) some row of tile is in [max_uih_len, seqlen], requires i_tile_right <= max_uih_len + // to return true + if(!is_tile_in_first_split && + (i_tile_bottom <= max_uih_len || i_tile_right <= max_uih_len)) return true; };