mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Update in GetTileRangeAlongX to consider for non-causal+local_size>0 situation and add test case to test_hstu_attention.sh
This commit is contained in:
@@ -54,13 +54,10 @@ struct HstuBlockMaskWithLocal
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
if(!is_tile_in_first_split)
|
||||
{
|
||||
if(!is_tile_in_first_split)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
|
||||
if constexpr(!kUseCausal)
|
||||
|
||||
@@ -36,5 +36,8 @@ for dtype in "fp16" "bf16"; do
|
||||
## jagged causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
|
||||
## jagged no-causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
|
||||
set +x
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user