[CK_TILE] moe sorting optimization: refactor subtoken logic to let more cases pick up the mp kernel (#2327)

* refactor subtoken logic to let more cases pick up the mp kernel

* typo
This commit is contained in:
carlushuang
2025-06-12 11:44:22 +08:00
committed by GitHub
parent 37554c31e8
commit 8aff45a8af

View File

@@ -127,37 +127,21 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex
constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
// at least 2 lines, one for sub_token unroll, one for cumsum
// should be enough
if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) {
if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols))
throw std::runtime_error("too many num_experts, can't allocate smem");
target_occupancy_ = 1;
}
int r = total_ / target_occupancy_ / smem_cols;
// Note: allocate at least cumsum_bufs + sub_unroll rows. Otherwise, fall back to the mp kernel
if(r < (cumsum_bufs + sub_unroll))
return cumsum_bufs;
// round to a multiple of sub_unroll
int r_for_sub_token = r - cumsum_bufs;
r_for_sub_token = min(r_for_sub_token, tokens_);
r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll;
r_for_sub_token = max(r_for_sub_token, 1);
r_for_sub_token = r_for_sub_token / sub_unroll * sub_unroll;
int r_token_min = (tokens_ + sub_unroll - 1) / sub_unroll * sub_unroll;
r_for_sub_token = min(r_for_sub_token, r_token_min);
if(r_for_sub_token > 1)
{
int r_unroll_ = r_for_sub_token / sub_unroll;
// round to 1x/2x/4x/8x number of sub_unroll
int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29
int mask_ = (1 << (31 - clz_)) - 1;
mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most
mask_ = ~mask_;
r_for_sub_token = (r_unroll_ & mask_) * sub_unroll;
}
// final check
if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) {
// final check, but usually should not happen
if( ((r_for_sub_token + cumsum_bufs) * smem_cols * target_occupancy_ ) > total_ ) {
throw std::runtime_error("can't run this kernel, request LDS over size");
}
@@ -167,6 +151,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex
return ck_tile::make_tuple(smem_rows, smem_cols);
}
// if return 0 or negative, means LDS is not enough
CK_TILE_HOST index_t moe_sorting_get_sub_token(int tokens_, int num_experts_)
{
auto [r_, c_] = moe_sorting_get_smem_row_col(tokens_, num_experts_);