From 01c123dedd669a64ff44f8cc9525c619c713edff Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 25 Jul 2025 11:16:54 +0000 Subject: [PATCH] Fix in GetTileRangeAlongX() and IsFullTileInsideMask() of HstuBlockMaskWithLocal --- .../18_hstu_attention/hstu_block_masking.hpp | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index d5b0724566..ac8b6a0f58 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -54,16 +54,13 @@ struct HstuBlockMaskWithLocal CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number, number) const { - // handle two special cases first - if(!is_tile_in_first_split) + if constexpr(kUseCausal) { - // the tile is completely inside [max_uih_len - min_full_attn_seqlen, max_uih_len) - if(i_y + YTile <= max_uih_len) - return ck_tile::make_tuple(0, max_uih_len); - // the tils is partially inside [max_uih_len - min_full_attn_seqlen, max_uih_len) and - // partially inside [max_uih_len, seqlen) - if(i_y < max_uih_len) - return ck_tile::make_tuple(0, seqlen); + if(!is_tile_in_first_split) + { + index_t x_end = min(i_y + YTile, seqlen); + return ck_tile::make_tuple(0, x_end); + }; }; if constexpr(!kUseCausal) @@ -241,10 +238,22 @@ struct HstuBlockMaskWithLocal number, number) const { - std::ignore = i_tile_left; + if constexpr(kUseCausal) + { + index_t i_tile_right = i_tile_left + TileWidth; - if(!is_tile_in_first_split && (i_tile_top + TileHeight <= max_uih_len)) - return true; + if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_uih_len)) + return true; + } + else + { + index_t i_tile_right = i_tile_left + TileWidth; + index_t i_tile_bottom = i_tile_top + TileHeight; + + if(!is_tile_in_first_split && i_tile_bottom <= max_uih_len && + i_tile_right <= i_tile_top + 1) + return true; + }; return false; }