mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
[CK_TILE] add moe-sorting MP kernel (#1910)
* moe sorting ex
* fix bug for race condition
* fix bug and optimze large expert
* fix
* optimize with sub_token_oneshot
* support skip empty tokens for expert sorting
* update moe_sorting
* tidy code
* support mp kernel
* hint mp
* remove use less code
* porting to example 15
---------
Co-authored-by: valarLip <340077269@qq.com>
[ROCm/composable_kernel commit: 353a612b44]
This commit is contained in:
@@ -152,6 +152,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
if(local_expert_masking)
|
||||
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
|
||||
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
@@ -163,6 +170,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
num_experts,
|
||||
@@ -174,13 +182,68 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
warmup,
|
||||
repeat};
|
||||
|
||||
auto ms = moe_sorting(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
|
||||
// auto ms = moe_sorting_mp(trait, karg, sc);
|
||||
|
||||
#if 0
|
||||
{
|
||||
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
|
||||
moe_sorting_ws.FromDevice(ws_host.data());
|
||||
|
||||
int * p_mesh = reinterpret_cast<int*>(ws_host.data());
|
||||
ck_tile::index_t row_size = ck_tile::impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
|
||||
std::cout << "topk_ids:" << std::endl;
|
||||
|
||||
int * p_topk_ids = reinterpret_cast<int*>(topk_ids_host.data());
|
||||
for(int i_token = 0; i_token < tokens; i_token++) {
|
||||
printf("[t:%2d]", i_token);
|
||||
for(int i_topk = 0; i_topk < topk; i_topk++) {
|
||||
printf("%d, ",p_topk_ids[i_token * topk + i_topk] );
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("----------------\n");
|
||||
|
||||
std::vector<int> l_cumsum (num_experts + 1, 0);
|
||||
for(int i_expert = 0; i_expert < num_experts; i_expert++ ) {
|
||||
printf("[e:%2d]", i_expert);
|
||||
int e_cnt = 0;
|
||||
for(int i_token = 0; i_token < tokens; i_token++) {
|
||||
auto v_mesh = p_mesh[i_expert * row_size + i_token];
|
||||
e_cnt += v_mesh != 0 ? 1 : 0;
|
||||
printf("%d, ", v_mesh);
|
||||
}
|
||||
int e_cnt_unit = (e_cnt + unit_size - 1) / unit_size;
|
||||
printf("[%d/%d]", e_cnt, e_cnt_unit);
|
||||
printf("\n");
|
||||
l_cumsum[i_expert + 1] = l_cumsum[i_expert] + e_cnt_unit;
|
||||
}
|
||||
|
||||
printf("----------------\n");
|
||||
printf("cumsum:\n");
|
||||
for(int i_cc= 0; i_cc < num_experts + 1; i_cc++) {
|
||||
printf("%2d, ", l_cumsum[i_cc]);
|
||||
}
|
||||
printf("\n");
|
||||
printf("----------------\n");
|
||||
|
||||
int * p_cumsum = p_mesh + ck_tile::impl::moe_sorting_mp_mesh_elem(tokens, num_experts);
|
||||
for(int i_expert = 0; i_expert < num_experts + 1; i_expert++ ) {
|
||||
printf("%2d(%d), ",p_cumsum[i_expert], p_cumsum[i_expert] / unit_size);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, mp:%d, ",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
num_experts,
|
||||
topk);
|
||||
topk,
|
||||
workspace_size != 0 ? 1 : 0);
|
||||
|
||||
if(local_expert_masking)
|
||||
{
|
||||
@@ -224,28 +287,41 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
num_experts,
|
||||
unit_size,
|
||||
local_expert_masking);
|
||||
rtn &= ck_tile::check_err(
|
||||
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_weights_host,
|
||||
sorted_weights_ref,
|
||||
std::string("OUT Error: Incorrect w!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_expert_ids_host,
|
||||
sorted_expert_ids_ref,
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
printf("total_tokens_post_pad:%d(%d), ",
|
||||
ref_total_tokens_post_pad,
|
||||
sorted_id_cnt_host.mData[0]);
|
||||
if(ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0])
|
||||
{
|
||||
size_t slen = ref_total_tokens_post_pad;
|
||||
rtn &= ck_tile::check_err(sorted_ids_host.slice({0}, {slen}),
|
||||
sorted_ids_ref.slice({0}, {slen}),
|
||||
std::string("OUT Error: Incorrect ids!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_weights_host.slice({0}, {slen}),
|
||||
sorted_weights_ref.slice({0}, {slen}),
|
||||
std::string("OUT Error: Incorrect w!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_expert_ids_host.slice({0}, {slen / unit_size}),
|
||||
sorted_expert_ids_ref.slice({0}, {slen / unit_size}),
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("(token size not equal!!)");
|
||||
rtn = false;
|
||||
}
|
||||
|
||||
if(moe_buf_size)
|
||||
{
|
||||
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
|
||||
rtn &= ck_tile::check_err(
|
||||
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
|
||||
}
|
||||
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
|
||||
printf("total_tokens_post_pad:%d(%d), ",
|
||||
ref_total_tokens_post_pad,
|
||||
sorted_id_cnt_host.mData[0]);
|
||||
// rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
|
||||
}
|
||||
|
||||
printf("valid:%s", rtn ? "y" : "n");
|
||||
|
||||
@@ -153,18 +153,106 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
}
|
||||
}
|
||||
#else
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
|
||||
auto sub_token_ = r_ - 2;
|
||||
r_ = (r_ - 2) / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0)
|
||||
{
|
||||
return moe_sorting_mp(t, a, s);
|
||||
}
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto sub_token_ = ck_tile::moe_sorting_get_sub_token(a.tokens, a.num_experts);
|
||||
auto row_ = sub_token_ / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
(void)c_;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(r_);
|
||||
MOE_SORTING_DISPATCH_EMASK_(row_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
#endif
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_2(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_3(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(1, true),
|
||||
MOE_SORTING_MP_1(1, true),
|
||||
MOE_SORTING_MP_2(1, true),
|
||||
MOE_SORTING_MP_3(1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(1, false),
|
||||
MOE_SORTING_MP_1(1, false),
|
||||
MOE_SORTING_MP_2(1, false),
|
||||
MOE_SORTING_MP_3(1, false));
|
||||
return ave_time;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts);
|
||||
}
|
||||
|
||||
@@ -18,4 +18,10 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// use below API before call moe_sorting() to indicate if need workspace or not
|
||||
// if return non zero, means need workspace, you need to allocate a GPU buffer
|
||||
// and set to moe_sorting_args.p_ws
|
||||
// NOTE: workspace size are required to clear zero before use the API
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts);
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
|
||||
@@ -17,6 +17,9 @@ struct fused_moe_args
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
|
||||
void* o_ptr; // [m, k], output token (no need to do zeroing)
|
||||
void* ws_ptr; // size is moe_sorting_get_workspace_size()
|
||||
// if return zero, then could be nullptr
|
||||
// must be cleard before use
|
||||
|
||||
const void* topk_ids_ptr; // [tokens, topk]
|
||||
const void* topk_weight_ptr; // [tokens, topk]
|
||||
|
||||
@@ -27,6 +27,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
|
||||
@@ -371,6 +371,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(
|
||||
num_sorted_tiles_host.get_element_space_size_in_bytes());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
fused_moe_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
@@ -394,6 +400,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
topk_weight_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
|
||||
Reference in New Issue
Block a user