From 95c93ba92e2f2d1400f0a83a99e739dd4aebda6b Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 26 Apr 2025 10:01:23 +0000 Subject: [PATCH] Update the GridSize() and GetTileIndex() in hstu kernel --- .../hstu_attention_fwd_kernel.hpp | 57 +++++++++++++------ 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 92351ecc69..f0780520ca 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -448,32 +448,53 @@ struct HstuAttentionFwdKernel ck_tile::index_t seqlen_, ck_tile::index_t hdim_v_) { - // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1), - nhead_, - batch_size_); + if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim) + { + return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1), + nhead_, + batch_size_); + } + else + { + return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0), + nhead_, + batch_size_); + } } CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) { - // const index_t num_tile_m0 = seqlen_q / kM0; - const index_t num_tile_n1 = - ck_tile::integer_divide_ceil(kargs.hdim_v, HstuAttentionPipeline::kN1); + if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim) + { + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, HstuAttentionPipeline::kN1); - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const index_t i_tile_m = i_block; + const index_t i_tile_n = 0; + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }