Felix/opt sorting (#2902)

* merge felix/sorting
* opt moe sorting  (#2822)
* opt moe sorting for 2k
---------
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>

[ROCm/composable_kernel commit: 4c826abfff]
This commit is contained in:
felix
2025-10-15 09:24:03 +08:00
committed by GitHub
parent f0b0b1e838
commit b6f6b7cd2a
4 changed files with 812 additions and 217 deletions

View File

@@ -194,22 +194,40 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return -1;
}
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::kBlockSize>(kernel{}, grids, blocks, 0, kargs); \
}()
// Produces the launchable for MoeSortingMultiPhaseKernel_P0_v2.
// Expands to an immediately-invoked lambda that instantiates the kernel for the given mesh
// element type / unroll factor / expert-masking flag / local-token flag, derives kernel args,
// grid and block dimensions from the host args `a`, and returns what ck_tile::make_kernel yields
// (0 bytes of dynamic LDS are requested).
// Relies on `a`, `ms_index_t` and `ms_weight_type` being visible at the expansion site.
// NOTE(review): MOE_SORTING_MP_0_V1 instantiates ck_tile::make_kernel<kernel::kBlockSize>, while
// this macro uses make_kernel's default template arguments even though the v2 kernel declares
// kBlockSize = 512 -- confirm the default launch configuration admits 512 threads per block.
#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_)              \
    [&]() {                                                                                      \
        constexpr ck_tile::index_t kUnroll = unroll_num_;                                        \
        constexpr bool kMaskExperts        = expert_masking_;                                    \
        constexpr bool kUseLocalToken      = local_token_;                                       \
        using problem_t = ck_tile::MoeSortingProblemMp<ms_index_t,                               \
                                                       ms_weight_type,                           \
                                                       mesh_type_,                               \
                                                       kUnroll,                                  \
                                                       kMaskExperts,                             \
                                                       kUseLocalToken>;                          \
        using kernel_t       = ck_tile::MoeSortingMultiPhaseKernel_P0_v2<problem_t>;             \
        auto kernel_args     = kernel_t::MakeKargs(a);                                           \
        const dim3 grid_dim  = kernel_t::GridSize(a);                                            \
        const dim3 block_dim = kernel_t::BlockSize(a);                                           \
        return ck_tile::make_kernel(kernel_t{}, grid_dim, block_dim, 0, kernel_args);            \
    }()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -286,6 +304,46 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
// Two-kernel dispatch (MOE_SORTING_MP_0_V2 followed by MOE_SORTING_MP_23) used by the callers on
// their small-token branch. The run-time flags t.local_expert_masking and is_local_token select
// which compile-time (expert_masking, local_token) instantiation is launched on stream config
// `s`; the time reported by ck_tile::launch_kernel is returned from the ENCLOSING function.
// token_vec_1_ is part of the signature (same shape as MOR_SORTING_MP_DISPATCH_) but is not
// referenced in this expansion.
#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_)    \
    if(t.local_expert_masking && is_local_token)                                                 \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true),                           \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true));                           \
    }                                                                                            \
    else if(t.local_expert_masking)                                                              \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false),                          \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false));                          \
    }                                                                                            \
    else if(is_local_token)                                                                      \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true),                          \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true));                          \
    }                                                                                            \
    else                                                                                         \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false),                         \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false));                         \
    }
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
@@ -294,7 +352,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
@@ -304,7 +362,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
@@ -317,7 +375,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
@@ -327,7 +385,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = ck_tile::launch_kernel( \
s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
@@ -369,69 +427,140 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
}
};
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
if(a.tokens < 2048)
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time = ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time = ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
if(t.local_expert_masking)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large topk %d\n", a.topk);
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1)
}
}
}
else
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
}
}
}
}

View File

