mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Fix in GetTileRangeAlongX() and IsFullTileInsideMask() of HstuBlockMaskWithLocal
This commit is contained in:
@@ -54,16 +54,13 @@ struct HstuBlockMaskWithLocal
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
// handle two special cases first
|
||||
if(!is_tile_in_first_split)
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
// the tile is completely inside [max_uih_len - min_full_attn_seqlen, max_uih_len)
|
||||
if(i_y + YTile <= max_uih_len)
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
// the tils is partially inside [max_uih_len - min_full_attn_seqlen, max_uih_len) and
|
||||
// partially inside [max_uih_len, seqlen)
|
||||
if(i_y < max_uih_len)
|
||||
return ck_tile::make_tuple(0, seqlen);
|
||||
if(!is_tile_in_first_split)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
};
|
||||
|
||||
if constexpr(!kUseCausal)
|
||||
@@ -241,10 +238,22 @@ struct HstuBlockMaskWithLocal
|
||||
number<TileWidth>,
|
||||
number<TileHeight>) const
|
||||
{
|
||||
std::ignore = i_tile_left;
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
|
||||
if(!is_tile_in_first_split && (i_tile_top + TileHeight <= max_uih_len))
|
||||
return true;
|
||||
if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_uih_len))
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t i_tile_bottom = i_tile_top + TileHeight;
|
||||
|
||||
if(!is_tile_in_first_split && i_tile_bottom <= max_uih_len &&
|
||||
i_tile_right <= i_tile_top + 1)
|
||||
return true;
|
||||
};
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user