[Performance] Use separate workgroups to handle seqlen scope [max_uih_len - min_full_attn_seqlen, seqlen]

This commit is contained in:
Qianfeng Zhang
2025-07-23 09:39:50 +00:00
parent ce6a044440
commit f49fe28ca2
5 changed files with 92 additions and 45 deletions

View File

@@ -128,8 +128,9 @@ struct batched_forward_causal_local_bias_dropout_dispatch
param.philox_offset);
}();
dim3 kGridSize =
HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen, param.hdim_v);
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
dim3 kGridSize = HstuKernel::GridSize(
param.num_batch, param.num_head, param.seqlen, param.hdim_v, has_minfull_attn_seqlen);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;

View File

@@ -495,17 +495,30 @@ struct HstuAttentionFwdKernel
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_,
ck_tile::index_t hdim_v_)
ck_tile::index_t hdim_v_,
bool has_minfull_attn_seqlen)
{
// The Q sequence [0, seqlen) will be split to two parts for allocating workgroups:
// 1) [0, seqlen - target - min_full_attn_seqlen)
// 2) [seqlen - target - min_full_attn_seqlen, seqlen)
ck_tile::index_t num_tile_in_seqlen =
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0);
if constexpr(kHasLocalMask)
{
if(has_minfull_attn_seqlen)
num_tile_in_seqlen += 1;
};
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim)
{
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
return dim3(batch_size_,
nhead_,
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) *
num_tile_in_seqlen *
ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1));
#else
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) *
return dim3(num_tile_in_seqlen *
ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1),
nhead_,
batch_size_);
@@ -514,11 +527,9 @@ struct HstuAttentionFwdKernel
else
{
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
return dim3(batch_size_,
nhead_,
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0));
return dim3(batch_size_, nhead_, num_tile_in_seqlen);
#else
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0),
return dim3(num_tile_in_seqlen),
nhead_,
batch_size_);
#endif
@@ -593,12 +604,8 @@ struct HstuAttentionFwdKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
@@ -628,13 +635,6 @@ struct HstuAttentionFwdKernel
batch_offset_o = query_start * kargs.seq_stride_o;
kargs.seqlen = kargs.seq_offsets_ptr[i_batch + 1] - kargs.seq_offsets_ptr[i_batch];
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen <= i_m0)
{
return;
}
}
else
{
@@ -650,9 +650,36 @@ struct HstuAttentionFwdKernel
int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch];
index_t seqlen_in_first_split = kargs.seqlen;
if constexpr(kHasLocalMask)
{
if(kargs.min_full_attn_seqlen > 0)
seqlen_in_first_split = kargs.seqlen - kargs.min_full_attn_seqlen - num_target;
};
index_t num_tile_in_first_split =
ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0);
bool is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
index_t i_m0 = is_tile_in_first_split
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
HstuAttentionPipeline::kM0) +
seqlen_in_first_split;
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1);
index_t seqlen_q_in_ctrl = is_tile_in_first_split ? seqlen_in_first_split : kargs.seqlen;
if(seqlen_q_in_ctrl <= i_m0)
return;
HstuMask mask = [&]() {
if constexpr(kHasLocalMask)
return make_hstu_block_mask_with_local<HstuMask>(kargs.seqlen,
return make_hstu_block_mask_with_local<HstuMask>(is_tile_in_first_split,
kargs.seqlen,
kargs.contextual_seqlen,
num_target,
kargs.window_size,
@@ -680,7 +707,7 @@ struct HstuAttentionFwdKernel
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen, kargs.hdim_qk),
make_tuple(seqlen_q_in_ctrl, kargs.hdim_qk),
make_tuple(kargs.seq_stride_q, 1),
number<HstuAttentionPipeline::kAlignmentQ>{},
number<1>{});
@@ -773,7 +800,7 @@ struct HstuAttentionFwdKernel
const auto bias_dram = [&]() {
const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
bias_ptr,
make_tuple(kargs.seqlen, kargs.seqlen),
make_tuple(seqlen_q_in_ctrl, kargs.seqlen),
make_tuple(kargs.seq_stride_bias, 1),
number<HstuAttentionPipeline::kAlignmentBias>{},
number<1>{});
@@ -824,7 +851,7 @@ struct HstuAttentionFwdKernel
auto o_dram = [&]() {
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.seqlen, kargs.hdim_v),
make_tuple(seqlen_q_in_ctrl, kargs.hdim_v),
make_tuple(kargs.seq_stride_o, 1),
number<HstuAttentionPipeline::kAlignmentO>{},
number<1>{});

