Felix/opt sorting (#2902)

* merge felix/sorting
* opt moe sorting  (#2822)
* opt moe storing for 2k
---------
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
felix
2025-10-15 09:24:03 +08:00
committed by GitHub
parent ca1ab083a7
commit 4c826abfff
4 changed files with 812 additions and 217 deletions

View File

@@ -194,22 +194,40 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return -1;
}
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::kBlockSize>(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -286,6 +304,46 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
} \
} \
else \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
} \
}
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
@@ -294,7 +352,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
@@ -304,7 +362,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
@@ -317,7 +375,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
@@ -327,7 +385,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = ck_tile::launch_kernel( \
s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
@@ -369,69 +427,140 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
}
};
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
if(a.tokens < 2048)
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time = ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time = ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
if(t.local_expert_masking)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large topk %d\n", a.topk);
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1)
}
}
}
else
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
}
}
}
}

View File

@@ -198,22 +198,40 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return -1;
}
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::kBlockSize>(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -290,6 +308,46 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
} \
} \
else \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
} \
}
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
@@ -297,7 +355,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
@@ -306,7 +364,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
@@ -318,7 +376,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
@@ -327,7 +385,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
@@ -344,67 +402,156 @@ float fused_moesorting_mp(fused_moesorting_trait t,
using ms_index_t = ck_tile::index_t;
using ms_weight_type = float;
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
if(t.clear_workspace_inside_api)
{
if(is_local_token)
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
k(s_);
}
else
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
k(s_);
}
}
};
if(a.tokens < 2048)
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time = ck_tile::launch_kernel(s,
MOE_SORTING_MP_0(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time = ck_tile::launch_kernel(s,
MOE_SORTING_MP_0(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
if(t.local_expert_masking)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large topk %d\n", a.topk);
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1)
}
}
}
else
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
}
}
}
}