mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
[CK_TILE] optimize moe sorting kernel, boost large context case up to 20x (#2153)
* combine 2-3 as single stage * support zeroing * improve long tokens * update specialization * b16 ws * 8bit topk optimize * update 15 example
This commit is contained in:
@@ -153,9 +153,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts);
|
||||
ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
|
||||
@@ -7,6 +7,14 @@
|
||||
#define MOE_SORTING_USE_EX_KERNEL 1
|
||||
#endif
|
||||
|
||||
#ifndef MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_SUPPORT_LARGE_EXPERT 0
|
||||
#endif
|
||||
|
||||
#ifndef MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
#define MOE_SORTING_SUPPORT_LARGE_TOPK 0
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
|
||||
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
|
||||
@@ -153,7 +161,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
}
|
||||
}
|
||||
#else
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0)
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0)
|
||||
{
|
||||
return moe_sorting_mp(t, a, s);
|
||||
}
|
||||
@@ -171,57 +179,107 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_2(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#endif
|
||||
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_size = kernel::GetSmemSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_3(unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \
|
||||
return ave_time; \
|
||||
}
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
@@ -230,29 +288,74 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
|
||||
if(t.local_expert_masking)
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(1, true),
|
||||
MOE_SORTING_MP_1(1, true),
|
||||
MOE_SORTING_MP_2(1, true),
|
||||
MOE_SORTING_MP_3(1, true));
|
||||
return ave_time;
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, false));
|
||||
return ave_time;
|
||||
}
|
||||
#else
|
||||
printf("do not support large expert %d\n", a.num_experts);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(1, false),
|
||||
MOE_SORTING_MP_1(1, false),
|
||||
MOE_SORTING_MP_2(1, false),
|
||||
MOE_SORTING_MP_3(1, false));
|
||||
return ave_time;
|
||||
ck_tile::index_t mesh_byte_size =
|
||||
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
|
||||
if(mesh_byte_size == 1)
|
||||
{
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
|
||||
}
|
||||
}
|
||||
else if(mesh_byte_size == 2)
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
|
||||
}
|
||||
#else
|
||||
printf("do not support large topk %d\n", a.topk);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts)
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts);
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,6 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
// if return non zero, means need workspace, you need to allocate a GPU buffer
|
||||
// and set to moe_sorting_args.p_ws
|
||||
// NOTE: workspace size are required to clear zero before use the API
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts);
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk);
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
|
||||
@@ -26,3 +26,9 @@ $EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
|
||||
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
|
||||
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
|
||||
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
|
||||
$EXE -t=128 -e=128 -k=6 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=32 -k=5 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=32 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=256 -k=5 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=256 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840
|
||||
@@ -56,4 +56,6 @@ struct fused_moe_traits
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
};
|
||||
|
||||
// if return zero, no ws needed
|
||||
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk);
|
||||
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -18,4 +18,5 @@ struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk);
|
||||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s);
|
||||
|
||||
@@ -2,6 +2,12 @@
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fused_moe.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
|
||||
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
}
|
||||
|
||||
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
|
||||
@@ -7,6 +7,14 @@
|
||||
#define MOE_SORTING_USE_EX_KERNEL 1
|
||||
#endif
|
||||
|
||||
#ifndef MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_SUPPORT_LARGE_EXPERT 0
|
||||
#endif
|
||||
|
||||
#ifndef MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
#define MOE_SORTING_SUPPORT_LARGE_TOPK 0
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
|
||||
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
|
||||
@@ -107,6 +115,10 @@
|
||||
}
|
||||
#endif
|
||||
|
||||
float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
fused_moesorting_args a,
|
||||
ck_tile::stream_config s);
|
||||
|
||||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
@@ -153,18 +165,198 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
}
|
||||
}
|
||||
#else
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
|
||||
auto sub_token_ = r_ - 2;
|
||||
r_ = (r_ - 2) / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
if(fused_moe_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0)
|
||||
{
|
||||
return fused_moesorting_mp(t, a, s);
|
||||
}
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto sub_token_ = ck_tile::moe_sorting_get_sub_token(a.tokens, a.num_experts);
|
||||
auto row_ = sub_token_ / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
(void)c_;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(r_);
|
||||
MOE_SORTING_DISPATCH_EMASK_(row_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
#endif
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#endif
|
||||
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_size = kernel::GetSmemSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \
|
||||
return ave_time; \
|
||||
}
|
||||
|
||||
float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
fused_moesorting_args a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, false));
|
||||
return ave_time;
|
||||
}
|
||||
#else
|
||||
printf("do not support large expert %d\n", a.num_experts);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t mesh_byte_size =
|
||||
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
|
||||
if(mesh_byte_size == 1)
|
||||
{
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
|
||||
}
|
||||
}
|
||||
else if(mesh_byte_size == 2)
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
|
||||
}
|
||||
#else
|
||||
printf("do not support large topk %d\n", a.topk);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -372,7 +372,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_host.get_element_space_size_in_bytes());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
|
||||
ck_tile::index_t workspace_size =
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
Reference in New Issue
Block a user