@@ -198,22 +198,40 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return -1;
}
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::kBlockSize>(kernel{}, grids, blocks, 0, kargs); \
}()
// Produces the launchable for MoeSortingMultiPhaseKernel_P0_v2.
// Expands to an immediately-invoked lambda that instantiates the kernel for the given mesh
// element type / unroll factor / expert-masking flag / local-token flag, derives kernel args,
// grid and block dimensions from the host args `a`, and returns what ck_tile::make_kernel yields
// (0 bytes of dynamic LDS are requested).
// Relies on `a`, `ms_index_t` and `ms_weight_type` being visible at the expansion site.
// NOTE(review): MOE_SORTING_MP_0_V1 instantiates ck_tile::make_kernel<kernel::kBlockSize>, while
// this macro uses make_kernel's default template arguments even though the v2 kernel declares
// kBlockSize = 512 -- confirm the default launch configuration admits 512 threads per block.
#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_)              \
    [&]() {                                                                                      \
        constexpr ck_tile::index_t kUnroll = unroll_num_;                                        \
        constexpr bool kMaskExperts        = expert_masking_;                                    \
        constexpr bool kUseLocalToken      = local_token_;                                       \
        using problem_t = ck_tile::MoeSortingProblemMp<ms_index_t,                               \
                                                       ms_weight_type,                           \
                                                       mesh_type_,                               \
                                                       kUnroll,                                  \
                                                       kMaskExperts,                             \
                                                       kUseLocalToken>;                          \
        using kernel_t       = ck_tile::MoeSortingMultiPhaseKernel_P0_v2<problem_t>;             \
        auto kernel_args     = kernel_t::MakeKargs(a);                                           \
        const dim3 grid_dim  = kernel_t::GridSize(a);                                            \
        const dim3 block_dim = kernel_t::BlockSize(a);                                           \
        return ck_tile::make_kernel(kernel_t{}, grid_dim, block_dim, 0, kernel_args);            \
    }()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -290,6 +308,46 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
// Two-kernel dispatch (MOE_SORTING_MP_0_V2 followed by MOE_SORTING_MP_23) used by the callers on
// their small-token branch. The run-time flags t.local_expert_masking and is_local_token select
// which compile-time (expert_masking, local_token) instantiation is launched on stream config
// `s`; the time reported by ck_tile::launch_kernel is returned from the ENCLOSING function.
// token_vec_1_ is part of the signature (same shape as MOR_SORTING_MP_DISPATCH_) but is not
// referenced in this expansion.
#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_)    \
    if(t.local_expert_masking && is_local_token)                                                 \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true),                           \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true));                           \
    }                                                                                            \
    else if(t.local_expert_masking)                                                              \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false),                          \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false));                          \
    }                                                                                            \
    else if(is_local_token)                                                                      \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true),                          \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true));                          \
    }                                                                                            \
    else                                                                                         \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false),                         \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false));                         \
    }
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
@@ -297,7 +355,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
@@ -306,7 +364,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
@@ -318,7 +376,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
@@ -327,7 +385,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
@@ -344,67 +402,156 @@ float fused_moesorting_mp(fused_moesorting_trait t,
using ms_index_t = ck_tile::index_t;
using ms_weight_type = float;
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
if(t.clear_workspace_inside_api)
{
if(is_local_token)
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
k(s_);
}
else
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
k(s_);
}
}
};
if(a.tokens < 2048)
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time = ck_tile::launch_kernel(s,
MOE_SORTING_MP_0(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time = ck_tile::launch_kernel(s,
MOE_SORTING_MP_0(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
if(t.local_expert_masking)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large topk %d\n", a.topk);
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1)
}
}
}
else
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
}
}
}
}

View File

