mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Merge branch 'wip_355' into wip_355_xcd_remap
This commit is contained in:
@@ -3,6 +3,6 @@ target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOUR
|
||||
|
||||
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template so the sources compile without explicitly declaring function template specializations
|
||||
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -Wno-error)
|
||||
# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
|
||||
|
||||
@@ -194,7 +194,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
@@ -205,7 +205,25 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
@@ -286,6 +304,46 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
// Two-kernel dispatch used by the `a.tokens < 2048` branch of the caller.
// Picks the <expert_masking, local_token> instantiation from `t.local_expert_masking` and
// `is_local_token`, launches MOE_SORTING_MP_0_V2 followed by MOE_SORTING_MP_23 through
// ck_tile::launch_kernel on stream config `s`, and returns the launch_kernel result from the
// enclosing function (every branch ends in `return ave_time;`).
// Unlike MOR_SORTING_MP_DISPATCH_, neither maybe_clear_workspace nor MOE_SORTING_MP_1 is launched.
// Parameters:
//   mesh_type_    - element type forwarded to both kernel macros
//   token_vec_0_  - unroll factor forwarded to MOE_SORTING_MP_0_V2
//   token_vec_1_  - accepted for signature parity with MOR_SORTING_MP_DISPATCH_; not used here
//   token_vec_23_ - unroll factor forwarded to MOE_SORTING_MP_23
// Requires `t`, `s` and `is_local_token` to be in scope at the expansion site.
#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
@@ -294,7 +352,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
@@ -303,7 +362,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
@@ -315,7 +375,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
@@ -324,7 +385,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
@@ -365,67 +427,136 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
}
|
||||
};
|
||||
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
if (a.tokens < 2048)
|
||||
{
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, false));
|
||||
return ave_time;
|
||||
}
|
||||
#else
|
||||
printf("do not support large expert %d\n", a.num_experts);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t mesh_byte_size =
|
||||
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
|
||||
if(mesh_byte_size == 1)
|
||||
{
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0_V2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
|
||||
}
|
||||
}
|
||||
else if(mesh_byte_size == 2)
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0_V2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, false));
|
||||
return ave_time;
|
||||
}
|
||||
#else
|
||||
printf("do not support large topk %d\n", a.topk);
|
||||
printf("do not support large expert %d\n", a.num_experts);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
|
||||
ck_tile::index_t mesh_byte_size =
|
||||
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
|
||||
if(mesh_byte_size == 1)
|
||||
{
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16)
|
||||
}
|
||||
}
|
||||
else if(mesh_byte_size == 2)
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8)
|
||||
}
|
||||
#else
|
||||
printf("do not support large topk %d\n", a.topk);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, false));
|
||||
return ave_time;
|
||||
}
|
||||
#else
|
||||
printf("do not support large expert %d\n", a.num_experts);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t mesh_byte_size =
|
||||
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
|
||||
if(mesh_byte_size == 1)
|
||||
{
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
|
||||
}
|
||||
}
|
||||
else if(mesh_byte_size == 2)
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
|
||||
}
|
||||
#else
|
||||
printf("do not support large topk %d\n", a.topk);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
@@ -209,7 +209,25 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
@@ -290,14 +308,14 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
@@ -305,7 +323,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
@@ -316,7 +334,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
@@ -324,7 +342,51 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
@@ -340,65 +402,154 @@ float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
|
||||
if(t.clear_workspace_inside_api)
|
||||
{
|
||||
if(is_local_token)
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (a.tokens < 2048)
|
||||
{
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, false));
|
||||
return ave_time;
|
||||
}
|
||||
#else
|
||||
printf("do not support large expert %d\n", a.num_experts);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t mesh_byte_size =
|
||||
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
|
||||
if(mesh_byte_size == 1)
|
||||
{
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0_V2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
|
||||
}
|
||||
}
|
||||
else if(mesh_byte_size == 2)
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0_V2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, false));
|
||||
return ave_time;
|
||||
}
|
||||
#else
|
||||
printf("do not support large topk %d\n", a.topk);
|
||||
printf("do not support large expert %d\n", a.num_experts);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
|
||||
ck_tile::index_t mesh_byte_size =
|
||||
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
|
||||
if(mesh_byte_size == 1)
|
||||
{
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16)
|
||||
}
|
||||
}
|
||||
else if(mesh_byte_size == 2)
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8)
|
||||
}
|
||||
#else
|
||||
printf("do not support large topk %d\n", a.topk);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, false));
|
||||
return ave_time;
|
||||
}
|
||||
#else
|
||||
printf("do not support large expert %d\n", a.num_experts);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t mesh_byte_size =
|
||||
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
|
||||
if(mesh_byte_size == 1)
|
||||
{
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
|
||||
}
|
||||
}
|
||||
else if(mesh_byte_size == 2)
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
|
||||
}
|
||||
#else
|
||||
printf("do not support large topk %d\n", a.topk);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1373,7 +1373,7 @@ CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_expe
|
||||
{
|
||||
index_t s_ = impl::moe_sorting_mp_mesh_smem_size(tokens_, num_experts_, topk_) +
|
||||
impl::moe_sorting_mp_cumsum_smem_size(num_experts_)
|
||||
#if 1
|
||||
#if MOE_SORTING_FUSE_MP_01
|
||||
+ impl::moe_sorting_mp_sem_smem_size();
|
||||
#else
|
||||
;
|
||||
@@ -1552,7 +1552,123 @@ p_m_cumsum
|
||||
|
||||
// Phase-0 kernel (v1): count topk_id into mesh.
// For every (token, topk-slot) entry of p_topk_ids, writes (topk-slot + 1) into the workspace
// mesh at [expert_id * mesh_stride + token]. Uses no dynamic shared memory (GetSmemSize() == 0);
// the grid size is derived from the device's multiprocessor count times OCCUPANCY.
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P0
|
||||
struct MoeSortingMultiPhaseKernel_P0_v1
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
// threads per block; also the per-block step used by the loop in operator()
static constexpr index_t BLOCK_SIZE = 256;
|
||||
// multiplier applied to the multiprocessor count when sizing the grid (see GridSize)
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
// Device-side arguments, filled from the host arguments by MakeKargs().
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicates the max possible tokens
|
||||
// used for ws/LDS calculation
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv topk_mdiv;
|
||||
};
|
||||
|
||||
// Returns multiProcessorCount of the current HIP device (queried via hipGetDevice +
// hipGetDeviceProperties on every call). The runtime queries live inside a lambda.
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
// Builds the kernel arguments from host arguments: the mesh lives in the workspace buffer
// (h.p_ws), mesh_stride is derived from h.tokens, and topk_mdiv is constructed from h.topk.
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
// Grid size is independent of the problem size: multiprocessor count * OCCUPANCY blocks.
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
// One-dimensional block of BLOCK_SIZE threads.
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
|
||||
|
||||
// Kernel body: each thread loads vectors of Problem::SubTokenTile topk ids and scatters
// (topk-slot + 1) into the mesh row of the corresponding expert.
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// topk ids are read SubTokenTile elements at a time
using topk_id_t = ext_vector_t<IndexType, Problem::SubTokenTile>;
|
||||
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
// actual token count: read from p_local_tokens[0] when Problem::LocalToken, else kargs.tokens
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
// with LocalToken the count is rounded up to a multiple of SubTokenTile so that
// total_elem below covers the last, possibly partial, vector
index_t rounded_tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
|
||||
Problem::SubTokenTile;
|
||||
}
|
||||
else
|
||||
return tokens;
|
||||
}();
|
||||
// with LocalToken the stride is recomputed from the runtime token count instead of
// using the host-computed kargs.mesh_stride (which was derived from h.tokens)
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
// number of SubTokenTile-wide vectors to process (topk_mdiv.divisor was built from topk)
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
// grid-stride loop over the vectors: start at the global thread index, step by the
// total number of launched threads
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
i += gridDim.x * BLOCK_SIZE)
|
||||
{
|
||||
auto x = p_topk_ids[i];
|
||||
static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) {
|
||||
IndexType eid = x[j.value]; // ext_vector_type must use int to []
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
// NOTE(review): divmod of the flat element index by topk presumably yields
// quotient -> curr_token_id and remainder -> curr_topk_id; confirm against mdiv
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
// skip the padding elements introduced by rounding tokens up
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[eid * mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
// stores topk slot + 1 (presumably so that 0 can mean "no entry" — TODO confirm)
p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P0_v2
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
@@ -1577,7 +1693,7 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
mdiv topk_mdiv;
|
||||
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
void* p_expert_cumsum;
|
||||
void* p_expert_cumsum; // [expert]
|
||||
index_t num_experts;
|
||||
};
|
||||
|
||||
@@ -1669,12 +1785,6 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
mask = p_local_expert_mask[eid];
|
||||
}
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride;
|
||||
// const index_t total_bytes = mesh_stride * 4;
|
||||
|
||||
// using vector_type = ext_vector_t<index_t, 4>;
|
||||
// auto zero_ = vector_type{0};
|
||||
|
||||
// vector_type* p_expert_mesh_clear = reinterpret_cast<vector_type*>(kargs.p_expert_mesh);
|
||||
for(index_t i = threadIdx.x; i < mesh_stride; i += BLOCK_SIZE)
|
||||
{
|
||||
p_expert_mesh[i] = 0;
|
||||
@@ -1750,6 +1860,8 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
// count the total tokens for an expert
|
||||
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P1
|
||||
|
||||
Reference in New Issue
Block a user