mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Tiny fix in hstu attention IsFullTileInsideMask()
This commit is contained in:
@@ -267,7 +267,11 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
std::ignore = i_tile_left;
|
||||
|
||||
if(min_full_attn_seqlen > 0 && i_tile_top >= max_uih_len - min_full_attn_seqlen)
|
||||
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
|
||||
|
||||
// assume num_target > 0 with high probability, don't check whether num_target is 0;
|
||||
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
|
||||
if(i_tile_top >= max_uih_len - min_full_attn_seqlen && i_tile_bottom < max_uih_len)
|
||||
return true;
|
||||
|
||||
return false;
|
||||
@@ -440,26 +444,29 @@ struct HstuBlockMaskNoLocal
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t i_tile_bottom = i_tile_top + TileHeight;
|
||||
index_t i_tile_right = i_tile_left + (TileWidth - 1);
|
||||
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
|
||||
|
||||
if(i_tile_right > i_tile_top ||
|
||||
(i_tile_bottom > max_uih_len && i_tile_right > max_uih_len))
|
||||
// assume num_target > 0 with high probability, don't check whether num_target is 0;
|
||||
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
|
||||
if(i_tile_bottom >= max_uih_len || i_tile_right > i_tile_top)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t i_tile_bottom = i_tile_top + TileHeight;
|
||||
index_t i_tile_right = i_tile_left + (TileWidth - 1);
|
||||
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
|
||||
|
||||
if(i_tile_bottom > max_uih_len && i_tile_right > max_uih_len)
|
||||
// assume num_target > 0 with high probability, don't check whether num_target is 0;
|
||||
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
|
||||
if(i_tile_bottom >= max_uih_len || i_tile_right >= max_uih_len)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <bool kUseCausal, bool kUseLocal>
|
||||
|
||||
Reference in New Issue
Block a user