From 5d7302c240df39c5df3b3e622dfe82cd42924c7d Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 12 Jun 2025 11:44:22 +0800 Subject: [PATCH] [CK_TILE] moe sorting optimization : refactor subtoken logic to let more kernel pickup mp kernel (#2327) * refactor subtoken logic to let more kernel pickup mp kernel * typo [ROCm/composable_kernel commit: 8aff45a8af0c868d8c3513dab3335e3b1d3e111f] --- .../fused_moe/kernel/moe_sorting_kernel.hpp | 37 ++++++------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 664294fe18..4166c1c602 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -127,37 +127,21 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt // at lease 2 lines, one for sub_token unroll, one for cumsum // should be enough - if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) { - if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols)) - throw std::runtime_error("too many num_experts, can't allocate smem"); - target_occupancy_ = 1; - } + int r = total_ / target_occupancy_ / smem_cols; + // Note: at lease allocate cumsum_bufs + sub_unroll as num-row. Otherwise, fallback to mp kernel + if(r < (cumsum_bufs + sub_unroll)) + return cumsum_bufs; + // round to sub_unroll multipl int r_for_sub_token = r - cumsum_bufs; - r_for_sub_token = min(r_for_sub_token, tokens_); - r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll; - r_for_sub_token = max(r_for_sub_token, 1); + r_for_sub_token = r_for_sub_token / sub_unroll * sub_unroll; + int r_token_min = (tokens_ + sub_unroll - 1) / sub_unroll * sub_unroll; + r_for_sub_token = min(r_for_sub_token, r_token_min); - if(r_for_sub_token > 1) - { - int r_unroll_ = r_for_sub_token / sub_unroll; - - - // round to 1x/2x/4x/8x number of sub_unroll - int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29 - int mask_ = (1 << (31 - clz_)) - 1; - - - mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most - mask_ = ~mask_; - - r_for_sub_token = (r_unroll_ & mask_) * sub_unroll; - } - - // final check - if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) { + // final check, but usually should not happen + if( ((r_for_sub_token + cumsum_bufs) * smem_cols * target_occupancy_ ) > total_ ) { throw std::runtime_error("can't run this kernel, request LDS over size"); } @@ -167,6 +151,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex return ck_tile::make_tuple(smem_rows, smem_cols); } +// if return 0 or negative, means LDS is not enough CK_TILE_HOST index_t moe_sorting_get_sub_token(int tokens_, int num_experts_) { auto [r_, c_] = moe_sorting_get_smem_row_col(tokens_, num_experts_);