From 140433620032c404e045e2757432a96b5f3ab77f Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 10 Aug 2025 04:22:21 +0000 Subject: [PATCH] Update HstuBlockMaskWithLocal::GetTileRangeAlongX, add comments and test cases for causal == false --- .../18_hstu_attention/hstu_block_masking.hpp | 82 ++++++++++++------- .../scripts/test_hstu_attention.sh | 2 + 2 files changed, 55 insertions(+), 29 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 5314169781..37d2ea8751 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -58,82 +58,106 @@ struct HstuBlockMaskWithLocal CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number, number) const { + // handle two special cases first if(!is_tile_in_first_split) { - index_t x_end = min(i_y + YTile, seqlen); - return ck_tile::make_tuple(0, x_end); - }; - - if constexpr(!kUseCausal) - { - if(i_y >= contextual_seqlen) + if constexpr(kUseCausal) { + index_t x_end = min(i_y + YTile, seqlen); + return ck_tile::make_tuple(0, x_end); + } + else + { + // tile is partially or completely in [max_uih_len-min_full_attn_seqlen, + // max_uih_len) if(i_y < max_uih_len) { - index_t x_start = max(0, i_y - max_attn_len); + return ck_tile::make_tuple(0, seqlen); + } + else // tile is completely inside [max_uih_len, seqlen) + { + index_t x_end = min(i_y + YTile, seqlen); + return ck_tile::make_tuple(0, x_end); + }; + }; + }; + + // is_tile_in_first_split is true, either min_full_attn_seqlen is 0 or tile is + // in [0, max_uih_len-min_full_attn_seqlen) + if constexpr(!kUseCausal) + { + if(i_y >= contextual_seqlen + max_attn_len) + { + // some row of the tile in [contextual_seqlen+max_attn_len, max_uih_len) + if(i_y < max_uih_len) + { + index_t x_start = i_y - max_attn_len; index_t x_start_aligned = x_start - x_start % XTile; + // some rows of the tile
in [max_uih_len -max_attn_len, max_uih_len) if(i_y + YTile > max_uih_len - max_attn_len) { return ck_tile::make_tuple(x_start_aligned, seqlen); } - else + else // whole tile in [contextual_seqlen+max_attn_len, max_uih_len + // -max_attn_len) { - index_t x_end = min(i_y + YTile + max_attn_len, seqlen); + index_t x_end = i_y + YTile + max_attn_len; return ck_tile::make_tuple(x_start_aligned, x_end); }; } - else + else // whole tile in [max_uih_len, seqlen) { - index_t x_start = i_y - max_attn_len; - index_t x_end = seqlen; + index_t x_start = max_uih_len - max_attn_len; + index_t x_end = min(i_y + YTile, seqlen); return ck_tile::make_tuple(x_start - x_start % XTile, x_end); } } - else + else // for i_y < contextual_seqlen + max_attn_len { - if(i_y + YTile > max_uih_len) + if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen) { - index_t x_end = min(i_y + YTile, seqlen); + index_t x_end = min(max(i_y + YTile + max_attn_len, max_uih_len), seqlen); return ck_tile::make_tuple(0, x_end); } - else + else // whole tile in [contextual_seqlen, seqlen) { - index_t x_end = max(i_y + YTile + max_attn_len, max_uih_len); + index_t x_end = min(i_y + YTile + max_attn_len, seqlen); return ck_tile::make_tuple(0, x_end); - }; + } } } else // kUseCausal && kUseLocal { - if(i_y >= contextual_seqlen) + if(i_y >= contextual_seqlen + max_attn_len) { index_t x_end = min(i_y + YTile, seqlen); + // some row of the tile in [contextual_seqlen+max_attn_len, max_uih_len) if(i_y < max_uih_len) { - index_t x_start = max(0, i_y - max_attn_len); + index_t x_start = i_y - max_attn_len; return ck_tile::make_tuple(x_start - x_start % XTile, x_end); } - else + else // whole tile in [max_uih_len, seqlen) { index_t x_start = max_uih_len - max_attn_len; return ck_tile::make_tuple(x_start - x_start % XTile, x_end); } } - else + else // for i_y < contextual_seqlen + max_attn_len { - index_t x_end = min(i_y + YTile, seqlen); - - if(i_y + YTile > max_uih_len) + if(i_y < contextual_seqlen) 
// some row of the tile in [0, contextual_seqlen) { + index_t x_end = min(max(i_y + YTile, max_uih_len), seqlen); return ck_tile::make_tuple(0, x_end); } - else + else // whole tile in [contextual_seqlen, seqlen) { - return ck_tile::make_tuple(0, max_uih_len); - }; + index_t x_end = min(i_y + YTile, seqlen); + return ck_tile::make_tuple(0, x_end); + } } }; } diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh index 508151821a..311e9d29d3 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh @@ -50,5 +50,7 @@ for dtype in "fp16" "bf16"; do ## jagged causal+local+context+target (minfull_len > max_uih_len) $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale + ## jagged no-causal+local+context+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale set +x done