[CK_TILE] moe_sorting support "local_tokens" feature for EP case (#2335)

* support local_token for hipgraph

* update README

* fix comment

* fix fmoe example
This commit is contained in:
carlushuang
2025-06-18 10:49:43 +08:00
committed by GitHub
parent c7c6a0ccb3
commit a4e1248dba
11 changed files with 495 additions and 162 deletions

View File

@@ -28,6 +28,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights;
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
a.local_tokens,
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;

View File

@@ -33,15 +33,18 @@
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
#define MOE_SORTING_DISPATCH_( \
sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
local_expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -51,32 +54,43 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \
row_, sub_token_onshot_, local_expert_masking_, local_token_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(is_local_token) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
@@ -175,6 +189,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
auto row_ = sub_token_ / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
bool is_local_expert_masking = t.local_expert_masking;
bool is_local_token = a.p_local_tokens != nullptr;
MOE_SORTING_DISPATCH_EMASK_(row_);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
@@ -183,15 +198,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return -1;
}
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -199,15 +216,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -215,15 +234,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
}()
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -231,15 +252,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -248,15 +271,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
}()
#endif
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -265,30 +290,55 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
}()
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \
return ave_time; \
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
} \
} \
else \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
} \
}
float fused_moesorting_mp(fused_moesorting_trait t,
fused_moesorting_args a,
ck_tile::stream_config s)
{
bool is_local_token = a.p_local_tokens != nullptr;
if(t.weight_type == "fp32" && t.index_type == "int32")
{
using ms_index_t = ck_tile::index_t;
@@ -360,3 +410,8 @@ float fused_moesorting_mp(fused_moesorting_trait t,
}
return -1;
}
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
{
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
}