From 9e62359b59787ca17ab4173a3662f9f54b79b229 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 18 Jun 2025 15:12:05 +0000 Subject: [PATCH] Tiny fix in hstu attention IsFullTileInsideMask() --- .../18_hstu_attention/hstu_block_masking.hpp | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 3ae5924f6a..6311a939d5 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -267,7 +267,11 @@ struct HstuBlockMaskWithLocal { std::ignore = i_tile_left; - if(min_full_attn_seqlen > 0 && i_tile_top >= max_uih_len - min_full_attn_seqlen) + index_t i_tile_bottom = i_tile_top + (TileHeight - 1); + + // assume num_target > 0 with high probability, don't check whether num_target is 0; + // so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile + if(i_tile_top >= max_uih_len - min_full_attn_seqlen && i_tile_bottom < max_uih_len) return true; return false; @@ -440,26 +444,29 @@ struct HstuBlockMaskNoLocal { if constexpr(kUseCausal) { - index_t i_tile_right = i_tile_left + TileWidth; - index_t i_tile_bottom = i_tile_top + TileHeight; + index_t i_tile_right = i_tile_left + (TileWidth - 1); + index_t i_tile_bottom = i_tile_top + (TileHeight - 1); - if(i_tile_right > i_tile_top || - (i_tile_bottom > max_uih_len && i_tile_right > max_uih_len)) + // assume num_target > 0 with high probability, don't check whether num_target is 0; + // so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile + if(i_tile_bottom >= max_uih_len || i_tile_right > i_tile_top) return false; return true; } else { - index_t i_tile_right = i_tile_left + TileWidth; - index_t i_tile_bottom = i_tile_top + TileHeight; + index_t i_tile_right = i_tile_left + (TileWidth - 1); + index_t i_tile_bottom = i_tile_top + (TileHeight - 1); - if(i_tile_bottom > max_uih_len && i_tile_right > max_uih_len) + // assume num_target > 0 with high probability, don't check whether num_target is 0; + // so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile + if(i_tile_bottom >= max_uih_len || i_tile_right >= max_uih_len) return false; return true; - }; - } + } + }; }; template