mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] moe_sorting support "local_tokens" feature for EP case (#2335)
* support local_token for hipgraph
* update README
* fix comment
* fix fmoe example
[ROCm/composable_kernel commit: a4e1248dba]
This commit is contained in:
@@ -14,14 +14,24 @@ This will result in an executable `build/bin/tile_example_moe_sorting`
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
-pr_i index data type. (currently only fp32 supported now) (default:int32)
|
||||
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
|
||||
-t number of input tokens (default:32)
|
||||
-e number of experts (default:8)
|
||||
-k topk (default:2)
|
||||
-st_i row stride of input, -1 means same as experts (default:-1)
|
||||
-seed seed to be used, -1 means random every time (default:-1)
|
||||
-kname when set to 1 it will print kernel name (default:0)
|
||||
-v turn CPU validation on (1) or off (0). (default:1)
|
||||
-pr_i index data type. Only int32 is currently supported. (default:int32)
|
||||
-pr_w output weight data type. Only fp32 is currently supported. (default:fp32)
|
||||
-t number of input tokens. (default:128)
|
||||
If "local_t" presents, this value indicates global concurrency of all ranks.
|
||||
-local_t Number of local input tokens for curent rank. (default:-1)
|
||||
This value must be within range "[0, t)", or "-1"(no such feature)
|
||||
This feature is to simulate EP case where where each rank has different tokens.
|
||||
Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
|
||||
-e number of num_experts (default:8)
|
||||
-k topk (default:4)
|
||||
-unit unit_size (default:32)
|
||||
-moe_buf_size moe_buf_size (default:0)
|
||||
-local_eid a list of experts enabled as local expert. e.g. "0,1,4,5" (default:-1)
|
||||
please make sure eid is in ascending order!
|
||||
-seed seed to be used. When set to -1, a random seed will be generated each time invoking this example (default:-1)
|
||||
-kname prints the kernel name when set to 1 (default:0)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
|
||||
```
|
||||
|
||||
@@ -18,10 +18,20 @@
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("pr_i", "int32", "index data type. (currently only int32 supported now)")
|
||||
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
|
||||
.insert("t", "128", "number of input tokens")
|
||||
arg_parser.insert("v", "1", "turn CPU validation on (1) or off (0).")
|
||||
.insert("pr_i", "int32", "index data type. Only int32 is currently supported.")
|
||||
.insert("pr_w", "fp32", "output weight data type. Only fp32 is currently supported.")
|
||||
.insert("t",
|
||||
"128",
|
||||
"number of input tokens.\n"
|
||||
"If \"local_t\" presents, this value indicates global concurrency of all ranks.")
|
||||
.insert(
|
||||
"local_t",
|
||||
"-1",
|
||||
"Number of local input tokens for curent rank.\n"
|
||||
"This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
|
||||
"This feature is to simulate EP case where where each rank has different tokens.\n"
|
||||
"Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
|
||||
.insert("e", "8", "number of num_experts")
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
@@ -30,8 +40,11 @@ auto create_args(int argc, char* argv[])
|
||||
"-1",
|
||||
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
|
||||
"please make sure eid is in ascending order!")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("seed",
|
||||
"-1",
|
||||
"seed to be used. When set to -1, a random seed will be generated each time "
|
||||
"invoking this example")
|
||||
.insert("kname", "0", "prints the kernel name when set to 1")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
@@ -70,6 +83,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
int tokens = args.get_int("t");
|
||||
int local_tokens = args.get_int("local_t");
|
||||
int num_experts = args.get_int("e");
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
@@ -95,6 +109,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
return false;
|
||||
}
|
||||
|
||||
// if local_tokens == tokens, not local_token, but better avoid this since no meaning for such
|
||||
// case
|
||||
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
|
||||
|
||||
if(local_tokens > tokens)
|
||||
{
|
||||
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool local_expert_masking = args.get_str("local_eid") != "-1";
|
||||
auto local_expert_masking_host = [&]() {
|
||||
if(local_expert_masking)
|
||||
@@ -143,6 +167,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
ck_tile::DeviceMem local_expert_masking_dev(
|
||||
local_expert_masking_host.get_element_space_size_in_bytes());
|
||||
|
||||
// used for simulating dynamic_tokens for EP case
|
||||
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
|
||||
if(is_local_token)
|
||||
{
|
||||
local_tokens_dev.ToDevice(&local_tokens);
|
||||
}
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
@@ -164,6 +195,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
@@ -236,13 +268,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
}
|
||||
#endif
|
||||
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, mp:%d, ",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
workspace_size != 0 ? 1 : 0);
|
||||
printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens);
|
||||
if(is_local_token)
|
||||
{
|
||||
printf("(%d)", local_tokens);
|
||||
}
|
||||
printf(", num_experts:%d, topk:%d, mp:%d, ", num_experts, topk, workspace_size != 0 ? 1 : 0);
|
||||
|
||||
if(local_expert_masking)
|
||||
{
|
||||
@@ -285,6 +316,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
ref_total_tokens_post_pad,
|
||||
num_experts,
|
||||
unit_size,
|
||||
is_local_token ? local_tokens
|
||||
: tokens,
|
||||
local_expert_masking);
|
||||
printf("total_tokens_post_pad:%d(%d), ",
|
||||
ref_total_tokens_post_pad,
|
||||
|
||||
@@ -33,15 +33,18 @@
|
||||
|
||||
#else
|
||||
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_( \
|
||||
sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking>; \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -51,32 +54,43 @@
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \
|
||||
row_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
|
||||
@@ -171,6 +185,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
auto row_ = sub_token_ / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(row_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
@@ -179,15 +194,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -195,15 +212,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -211,15 +230,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -227,15 +248,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -244,15 +267,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
}()
|
||||
#endif
|
||||
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -261,28 +286,53 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \
|
||||
return ave_time; \
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
}
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
|
||||
@@ -31,4 +31,14 @@ $EXE -t=8192 -e=32 -k=5 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=32 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=256 -k=5 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=256 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145
|
||||
$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99
|
||||
$EXE -t=99 -local_t=93 -e=121 -moe_buf_size=10244
|
||||
$EXE -t=536 -local_t=345 -e=802 -k=99
|
||||
$EXE -t=331 -local_t=39 -e=83 -k=33
|
||||
$EXE -t=765 -local_t=654 -e=783 -k=8
|
||||
$EXE -t=23 -local_t=9 -e=1 -k=1
|
||||
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
|
||||
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
|
||||
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940
|
||||
|
||||
@@ -16,6 +16,7 @@ struct fused_moe_args
|
||||
const void* d_scale_ptr; // [e, 1, k], down scale
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
|
||||
const void* local_tokens; // [1] if not nullptr, tokens read from here
|
||||
void* o_ptr; // [m, k], output token (no need to do zeroing)
|
||||
void* ws_ptr; // size is moe_sorting_get_workspace_size()
|
||||
// if return zero, then could be nullptr
|
||||
|
||||
@@ -28,6 +28,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
|
||||
@@ -33,15 +33,18 @@
|
||||
|
||||
#else
|
||||
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_( \
|
||||
sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking>; \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -51,32 +54,43 @@
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \
|
||||
row_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
|
||||
@@ -175,6 +189,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
auto row_ = sub_token_ / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(row_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
@@ -183,15 +198,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -199,15 +216,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -215,15 +234,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -231,15 +252,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -248,15 +271,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
}()
|
||||
#endif
|
||||
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -265,30 +290,55 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \
|
||||
return ave_time; \
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
}
|
||||
|
||||
float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
fused_moesorting_args a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
@@ -360,3 +410,8 @@ float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
}
|
||||
|
||||
@@ -87,7 +87,18 @@ void topid_unique_gen(
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("t", "128", "num input tokens")
|
||||
arg_parser
|
||||
.insert("t",
|
||||
"128",
|
||||
"number of input tokens.\n"
|
||||
"If \"local_t\" presents, this value indicates global concurrency of all ranks.")
|
||||
.insert(
|
||||
"local_t",
|
||||
"-1",
|
||||
"Number of local input tokens for curent rank.\n"
|
||||
"This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
|
||||
"This feature is to simulate EP case where where each rank has different tokens.\n"
|
||||
"Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
|
||||
.insert("e", "32", "num of experts")
|
||||
.insert("k", "5", "topk")
|
||||
.insert("h", "8192", "hidden_size of this model")
|
||||
@@ -131,6 +142,7 @@ template <typename I, typename W, typename O, typename ST, typename SW, typename
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t tokens = arg_parser.get_int("t");
|
||||
ck_tile::index_t local_tokens = arg_parser.get_int("local_t");
|
||||
ck_tile::index_t experts = arg_parser.get_int("e");
|
||||
ck_tile::index_t topk = arg_parser.get_int("k");
|
||||
ck_tile::index_t hidden_size = arg_parser.get_int("h");
|
||||
@@ -169,6 +181,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// w1 (Down, N size)
|
||||
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
|
||||
|
||||
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
|
||||
|
||||
if(local_tokens > tokens)
|
||||
{
|
||||
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
if(prec_i != prec_w)
|
||||
@@ -198,11 +218,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return std::string(", st:") + std::to_string(stride);
|
||||
}();
|
||||
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens;
|
||||
|
||||
if(is_local_token)
|
||||
{
|
||||
std::cout << "(" << local_tokens << ")";
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
|
||||
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
|
||||
<< ", act:"
|
||||
<< ", e:" << experts << ", k:" << topk << stride_str << ", hidden:" << hidden_size
|
||||
<< ", interm:" << intermediate_size << ", tp:" << tp << ", act:"
|
||||
<< activation
|
||||
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
|
||||
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
|
||||
@@ -377,6 +403,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
|
||||
if(is_local_token)
|
||||
{
|
||||
local_tokens_dev.ToDevice(&local_tokens);
|
||||
}
|
||||
|
||||
fused_moe_traits traits{prec_i,
|
||||
prec_w,
|
||||
@@ -400,6 +431,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
|
||||
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
@@ -463,6 +495,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
local_expert_masking);
|
||||
if(activation == 0)
|
||||
{
|
||||
@@ -495,6 +528,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
local_expert_masking);
|
||||
|
||||
// done, preparing GPU buffer
|
||||
@@ -506,6 +540,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem o_buf(o_host);
|
||||
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
|
||||
if(is_local_token)
|
||||
{
|
||||
local_tokens_dev.ToDevice(&local_tokens);
|
||||
}
|
||||
|
||||
// manually clear output buffer for atomic
|
||||
o_buf.SetZero();
|
||||
@@ -542,7 +581,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
hidden_size,
|
||||
intermediate_size / tp,
|
||||
tokens,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
|
||||
@@ -21,10 +21,12 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
index_t& unit_cnt,
|
||||
const index_t experts,
|
||||
const index_t unit_size,
|
||||
const index_t tokens,
|
||||
bool local_expert_masking,
|
||||
bool skip_experts_with_zero_token = true)
|
||||
{
|
||||
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
|
||||
// note: if tokens is smaller than topk_ids.mDesc.get_lengths()[0], indicating local_token case
|
||||
const index_t num_token = tokens; // topk_ids.mDesc.get_lengths()[0];
|
||||
const index_t topk = topk_ids.mDesc.get_lengths()[1];
|
||||
// allocate a temp buffer, and fill the value with [number_token|topk]
|
||||
std::vector<std::vector<IndexType>> expert_tokens(
|
||||
|
||||
@@ -165,7 +165,8 @@ struct MoeSortingHostArgs
|
||||
const void* p_topk_ids; // [token, topk]
|
||||
const void* p_weights; // [token, topk]
|
||||
|
||||
const void* p_local_expert_mask;
|
||||
const void* p_local_expert_mask; // [experts]
|
||||
const void* p_local_tokens; // [1] if not nullptr, tokens read from here
|
||||
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
@@ -177,7 +178,7 @@ struct MoeSortingHostArgs
|
||||
void* p_ws; // size is moe_sorting_get_workspace_size()
|
||||
// if return zero, then could be nullptr
|
||||
// must be cleard before use
|
||||
index_t tokens;
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens used for ws/LDS calculation
|
||||
index_t unit_size; // this is the M_a of fused-moe kernel
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
@@ -201,6 +202,7 @@ struct MoeSortingKernel
|
||||
const void* p_topk_ids;
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask;
|
||||
const void* p_local_tokens; // [1] if not nullptr, tokens read from here
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
@@ -253,6 +255,7 @@ struct MoeSortingKernel
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
|
||||
@@ -263,9 +266,13 @@ struct MoeSortingKernel
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
|
||||
const auto blocks = BlockSize(h);
|
||||
// NOTE: tokens could from p_local_tokens, so here this variable is useless
|
||||
// hence moe_align_block_size_kernel() will not behavior properly if we have dynamic tokens
|
||||
// (indeed we can deprecate moe_align_block_size_kernel)
|
||||
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
// NOTE: tokens could from p_local_tokens, so here the LDS will be bigger than expected (but works)
|
||||
k.smem_rows = [&](){
|
||||
auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
|
||||
(void) c_;
|
||||
@@ -1009,8 +1016,19 @@ struct MoeSortingKernel
|
||||
}
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
extern __shared__ char smem[];
|
||||
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
(void)numel;
|
||||
index_t tokens_ = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
return moe_align_block_size_kernel_ex(
|
||||
static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
@@ -1020,7 +1038,7 @@ struct MoeSortingKernel
|
||||
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
|
||||
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
|
||||
kargs.num_experts,
|
||||
kargs.tokens,
|
||||
tokens_,
|
||||
kargs.unit_size_mdiv,
|
||||
kargs.topk_mdiv,
|
||||
kargs.expert_mdiv,
|
||||
@@ -1245,6 +1263,7 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by
|
||||
|
||||
} // namespace impl
|
||||
|
||||
// TODO: tokens could be from
|
||||
// prefer to run mp kernel if is not oneshot
|
||||
CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
|
||||
{
|
||||
@@ -1351,9 +1370,11 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens;
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
|
||||
// used for ws/LDS calculation
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv topk_mdiv;
|
||||
};
|
||||
@@ -1373,11 +1394,12 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
@@ -1394,7 +1416,26 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
index_t rounded_tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
|
||||
Problem::SubTokenTile;
|
||||
}
|
||||
else
|
||||
return tokens;
|
||||
}();
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
@@ -1405,8 +1446,15 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
IndexType eid = x[j.value]; // ext_vector_type must use int to []
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1542,6 +1590,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
{
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_expert_sem; // [1]
|
||||
@@ -1569,6 +1618,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
@@ -1580,8 +1630,17 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.wg_count = WGCounts(h);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
k.wg_count = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return GridSize(h);
|
||||
}
|
||||
else
|
||||
{
|
||||
return WGCounts(h);
|
||||
}
|
||||
}();
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
@@ -1607,13 +1666,46 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
workgroup_barrier wb{reinterpret_cast<uint32_t*>(kargs.p_expert_sem)};
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
index_t rounded_tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
|
||||
Problem::SubTokenTile;
|
||||
}
|
||||
else
|
||||
return tokens;
|
||||
}();
|
||||
index_t wg_count = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile;
|
||||
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
// no more than grid_size
|
||||
return min(elem_cnt, kargs.wg_count);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.wg_count;
|
||||
}
|
||||
}();
|
||||
|
||||
{
|
||||
using topk_id_t = ext_vector_t<IndexType, Problem::SubTokenTile>;
|
||||
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
IndexType* p_expert_mesh = reinterpret_cast<IndexType*>(kargs.p_expert_mesh);
|
||||
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
@@ -1625,10 +1717,19 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
kargs.topk_mdiv.divmod(
|
||||
i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
|
||||
// p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
if(static_cast<index_t>(blockIdx.x) < kargs.wg_count)
|
||||
if(static_cast<index_t>(blockIdx.x) < wg_count)
|
||||
{
|
||||
wb.inc();
|
||||
}
|
||||
@@ -1642,7 +1743,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if(eid >= kargs.num_experts)
|
||||
return;
|
||||
|
||||
wb.wait_lt(kargs.wg_count);
|
||||
wb.wait_lt(wg_count);
|
||||
|
||||
for(; eid < kargs.num_experts; eid += gridDim.x)
|
||||
{
|
||||
@@ -1731,6 +1832,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
@@ -1747,6 +1849,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
|
||||
@@ -1942,6 +2045,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
{
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask;
|
||||
const void* p_local_tokens;
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_expert_mesh; // [token, expert]
|
||||
@@ -1958,6 +2062,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
Kargs k;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
@@ -1994,6 +2099,16 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / WarpSize;
|
||||
int lane_id = threadIdx.x % WarpSize;
|
||||
@@ -2019,7 +2134,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
{
|
||||
int i_token = i * BLOCK_SIZE + threadIdx.x;
|
||||
IndexType x = 0;
|
||||
if(i_token < kargs.tokens)
|
||||
if(i_token < tokens)
|
||||
{
|
||||
x = p_expert_mesh[eid * kargs.mesh_stride + i_token];
|
||||
}
|
||||
@@ -2066,7 +2181,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[i] = tokens;
|
||||
#endif
|
||||
@@ -2105,6 +2220,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
{
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
@@ -2127,6 +2243,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
Kargs k;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
@@ -2346,6 +2463,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
return; // skip empty expert
|
||||
}
|
||||
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
// cumsum one by one
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
@@ -2357,7 +2485,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
{
|
||||
int i_token_pack = i * BLOCK_SIZE + threadIdx.x;
|
||||
r_t x_v = 0;
|
||||
if(i_token_pack < (kargs.tokens + index_pack - 1) / index_pack)
|
||||
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
|
||||
{
|
||||
x_v = reinterpret_cast<r_t*>(p_expert_mesh +
|
||||
eid * kargs.mesh_stride)[i_token_pack];
|
||||
@@ -2554,7 +2682,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[i] = tokens;
|
||||
#endif
|
||||
|
||||
@@ -31,6 +31,7 @@ template <typename IndexType_,
|
||||
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
|
||||
bool SubTokenOneShot_, // if we only loop over once or not
|
||||
bool LocalExpertMasking_, // used in EP case
|
||||
bool LocalToken_, // used in EP case
|
||||
bool SkipExpertsWithZeroTokens_ = true,
|
||||
index_t ExpertTile_ = 0>
|
||||
struct MoeSortingProblemEx
|
||||
@@ -44,6 +45,7 @@ struct MoeSortingProblemEx
|
||||
static constexpr index_t SubTokenTile = SubTokenTile_;
|
||||
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
|
||||
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
|
||||
static constexpr bool LocalToken = LocalToken_;
|
||||
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
|
||||
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
|
||||
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
|
||||
@@ -54,6 +56,7 @@ template <typename IndexType_,
|
||||
typename MeshType_,
|
||||
index_t SubTokenTile_, // 1,2,4,8
|
||||
bool LocalExpertMasking_, // used in EP case
|
||||
bool LocalToken_, // used in EP case
|
||||
bool SkipExpertsWithZeroTokens_ = true>
|
||||
struct MoeSortingProblemMp
|
||||
{
|
||||
@@ -64,6 +67,7 @@ struct MoeSortingProblemMp
|
||||
|
||||
static constexpr index_t SubTokenTile = SubTokenTile_;
|
||||
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
|
||||
static constexpr bool LocalToken = LocalToken_;
|
||||
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
|
||||
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 ||
|
||||
SubTokenTile == 8 || SubTokenTile == 16);
|
||||
|
||||
Reference in New Issue
Block a user