[CK_TILE] moe_sorting support "local_tokens" feature for EP case (#2335)

* support local_token for hipgraph

* update README

* fix comment

* fix fmoe example

[ROCm/composable_kernel commit: a4e1248dba]
This commit is contained in:
carlushuang
2025-06-18 10:49:43 +08:00
committed by GitHub
parent 5ff790eb05
commit 8660f6ef22
11 changed files with 495 additions and 162 deletions

View File

@@ -14,14 +14,24 @@ This will result in an executable `build/bin/tile_example_moe_sorting`
## example
```
args:
-v weather do CPU validation or not (default:1)
-pr_i index data type. (currently only fp32 supported now) (default:int32)
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
-t number of input tokens (default:32)
-e number of experts (default:8)
-k topk (default:2)
-st_i row stride of input, -1 means same as experts (default:-1)
-seed seed to be used, -1 means random every time (default:-1)
-kname when set to 1 it will print kernel name (default:0)
-v turn CPU validation on (1) or off (0). (default:1)
-pr_i index data type. Only int32 is currently supported. (default:int32)
-pr_w output weight data type. Only fp32 is currently supported. (default:fp32)
-t number of input tokens. (default:128)
If "local_t" presents, this value indicates global concurrency of all ranks.
-local_t Number of local input tokens for curent rank. (default:-1)
This value must be within range "[0, t)", or "-1"(no such feature)
This feature is to simulate EP case where where each rank has different tokens.
Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
-e number of num_experts (default:8)
-k topk (default:4)
-unit unit_size (default:32)
-moe_buf_size moe_buf_size (default:0)
-local_eid a list of experts enabled as local expert. e.g. "0,1,4,5" (default:-1)
please make sure eid is in ascending order!
-seed seed to be used. When set to -1, a random seed will be generated each time invoking this example (default:-1)
-kname prints the kernel name when set to 1 (default:0)
-warmup number of iterations before benchmark the kernel (default:5)
-repeat number of iterations to benchmark the kernel (default:20)
```

View File

@@ -18,10 +18,20 @@
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("pr_i", "int32", "index data type. (currently only int32 supported now)")
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
.insert("t", "128", "number of input tokens")
arg_parser.insert("v", "1", "turn CPU validation on (1) or off (0).")
.insert("pr_i", "int32", "index data type. Only int32 is currently supported.")
.insert("pr_w", "fp32", "output weight data type. Only fp32 is currently supported.")
.insert("t",
"128",
"number of input tokens.\n"
"If \"local_t\" presents, this value indicates global concurrency of all ranks.")
.insert(
"local_t",
"-1",
"Number of local input tokens for curent rank.\n"
"This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
"This feature is to simulate EP case where where each rank has different tokens.\n"
"Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
.insert("e", "8", "number of num_experts")
.insert("k", "4", "topk")
.insert("unit", "32", "unit_size")
@@ -30,8 +40,11 @@ auto create_args(int argc, char* argv[])
"-1",
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
"please make sure eid is in ascending order!")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name")
.insert("seed",
"-1",
"seed to be used. When set to -1, a random seed will be generated each time "
"invoking this example")
.insert("kname", "0", "prints the kernel name when set to 1")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
@@ -70,6 +83,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
int tokens = args.get_int("t");
int local_tokens = args.get_int("local_t");
int num_experts = args.get_int("e");
int topk = args.get_int("k");
int seed = args.get_int("seed");
@@ -95,6 +109,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
return false;
}
// if local_tokens == tokens, not local_token, but better avoid this since no meaning for such
// case
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
if(local_tokens > tokens)
{
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
return false;
}
bool local_expert_masking = args.get_str("local_eid") != "-1";
auto local_expert_masking_host = [&]() {
if(local_expert_masking)
@@ -143,6 +167,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::DeviceMem local_expert_masking_dev(
local_expert_masking_host.get_element_space_size_in_bytes());
// used for simulating dynamic_tokens for EP case
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
if(is_local_token)
{
local_tokens_dev.ToDevice(&local_tokens);
}
topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data());
if(moe_buf_size > 0)
@@ -164,6 +195,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
weights_dev.GetDeviceBuffer(),
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
: nullptr,
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(),
sorted_expert_ids_dev.GetDeviceBuffer(),
@@ -236,13 +268,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
}
#endif
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, mp:%d, ",
index_prec.c_str(),
weight_prec.c_str(),
tokens,
num_experts,
topk,
workspace_size != 0 ? 1 : 0);
printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens);
if(is_local_token)
{
printf("(%d)", local_tokens);
}
printf(", num_experts:%d, topk:%d, mp:%d, ", num_experts, topk, workspace_size != 0 ? 1 : 0);
if(local_expert_masking)
{
@@ -285,6 +316,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ref_total_tokens_post_pad,
num_experts,
unit_size,
is_local_token ? local_tokens
: tokens,
local_expert_masking);
printf("total_tokens_post_pad:%d(%d), ",
ref_total_tokens_post_pad,

View File

@@ -33,15 +33,18 @@
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
#define MOE_SORTING_DISPATCH_( \
sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
local_expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -51,32 +54,43 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \
row_, sub_token_onshot_, local_expert_masking_, local_token_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(is_local_token) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
@@ -171,6 +185,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
auto row_ = sub_token_ / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
bool is_local_expert_masking = t.local_expert_masking;
bool is_local_token = a.p_local_tokens != nullptr;
MOE_SORTING_DISPATCH_EMASK_(row_);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
@@ -179,15 +194,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return -1;
}
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -195,15 +212,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -211,15 +230,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
}()
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -227,15 +248,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -244,15 +267,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
}()
#endif
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -261,28 +286,53 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
}()
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \
return ave_time; \
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
} \
} \
else \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
} \
}
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
bool is_local_token = a.p_local_tokens != nullptr;
if(t.weight_type == "fp32" && t.index_type == "int32")
{
using ms_index_t = ck_tile::index_t;

View File

@@ -31,4 +31,14 @@ $EXE -t=8192 -e=32 -k=5 -moe_buf_size=163840
$EXE -t=8192 -e=32 -k=8 -moe_buf_size=163840
$EXE -t=8192 -e=256 -k=5 -moe_buf_size=163840
$EXE -t=8192 -e=256 -k=8 -moe_buf_size=163840
$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840
$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840
$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145
$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99
$EXE -t=99 -local_t=93 -e=121 -moe_buf_size=10244
$EXE -t=536 -local_t=345 -e=802 -k=99
$EXE -t=331 -local_t=39 -e=83 -k=33
$EXE -t=765 -local_t=654 -e=783 -k=8
$EXE -t=23 -local_t=9 -e=1 -k=1
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940

View File

@@ -16,6 +16,7 @@ struct fused_moe_args
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
const void* local_tokens; // [1] if not nullptr, tokens read from here
void* o_ptr; // [m, k], output token (no need to do zeroing)
void* ws_ptr; // size is moe_sorting_get_workspace_size()
// if return zero, then could be nullptr

View File

@@ -28,6 +28,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights;
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
a.local_tokens,
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;

View File

@@ -33,15 +33,18 @@
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
#define MOE_SORTING_DISPATCH_( \
sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
local_expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -51,32 +54,43 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \
row_, sub_token_onshot_, local_expert_masking_, local_token_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(is_local_token) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
@@ -175,6 +189,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
auto row_ = sub_token_ / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
bool is_local_expert_masking = t.local_expert_masking;
bool is_local_token = a.p_local_tokens != nullptr;
MOE_SORTING_DISPATCH_EMASK_(row_);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
@@ -183,15 +198,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return -1;
}
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -199,15 +216,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -215,15 +234,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
}()
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -231,15 +252,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -248,15 +271,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
}()
#endif
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking>; \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
@@ -265,30 +290,55 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
}()
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \
return ave_time; \
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
} \
} \
else \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
} \
}
float fused_moesorting_mp(fused_moesorting_trait t,
fused_moesorting_args a,
ck_tile::stream_config s)
{
bool is_local_token = a.p_local_tokens != nullptr;
if(t.weight_type == "fp32" && t.index_type == "int32")
{
using ms_index_t = ck_tile::index_t;
@@ -360,3 +410,8 @@ float fused_moesorting_mp(fused_moesorting_trait t,
}
return -1;
}
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
{
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
}

View File

@@ -87,7 +87,18 @@ void topid_unique_gen(
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("t", "128", "num input tokens")
arg_parser
.insert("t",
"128",
"number of input tokens.\n"
"If \"local_t\" presents, this value indicates global concurrency of all ranks.")
.insert(
"local_t",
"-1",
"Number of local input tokens for curent rank.\n"
"This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
"This feature is to simulate EP case where where each rank has different tokens.\n"
"Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
.insert("e", "32", "num of experts")
.insert("k", "5", "topk")
.insert("h", "8192", "hidden_size of this model")
@@ -131,6 +142,7 @@ template <typename I, typename W, typename O, typename ST, typename SW, typename
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t tokens = arg_parser.get_int("t");
ck_tile::index_t local_tokens = arg_parser.get_int("local_t");
ck_tile::index_t experts = arg_parser.get_int("e");
ck_tile::index_t topk = arg_parser.get_int("k");
ck_tile::index_t hidden_size = arg_parser.get_int("h");
@@ -169,6 +181,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
// w1 (Down, N size)
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
if(local_tokens > tokens)
{
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
return false;
}
auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_w)
@@ -198,11 +218,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
return std::string(", st:") + std::to_string(stride);
}();
std::cout << "[" << api_str << "|" << prec_str << "]"
<< " t:" << tokens;
if(is_local_token)
{
std::cout << "(" << local_tokens << ")";
}
std::cout
<< "[" << api_str << "|" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", act:"
<< ", e:" << experts << ", k:" << topk << stride_str << ", hidden:" << hidden_size
<< ", interm:" << intermediate_size << ", tp:" << tp << ", act:"
<< activation
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
@@ -377,6 +403,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
if(workspace_size != 0)
moe_sorting_ws.SetZero(); // note, clear here!!!!
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
if(is_local_token)
{
local_tokens_dev.ToDevice(&local_tokens);
}
fused_moe_traits traits{prec_i,
prec_w,
@@ -400,6 +431,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
: nullptr,
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
o_buf.GetDeviceBuffer(),
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
topk_ids_buf.GetDeviceBuffer(),
@@ -463,6 +495,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.mData[0],
experts,
block_m,
is_local_token ? local_tokens : tokens,
local_expert_masking);
if(activation == 0)
{
@@ -495,6 +528,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.mData[0],
experts,
block_m,
is_local_token ? local_tokens : tokens,
local_expert_masking);
// done, preparing GPU buffer
@@ -506,6 +540,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem o_buf(o_host);
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
if(is_local_token)
{
local_tokens_dev.ToDevice(&local_tokens);
}
// manually clear output buffer for atomic
o_buf.SetZero();
@@ -542,7 +581,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size,
intermediate_size / tp,
tokens,
is_local_token ? local_tokens : tokens,
experts,
topk,
stride};