View File

@@ -116,8 +116,12 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
param.philox_offset);
}();
dim3 kGridSize =
HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen, param.hdim_v);
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
dim3 kGridSize = HstuKernel::GridSize(param.num_batch,
param.num_head,
param.max_seqlen,
param.hdim_v,
has_minfull_attn_seqlen);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;

View File

@@ -14,6 +14,10 @@ struct HstuBlockMaskWithLocal
static constexpr bool kUseLocal = true;
static constexpr bool IsMasking = true;
// is_tile_in_first_split is false only when min_full_attn_seqlen > 0 and the current
// tile is inside scope [max_uih_len - min_full_attn_seqlen, seqlen); for other cases
// and tiles, is_tile_in_first_split is true
bool is_tile_in_first_split;
int seqlen;
int contextual_seqlen;
@@ -23,12 +27,14 @@ struct HstuBlockMaskWithLocal
int max_uih_len;
int max_id;
CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(int seqlen_,
CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(bool is_tile_in_first_split_,
int seqlen_,
int contextual_seqlen_,
int max_attn_len_,
int min_full_attn_seqlen_,
int num_target_)
: seqlen(seqlen_),
: is_tile_in_first_split(is_tile_in_first_split_),
seqlen(seqlen_),
contextual_seqlen(contextual_seqlen_),
max_attn_len(max_attn_len_),
min_full_attn_seqlen(min_full_attn_seqlen_)
@@ -48,10 +54,16 @@ struct HstuBlockMaskWithLocal
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
// handle two special cases first
if(!is_tile_in_first_split)
{
index_t x_end = min(i_y + YTile, seqlen);
return ck_tile::make_tuple(0, x_end);
// the tile is completely inside [max_uih_len - min_full_attn_seqlen, max_uih_len)
if(i_y + YTile <= max_uih_len)
return ck_tile::make_tuple(0, max_uih_len);
// the tile is partially inside [max_uih_len - min_full_attn_seqlen, max_uih_len) and
// partially inside [max_uih_len, seqlen)
if(i_y < max_uih_len)
return ck_tile::make_tuple(0, seqlen);
};
if constexpr(!kUseCausal)
@@ -204,8 +216,7 @@ struct HstuBlockMaskWithLocal
// diagonal line are always considered
if constexpr(kUseCausal)
{
bool in_min_full_scope =
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
bool in_min_full_scope = !is_tile_in_first_split;
bool res = (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) || in_min_full_scope));
@@ -214,8 +225,7 @@ struct HstuBlockMaskWithLocal
}
else
{
bool in_min_full_scope =
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
bool in_min_full_scope = !is_tile_in_first_split;
bool res = (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
@@ -233,11 +243,7 @@ struct HstuBlockMaskWithLocal
{
std::ignore = i_tile_left;
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
// assume num_target > 0 with high probability, don't check whether num_target is 0;
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
if(i_tile_top >= max_uih_len - min_full_attn_seqlen && i_tile_bottom < max_uih_len)
if(!is_tile_in_first_split && (i_tile_top + TileHeight <= max_uih_len))
return true;
return false;
@@ -423,14 +429,19 @@ struct HstuBlockMasking
};
template <typename HstuBlockMaskType>
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(int seqlen_,
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(bool is_tile_in_first_split_,
int seqlen_,
int contextual_seqlen_,
int num_target,
int max_attn_len_,
int min_full_attn_seqlen_)
{
return HstuBlockMaskType{
seqlen_, contextual_seqlen_, max_attn_len_, min_full_attn_seqlen_, num_target};
return HstuBlockMaskType{is_tile_in_first_split_,
seqlen_,
contextual_seqlen_,
max_attn_len_,
min_full_attn_seqlen_,
num_target};
};
template <typename HstuBlockMaskType>

View File

@@ -113,8 +113,12 @@ struct reference_hstu_attention
HstuMask mask = [&]() {
if constexpr(kHasLocalMask)
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(
seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen);
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(true,
seqlen,
contextual_seqlen,
num_target,
max_attn_len,
min_full_attn_seqlen);
else
return ck_tile::make_hstu_block_mask_without_local<HstuMask>(
seqlen, contextual_seqlen, num_target);