Fix in GetTileRangeAlongX() and IsFullTileInsideMask() of HstuBlockMaskWithLocal

This commit is contained in:
Qianfeng Zhang
2025-07-25 11:16:54 +00:00
parent cf012c23fc
commit 01c123dedd

View File

@@ -54,16 +54,13 @@ struct HstuBlockMaskWithLocal
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
// handle two special cases first
if(!is_tile_in_first_split)
if constexpr(kUseCausal)
{
// the tile is completely inside [max_uih_len - min_full_attn_seqlen, max_uih_len)
if(i_y + YTile <= max_uih_len)
return ck_tile::make_tuple(0, max_uih_len);
// the tile is partially inside [max_uih_len - min_full_attn_seqlen, max_uih_len) and
// partially inside [max_uih_len, seqlen)
if(i_y < max_uih_len)
return ck_tile::make_tuple(0, seqlen);
if(!is_tile_in_first_split)
{
index_t x_end = min(i_y + YTile, seqlen);
return ck_tile::make_tuple(0, x_end);
};
};
if constexpr(!kUseCausal)
@@ -241,10 +238,22 @@ struct HstuBlockMaskWithLocal
number<TileWidth>,
number<TileHeight>) const
{
std::ignore = i_tile_left;
if constexpr(kUseCausal)
{
index_t i_tile_right = i_tile_left + TileWidth;
if(!is_tile_in_first_split && (i_tile_top + TileHeight <= max_uih_len))
return true;
if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_uih_len))
return true;
}
else
{
index_t i_tile_right = i_tile_left + TileWidth;
index_t i_tile_bottom = i_tile_top + TileHeight;
if(!is_tile_in_first_split && i_tile_bottom <= max_uih_len &&
i_tile_right <= i_tile_top + 1)
return true;
};
return false;
}