View File

@@ -21,10 +21,12 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
index_t& unit_cnt,
const index_t experts,
const index_t unit_size,
const index_t tokens,
bool local_expert_masking,
bool skip_experts_with_zero_token = true)
{
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
// note: if tokens is smaller than topk_ids.mDesc.get_lengths()[0], indicating local_token case
const index_t num_token = tokens; // topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1];
// allocate a temp buffer, and fill the value with [number_token|topk]
std::vector<std::vector<IndexType>> expert_tokens(

View File

@@ -165,7 +165,8 @@ struct MoeSortingHostArgs
const void* p_topk_ids; // [token, topk]
const void* p_weights; // [token, topk]
const void* p_local_expert_mask;
const void* p_local_expert_mask; // [experts]
const void* p_local_tokens; // [1] if not nullptr, tokens read from here
void* p_sorted_token_ids;
void* p_sorted_weights;
@@ -177,7 +178,7 @@ struct MoeSortingHostArgs
void* p_ws; // size is moe_sorting_get_workspace_size()
// if return zero, then could be nullptr
// must be cleard before use
index_t tokens;
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens used for ws/LDS calculation
index_t unit_size; // this is the M_a of fused-moe kernel
index_t num_experts;
index_t topk;
@@ -201,6 +202,7 @@ struct MoeSortingKernel
const void* p_topk_ids;
const void* p_weights;
const void* p_local_expert_mask;
const void* p_local_tokens; // [1] if not nullptr, tokens read from here
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
@@ -253,6 +255,7 @@ struct MoeSortingKernel
k.p_topk_ids = h.p_topk_ids;
k.p_weights = h.p_weights;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_local_tokens = h.p_local_tokens;
k.p_sorted_token_ids = h.p_sorted_token_ids;
k.p_sorted_weights = h.p_sorted_weights;
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
@@ -263,9 +266,13 @@ struct MoeSortingKernel
k.moe_buf_bytes = h.moe_buf_bytes;
const auto blocks = BlockSize(h);
// NOTE: tokens could from p_local_tokens, so here this variable is useless
// hence moe_align_block_size_kernel() will not behavior properly if we have dynamic tokens
// (indeed we can deprecate moe_align_block_size_kernel)
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
// NOTE: tokens could from p_local_tokens, so here the LDS will be bigger than expected (but works)
k.smem_rows = [&](){
auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
(void) c_;
@@ -1009,8 +1016,19 @@ struct MoeSortingKernel
}
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
extern __shared__ char smem[];
#if MOE_SORTING_USE_EX_KERNEL
(void)numel;
index_t tokens_ = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
return moe_align_block_size_kernel_ex(
static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
@@ -1020,7 +1038,7 @@ struct MoeSortingKernel
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
kargs.num_experts,
kargs.tokens,
tokens_,
kargs.unit_size_mdiv,
kargs.topk_mdiv,
kargs.expert_mdiv,
@@ -1245,6 +1263,7 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by
} // namespace impl
// TODO: tokens could be from
// prefer to run mp kernel if is not oneshot
CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
{
@@ -1351,9 +1370,11 @@ struct MoeSortingMultiPhaseKernel_P0
struct Kargs
{
const void* p_topk_ids; // [tokens, topk]
void* p_expert_mesh; // [expert, tokens]
index_t tokens;
const void* p_topk_ids; // [tokens, topk]
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
void* p_expert_mesh; // [expert, tokens]
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
// used for ws/LDS calculation
index_t mesh_stride; // mesh_stride for p_expert_mesh
mdiv topk_mdiv;
};
@@ -1373,11 +1394,12 @@ struct MoeSortingMultiPhaseKernel_P0
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_expert_mesh = h.p_ws;
k.tokens = h.tokens;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
k.p_topk_ids = h.p_topk_ids;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.tokens = h.tokens;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
return k;
}
@@ -1394,7 +1416,26 @@ struct MoeSortingMultiPhaseKernel_P0
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
index_t rounded_tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
Problem::SubTokenTile;
}
else
return tokens;
}();
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
#pragma unroll Problem::SubTokenTile
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
@@ -1405,8 +1446,15 @@ struct MoeSortingMultiPhaseKernel_P0
IndexType eid = x[j.value]; // ext_vector_type must use int to []
uint32_t curr_token_id, curr_topk_id;
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
if constexpr(Problem::LocalToken)
{
if(static_cast<index_t>(curr_token_id) < tokens)
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
});
}
}
@@ -1542,6 +1590,7 @@ struct MoeSortingMultiPhaseKernel_P01
{
const void* p_topk_ids; // [tokens, topk]
const void* p_local_expert_mask; // [expert]
const void* p_local_tokens; // [1]
void* p_expert_mesh; // [expert, tokens]
void* p_expert_cumsum; // [expert + 1]
void* p_expert_sem; // [1]
@@ -1569,6 +1618,7 @@ struct MoeSortingMultiPhaseKernel_P01
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.p_expert_cumsum = reinterpret_cast<void*>(
reinterpret_cast<char*>(h.p_ws) +
@@ -1580,8 +1630,17 @@ struct MoeSortingMultiPhaseKernel_P01
k.tokens = h.tokens;
k.num_experts = h.num_experts;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.wg_count = WGCounts(h);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
k.wg_count = [&]() {
if constexpr(Problem::LocalToken)
{
return GridSize(h);
}
else
{
return WGCounts(h);
}
}();
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
return k;
}
@@ -1607,13 +1666,46 @@ struct MoeSortingMultiPhaseKernel_P01
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
workgroup_barrier wb{reinterpret_cast<uint32_t*>(kargs.p_expert_sem)};
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
index_t rounded_tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
Problem::SubTokenTile;
}
else
return tokens;
}();
index_t wg_count = [&]() {
if constexpr(Problem::LocalToken)
{
index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile;
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
// no more than grid_size
return min(elem_cnt, kargs.wg_count);
}
else
{
return kargs.wg_count;
}
}();
{
using topk_id_t = ext_vector_t<IndexType, Problem::SubTokenTile>;
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
IndexType* p_expert_mesh = reinterpret_cast<IndexType*>(kargs.p_expert_mesh);
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
#pragma unroll Problem::SubTokenTile
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
@@ -1625,10 +1717,19 @@ struct MoeSortingMultiPhaseKernel_P01
uint32_t curr_token_id, curr_topk_id;
kargs.topk_mdiv.divmod(
i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
// p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
if constexpr(Problem::LocalToken)
{
if(static_cast<index_t>(curr_token_id) < tokens)
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
});
}
if(static_cast<index_t>(blockIdx.x) < kargs.wg_count)
if(static_cast<index_t>(blockIdx.x) < wg_count)
{
wb.inc();
}
@@ -1642,7 +1743,7 @@ struct MoeSortingMultiPhaseKernel_P01
if(eid >= kargs.num_experts)
return;
wb.wait_lt(kargs.wg_count);
wb.wait_lt(wg_count);
for(; eid < kargs.num_experts; eid += gridDim.x)
{
@@ -1731,6 +1832,7 @@ struct MoeSortingMultiPhaseKernel_P2
struct Kargs
{
const void* p_local_expert_mask; // [expert]
const void* p_local_tokens; // [1]
void* p_expert_mesh; // [expert, tokens]
void* p_expert_cumsum; // [expert + 1]
void* p_total_tokens_post_pad; // [1]
@@ -1747,6 +1849,7 @@ struct MoeSortingMultiPhaseKernel_P2
{
Kargs k;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_cumsum = reinterpret_cast<void*>(
reinterpret_cast<char*>(h.p_ws) +
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
@@ -1942,6 +2045,7 @@ struct MoeSortingMultiPhaseKernel_P3
{
const void* p_weights;
const void* p_local_expert_mask;
const void* p_local_tokens;
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_expert_mesh; // [token, expert]
@@ -1958,6 +2062,7 @@ struct MoeSortingMultiPhaseKernel_P3
Kargs k;
k.p_weights = h.p_weights;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_local_tokens = h.p_local_tokens;
k.p_sorted_token_ids = h.p_sorted_token_ids;
k.p_sorted_weights = h.p_sorted_weights;
k.p_expert_mesh = h.p_ws;
@@ -1994,6 +2099,16 @@ struct MoeSortingMultiPhaseKernel_P3
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
int eid = blockIdx.x;
int wave_id = threadIdx.x / WarpSize;
int lane_id = threadIdx.x % WarpSize;
@@ -2019,7 +2134,7 @@ struct MoeSortingMultiPhaseKernel_P3
{
int i_token = i * BLOCK_SIZE + threadIdx.x;
IndexType x = 0;
if(i_token < kargs.tokens)
if(i_token < tokens)
{
x = p_expert_mesh[eid * kargs.mesh_stride + i_token];
}
@@ -2066,7 +2181,7 @@ struct MoeSortingMultiPhaseKernel_P3
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
#else
p_sorted_token_ids[i] = tokens;
#endif
@@ -2105,6 +2220,7 @@ struct MoeSortingMultiPhaseKernel_P23
{
const void* p_weights;
const void* p_local_expert_mask; // [expert]
const void* p_local_tokens; // [1]
void* p_expert_mesh; // [expert, tokens]
void* p_expert_cumsum; // [expert + 1]
void* p_total_tokens_post_pad; // [1]
@@ -2127,6 +2243,7 @@ struct MoeSortingMultiPhaseKernel_P23
Kargs k;
k.p_weights = h.p_weights;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.p_expert_cumsum = reinterpret_cast<void*>(
reinterpret_cast<char*>(h.p_ws) +
@@ -2346,6 +2463,17 @@ struct MoeSortingMultiPhaseKernel_P23
return; // skip empty expert
}
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
// cumsum one by one
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
@@ -2357,7 +2485,7 @@ struct MoeSortingMultiPhaseKernel_P23
{
int i_token_pack = i * BLOCK_SIZE + threadIdx.x;
r_t x_v = 0;
if(i_token_pack < (kargs.tokens + index_pack - 1) / index_pack)
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
{
x_v = reinterpret_cast<r_t*>(p_expert_mesh +
eid * kargs.mesh_stride)[i_token_pack];
@@ -2554,7 +2682,7 @@ struct MoeSortingMultiPhaseKernel_P23
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
#else
p_sorted_token_ids[i] = tokens;
#endif

View File

@@ -31,6 +31,7 @@ template <typename IndexType_,
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
bool SubTokenOneShot_, // if we only loop over once or not
bool LocalExpertMasking_, // used in EP case
bool LocalToken_, // used in EP case
bool SkipExpertsWithZeroTokens_ = true,
index_t ExpertTile_ = 0>
struct MoeSortingProblemEx
@@ -44,6 +45,7 @@ struct MoeSortingProblemEx
static constexpr index_t SubTokenTile = SubTokenTile_;
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
static constexpr bool LocalToken = LocalToken_;
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
@@ -54,6 +56,7 @@ template <typename IndexType_,
typename MeshType_,
index_t SubTokenTile_, // 1,2,4,8
bool LocalExpertMasking_, // used in EP case
bool LocalToken_, // used in EP case
bool SkipExpertsWithZeroTokens_ = true>
struct MoeSortingProblemMp
{
@@ -64,6 +67,7 @@ struct MoeSortingProblemMp
static constexpr index_t SubTokenTile = SubTokenTile_;
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
static constexpr bool LocalToken = LocalToken_;
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 ||
SubTokenTile == 8 || SubTokenTile == 16);