mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
[Performance] Use separate workgroups to handle seqlen scope [max_uih_len - minfull_attn_seqlen, seqlen]
This commit is contained in:
@@ -128,8 +128,9 @@ struct batched_forward_causal_local_bias_dropout_dispatch
|
||||
param.philox_offset);
|
||||
}();
|
||||
|
||||
dim3 kGridSize =
|
||||
HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen, param.hdim_v);
|
||||
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
|
||||
dim3 kGridSize = HstuKernel::GridSize(
|
||||
param.num_batch, param.num_head, param.seqlen, param.hdim_v, has_minfull_attn_seqlen);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -495,17 +495,30 @@ struct HstuAttentionFwdKernel
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
ck_tile::index_t hdim_v_,
|
||||
bool has_minfull_attn_seqlen)
|
||||
{
|
||||
// The Q sequence [0, seqlen) will be split to two parts for allocating workgroups:
|
||||
// 1) [0, seqlen - target - min_full_attn_seqlen)
|
||||
// 2) [seqlen - target - min_full_attn_seqlen, seqlen)
|
||||
ck_tile::index_t num_tile_in_seqlen =
|
||||
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0);
|
||||
|
||||
if constexpr(kHasLocalMask)
|
||||
{
|
||||
if(has_minfull_attn_seqlen)
|
||||
num_tile_in_seqlen += 1;
|
||||
};
|
||||
|
||||
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim)
|
||||
{
|
||||
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
return dim3(batch_size_,
|
||||
nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) *
|
||||
num_tile_in_seqlen *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1));
|
||||
#else
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) *
|
||||
return dim3(num_tile_in_seqlen *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
@@ -514,11 +527,9 @@ struct HstuAttentionFwdKernel
|
||||
else
|
||||
{
|
||||
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
return dim3(batch_size_,
|
||||
nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0));
|
||||
return dim3(batch_size_, nhead_, num_tile_in_seqlen);
|
||||
#else
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0),
|
||||
return dim3(num_tile_in_seqlen),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
#endif
|
||||
@@ -593,12 +604,8 @@ struct HstuAttentionFwdKernel
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
@@ -628,13 +635,6 @@ struct HstuAttentionFwdKernel
|
||||
batch_offset_o = query_start * kargs.seq_stride_o;
|
||||
|
||||
kargs.seqlen = kargs.seq_offsets_ptr[i_batch + 1] - kargs.seq_offsets_ptr[i_batch];
|
||||
|
||||
// # of required blocks is different in each group, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen <= i_m0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -650,9 +650,36 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch];
|
||||
|
||||
index_t seqlen_in_first_split = kargs.seqlen;
|
||||
|
||||
if constexpr(kHasLocalMask)
|
||||
{
|
||||
if(kargs.min_full_attn_seqlen > 0)
|
||||
seqlen_in_first_split = kargs.seqlen - kargs.min_full_attn_seqlen - num_target;
|
||||
};
|
||||
|
||||
index_t num_tile_in_first_split =
|
||||
ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0);
|
||||
|
||||
bool is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
|
||||
|
||||
index_t i_m0 = is_tile_in_first_split
|
||||
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
|
||||
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
|
||||
HstuAttentionPipeline::kM0) +
|
||||
seqlen_in_first_split;
|
||||
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1);
|
||||
|
||||
index_t seqlen_q_in_ctrl = is_tile_in_first_split ? seqlen_in_first_split : kargs.seqlen;
|
||||
|
||||
if(seqlen_q_in_ctrl <= i_m0)
|
||||
return;
|
||||
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasLocalMask)
|
||||
return make_hstu_block_mask_with_local<HstuMask>(kargs.seqlen,
|
||||
return make_hstu_block_mask_with_local<HstuMask>(is_tile_in_first_split,
|
||||
kargs.seqlen,
|
||||
kargs.contextual_seqlen,
|
||||
num_target,
|
||||
kargs.window_size,
|
||||
@@ -680,7 +707,7 @@ struct HstuAttentionFwdKernel
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_qk),
|
||||
make_tuple(seqlen_q_in_ctrl, kargs.hdim_qk),
|
||||
make_tuple(kargs.seq_stride_q, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
@@ -773,7 +800,7 @@ struct HstuAttentionFwdKernel
|
||||
const auto bias_dram = [&]() {
|
||||
const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
bias_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.seqlen),
|
||||
make_tuple(seqlen_q_in_ctrl, kargs.seqlen),
|
||||
make_tuple(kargs.seq_stride_bias, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentBias>{},
|
||||
number<1>{});
|
||||
@@ -824,7 +851,7 @@ struct HstuAttentionFwdKernel
|
||||
auto o_dram = [&]() {
|
||||
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_v),
|
||||
make_tuple(seqlen_q_in_ctrl, kargs.hdim_v),
|
||||
make_tuple(kargs.seq_stride_o, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentO>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -116,8 +116,12 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
param.philox_offset);
|
||||
}();
|
||||
|
||||
dim3 kGridSize =
|
||||
HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen, param.hdim_v);
|
||||
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch,
|
||||
param.num_head,
|
||||
param.max_seqlen,
|
||||
param.hdim_v,
|
||||
has_minfull_attn_seqlen);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -14,6 +14,10 @@ struct HstuBlockMaskWithLocal
|
||||
static constexpr bool kUseLocal = true;
|
||||
static constexpr bool IsMasking = true;
|
||||
|
||||
// is_tile_in_first_split is false only when min_full_attn_seqlen > 0 and the current
|
||||
// tile is inside scope [max_uih_len - min_full_attn_seqlen, seqlen); for other cases
|
||||
// and tiles, is_tile_in_first_split is true
|
||||
bool is_tile_in_first_split;
|
||||
int seqlen;
|
||||
int contextual_seqlen;
|
||||
|
||||
@@ -23,12 +27,14 @@ struct HstuBlockMaskWithLocal
|
||||
int max_uih_len;
|
||||
int max_id;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(int seqlen_,
|
||||
CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(bool is_tile_in_first_split_,
|
||||
int seqlen_,
|
||||
int contextual_seqlen_,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_,
|
||||
int num_target_)
|
||||
: seqlen(seqlen_),
|
||||
: is_tile_in_first_split(is_tile_in_first_split_),
|
||||
seqlen(seqlen_),
|
||||
contextual_seqlen(contextual_seqlen_),
|
||||
max_attn_len(max_attn_len_),
|
||||
min_full_attn_seqlen(min_full_attn_seqlen_)
|
||||
@@ -48,10 +54,16 @@ struct HstuBlockMaskWithLocal
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
|
||||
// handle two special cases first
|
||||
if(!is_tile_in_first_split)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
// the tile is completely inside [max_uih_len - min_full_attn_seqlen, max_uih_len)
|
||||
if(i_y + YTile <= max_uih_len)
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
// the tile is partially inside [max_uih_len - min_full_attn_seqlen, max_uih_len) and
|
||||
// partially inside [max_uih_len, seqlen)
|
||||
if(i_y < max_uih_len)
|
||||
return ck_tile::make_tuple(0, seqlen);
|
||||
};
|
||||
|
||||
if constexpr(!kUseCausal)
|
||||
@@ -204,8 +216,7 @@ struct HstuBlockMaskWithLocal
|
||||
// diagonal line is always considered
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
bool in_min_full_scope =
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
|
||||
bool in_min_full_scope = !is_tile_in_first_split;
|
||||
|
||||
bool res = (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) || in_min_full_scope));
|
||||
@@ -214,8 +225,7 @@ struct HstuBlockMaskWithLocal
|
||||
}
|
||||
else
|
||||
{
|
||||
bool in_min_full_scope =
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
|
||||
bool in_min_full_scope = !is_tile_in_first_split;
|
||||
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
|
||||
@@ -233,11 +243,7 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
std::ignore = i_tile_left;
|
||||
|
||||
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
|
||||
|
||||
// assume num_target > 0 with high probability, don't check whether num_target is 0;
|
||||
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
|
||||
if(i_tile_top >= max_uih_len - min_full_attn_seqlen && i_tile_bottom < max_uih_len)
|
||||
if(!is_tile_in_first_split && (i_tile_top + TileHeight <= max_uih_len))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
@@ -423,14 +429,19 @@ struct HstuBlockMasking
|
||||
};
|
||||
|
||||
template <typename HstuBlockMaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(int seqlen_,
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(bool is_tile_in_first_split_,
|
||||
int seqlen_,
|
||||
int contextual_seqlen_,
|
||||
int num_target,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_)
|
||||
{
|
||||
return HstuBlockMaskType{
|
||||
seqlen_, contextual_seqlen_, max_attn_len_, min_full_attn_seqlen_, num_target};
|
||||
return HstuBlockMaskType{is_tile_in_first_split_,
|
||||
seqlen_,
|
||||
contextual_seqlen_,
|
||||
max_attn_len_,
|
||||
min_full_attn_seqlen_,
|
||||
num_target};
|
||||
};
|
||||
|
||||
template <typename HstuBlockMaskType>
|
||||
|
||||
@@ -113,8 +113,12 @@ struct reference_hstu_attention
|
||||
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasLocalMask)
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(
|
||||
seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen);
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(true,
|
||||
seqlen,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
max_attn_len,
|
||||
min_full_attn_seqlen);
|
||||
else
|
||||
return ck_tile::make_hstu_block_mask_without_local<HstuMask>(
|
||||
seqlen, contextual_seqlen, num_target);
|
||||
|
||||
Reference in New Issue
Block a user