mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Update to support min_full_attn_seqlen be bigger than max_uih_len
This commit is contained in:
@@ -664,18 +664,33 @@ struct HstuAttentionFwdKernel
|
||||
{
|
||||
if(kargs.min_full_attn_seqlen > 0)
|
||||
{
|
||||
seqlen_in_first_split = kargs.seqlen - kargs.min_full_attn_seqlen - num_target;
|
||||
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
|
||||
if(kargs.seqlen - num_target > kargs.min_full_attn_seqlen)
|
||||
{
|
||||
seqlen_in_first_split = kargs.seqlen - num_target - kargs.min_full_attn_seqlen;
|
||||
|
||||
index_t num_tile_in_first_split =
|
||||
ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0);
|
||||
index_t num_tile_in_first_split = ck_tile::integer_divide_ceil(
|
||||
seqlen_in_first_split, HstuAttentionPipeline::kM0);
|
||||
|
||||
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
|
||||
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
|
||||
|
||||
i_m0 = is_tile_in_first_split
|
||||
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
|
||||
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
|
||||
HstuAttentionPipeline::kM0) +
|
||||
seqlen_in_first_split;
|
||||
i_m0 =
|
||||
is_tile_in_first_split
|
||||
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
|
||||
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
|
||||
HstuAttentionPipeline::kM0) +
|
||||
seqlen_in_first_split;
|
||||
}
|
||||
else
|
||||
{
|
||||
seqlen_in_first_split = 0;
|
||||
is_tile_in_first_split = false;
|
||||
|
||||
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
|
||||
kargs.min_full_attn_seqlen = kargs.seqlen - num_target;
|
||||
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
};
|
||||
}
|
||||
else
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
|
||||
@@ -41,6 +41,10 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
max_uih_len = seqlen - num_target_;
|
||||
|
||||
// assuming min_full_attn_seqlen has higher priority, ensure contextual scope not collide
|
||||
// with min_full_attn_seqlen scope
|
||||
contextual_seqlen = min(contextual_seqlen, max_uih_len - min_full_attn_seqlen);
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
max_id = max_uih_len - (contextual_seqlen - 1);
|
||||
else
|
||||
|
||||
@@ -114,12 +114,24 @@ struct reference_hstu_attention
|
||||
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasLocalMask)
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(true,
|
||||
seqlen,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
max_attn_len,
|
||||
min_full_attn_seqlen);
|
||||
// need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the
|
||||
// user passed min_full_attn_seqlen is bigger than max_uih_len
|
||||
if(seqlen - num_target > min_full_attn_seqlen)
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(
|
||||
true,
|
||||
seqlen,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
max_attn_len,
|
||||
min_full_attn_seqlen);
|
||||
else
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(true,
|
||||
seqlen,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
max_attn_len,
|
||||
seqlen -
|
||||
num_target);
|
||||
else
|
||||
return ck_tile::make_hstu_block_mask_without_local<HstuMask>(
|
||||
seqlen, contextual_seqlen, num_target);
|
||||
|
||||
@@ -44,5 +44,11 @@ for dtype in "fp16" "bf16"; do
|
||||
## jagged no-causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale
|
||||
|
||||
## jagged causal+local+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale
|
||||
|
||||
## jagged causal+local+context+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale
|
||||
|
||||
set +x
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user