[CK_TILE] add moe-sorting MP kernel (#1910)

* moe sorting ex

* fix bug for race condition

* fix bug and optimze large expert

* fix

* optimize with sub_token_oneshot

* support skip empty tokens for expert sorting

* update moe_sorting

* tidy code

* support mp kernel

* hint mp

* remove use less code

* porting to example 15

---------

Co-authored-by: valarLip <340077269@qq.com>

[ROCm/composable_kernel commit: 353a612b44]
This commit is contained in:
carlushuang
2025-02-25 17:56:55 +08:00
committed by GitHub
parent 1d09b0928c
commit 1d32e34075
8 changed files with 1043 additions and 31 deletions

View File

@@ -152,6 +152,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
if(local_expert_masking)
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts);
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
if(workspace_size != 0)
moe_sorting_ws.SetZero(); // note, clear here!!!!
moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
@@ -163,6 +170,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_expert_ids_dev.GetDeviceBuffer(),
sorted_id_cnt_dev.GetDeviceBuffer(),
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
tokens,
unit_size,
num_experts,
@@ -174,13 +182,68 @@ bool test_moe_sorting(ck_tile::ArgParser args)
/* log_level = */ (kname ? 1 : 0),
warmup,
repeat};
auto ms = moe_sorting(trait, karg, sc);
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
// auto ms = moe_sorting_mp(trait, karg, sc);
#if 0
{
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
moe_sorting_ws.FromDevice(ws_host.data());
int * p_mesh = reinterpret_cast<int*>(ws_host.data());
ck_tile::index_t row_size = ck_tile::impl::moe_sorting_mp_mesh_stride(tokens);
std::cout << "topk_ids:" << std::endl;
int * p_topk_ids = reinterpret_cast<int*>(topk_ids_host.data());
for(int i_token = 0; i_token < tokens; i_token++) {
printf("[t:%2d]", i_token);
for(int i_topk = 0; i_topk < topk; i_topk++) {
printf("%d, ",p_topk_ids[i_token * topk + i_topk] );
}
printf("\n");
}
printf("----------------\n");
std::vector<int> l_cumsum (num_experts + 1, 0);
for(int i_expert = 0; i_expert < num_experts; i_expert++ ) {
printf("[e:%2d]", i_expert);
int e_cnt = 0;
for(int i_token = 0; i_token < tokens; i_token++) {
auto v_mesh = p_mesh[i_expert * row_size + i_token];
e_cnt += v_mesh != 0 ? 1 : 0;
printf("%d, ", v_mesh);
}
int e_cnt_unit = (e_cnt + unit_size - 1) / unit_size;
printf("[%d/%d]", e_cnt, e_cnt_unit);
printf("\n");
l_cumsum[i_expert + 1] = l_cumsum[i_expert] + e_cnt_unit;
}
printf("----------------\n");
printf("cumsum:\n");
for(int i_cc= 0; i_cc < num_experts + 1; i_cc++) {
printf("%2d, ", l_cumsum[i_cc]);
}
printf("\n");
printf("----------------\n");
int * p_cumsum = p_mesh + ck_tile::impl::moe_sorting_mp_mesh_elem(tokens, num_experts);
for(int i_expert = 0; i_expert < num_experts + 1; i_expert++ ) {
printf("%2d(%d), ",p_cumsum[i_expert], p_cumsum[i_expert] / unit_size);
}
printf("\n");
}
#endif
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, mp:%d, ",
index_prec.c_str(),
weight_prec.c_str(),
tokens,
num_experts,
topk);
topk,
workspace_size != 0 ? 1 : 0);
if(local_expert_masking)
{
@@ -224,28 +287,41 @@ bool test_moe_sorting(ck_tile::ArgParser args)
num_experts,
unit_size,
local_expert_masking);
rtn &= ck_tile::check_err(
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
rtn &= ck_tile::check_err(sorted_weights_host,
sorted_weights_ref,
std::string("OUT Error: Incorrect w!"),
1e-6,
1e-6);
rtn &= ck_tile::check_err(sorted_expert_ids_host,
sorted_expert_ids_ref,
std::string("OUT Error: Incorrect eid!"),
1e-6,
1e-6);
printf("total_tokens_post_pad:%d(%d), ",
ref_total_tokens_post_pad,
sorted_id_cnt_host.mData[0]);
if(ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0])
{
size_t slen = ref_total_tokens_post_pad;
rtn &= ck_tile::check_err(sorted_ids_host.slice({0}, {slen}),
sorted_ids_ref.slice({0}, {slen}),
std::string("OUT Error: Incorrect ids!"),
1e-6,
1e-6);
rtn &= ck_tile::check_err(sorted_weights_host.slice({0}, {slen}),
sorted_weights_ref.slice({0}, {slen}),
std::string("OUT Error: Incorrect w!"),
1e-6,
1e-6);
rtn &= ck_tile::check_err(sorted_expert_ids_host.slice({0}, {slen / unit_size}),
sorted_expert_ids_ref.slice({0}, {slen / unit_size}),
std::string("OUT Error: Incorrect eid!"),
1e-6,
1e-6);
}
else
{
printf("(token size not equal!!)");
rtn = false;
}
if(moe_buf_size)
{
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
rtn &= ck_tile::check_err(
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
}
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
printf("total_tokens_post_pad:%d(%d), ",
ref_total_tokens_post_pad,
sorted_id_cnt_host.mData[0]);
// rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
}
printf("valid:%s", rtn ? "y" : "n");

