From 971d0d98d477b6ac3a7afb64206be75a45c9a90b Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 8 Aug 2025 09:25:25 +0000 Subject: [PATCH] Update to support min_full_attn_seqlen be bigger than max_uih_len --- .../hstu_attention_fwd_kernel.hpp | 33 ++++++++++++++----- .../18_hstu_attention/hstu_block_masking.hpp | 4 +++ .../reference_hstu_attention.hpp | 24 ++++++++++---- .../scripts/test_hstu_attention.sh | 6 ++++ 4 files changed, 52 insertions(+), 15 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index dae068448a..788cb719b1 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -664,18 +664,33 @@ struct HstuAttentionFwdKernel { if(kargs.min_full_attn_seqlen > 0) { - seqlen_in_first_split = kargs.seqlen - kargs.min_full_attn_seqlen - num_target; + // need consider for cases where min_full_attn_seqlen be bigger than max_uih_len + if(kargs.seqlen - num_target > kargs.min_full_attn_seqlen) + { + seqlen_in_first_split = kargs.seqlen - num_target - kargs.min_full_attn_seqlen; - index_t num_tile_in_first_split = - ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0); + index_t num_tile_in_first_split = ck_tile::integer_divide_ceil( + seqlen_in_first_split, HstuAttentionPipeline::kM0); - is_tile_in_first_split = (i_tile_m < num_tile_in_first_split); + is_tile_in_first_split = (i_tile_m < num_tile_in_first_split); - i_m0 = is_tile_in_first_split - ? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0) - : __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) * - HstuAttentionPipeline::kM0) + - seqlen_in_first_split; + i_m0 = + is_tile_in_first_split + ? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0) + : __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) * + HstuAttentionPipeline::kM0) + + seqlen_in_first_split; + } + else + { + seqlen_in_first_split = 0; + is_tile_in_first_split = false; + + // adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor + kargs.min_full_attn_seqlen = kargs.seqlen - num_target; + + i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); + }; } else i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 6a29971ea8..5314169781 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -41,6 +41,10 @@ struct HstuBlockMaskWithLocal { max_uih_len = seqlen - num_target_; + // assuming min_full_attn_seqlen has higher priority, ensure contextual scope not collide + // with min_full_attn_seqlen scope + contextual_seqlen = min(contextual_seqlen, max_uih_len - min_full_attn_seqlen); + if(contextual_seqlen > 0) max_id = max_uih_len - (contextual_seqlen - 1); else diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 96324946b2..b35f96461a 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -114,12 +114,24 @@ struct reference_hstu_attention HstuMask mask = [&]() { if constexpr(kHasLocalMask) - return ck_tile::make_hstu_block_mask_with_local(true, - seqlen, - contextual_seqlen, - num_target, - max_attn_len, - min_full_attn_seqlen); + // need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the + // user passed min_full_attn_seqlen is bigger than max_uih_len + if(seqlen - num_target > min_full_attn_seqlen) + return ck_tile::make_hstu_block_mask_with_local( + true, + seqlen, + contextual_seqlen, + num_target, + max_attn_len, + min_full_attn_seqlen); + else + return ck_tile::make_hstu_block_mask_with_local(true, + seqlen, + contextual_seqlen, + num_target, + max_attn_len, + seqlen - + num_target); else return ck_tile::make_hstu_block_mask_without_local( seqlen, contextual_seqlen, num_target); diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh index 7f1fd3f4ca..508151821a 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh @@ -44,5 +44,11 @@ for dtype in "fp16" "bf16"; do ## jagged no-causal+local+context+target $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale + ## jagged causal+local+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale + + ## jagged causal+local+context+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale + set +x done