From f49fe28ca229b4b4812e7bbb883c14f26710dc42 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Wed, 23 Jul 2025 09:39:50 +0000 Subject: [PATCH] [Performance] Use separate workgroups to handle seqlen scope [max_uih_len - minfull_attn_seqlen, seqlen] --- ...stu_attention_batched_forward_dispatch.hpp | 5 +- .../hstu_attention_fwd_kernel.hpp | 71 +++++++++++++------ ...hstu_attention_jagged_forward_dispatch.hpp | 8 ++- .../18_hstu_attention/hstu_block_masking.hpp | 45 +++++++----- .../reference_hstu_attention.hpp | 8 ++- 5 files changed, 92 insertions(+), 45 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 52bd5222e8..36bb6b261d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -128,8 +128,9 @@ struct batched_forward_causal_local_bias_dropout_dispatch param.philox_offset); }(); - dim3 kGridSize = - HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen, param.hdim_v); + bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0); + dim3 kGridSize = HstuKernel::GridSize( + param.num_batch, param.num_head, param.seqlen, param.hdim_v, has_minfull_attn_seqlen); constexpr dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index b294a4a3e0..34e3c98ac6 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -495,17 +495,30 @@ struct HstuAttentionFwdKernel CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_, - ck_tile::index_t 
hdim_v_) + ck_tile::index_t hdim_v_, + bool has_minfull_attn_seqlen) { + // The Q sequence [0, seqlen) will be split to two parts for allocating workgroups: + // 1) [0, seqlen - target - min_full_attn_seqlen) + // 2) [seqlen - target - min_full_attn_seqlen, seqlen) + ck_tile::index_t num_tile_in_seqlen = + ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0); + + if constexpr(kHasLocalMask) + { + if(has_minfull_attn_seqlen) + num_tile_in_seqlen += 1; + }; + if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim) { #if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM return dim3(batch_size_, nhead_, - ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) * + num_tile_in_seqlen * ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1)); #else - return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) * + return dim3(num_tile_in_seqlen * ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1), nhead_, batch_size_); @@ -514,11 +527,9 @@ struct HstuAttentionFwdKernel else { #if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM - return dim3(batch_size_, - nhead_, - ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0)); + return dim3(batch_size_, nhead_, num_tile_in_seqlen); #else - return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0), + return dim3(num_tile_in_seqlen, nhead_, batch_size_); #endif @@ -593,12 +604,8 @@ struct HstuAttentionFwdKernel // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1); - long_index_t batch_offset_q = 0; long_index_t batch_offset_k = 0; long_index_t batch_offset_v = 0; @@ -628,13 +635,6 @@ struct HstuAttentionFwdKernel batch_offset_o = query_start *
kargs.seq_stride_o; kargs.seqlen = kargs.seq_offsets_ptr[i_batch + 1] - kargs.seq_offsets_ptr[i_batch]; - - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier - if(kargs.seqlen <= i_m0) - { - return; - } } else { @@ -650,9 +650,36 @@ struct HstuAttentionFwdKernel int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch]; + index_t seqlen_in_first_split = kargs.seqlen; + + if constexpr(kHasLocalMask) + { + if(kargs.min_full_attn_seqlen > 0) + seqlen_in_first_split = kargs.seqlen - kargs.min_full_attn_seqlen - num_target; + }; + + index_t num_tile_in_first_split = + ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0); + + bool is_tile_in_first_split = (i_tile_m < num_tile_in_first_split); + + index_t i_m0 = is_tile_in_first_split + ? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0) + : __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) * + HstuAttentionPipeline::kM0) + + seqlen_in_first_split; + + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1); + + index_t seqlen_q_in_ctrl = is_tile_in_first_split ? 
seqlen_in_first_split : kargs.seqlen; + + if(seqlen_q_in_ctrl <= i_m0) + return; + HstuMask mask = [&]() { if constexpr(kHasLocalMask) - return make_hstu_block_mask_with_local(kargs.seqlen, + return make_hstu_block_mask_with_local(is_tile_in_first_split, + kargs.seqlen, kargs.contextual_seqlen, num_target, kargs.window_size, @@ -680,7 +707,7 @@ struct HstuAttentionFwdKernel const auto q_dram = [&]() { const auto q_dram_naive = make_naive_tensor_view( q_ptr, - make_tuple(kargs.seqlen, kargs.hdim_qk), + make_tuple(seqlen_q_in_ctrl, kargs.hdim_qk), make_tuple(kargs.seq_stride_q, 1), number{}, number<1>{}); @@ -773,7 +800,7 @@ struct HstuAttentionFwdKernel const auto bias_dram = [&]() { const auto bias_dram_naive = make_naive_tensor_view( bias_ptr, - make_tuple(kargs.seqlen, kargs.seqlen), + make_tuple(seqlen_q_in_ctrl, kargs.seqlen), make_tuple(kargs.seq_stride_bias, 1), number{}, number<1>{}); @@ -824,7 +851,7 @@ struct HstuAttentionFwdKernel auto o_dram = [&]() { const auto o_dram_naive = make_naive_tensor_view( o_ptr, - make_tuple(kargs.seqlen, kargs.hdim_v), + make_tuple(seqlen_q_in_ctrl, kargs.hdim_v), make_tuple(kargs.seq_stride_o, 1), number{}, number<1>{}); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index a4d27b7eff..676ecc3e50 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -116,8 +116,12 @@ struct jagged_forward_causal_local_bias_dropout_dispatch param.philox_offset); }(); - dim3 kGridSize = - HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen, param.hdim_v); + bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0); + dim3 kGridSize = HstuKernel::GridSize(param.num_batch, + param.num_head, + param.max_seqlen, + param.hdim_v, + has_minfull_attn_seqlen); constexpr dim3 kBlockSize = 
HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 54d4a32476..d5b0724566 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -14,6 +14,10 @@ struct HstuBlockMaskWithLocal static constexpr bool kUseLocal = true; static constexpr bool IsMasking = true; + // is_tile_in_first_split is false only when min_full_attn_seqlen > 0 and the current + // tile is inside scope [max_uih_len - min_full_attn_seqlen, seqlen); for other cases + // and tiles, is_tile_in_first_split is true + bool is_tile_in_first_split; int seqlen; int contextual_seqlen; @@ -23,12 +27,14 @@ struct HstuBlockMaskWithLocal int max_uih_len; int max_id; - CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(int seqlen_, + CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(bool is_tile_in_first_split_, + int seqlen_, int contextual_seqlen_, int max_attn_len_, int min_full_attn_seqlen_, int num_target_) - : seqlen(seqlen_), + : is_tile_in_first_split(is_tile_in_first_split_), + seqlen(seqlen_), contextual_seqlen(contextual_seqlen_), max_attn_len(max_attn_len_), min_full_attn_seqlen(min_full_attn_seqlen_) @@ -48,10 +54,16 @@ struct HstuBlockMaskWithLocal CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number, number) const { - if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen) + // handle two special cases first + if(!is_tile_in_first_split) { - index_t x_end = min(i_y + YTile, seqlen); - return ck_tile::make_tuple(0, x_end); + // the tile is completely inside [max_uih_len - min_full_attn_seqlen, max_uih_len) + if(i_y + YTile <= max_uih_len) + return ck_tile::make_tuple(0, max_uih_len); + // the tile is partially inside [max_uih_len - min_full_attn_seqlen, max_uih_len) and + // partially inside [max_uih_len, seqlen) + if(i_y
< max_uih_len) + return ck_tile::make_tuple(0, seqlen); }; if constexpr(!kUseCausal) @@ -204,8 +216,7 @@ struct HstuBlockMaskWithLocal // diagonal line are always considerred if constexpr(kUseCausal) { - bool in_min_full_scope = - (min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false; + bool in_min_full_scope = !is_tile_in_first_split; bool res = (((row_id > col_id) || (row == col)) && ((row_id - col_id <= max_attn_len) || in_min_full_scope)); @@ -214,8 +225,7 @@ struct HstuBlockMaskWithLocal } else { - bool in_min_full_scope = - (min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false; + bool in_min_full_scope = !is_tile_in_first_split; bool res = (((row_id != col_id) || (row == col)) && ((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope)); @@ -233,11 +243,7 @@ struct HstuBlockMaskWithLocal { std::ignore = i_tile_left; - index_t i_tile_bottom = i_tile_top + (TileHeight - 1); - - // assume num_target > 0 with high probability, don't check whether num_target is 0; - // so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile - if(i_tile_top >= max_uih_len - min_full_attn_seqlen && i_tile_bottom < max_uih_len) + if(!is_tile_in_first_split && (i_tile_top + TileHeight <= max_uih_len)) return true; return false; @@ -423,14 +429,19 @@ struct HstuBlockMasking }; template -CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(int seqlen_, +CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(bool is_tile_in_first_split_, + int seqlen_, int contextual_seqlen_, int num_target, int max_attn_len_, int min_full_attn_seqlen_) { - return HstuBlockMaskType{ - seqlen_, contextual_seqlen_, max_attn_len_, min_full_attn_seqlen_, num_target}; + return HstuBlockMaskType{is_tile_in_first_split_, + seqlen_, + contextual_seqlen_, + max_attn_len_, + min_full_attn_seqlen_, + num_target}; }; template diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp 
b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 9eeaaa35be..645f53aebd 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -113,8 +113,12 @@ struct reference_hstu_attention HstuMask mask = [&]() { if constexpr(kHasLocalMask) - return ck_tile::make_hstu_block_mask_with_local( - seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen); + return ck_tile::make_hstu_block_mask_with_local(true, + seqlen, + contextual_seqlen, + num_target, + max_attn_len, + min_full_attn_seqlen); else return ck_tile::make_hstu_block_mask_without_local( seqlen, contextual_seqlen, num_target);