@@ -20,7 +20,7 @@ namespace ck_tile {
#endif
#ifndef MOE_SORTING_FUSE_MP_01
#define MOE_SORTING_FUSE_MP_01 0
#define MOE_SORTING_FUSE_MP_01 1
#endif
// whether to use 2d buffer indexing for fmoe ws or 1d
@@ -527,7 +527,7 @@ struct MoeSortingKernel
}
__syncthreads();
#if 1
#if MOE_SORTING_FUSE_MP_01
if(tid < num_experts)
{
tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
@@ -1322,18 +1322,18 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
}
}
template <index_t BLOCK_SIZE = 256>
template <index_t kBlockSize = 256>
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes, index_t gid)
{
// const index_t offset = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x;
long_index_t offset = static_cast<long_index_t>(gid) * BLOCK_SIZE + threadIdx.x;
// const index_t offset = (blockIdx.x - 1) * kBlockSize + threadIdx.x;
long_index_t offset = static_cast<long_index_t>(gid) * kBlockSize + threadIdx.x;
if(offset < buf_bytes / 16)
{
buf[offset] = uint8x16_t{0};
}
}
template <index_t BLOCK_SIZE = 256>
template <index_t kBlockSize = 256>
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(
void* buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks)
{
@@ -1345,7 +1345,7 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
auto zero_ = vector_type{0};
for(long_index_t i = gid * BLOCK_SIZE + threadIdx.x; i < total_elems; i += blocks * BLOCK_SIZE)
for(long_index_t i = gid * kBlockSize + threadIdx.x; i < total_elems; i += blocks * kBlockSize)
{
p_buf[i] = zero_;
}
@@ -1552,7 +1552,7 @@ p_m_cumsum
// count topk_id into mesh
template <typename Problem_>
struct MoeSortingMultiPhaseKernel_P0
struct MoeSortingMultiPhaseKernel_P0_v1
{
using Problem = remove_cvref_t<Problem_>;
@@ -1673,6 +1673,197 @@ struct MoeSortingMultiPhaseKernel_P0
}
}
};
// "P0" (v2) kernel of the multi-phase MoE sorting path; the dispatch macros launch it first.
// One thread block handles one expert (GridSize() returns num_experts and blockIdx.x is used as
// the expert id). Each block:
//   1. zeroes its own row of the expert mesh,
//   2. scans every topk id and, for each entry equal to its expert id, stores (topk slot + 1)
//      at that token's position in the row,
//   3. counts the non-zero entries of the row and writes the count to p_expert_cumsum[eid]
//      (this last step is skipped by the early return described inside operator()).
template <typename Problem_>
struct MoeSortingMultiPhaseKernel_P0_v2
{
using Problem = remove_cvref_t<Problem_>;
using IndexType = typename Problem::IndexType;
using WeightType = typename Problem::WeightType;
using MeshType = typename Problem::MeshType;
// Threads per block; also the stride of the block-wide loops in operator().
static constexpr index_t kBlockSize = 512;
typedef MoeSortingHostArgs MoeSortingKargs;
using Hargs = MoeSortingHostArgs;
// Device-side arguments, filled by MakeKargs() from the host args.
struct Kargs
{
const void* p_topk_ids; // [tokens, topk]
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
void* p_expert_mesh; // [expert, tokens]
index_t tokens; // if p_local_tokens is not nullptr, this indicates the max possible tokens
// used for ws/LDS calculation
index_t mesh_stride; // mesh_stride for p_expert_mesh
mdiv topk_mdiv;
const void* p_local_expert_mask; // [expert]
void* p_expert_cumsum; // [expert]
index_t num_experts;
};
// Returns multiProcessorCount of the current HIP device. Not referenced inside this struct.
// NOTE(review): declared constexpr although it calls the HIP runtime -- confirm this is intended.
CK_TILE_HOST static constexpr auto get_num_cu()
{
index_t num_cu = [&]() {
hipDeviceProp_t dev_prop;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
return dev_prop.multiProcessorCount;
}();
return num_cu;
}
// Translates host args into Kargs. The expert mesh starts at the beginning of the workspace
// h.p_ws; the per-expert count array is placed
// impl::moe_sorting_mp_mesh_smem_size(tokens, num_experts, topk) bytes after it.
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.p_expert_cumsum = reinterpret_cast<void*>(
reinterpret_cast<char*>(h.p_ws) +
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
k.tokens = h.tokens;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
k.p_local_expert_mask = h.p_local_expert_mask;
k.num_experts = h.num_experts;
return k;
}
// One block per expert.
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return h.num_experts; }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
// in byte
// CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
// One IndexType slot per wave of the block; used for the cross-wave reduction in operator().
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
return kBlockSize / get_warp_size() * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
__shared__ char smem[GetSmemSize()];
// topk ids are read index_pack at a time through this vector type.
using topk_id_t = ext_vector_t<IndexType, index_pack>;
const int eid = blockIdx.x;
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
const IndexType* p_local_expert_mask =
static_cast<const IndexType*>(kargs.p_local_expert_mask);
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
index_t lane_id = threadIdx.x % get_warp_size();
index_t wave_id = threadIdx.x / get_warp_size();
// Actual token count: read from device memory when LocalToken, otherwise the host value.
const index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
// With LocalToken the count is rounded up to a multiple of index_pack so the vectorized
// topk-id loop covers the tail; out-of-range tokens are filtered when writing the mesh.
index_t rounded_tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return (tokens + index_pack - 1) / index_pack * index_pack;
}
else
return tokens;
}();
// With LocalToken the row stride is recomputed from the run-time token count instead of
// using the host-computed kargs.mesh_stride.
index_t mesh_stride = [&]() {
if constexpr(Problem::LocalToken)
{
return impl::moe_sorting_mp_mesh_stride(tokens);
}
else
{
return kargs.mesh_stride;
}
}();
// mask stays 1 unless LocalExpertMasking is enabled, in which case it is this expert's entry.
IndexType mask = 1;
if constexpr(Problem::LocalExpertMasking)
{
mask = p_local_expert_mask[eid];
}
// Row of the mesh owned by this block's expert.
MeshType* p_expert_mesh =
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride;
// Step 1: clear the row (block-wide loop, kBlockSize elements per pass).
for(index_t i = threadIdx.x; i < mesh_stride; i += kBlockSize)
{
p_expert_mesh[i] = 0;
}
// presumably makes the zeroing visible block-wide before the row is marked -- TODO confirm
// the exact semantics of block_sync_load_raw.
ck_tile::block_sync_load_raw(0);
// Number of topk_id_t vectors to scan.
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / index_pack;
// Step 2: every block scans ALL topk ids and keeps only those equal to its own expert id.
#pragma unroll index_pack
for(index_t i = threadIdx.x; i < total_elem; i += kBlockSize)
{
auto x = p_topk_ids[i];
static_for<0, index_pack, 1>{}([&](auto j) {
IndexType eid_x = x[j.value]; // ext_vector_type must use int to []
if(eid_x == eid)
{
uint32_t curr_token_id, curr_topk_id;
// presumably splits the flat [token, topk] index into token id and topk slot -- verify
// against mdiv::divmod.
kargs.topk_mdiv.divmod(i * index_pack + j, curr_token_id, curr_topk_id);
// Stored value is (topk slot + 1), so a non-zero entry marks "token routed to this expert".
// NOTE(review): the value is masked to 16 bits and then converted to MeshType on store;
// callers also instantiate MeshType as uint8_t -- confirm the range is sufficient.
if constexpr(Problem::LocalToken)
{
// rounded_tokens may exceed tokens; drop the padded tail.
if(static_cast<index_t>(curr_token_id) < tokens)
p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff;
}
});
}
ck_tile::block_sync_load_raw(0);
// Step 3: count non-zero entries of the row and publish the total for this expert.
{
// NOTE(review): the trailing comment says int32x4, but the declared type is a vector of
// index_pack MeshType elements.
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
const r_t* p_expert_mesh_r = reinterpret_cast<r_t*>(p_expert_mesh);
// Passes needed for kBlockSize threads to cover mesh_stride / index_pack vectors.
int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
// The whole block leaves together here (mask depends only on blockIdx.x), so the
// __syncthreads() below is not reached by a partial block.
// NOTE(review): mask is only loaded under Problem::LocalExpertMasking, yet this skip is
// gated on Problem::LocalToken -- confirm which flag is intended. When the skip is taken,
// p_expert_cumsum[eid] is not written by this kernel.
if(Problem::LocalToken && mask == 0)
return; // skip
index_t cnt = 0; // per-wave cnt
for(int i = 0; i < loops; i++)
{
int position = i * kBlockSize + threadIdx.x;
r_t v{0};
if(position < (mesh_stride / index_pack))
v = p_expert_mesh_r[position];
index_t local_sum = 0;
static_for<0, index_pack, 1>{}(
[&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; });
// presumably returns the sum of local_sum across the wave -- verify against
// impl::moe_sorting_wave_reduce.
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
}
// reduce cross wave
IndexType* s = reinterpret_cast<IndexType*>(smem);
// Lane 0 of each wave publishes that wave's count to shared memory.
if(lane_id == 0)
{
s[wave_id] = cnt;
}
__syncthreads();
// Thread 0 adds up the per-wave counts and writes this expert's total.
if(threadIdx.x == 0)
{
index_t c = 0;
for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
{
c += s[i];
}
p_expert_cumsum[eid] = c;
}
}
}
};
// count total tokens for an expert
template <typename Problem_>

