mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK_TILE] moe sorting optimization : refactor subtoken logic to let more kernel pickup mp kernel (#2327)
* refactor subtoken logic to let more cases pick up the mp kernel
* typo
This commit is contained in:
@@ -127,37 +127,21 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex
|
||||
constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
|
||||
// at least 2 lines, one for sub_token unroll, one for cumsum
|
||||
// should be enough
|
||||
if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) {
|
||||
if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols))
|
||||
throw std::runtime_error("too many num_experts, can't allocate smem");
|
||||
target_occupancy_ = 1;
|
||||
}
|
||||
|
||||
int r = total_ / target_occupancy_ / smem_cols;
|
||||
|
||||
// Note: allocate at least cumsum_bufs + sub_unroll rows. Otherwise, fall back to the mp kernel
|
||||
if(r < (cumsum_bufs + sub_unroll))
|
||||
return cumsum_bufs;
|
||||
|
||||
// round to a multiple of sub_unroll
|
||||
int r_for_sub_token = r - cumsum_bufs;
|
||||
r_for_sub_token = min(r_for_sub_token, tokens_);
|
||||
r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll;
|
||||
r_for_sub_token = max(r_for_sub_token, 1);
|
||||
r_for_sub_token = r_for_sub_token / sub_unroll * sub_unroll;
|
||||
int r_token_min = (tokens_ + sub_unroll - 1) / sub_unroll * sub_unroll;
|
||||
r_for_sub_token = min(r_for_sub_token, r_token_min);
|
||||
|
||||
if(r_for_sub_token > 1)
|
||||
{
|
||||
int r_unroll_ = r_for_sub_token / sub_unroll;
|
||||
|
||||
|
||||
// round to 1x/2x/4x/8x number of sub_unroll
|
||||
int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29
|
||||
int mask_ = (1 << (31 - clz_)) - 1;
|
||||
|
||||
|
||||
mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most
|
||||
mask_ = ~mask_;
|
||||
|
||||
r_for_sub_token = (r_unroll_ & mask_) * sub_unroll;
|
||||
}
|
||||
|
||||
// final check
|
||||
if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) {
|
||||
// final check, but usually should not happen
|
||||
if( ((r_for_sub_token + cumsum_bufs) * smem_cols * target_occupancy_ ) > total_ ) {
|
||||
throw std::runtime_error("can't run this kernel, request LDS over size");
|
||||
}
|
||||
|
||||
@@ -167,6 +151,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex
|
||||
return ck_tile::make_tuple(smem_rows, smem_cols);
|
||||
}
|
||||
|
||||
// a return value of 0 or negative means LDS is not enough
|
||||
CK_TILE_HOST index_t moe_sorting_get_sub_token(int tokens_, int num_experts_)
|
||||
{
|
||||
auto [r_, c_] = moe_sorting_get_smem_row_col(tokens_, num_experts_);
|
||||
|
||||
Reference in New Issue
Block a user