mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
[CK_TILE] moe sorting ex kernel to support expert > 128 (#1840)
* moe sorting ex
* fix bug for race condition
* fix bug and optimze large expert
* fix
* optimize with sub_token_oneshot
* support skip empty tokens for expert sorting
* update moe_sorting
* tidy code
[ROCm/composable_kernel commit: c0adab4850]
This commit is contained in:
@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
.insert("moe_buf_size", "0", "moe_buf_size")
|
||||
.insert("local_eid",
|
||||
"-1",
|
||||
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
|
||||
"please make sure eid is in ascending order!")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
|
||||
int max_output_ids =
|
||||
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
|
||||
|
||||
@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
return false;
|
||||
}
|
||||
|
||||
bool local_expert_masking = args.get_str("local_eid") != "-1";
|
||||
auto local_expert_masking_host = [&]() {
|
||||
if(local_expert_masking)
|
||||
{
|
||||
auto local_eid = args.get_int_vec("local_eid");
|
||||
// std::vector<int> v_ {num_experts, 0};
|
||||
ck_tile::HostTensor<IndexType> v_{{num_experts}};
|
||||
v_.SetZero();
|
||||
for(auto eid : local_eid)
|
||||
{
|
||||
if(eid >= num_experts)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"local_eid larger than number of expert, please check");
|
||||
}
|
||||
v_.mData[eid] = 1;
|
||||
}
|
||||
return v_;
|
||||
}
|
||||
else
|
||||
// return std::vector<int>{};
|
||||
return ck_tile::HostTensor<IndexType>{{1}};
|
||||
}();
|
||||
|
||||
// tokens already considered batch size
|
||||
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
|
||||
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
|
||||
@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
sorted_expert_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem local_expert_masking_dev(
|
||||
local_expert_masking_host.get_element_space_size_in_bytes());
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
{
|
||||
moe_buf_dev.ToDevice(moe_buf_host.data());
|
||||
}
|
||||
if(local_expert_masking)
|
||||
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
|
||||
|
||||
moe_sorting_trait trait{index_prec, weight_prec};
|
||||
moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
|
||||
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
warmup,
|
||||
repeat};
|
||||
auto ms = moe_sorting(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
ms);
|
||||
topk);
|
||||
|
||||
if(local_expert_masking)
|
||||
{
|
||||
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
|
||||
}
|
||||
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
else
|
||||
printf("ms:%f, ", ms);
|
||||
fflush(stdout);
|
||||
if(ms < 0)
|
||||
{
|
||||
@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
int32_t ref_total_tokens_post_pad = 0;
|
||||
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
|
||||
weights_host,
|
||||
local_expert_masking_host,
|
||||
sorted_ids_ref,
|
||||
sorted_weights_ref,
|
||||
sorted_expert_ids_ref,
|
||||
ref_total_tokens_post_pad,
|
||||
num_experts,
|
||||
unit_size);
|
||||
unit_size,
|
||||
local_expert_masking);
|
||||
rtn &= ck_tile::check_err(
|
||||
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_weights_host,
|
||||
@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
|
||||
}
|
||||
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
|
||||
printf("total_tokens_post_pad:%d(%d), ",
|
||||
ref_total_tokens_post_pad,
|
||||
sorted_id_cnt_host.mData[0]);
|
||||
}
|
||||
|
||||
printf("valid:%s\n", rtn ? "y" : "n");
|
||||
printf("valid:%s", rtn ? "y" : "n");
|
||||
fflush(stdout);
|
||||
if(!rtn)
|
||||
printf(", (%d)", seed);
|
||||
printf("\n");
|
||||
fflush(stdout);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
@@ -3,6 +3,12 @@
|
||||
|
||||
#include "moe_sorting_api.hpp"
|
||||
|
||||
#ifndef MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_USE_EX_KERNEL 1
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
|
||||
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr ck_tile::index_t expert_tile = expert_tile_; \
|
||||
@@ -17,6 +23,67 @@
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#else
|
||||
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
|
||||
if(is_local_expert_masking) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
if(a.num_experts <= 8) \
|
||||
{ \
|
||||
@@ -38,11 +105,13 @@
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
|
||||
}
|
||||
#endif
|
||||
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
if(a.num_experts > 127)
|
||||
{
|
||||
printf("lds size exceed, only support experts <127 \n");
|
||||
@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
MOE_SORTING_DISPATCH(4);
|
||||
}
|
||||
}
|
||||
#else
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
|
||||
auto sub_token_ = r_ - 2;
|
||||
r_ = (r_ - 2) / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
(void)c_;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(r_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
#endif
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -10,7 +10,8 @@
|
||||
struct moe_sorting_trait
|
||||
{
|
||||
std::string index_type;
|
||||
std::string weight_type; // currently always float
|
||||
std::string weight_type; // currently always float
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
};
|
||||
|
||||
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
|
||||
@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
|
||||
$EXE -t=1 -e=1 -k=1
|
||||
$EXE -t=99 -e=2 -k=1
|
||||
$EXE -t=333 -e=99 -k=13
|
||||
$EXE -t=11 -e=256 -k=5
|
||||
$EXE -t=64 -e=455 -k=8
|
||||
$EXE -t=777 -e=802 -k=99
|
||||
$EXE -t=4097 -e=906 -k=51
|
||||
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
|
||||
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
|
||||
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
|
||||
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
|
||||
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
|
||||
|
||||
@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
|
||||
|
||||
@@ -3,6 +3,12 @@
|
||||
|
||||
#include "fused_moesorting.hpp"
|
||||
|
||||
#ifndef MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_USE_EX_KERNEL 1
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
|
||||
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr ck_tile::index_t expert_tile = expert_tile_; \
|
||||
@@ -17,6 +23,24 @@
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#else
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
if(a.num_experts <= 8) \
|
||||
{ \
|
||||
@@ -38,11 +62,13 @@
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
|
||||
}
|
||||
#endif
|
||||
|
||||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
if(a.num_experts > 127)
|
||||
{
|
||||
printf("lds size exceed, only support experts <127 \n");
|
||||
@@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
MOE_SORTING_DISPATCH(4);
|
||||
}
|
||||
}
|
||||
#else
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
|
||||
auto sub_token_ = r_ - 2;
|
||||
r_ = (r_ - 2) / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
(void)c_;
|
||||
if(is_sub_token_onshot)
|
||||
{
|
||||
if(r_ % 8 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(8, true);
|
||||
}
|
||||
else if(r_ % 4 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(4, true);
|
||||
}
|
||||
else if(r_ % 2 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(2, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(1, true);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(r_ % 8 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(8, false);
|
||||
}
|
||||
else if(r_ % 4 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(4, false);
|
||||
}
|
||||
else if(r_ % 2 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(2, false);
|
||||
}
|
||||
else
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(1, false);
|
||||
}
|
||||
}
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
#endif
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user