View File

@@ -194,22 +194,40 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return -1;
}
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::kBlockSize>(kernel{}, grids, blocks, 0, kargs); \
}()
// Produces the launchable for MoeSortingMultiPhaseKernel_P0_v2.
// Expands to an immediately-invoked lambda that instantiates the kernel for the given mesh
// element type / unroll factor / expert-masking flag / local-token flag, derives kernel args,
// grid and block dimensions from the host args `a`, and returns what ck_tile::make_kernel yields
// (0 bytes of dynamic LDS are requested).
// Relies on `a`, `ms_index_t` and `ms_weight_type` being visible at the expansion site.
// NOTE(review): MOE_SORTING_MP_0_V1 instantiates ck_tile::make_kernel<kernel::kBlockSize>, while
// this macro uses make_kernel's default template arguments even though the v2 kernel declares
// kBlockSize = 512 -- confirm the default launch configuration admits 512 threads per block.
#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_)              \
    [&]() {                                                                                      \
        constexpr ck_tile::index_t kUnroll = unroll_num_;                                        \
        constexpr bool kMaskExperts        = expert_masking_;                                    \
        constexpr bool kUseLocalToken      = local_token_;                                       \
        using problem_t = ck_tile::MoeSortingProblemMp<ms_index_t,                               \
                                                       ms_weight_type,                           \
                                                       mesh_type_,                               \
                                                       kUnroll,                                  \
                                                       kMaskExperts,                             \
                                                       kUseLocalToken>;                          \
        using kernel_t       = ck_tile::MoeSortingMultiPhaseKernel_P0_v2<problem_t>;             \
        auto kernel_args     = kernel_t::MakeKargs(a);                                           \
        const dim3 grid_dim  = kernel_t::GridSize(a);                                            \
        const dim3 block_dim = kernel_t::BlockSize(a);                                           \
        return ck_tile::make_kernel(kernel_t{}, grid_dim, block_dim, 0, kernel_args);            \
    }()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -286,6 +304,46 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
