diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index ac8b6a0f58..6a29971ea8 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -54,13 +54,10 @@ struct HstuBlockMaskWithLocal CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number, number) const { - if constexpr(kUseCausal) + if(!is_tile_in_first_split) { - if(!is_tile_in_first_split) - { - index_t x_end = min(i_y + YTile, seqlen); - return ck_tile::make_tuple(0, x_end); - }; + index_t x_end = min(i_y + YTile, seqlen); + return ck_tile::make_tuple(0, x_end); }; if constexpr(!kUseCausal) diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh index 8144424f54..16b141f80b 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh @@ -36,5 +36,8 @@ for dtype in "fp16" "bf16"; do ## jagged causal+local+context+target $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 + ## jagged no-causal+local+context+target + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 + set +x done