View File

@@ -153,18 +153,106 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
}
}
#else
using index_t = ck_tile::index_t;
using ms_weight_type = float;
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
auto sub_token_ = r_ - 2;
r_ = (r_ - 2) / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0)
{
return moe_sorting_mp(t, a, s);
}
using index_t = ck_tile::index_t;
using ms_weight_type = float;
auto sub_token_ = ck_tile::moe_sorting_get_sub_token(a.tokens, a.num_experts);
auto row_ = sub_token_ / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
bool is_local_expert_masking = t.local_expert_masking;
(void)c_;
MOE_SORTING_DISPATCH_EMASK_(r_);
MOE_SORTING_DISPATCH_EMASK_(row_);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
return -1;
}
#define MOE_SORTING_MP_0(unroll_num_, expert_masking_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
using ms_problem = \
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(unroll_num_, expert_masking_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
using ms_problem = \
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_2(unroll_num_, expert_masking_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
using ms_problem = \
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_3(unroll_num_, expert_masking_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
using ms_problem = \
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
{
using ms_index_t = ck_tile::index_t;
using ms_weight_type = float;
if(t.local_expert_masking)
{
float ave_time = ck_tile::launch_kernel(s,
MOE_SORTING_MP_0(1, true),
MOE_SORTING_MP_1(1, true),
MOE_SORTING_MP_2(1, true),
MOE_SORTING_MP_3(1, true));
return ave_time;
}
else
{
float ave_time = ck_tile::launch_kernel(s,
MOE_SORTING_MP_0(1, false),
MOE_SORTING_MP_1(1, false),
MOE_SORTING_MP_2(1, false),
MOE_SORTING_MP_3(1, false));
return ave_time;
}
}
return -1;
}
int moe_sorting_get_workspace_size(int tokens, int num_experts)
{
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts);
}

View File

@@ -18,4 +18,10 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
{
};
// use below API before call moe_sorting() to indicate if need workspace or not
// if return non zero, means need workspace, you need to allocate a GPU buffer
// and set to moe_sorting_args.p_ws
// NOTE: workspace size are required to clear zero before use the API
int moe_sorting_get_workspace_size(int tokens, int num_experts);
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);

View File

@@ -17,6 +17,9 @@ struct fused_moe_args
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
void* o_ptr; // [m, k], output token (no need to do zeroing)
void* ws_ptr; // size is moe_sorting_get_workspace_size()
// if return zero, then could be nullptr
// must be cleard before use
const void* topk_ids_ptr; // [tokens, topk]
const void* topk_weight_ptr; // [tokens, topk]

View File

@@ -27,6 +27,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
a.o_ptr, // void* p_moe_buf;
a.ws_ptr, // void* p_ws;
a.num_tokens, // index_t tokens;
a.block_m, // index_t unit_size;
a.num_experts, // index_t num_experts;

View File

@@ -371,6 +371,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem num_sorted_tiles_buf(
num_sorted_tiles_host.get_element_space_size_in_bytes());
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
if(workspace_size != 0)
moe_sorting_ws.SetZero(); // note, clear here!!!!
fused_moe_traits traits{prec_i,
prec_w,
prec_o,
@@ -394,6 +400,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
: nullptr,
o_buf.GetDeviceBuffer(),
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
topk_ids_buf.GetDeviceBuffer(),
topk_weight_buf.GetDeviceBuffer(),
sorted_token_ids_buf.GetDeviceBuffer(),