// Two-kernel dispatch (MOE_SORTING_MP_0_V2 followed by MOE_SORTING_MP_23) used by the callers on
// their small-token branch. The run-time flags t.local_expert_masking and is_local_token select
// which compile-time (expert_masking, local_token) instantiation is launched on stream config
// `s`; the time reported by ck_tile::launch_kernel is returned from the ENCLOSING function.
// token_vec_1_ is part of the signature (same shape as MOR_SORTING_MP_DISPATCH_) but is not
// referenced in this expansion.
#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_)    \
    if(t.local_expert_masking && is_local_token)                                                 \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true),                           \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true));                           \
    }                                                                                            \
    else if(t.local_expert_masking)                                                              \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false),                          \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false));                          \
    }                                                                                            \
    else if(is_local_token)                                                                      \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true),                          \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true));                          \
    }                                                                                            \
    else                                                                                         \
    {                                                                                            \
        return ck_tile::launch_kernel(                                                           \
            s,                                                                                   \
            MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false),                         \
            MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false));                         \
    }
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
@@ -294,7 +352,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
@@ -304,7 +362,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
@@ -317,7 +375,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
@@ -327,7 +385,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = ck_tile::launch_kernel( \
s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
@@ -368,70 +426,140 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
}
}
};
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
if(!ck_tile::is_gfx12_supported() && a.tokens < 2048)
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time = ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time = ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
if(t.local_expert_masking)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large topk %d\n", a.topk);
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1)
}
}
}
else
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
}
}
}
}