mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[CK_TILE] optimize moe-sorting kernel (#1771)
* opt moe sorting * remove commented code
This commit is contained in:
@@ -130,7 +130,8 @@ struct MoeSortingKernel
|
||||
// Returns the dynamic shared-memory size, in bytes, required by the kernel for
// the launch described by `h`.
//
// Layout (in units of index_t), matching the indexing used in the kernel body:
//   tokens_cnts : (blocks.x + 1) rows, each row (num_experts + 1) entries wide
//   cumsum      : (num_experts + 1) entries
//
// The stale un-padded formula (row width == num_experts) has been removed: it
// returned first, shadowing the padded one, and under-allocated the buffer for
// the (num_experts + 1) row stride.
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
{
    const auto blocks = BlockSize(h);
    // usually num_experts is power of 2, we pad 1 dword here for the row-size
    // (presumably to spread column accesses across LDS banks -- TODO confirm)
    return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
@@ -154,6 +155,75 @@ struct MoeSortingKernel
|
||||
return k;
|
||||
}
|
||||
|
||||
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
|
||||
template <typename data_t, int wave_size>
|
||||
__device__ inline void wave_cumsum(data_t& thread_data) const
|
||||
{
|
||||
// wave_size must be power of 2
|
||||
constexpr int row_mask = 0xf;
|
||||
constexpr int bank_mask = 0xf;
|
||||
constexpr bool bound_ctrl = true; // ! out-of-bound is zero !
|
||||
auto reduce_op = [&](auto x_, auto y_) { return x_ + y_; };
|
||||
|
||||
if constexpr(wave_size > 1)
|
||||
{
|
||||
thread_data = reduce_op(
|
||||
thread_data,
|
||||
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x111,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:1
|
||||
}
|
||||
|
||||
if constexpr(wave_size > 2)
|
||||
{
|
||||
thread_data = reduce_op(
|
||||
thread_data,
|
||||
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x112,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:2
|
||||
}
|
||||
if constexpr(wave_size > 4)
|
||||
{
|
||||
thread_data =
|
||||
reduce_op(thread_data,
|
||||
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x114,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:4
|
||||
}
|
||||
if constexpr(wave_size > 8)
|
||||
{
|
||||
thread_data =
|
||||
reduce_op(thread_data,
|
||||
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x118,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:8
|
||||
}
|
||||
|
||||
if constexpr(wave_size > 16)
|
||||
{
|
||||
// now row-0, row-0+row-1, row-1+row-2, row-2+row-3
|
||||
int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 1) << 2, __builtin_bit_cast(int, thread_data));
|
||||
v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
|
||||
thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
|
||||
}
|
||||
|
||||
if constexpr(wave_size > 32)
|
||||
{
|
||||
// lane-id 48...63->31
|
||||
int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 17) << 2, __builtin_bit_cast(int, thread_data));
|
||||
v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
|
||||
thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
|
||||
{
|
||||
return row * total_col + col;
|
||||
@@ -187,48 +257,124 @@ struct MoeSortingKernel
|
||||
index_t* shared_mem = reinterpret_cast<index_t*>(smem);
|
||||
|
||||
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
|
||||
index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
|
||||
index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts+1); // 1: (num_experts + 1)
|
||||
|
||||
for(int i = 0; i < num_experts; ++i)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
|
||||
tokens_cnts[calc_index(num_experts+1, tid + 1, i)] = 0;
|
||||
}
|
||||
|
||||
#pragma unroll Problem_::InternalLoadUnroll
|
||||
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
|
||||
{
|
||||
++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
|
||||
++tokens_cnts[calc_index(num_experts+1, tid + 1, topk_id[i])];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#if 1
|
||||
if(tid < num_experts)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
|
||||
for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
|
||||
tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0;
|
||||
index_t local_c[8];
|
||||
index_t prev_c = 0;
|
||||
// TODO: manually unroll. pragma unroll does not work well when we have dependency
|
||||
for(int i = 1; i <= static_cast<index_t>(blockDim.x); i+= 8)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts, i, tid)] +=
|
||||
tokens_cnts[calc_index(num_experts, i - 1, tid)];
|
||||
local_c[0] = tokens_cnts[calc_index(num_experts+1, i + 0, tid)];
|
||||
local_c[1] = tokens_cnts[calc_index(num_experts+1, i + 1, tid)];
|
||||
local_c[2] = tokens_cnts[calc_index(num_experts+1, i + 2, tid)];
|
||||
local_c[3] = tokens_cnts[calc_index(num_experts+1, i + 3, tid)];
|
||||
local_c[4] = tokens_cnts[calc_index(num_experts+1, i + 4, tid)];
|
||||
local_c[5] = tokens_cnts[calc_index(num_experts+1, i + 5, tid)];
|
||||
local_c[6] = tokens_cnts[calc_index(num_experts+1, i + 6, tid)];
|
||||
local_c[7] = tokens_cnts[calc_index(num_experts+1, i + 7, tid)];
|
||||
|
||||
local_c[0] += prev_c;
|
||||
local_c[1] += local_c[0];
|
||||
local_c[2] += local_c[1];
|
||||
local_c[3] += local_c[2];
|
||||
local_c[4] += local_c[3];
|
||||
local_c[5] += local_c[4];
|
||||
local_c[6] += local_c[5];
|
||||
local_c[7] += local_c[6];
|
||||
prev_c = local_c[7];
|
||||
|
||||
tokens_cnts[calc_index(num_experts+1, i + 0, tid)] = local_c[0];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 1, tid)] = local_c[1];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 2, tid)] = local_c[2];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 3, tid)] = local_c[3];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 4, tid)] = local_c[4];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 5, tid)] = local_c[5];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 6, tid)] = local_c[6];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 7, tid)] = local_c[7];
|
||||
}
|
||||
}
|
||||
#else
|
||||
// TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic
|
||||
{
|
||||
if(tid < num_experts)
|
||||
tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0;
|
||||
for(int i = 0; i < num_experts; i+=8) {
|
||||
index_t local_c[8];
|
||||
#pragma unroll
|
||||
for(int j = 0; j < 8; j++) {
|
||||
local_c[j] = tokens_cnts[calc_index(num_experts+1, tid+1, i+j)];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int j = 0; j < 8; j++) {
|
||||
wave_cumsum<int, 64>(local_c[j]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int j = 0; j < 8; j++) {
|
||||
tokens_cnts[calc_index(num_experts+1, tid+1, i+j)] = local_c[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
if constexpr (Problem::ExpertTile == 0) {
|
||||
if(tid == 0)
|
||||
{
|
||||
cumsum[0] = 0;
|
||||
for(int i = 1; i <= num_experts; ++i)
|
||||
{
|
||||
auto current_units = [&]() {
|
||||
index_t x_ = tokens_cnts[calc_index(num_experts+1, blockDim.x, i - 1)] +
|
||||
unit_size_mdiv.divisor - 1;
|
||||
index_t y_ = unit_size_mdiv.div(x_);
|
||||
return max(y_, 1) * unit_size_mdiv.divisor;
|
||||
}();
|
||||
cumsum[i] = cumsum[i - 1] + current_units;
|
||||
}
|
||||
*p_total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
} else {
|
||||
// TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert)
|
||||
// for simplicity, not check experts here.
|
||||
int local_cnt = tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)];
|
||||
int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
|
||||
int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
|
||||
int local_cumsum = padded_tokens_per_expert;
|
||||
wave_cumsum<int, 64>(local_cumsum);
|
||||
|
||||
if(tid == (num_experts - 1)) {
|
||||
cumsum[0] = 0;
|
||||
*p_total_tokens_post_pad = local_cumsum;
|
||||
}
|
||||
if(tid < num_experts) {
|
||||
cumsum[tid + 1] = local_cumsum;
|
||||
}
|
||||
}
|
||||
|
||||
// __syncthreads();
|
||||
if(tid == 0)
|
||||
{
|
||||
cumsum[0] = 0;
|
||||
for(int i = 1; i <= num_experts; ++i)
|
||||
{
|
||||
auto current_units = [&]() {
|
||||
index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
|
||||
unit_size_mdiv.divisor - 1;
|
||||
index_t y_ = unit_size_mdiv.div(x_);
|
||||
return max(y_, 1) * unit_size_mdiv.divisor;
|
||||
}();
|
||||
cumsum[i] = cumsum[i - 1] + current_units;
|
||||
}
|
||||
*p_total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
__syncthreads();
|
||||
if(tid < num_experts)
|
||||
{
|
||||
for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
|
||||
int e_start = cumsum[tid];
|
||||
int e_end = cumsum[tid + 1];
|
||||
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
|
||||
{
|
||||
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
|
||||
}
|
||||
@@ -238,8 +384,8 @@ struct MoeSortingKernel
|
||||
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
|
||||
{
|
||||
index_t expert_id = topk_id[i];
|
||||
index_t rank_post_pad =
|
||||
tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
|
||||
index_t local_cnt = tokens_cnts[calc_index(num_experts+1, tid, expert_id)];
|
||||
index_t rank_post_pad = local_cnt + cumsum[expert_id];
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
|
||||
@@ -247,27 +393,54 @@ struct MoeSortingKernel
|
||||
#else
|
||||
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
|
||||
#endif
|
||||
p_sorted_weights[rank_post_pad] = weights[i];
|
||||
++tokens_cnts[calc_index(num_experts, tid, expert_id)];
|
||||
p_sorted_weights[rank_post_pad] = weights[i];
|
||||
tokens_cnts[calc_index(num_experts+1, tid, expert_id)] = local_cnt+1;
|
||||
}
|
||||
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
if(tid < num_experts)
|
||||
{
|
||||
index_t expert_offset =
|
||||
cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
|
||||
while(expert_offset < cumsum[tid + 1])
|
||||
if constexpr (Problem::ExpertTile == 0) {
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
if(tid < num_experts)
|
||||
{
|
||||
index_t expert_offset =
|
||||
cumsum[tid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)];
|
||||
index_t expert_end = cumsum[tid + 1];
|
||||
while(expert_offset < expert_end)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[expert_offset] =
|
||||
MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
|
||||
p_sorted_token_ids[expert_offset] =
|
||||
MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[expert_offset] = prefill_token;
|
||||
p_sorted_token_ids[expert_offset] = prefill_token;
|
||||
#endif
|
||||
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
|
||||
expert_offset++;
|
||||
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
|
||||
expert_offset++;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
// TODO: only support expert-tile like 8, 16, 32
|
||||
static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
|
||||
{
|
||||
index_t eid = tid / experts_per_wave;
|
||||
index_t expert_offset =
|
||||
cumsum[eid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, eid)] + tid % experts_per_wave;
|
||||
index_t expert_end = cumsum[eid + 1];
|
||||
if(eid < num_experts) {
|
||||
while(expert_offset < expert_end)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[expert_offset] =
|
||||
MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[expert_offset] = prefill_token;
|
||||
#endif
|
||||
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
|
||||
expert_offset+=experts_per_wave;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
|
||||
Reference in New Issue
Block a user