mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] add moe-sorting MP kernel (#1910)
* moe sorting ex * fix bug for race condition * fix bug and optimze large expert * fix * optimize with sub_token_oneshot * support skip empty tokens for expert sorting * update moe_sorting * tidy code * support mp kernel * hint mp * remove use less code * porting to example 15 --------- Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
@@ -152,6 +152,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
if(local_expert_masking)
|
||||
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
|
||||
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
@@ -163,6 +170,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
num_experts,
|
||||
@@ -174,13 +182,68 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
warmup,
|
||||
repeat};
|
||||
|
||||
auto ms = moe_sorting(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
|
||||
// auto ms = moe_sorting_mp(trait, karg, sc);
|
||||
|
||||
#if 0
|
||||
{
|
||||
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
|
||||
moe_sorting_ws.FromDevice(ws_host.data());
|
||||
|
||||
int * p_mesh = reinterpret_cast<int*>(ws_host.data());
|
||||
ck_tile::index_t row_size = ck_tile::impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
|
||||
std::cout << "topk_ids:" << std::endl;
|
||||
|
||||
int * p_topk_ids = reinterpret_cast<int*>(topk_ids_host.data());
|
||||
for(int i_token = 0; i_token < tokens; i_token++) {
|
||||
printf("[t:%2d]", i_token);
|
||||
for(int i_topk = 0; i_topk < topk; i_topk++) {
|
||||
printf("%d, ",p_topk_ids[i_token * topk + i_topk] );
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("----------------\n");
|
||||
|
||||
std::vector<int> l_cumsum (num_experts + 1, 0);
|
||||
for(int i_expert = 0; i_expert < num_experts; i_expert++ ) {
|
||||
printf("[e:%2d]", i_expert);
|
||||
int e_cnt = 0;
|
||||
for(int i_token = 0; i_token < tokens; i_token++) {
|
||||
auto v_mesh = p_mesh[i_expert * row_size + i_token];
|
||||
e_cnt += v_mesh != 0 ? 1 : 0;
|
||||
printf("%d, ", v_mesh);
|
||||
}
|
||||
int e_cnt_unit = (e_cnt + unit_size - 1) / unit_size;
|
||||
printf("[%d/%d]", e_cnt, e_cnt_unit);
|
||||
printf("\n");
|
||||
l_cumsum[i_expert + 1] = l_cumsum[i_expert] + e_cnt_unit;
|
||||
}
|
||||
|
||||
printf("----------------\n");
|
||||
printf("cumsum:\n");
|
||||
for(int i_cc= 0; i_cc < num_experts + 1; i_cc++) {
|
||||
printf("%2d, ", l_cumsum[i_cc]);
|
||||
}
|
||||
printf("\n");
|
||||
printf("----------------\n");
|
||||
|
||||
int * p_cumsum = p_mesh + ck_tile::impl::moe_sorting_mp_mesh_elem(tokens, num_experts);
|
||||
for(int i_expert = 0; i_expert < num_experts + 1; i_expert++ ) {
|
||||
printf("%2d(%d), ",p_cumsum[i_expert], p_cumsum[i_expert] / unit_size);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, mp:%d, ",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
num_experts,
|
||||
topk);
|
||||
topk,
|
||||
workspace_size != 0 ? 1 : 0);
|
||||
|
||||
if(local_expert_masking)
|
||||
{
|
||||
@@ -224,28 +287,41 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
num_experts,
|
||||
unit_size,
|
||||
local_expert_masking);
|
||||
rtn &= ck_tile::check_err(
|
||||
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_weights_host,
|
||||
sorted_weights_ref,
|
||||
std::string("OUT Error: Incorrect w!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_expert_ids_host,
|
||||
sorted_expert_ids_ref,
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
printf("total_tokens_post_pad:%d(%d), ",
|
||||
ref_total_tokens_post_pad,
|
||||
sorted_id_cnt_host.mData[0]);
|
||||
if(ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0])
|
||||
{
|
||||
size_t slen = ref_total_tokens_post_pad;
|
||||
rtn &= ck_tile::check_err(sorted_ids_host.slice({0}, {slen}),
|
||||
sorted_ids_ref.slice({0}, {slen}),
|
||||
std::string("OUT Error: Incorrect ids!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_weights_host.slice({0}, {slen}),
|
||||
sorted_weights_ref.slice({0}, {slen}),
|
||||
std::string("OUT Error: Incorrect w!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_expert_ids_host.slice({0}, {slen / unit_size}),
|
||||
sorted_expert_ids_ref.slice({0}, {slen / unit_size}),
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("(token size not equal!!)");
|
||||
rtn = false;
|
||||
}
|
||||
|
||||
if(moe_buf_size)
|
||||
{
|
||||
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
|
||||
rtn &= ck_tile::check_err(
|
||||
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
|
||||
}
|
||||
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
|
||||
printf("total_tokens_post_pad:%d(%d), ",
|
||||
ref_total_tokens_post_pad,
|
||||
sorted_id_cnt_host.mData[0]);
|
||||
// rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
|
||||
}
|
||||
|
||||
printf("valid:%s", rtn ? "y" : "n");
|
||||
|
||||
@@ -153,18 +153,106 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
}
|
||||
}
|
||||
#else
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
|
||||
auto sub_token_ = r_ - 2;
|
||||
r_ = (r_ - 2) / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0)
|
||||
{
|
||||
return moe_sorting_mp(t, a, s);
|
||||
}
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto sub_token_ = ck_tile::moe_sorting_get_sub_token(a.tokens, a.num_experts);
|
||||
auto row_ = sub_token_ / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
(void)c_;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(r_);
|
||||
MOE_SORTING_DISPATCH_EMASK_(row_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
#endif
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_2(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_3(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(1, true),
|
||||
MOE_SORTING_MP_1(1, true),
|
||||
MOE_SORTING_MP_2(1, true),
|
||||
MOE_SORTING_MP_3(1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(1, false),
|
||||
MOE_SORTING_MP_1(1, false),
|
||||
MOE_SORTING_MP_2(1, false),
|
||||
MOE_SORTING_MP_3(1, false));
|
||||
return ave_time;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts);
|
||||
}
|
||||
|
||||
@@ -18,4 +18,10 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// use below API before call moe_sorting() to indicate if need workspace or not
|
||||
// if return non zero, means need workspace, you need to allocate a GPU buffer
|
||||
// and set to moe_sorting_args.p_ws
|
||||
// NOTE: workspace size are required to clear zero before use the API
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts);
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
|
||||
@@ -17,6 +17,9 @@ struct fused_moe_args
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
|
||||
void* o_ptr; // [m, k], output token (no need to do zeroing)
|
||||
void* ws_ptr; // size is moe_sorting_get_workspace_size()
|
||||
// if return zero, then could be nullptr
|
||||
// must be cleard before use
|
||||
|
||||
const void* topk_ids_ptr; // [tokens, topk]
|
||||
const void* topk_weight_ptr; // [tokens, topk]
|
||||
|
||||
@@ -27,6 +27,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
|
||||
@@ -371,6 +371,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(
|
||||
num_sorted_tiles_host.get_element_space_size_in_bytes());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
fused_moe_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
@@ -394,6 +400,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
topk_weight_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
|
||||
@@ -101,7 +101,7 @@ namespace ck_tile {
|
||||
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
|
||||
|
||||
|
||||
CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_)
|
||||
CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_experts_)
|
||||
{
|
||||
/* num_experts + 1
|
||||
* +--------------------------------------+
|
||||
@@ -132,7 +132,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int nu
|
||||
|
||||
// round to sub_unroll multipl
|
||||
int r_for_sub_token = r - cumsum_bufs;
|
||||
r_for_sub_token = min(r_for_sub_token, num_tokens_);
|
||||
r_for_sub_token = min(r_for_sub_token, tokens_);
|
||||
r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll;
|
||||
r_for_sub_token = max(r_for_sub_token, 1);
|
||||
|
||||
@@ -148,7 +148,6 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int nu
|
||||
|
||||
mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most
|
||||
mask_ = ~mask_;
|
||||
//printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout);
|
||||
|
||||
r_for_sub_token = (r_unroll_ & mask_) * sub_unroll;
|
||||
}
|
||||
@@ -161,11 +160,17 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int nu
|
||||
return r_for_sub_token + cumsum_bufs;
|
||||
}();
|
||||
|
||||
// printf("r:%d, c:%d\n", smem_rows, smem_cols);
|
||||
|
||||
return ck_tile::make_tuple(smem_rows, smem_cols);
|
||||
}
|
||||
|
||||
CK_TILE_HOST index_t moe_sorting_get_sub_token(int tokens_, int num_experts_)
|
||||
{
|
||||
auto [r_, c_] = moe_sorting_get_smem_row_col(tokens_, num_experts_);
|
||||
auto sub_token_ = r_ - 2;
|
||||
(void) c_;
|
||||
return sub_token_;
|
||||
}
|
||||
|
||||
struct MoeSortingHostArgs
|
||||
{
|
||||
const void* p_topk_ids; // [token, topk]
|
||||
@@ -180,6 +185,9 @@ struct MoeSortingHostArgs
|
||||
// we fused the setzero of output of fused-moe buffer
|
||||
// set this pointer to nullptr will skip this operation
|
||||
void* p_moe_buf;
|
||||
void* p_ws; // size is moe_sorting_get_workspace_size()
|
||||
// if return zero, then could be nullptr
|
||||
// must be cleard before use
|
||||
index_t tokens;
|
||||
index_t unit_size; // this is the M_a of fused-moe kernel
|
||||
index_t num_experts;
|
||||
@@ -1046,6 +1054,812 @@ struct MoeSortingKernel
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
|
||||
// [expert, padded_tokens]
|
||||
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens)
|
||||
{
|
||||
constexpr index_t chunk = 32;
|
||||
return (tokens + chunk - 1) / chunk * chunk;
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_elem(index_t tokens, index_t num_experts)
|
||||
{
|
||||
index_t row_size = moe_sorting_mp_mesh_stride(tokens);
|
||||
return num_experts * row_size;
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_elem(index_t num_experts)
|
||||
{
|
||||
constexpr index_t chunk = 32;
|
||||
index_t row_size = num_experts + 1;
|
||||
return (row_size + chunk - 1) / chunk * chunk;
|
||||
};
|
||||
|
||||
template <typename T, typename F, index_t wave_size_ = warpSize>
|
||||
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
// constexpr int reduce_stage = 6; // 1<<6=64
|
||||
// clang-format off
|
||||
constexpr int reduce_stage = [](){
|
||||
if constexpr(wave_size_ == 2) return 1;
|
||||
else if constexpr(wave_size_ == 4) return 2;
|
||||
else if constexpr(wave_size_ == 8) return 3;
|
||||
else if constexpr(wave_size_ == 16) return 4;
|
||||
else if constexpr(wave_size_ == 32) return 5;
|
||||
else if constexpr(wave_size_ == 64) return 6;
|
||||
else return 0;
|
||||
}();
|
||||
// clang-format on
|
||||
T v_local = local;
|
||||
#pragma unroll reduce_stage
|
||||
for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
|
||||
{
|
||||
int src_lane = __lane_id() ^ (1 << i_stage);
|
||||
int32_t v_remote_tmp =
|
||||
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
|
||||
T v_remote = bit_cast<T>(v_remote_tmp);
|
||||
v_local = reduce_f(v_local, v_remote);
|
||||
}
|
||||
return v_local;
|
||||
}
|
||||
|
||||
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
|
||||
// NOTE: wave_size need at least be 16!! dpp 16 is one row
|
||||
template <typename data_t, int wave_size>
|
||||
CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
|
||||
{
|
||||
// wave_size must be power of 2
|
||||
constexpr int row_mask = 0xf;
|
||||
constexpr int bank_mask = 0xf;
|
||||
constexpr bool bound_ctrl = true; // ! out-of-bound is zero !
|
||||
auto reduce_op = [&](auto x_, auto y_) { return x_ + y_; };
|
||||
|
||||
if constexpr(wave_size > 1)
|
||||
{
|
||||
thread_data = reduce_op(
|
||||
thread_data,
|
||||
__builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x111,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:1
|
||||
}
|
||||
|
||||
if constexpr(wave_size > 2)
|
||||
{
|
||||
thread_data = reduce_op(
|
||||
thread_data,
|
||||
__builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x112,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:2
|
||||
}
|
||||
if constexpr(wave_size > 4)
|
||||
{
|
||||
thread_data = reduce_op(
|
||||
thread_data,
|
||||
__builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x114,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:4
|
||||
}
|
||||
if constexpr(wave_size == 8)
|
||||
{
|
||||
|
||||
// wave-size=8 need one extra shift
|
||||
thread_data = reduce_op(
|
||||
thread_data,
|
||||
__builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x118,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:8
|
||||
#if 0
|
||||
constexpr int bank_mask_0_7 = 0b1100;
|
||||
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
|
||||
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_update_dpp(0, /* old value */
|
||||
__builtin_bit_cast(int, thread_data),
|
||||
0x157,
|
||||
row_mask,
|
||||
bank_mask_0_7,
|
||||
bound_ctrl))// row_newbcast:7
|
||||
);
|
||||
#else
|
||||
data_t xxx =
|
||||
__builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x157,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl)); // row_newbcast:7
|
||||
|
||||
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
|
||||
thread_data = thread_data - yyy;
|
||||
#endif
|
||||
}
|
||||
if constexpr(wave_size > 8)
|
||||
{
|
||||
thread_data = reduce_op(
|
||||
thread_data,
|
||||
__builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x118,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:8
|
||||
}
|
||||
|
||||
if constexpr(wave_size > 16)
|
||||
{
|
||||
// now row-0, row-0+row-1, row-1+row-2, row-2+row-3
|
||||
int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 1) << 2,
|
||||
__builtin_bit_cast(int, thread_data));
|
||||
v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
|
||||
thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
|
||||
}
|
||||
|
||||
if constexpr(wave_size > 32)
|
||||
{
|
||||
// lane-id 48...63->31
|
||||
int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 17) << 2,
|
||||
__builtin_bit_cast(int, thread_data));
|
||||
v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
|
||||
thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BLOCK_SIZE = 256>
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes, index_t gid)
|
||||
{
|
||||
// const index_t offset = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x;
|
||||
index_t offset = gid * BLOCK_SIZE + threadIdx.x;
|
||||
if(offset < buf_bytes / 16)
|
||||
{
|
||||
buf[offset] = uint8x16_t{0};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
||||
// prefer to run mp kernel if is not oneshot
|
||||
CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
|
||||
{
|
||||
auto sub_token_ = moe_sorting_get_sub_token(tokens_, num_experts_);
|
||||
bool is_sub_token_onshot = tokens_ <= sub_token_;
|
||||
return is_sub_token_onshot;
|
||||
}
|
||||
|
||||
// return size in byte
|
||||
CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_)
|
||||
{
|
||||
index_t elem = impl::moe_sorting_mp_mesh_elem(tokens_, num_experts_) +
|
||||
impl::moe_sorting_mp_cumsum_elem(num_experts_);
|
||||
return elem * sizeof(index_t);
|
||||
}
|
||||
|
||||
// return size in byte
|
||||
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_)
|
||||
{
|
||||
#if 1
|
||||
if(moe_sorting_is_oneshot(tokens_, num_experts_))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return moe_sorting_mp_get_workspace_size(tokens_, num_experts_);
|
||||
}
|
||||
#else
|
||||
return moe_sorting_mp_get_workspace_size(tokens_, num_experts_);
|
||||
#endif
|
||||
}
|
||||
|
||||
// below kernel is multi-phase implementation for large token and/or expert case
|
||||
|
||||
// write into a buffer to record the token cnt
|
||||
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
|
||||
// number)
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
/*
|
||||
|
||||
p_expert_mesh:
|
||||
t0 t1 t2 t3 t4 r5
|
||||
+--+--+--+--+--+--+
|
||||
e0 | 1| | | | | |
|
||||
e1 | | | 1| 1| 1| |
|
||||
e2 | | 1| | 1| | |
|
||||
e3 | 1| 1| 1| 1| 1| |
|
||||
e4 | | | | | | |
|
||||
e5 | 1| 1| 1| | | 1|
|
||||
|
||||
|
||||
p_expert_cumsum:
|
||||
| 1| 3| 2| 5| 0| 4|
|
||||
e0 e1 e2 e3 e4 e5
|
||||
|
||||
p_expert_cumsum(with M_a pad, and skip zero tokens):
|
||||
| 4| 4| 4| 8| 0| 4|
|
||||
e0 e1 e2 e3 e4 e5
|
||||
|
||||
p_expert_cumsum
|
||||
| 0| 4| 8|12|20|20|24|
|
||||
|
||||
local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4)
|
||||
|
||||
p_m_cumsum
|
||||
| 0| 1| 1| 2| 3| 3| 4|
|
||||
|
||||
*/
|
||||
|
||||
// count topk_id into mesh
|
||||
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P0
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv topk_mdiv;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
using topk_id_t = ext_vector_t<IndexType, Problem::SubTokenTile>;
|
||||
|
||||
static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 ||
|
||||
Problem::SubTokenTile == 4);
|
||||
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
IndexType* p_expert_mesh = reinterpret_cast<IndexType*>(kargs.p_expert_mesh);
|
||||
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; i += blockDim.x)
|
||||
{
|
||||
auto x = p_topk_ids[i];
|
||||
static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) {
|
||||
IndexType eid = x[j.value]; // ext_vector_type must use int to []
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// cnt total tokens for a expert
|
||||
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P1
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum =
|
||||
reinterpret_cast<void*>(reinterpret_cast<IndexType*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts));
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / warpSize * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
int eid = blockIdx.x;
|
||||
|
||||
constexpr index_t index_pack = 4; // always packed
|
||||
using r_t = ext_vector_t<IndexType, index_pack>; // always use int32x4
|
||||
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
|
||||
reinterpret_cast<index_t*>(kargs.p_expert_mesh) + eid * kargs.mesh_stride);
|
||||
|
||||
static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 ||
|
||||
Problem::SubTokenTile == 4);
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
|
||||
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
|
||||
|
||||
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
IndexType mask = p_local_expert_mask[eid];
|
||||
if(mask == 0)
|
||||
return; // skip
|
||||
}
|
||||
|
||||
index_t cnt = 0; // per-wave cnt
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int position = i * BLOCK_SIZE + threadIdx.x;
|
||||
r_t v{0};
|
||||
if(position < (kargs.mesh_stride / index_pack))
|
||||
v = p_expert_mesh[position];
|
||||
index_t local_sum = 0;
|
||||
static_for<0, index_pack, 1>{}(
|
||||
[&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; });
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
if(lane_id == 0)
|
||||
{
|
||||
s[wave_id] = cnt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
p_expert_cumsum[eid] = c;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// token count cumsum
|
||||
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P2
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
index_t num_experts;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv unit_size_mdiv;
|
||||
index_t moe_buf_bytes;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
// k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum =
|
||||
reinterpret_cast<void*>(reinterpret_cast<IndexType*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts));
|
||||
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
|
||||
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
|
||||
|
||||
k.p_moe_buf = h.p_moe_buf;
|
||||
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
// use 1 block to cumsum
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return 2 * BLOCK_SIZE * sizeof(IndexType);
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - 1);
|
||||
return;
|
||||
}
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
IndexType* p_total_tokens_post_pad =
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
|
||||
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
|
||||
for(index_t i = 0; i < loops; i++)
|
||||
{
|
||||
index_t position = i * BLOCK_SIZE + threadIdx.x;
|
||||
IndexType a_ = 0; // token count for a expert
|
||||
IndexType b_ = 0; // mask for a expert
|
||||
if(position < kargs.num_experts)
|
||||
{
|
||||
a_ = p_expert_cumsum[position];
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
b_ = p_local_expert_mask[position];
|
||||
}
|
||||
|
||||
int blocks_pers_expert =
|
||||
kargs.unit_size_mdiv.div(a_ + kargs.unit_size_mdiv.divisor - 1);
|
||||
// pad token
|
||||
int padded_blocks_per_expert = [&]() {
|
||||
int x_ = [&]() {
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
// if local_cnt is zero, blocks_pers_expert will be zero
|
||||
// this is what we want to achieve
|
||||
return blocks_pers_expert; // * kargs.unit_size_mdiv.divisor;
|
||||
}
|
||||
else
|
||||
{
|
||||
return max(blocks_pers_expert, 1);
|
||||
}
|
||||
}();
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
return b_ ? x_ : 0;
|
||||
}
|
||||
else
|
||||
return x_;
|
||||
}();
|
||||
|
||||
IndexType cumsum_a = padded_blocks_per_expert;
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
cumsum_b += prev_b;
|
||||
});
|
||||
|
||||
// Now let's add previous cumsum
|
||||
cumsum_a += prev_cumsum_a;
|
||||
cumsum_b += prev_cumsum_b;
|
||||
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
{
|
||||
s[2] = cumsum_a; // store the last cumsum
|
||||
s[3] = cumsum_b;
|
||||
}
|
||||
|
||||
IndexType out_0 = cumsum_a - padded_blocks_per_expert; // exclusive cumsum tok cnt
|
||||
IndexType out_1 = cumsum_b - b_; // exclusive cumsum mask cnt
|
||||
|
||||
__syncthreads();
|
||||
prev_cumsum_a = s[2];
|
||||
prev_cumsum_b = s[3];
|
||||
|
||||
if(position < kargs.num_experts)
|
||||
{
|
||||
p_expert_cumsum[position] = out_0 * kargs.unit_size_mdiv.divisor;
|
||||
}
|
||||
|
||||
{
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
if(b_)
|
||||
{
|
||||
for(int j = 0; j < blocks_pers_expert; j++)
|
||||
{
|
||||
p_sorted_expert_ids[out_0 + j] = out_1;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int j = 0; j < blocks_pers_expert; j++)
|
||||
{
|
||||
p_sorted_expert_ids[out_0 + j] = position;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
|
||||
p_total_tokens_post_pad[0] = total_tokens_post_pad;
|
||||
p_expert_cumsum[kargs.num_experts] = total_tokens_post_pad;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P3
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask;
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_expert_mesh; // [token, expert]
|
||||
void* p_expert_cumsum;
|
||||
|
||||
index_t tokens;
|
||||
index_t num_experts;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv topk_mdiv;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum =
|
||||
reinterpret_cast<void*>(reinterpret_cast<IndexType*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts));
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return (4 + BLOCK_SIZE / warpSize) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
IndexType* p_expert_mesh = reinterpret_cast<IndexType*>(kargs.p_expert_mesh);
|
||||
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 ||
|
||||
Problem::SubTokenTile == 4);
|
||||
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / warpSize;
|
||||
int lane_id = threadIdx.x % warpSize;
|
||||
int e_start = p_expert_cumsum[eid];
|
||||
int e_end = p_expert_cumsum[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
if(e_start == e_end)
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
int e_mask = p_local_expert_mask[eid];
|
||||
if(e_mask == 0)
|
||||
return; // skip empty expert
|
||||
}
|
||||
|
||||
// cumsum one by one
|
||||
int loops = (kargs.mesh_stride + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int prev_cumsum = 0;
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int i_token = i * BLOCK_SIZE + threadIdx.x;
|
||||
IndexType x = 0;
|
||||
if(i_token < kargs.tokens)
|
||||
{
|
||||
x = p_expert_mesh[eid * kargs.mesh_stride + i_token];
|
||||
}
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int position = cumsum - i_show;
|
||||
prev_cumsum = s[0]; // update the last cumsum
|
||||
|
||||
if(i_show)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start + position] = MOE_SORTING_MOCK_ID(i_token, i_topk);
|
||||
#else
|
||||
p_sorted_token_ids[e_start + position] = i_token;
|
||||
#endif
|
||||
p_sorted_weights[e_start + position] =
|
||||
p_weights[i_token * kargs.topk_mdiv.divisor + i_topk];
|
||||
}
|
||||
}
|
||||
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[i] = tokens;
|
||||
#endif
|
||||
p_sorted_weights[i] = static_cast<WeightType>(0.0);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#undef MOE_SORTING_MOCK_ID
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -49,4 +49,21 @@ struct MoeSortingProblemEx
|
||||
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
|
||||
};
|
||||
|
||||
template <typename IndexType_,
|
||||
typename WeightType_,
|
||||
index_t SubTokenTile_, // 1,2,4
|
||||
bool LocalExpertMasking_, // used in EP case
|
||||
bool SkipExpertsWithZeroTokens_ = true>
|
||||
struct MoeSortingProblemMp
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using WeightType = remove_cvref_t<WeightType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
|
||||
static constexpr index_t SubTokenTile = SubTokenTile_;
|
||||
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
|
||||
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
|
||||
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4);
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user