mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE]Moe update index (#1672)
* update MOCK_ID for moe-sorting * add moe-smoothquant * update a comment * fix format * hot fix * update topk in overflow case * update comments * update bf16 cvt --------- Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
@@ -12,20 +12,77 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
|
||||
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
|
||||
|
||||
// clang-format off
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
// each slice map to one expert, and one expert can have multiple slices
|
||||
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
|
||||
//
|
||||
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
|
||||
//
|
||||
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
|
||||
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
|
||||
// In some cases (like smooth-quant), we need the topk information to index into the quantized tokens
|
||||
// produced by the smooth-quant of different experts. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr:
|
||||
//
|
||||
// 32bit 0........23 24.....31 bit
|
||||
// (data) -> (token_id | topk_id)
|
||||
// low 24 bit is for token id, top 8 bit is for topk id
|
||||
//
|
||||
// the input after smooth-quant is [topk, token, hidden_dim], originally it is [token, hidden_dim]
|
||||
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
|
||||
// * length is (max_num_tokens_padded + block_size - 1) / block_size
|
||||
//
|
||||
// num_tokens_post_padded_ptr : [28]
|
||||
// num_sorted_tiles_ptr : [7]
|
||||
//
|
||||
// * different from vLLM
|
||||
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
|
||||
// 2) need sorted_weight_ptr
|
||||
// 3) use num_sorted_tiles_ptr, already divided by M_a
|
||||
//
|
||||
// * below used for indexing
|
||||
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
|
||||
// 2) sorted_weight_ptr
|
||||
// 3) sorted_expert_ids_ptr
|
||||
// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
|
||||
//
|
||||
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
|
||||
struct MoeSortingHostArgs
|
||||
{
|
||||
const void* p_topk_ids;
|
||||
const void* p_weights;
|
||||
const void* p_topk_ids; // [token, topk]
|
||||
const void* p_weights; // [token, topk]
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_total_tokens_post_pad;
|
||||
// we fused the setzero of output of fused-moe buffer
|
||||
// setting this pointer to nullptr skips this operation
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
index_t unit_size;
|
||||
index_t unit_size; // this is the M_a of fused-moe kernel
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t moe_buf_bytes;
|
||||
index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
@@ -183,8 +240,14 @@ struct MoeSortingKernel
|
||||
index_t expert_id = topk_id[i];
|
||||
index_t rank_post_pad =
|
||||
tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
|
||||
p_sorted_token_ids[rank_post_pad] = MOE_SORTING_MOCK_ID(curr_token_id, curr_topk_id);
|
||||
#else
|
||||
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
|
||||
p_sorted_weights[rank_post_pad] = weights[i];
|
||||
#endif
|
||||
p_sorted_weights[rank_post_pad] = weights[i];
|
||||
++tokens_cnts[calc_index(num_experts, tid, expert_id)];
|
||||
}
|
||||
|
||||
@@ -195,8 +258,13 @@ struct MoeSortingKernel
|
||||
cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
|
||||
while(expert_offset < cumsum[tid + 1])
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[expert_offset] =
|
||||
MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[expert_offset] = prefill_token;
|
||||
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
|
||||
#endif
|
||||
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
|
||||
expert_offset++;
|
||||
}
|
||||
}
|
||||
@@ -229,4 +297,7 @@ struct MoeSortingKernel
|
||||
smem);
|
||||
}
|
||||
};
|
||||
|
||||
#undef MOE_SORTING_MOCK_ID
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user