Update to support min_full_attn_seqlen being bigger than max_uih_len

This commit is contained in:
Qianfeng Zhang
2025-08-08 09:25:25 +00:00
parent fd25f5df05
commit 971d0d98d4
4 changed files with 52 additions and 15 deletions

View File

@@ -664,18 +664,33 @@ struct HstuAttentionFwdKernel
{
if(kargs.min_full_attn_seqlen > 0)
{
seqlen_in_first_split = kargs.seqlen - kargs.min_full_attn_seqlen - num_target;
// need to consider cases where min_full_attn_seqlen is bigger than max_uih_len
if(kargs.seqlen - num_target > kargs.min_full_attn_seqlen)
{
seqlen_in_first_split = kargs.seqlen - num_target - kargs.min_full_attn_seqlen;
index_t num_tile_in_first_split =
ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0);
index_t num_tile_in_first_split = ck_tile::integer_divide_ceil(
seqlen_in_first_split, HstuAttentionPipeline::kM0);
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
i_m0 = is_tile_in_first_split
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
HstuAttentionPipeline::kM0) +
seqlen_in_first_split;
i_m0 =
is_tile_in_first_split
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
HstuAttentionPipeline::kM0) +
seqlen_in_first_split;
}
else
{
seqlen_in_first_split = 0;
is_tile_in_first_split = false;
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
kargs.min_full_attn_seqlen = kargs.seqlen - num_target;
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
};
}
else
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);

View File

@@ -41,6 +41,10 @@ struct HstuBlockMaskWithLocal
{
max_uih_len = seqlen - num_target_;
// assuming min_full_attn_seqlen has higher priority, ensure the contextual scope does not
// collide with the min_full_attn_seqlen scope
contextual_seqlen = min(contextual_seqlen, max_uih_len - min_full_attn_seqlen);
if(contextual_seqlen > 0)
max_id = max_uih_len - (contextual_seqlen - 1);
else

View File

@@ -114,12 +114,24 @@ struct reference_hstu_attention
HstuMask mask = [&]() {
if constexpr(kHasLocalMask)
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(true,
seqlen,
contextual_seqlen,
num_target,
max_attn_len,
min_full_attn_seqlen);
// need to adjust the min_full_attn_seqlen passed to HstuBlockMask() if the
// user-passed min_full_attn_seqlen is bigger than max_uih_len
if(seqlen - num_target > min_full_attn_seqlen)
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(
true,
seqlen,
contextual_seqlen,
num_target,
max_attn_len,
min_full_attn_seqlen);
else
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(true,
seqlen,
contextual_seqlen,
num_target,
max_attn_len,
seqlen -
num_target);
else
return ck_tile::make_hstu_block_mask_without_local<HstuMask>(
seqlen, contextual_seqlen, num_target);

View File

@@ -44,5 +44,11 @@ for dtype in "fp16" "bf16"; do
## jagged no-causal+local+context+target
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale
## jagged causal+local+target (minfull_len > max_uih_len)
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale
## jagged causal+local+context+target (minfull_len > max_uih_len)
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale
set +x
done