mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_TILE] optimize moe-sorting kernel (#1771)
* optimize moe sorting
* remove commented code
This commit is contained in:
@@ -3,18 +3,42 @@
|
||||
|
||||
#include "fused_moesorting.hpp"
|
||||
|
||||
// Earlier form of MOE_SORTING_DISPATCH: instantiates ck_tile::MoeSortingProblem
// from index_t, ms_weight_type and the unroll factor only (no expert-tile
// argument), then launches ck_tile::MoeSortingKernel with the kernel arguments,
// grid size, block size and LDS byte count that the kernel derives from `a`.
// Expects `a`, `s`, `index_t` and `ms_weight_type` to be visible where it expands.
// NOTE(review): the hunk header (-3,18 +3,42) and the second definition of this
// macro further down suggest these are the lines the commit removed; the `|` and
// `||||` lines look like diff-viewer residue, not source — confirm against the repository.
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
// Instantiate and launch one MoeSortingKernel specialization, then return its
// measured time from the enclosing function.
//
//   unroll_num_  : compile-time unroll factor forwarded to MoeSortingProblem.
//   expert_tile_ : compile-time expert-tile value forwarded to MoeSortingProblem.
//
// The expansion reads `a` (kernel arguments), `s` (stream config), `index_t` and
// `ms_weight_type` from the surrounding scope. It declares locals and ends with
// `return ave_time;`, so it must be expanded inside its own braces within a
// function returning float; nothing after the expansion in that scope is reached.
//
// Every line except the last must end in a backslash with nothing between the
// lines, otherwise the macro body is cut short.
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_)                         \
    constexpr ck_tile::index_t unroll_num  = unroll_num_;                             \
    constexpr ck_tile::index_t expert_tile = expert_tile_;                            \
    using ms_problem =                                                                \
        ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
    using kernel         = ck_tile::MoeSortingKernel<ms_problem>;                     \
    auto kargs           = kernel::MakeKargs(a);                                      \
    const dim3 grids     = kernel::GridSize(a);                                       \
    const dim3 blocks    = kernel::BlockSize(a);                                      \
    const auto lds_bytes = kernel::GetSmemSize(a);                                    \
    float ave_time       = ck_tile::launch_kernel(                                    \
        s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));          \
    return ave_time;
|
||||
|
||||
// Choose the expert-tile template argument from the runtime value
// `a.num_experts` and dispatch to MOE_SORTING_DISPATCH_ETILE with the given
// unroll factor.
//
// The first matching threshold wins: <= 8 -> 8, <= 16 -> 16, <= 32 -> 32,
// <= 64 -> 64; anything larger uses an expert tile of 0.
// NOTE(review): 0 presumably selects the kernel's non-tiled/generic path —
// confirm against MoeSortingProblem.
//
// Each branch expands MOE_SORTING_DISPATCH_ETILE, which ends in
// `return ave_time;`, so every path through this macro returns from the
// enclosing function. Expects `a` and `s` to be visible where it expands.
#define MOE_SORTING_DISPATCH(unroll_num_)           \
    if(a.num_experts <= 8)                          \
    {                                               \
        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8)  \
    }                                               \
    else if(a.num_experts <= 16)                    \
    {                                               \
        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
    }                                               \
    else if(a.num_experts <= 32)                    \
    {                                               \
        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
    }                                               \
    else if(a.num_experts <= 64)                    \
    {                                               \
        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
    }                                               \
    else                                            \
    {                                               \
        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0)  \
    }
|
||||
|
||||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
case(6): {
|
||||
MOE_SORTING_DISPATCH(6);
|
||||
}
|
||||
case(7): {
|
||||
MOE_SORTING_DISPATCH(7);
|
||||
}
|
||||
case(8): {
|
||||
MOE_SORTING_DISPATCH(8);
|
||||
}
|
||||
case(9): {
|
||||
MOE_SORTING_DISPATCH(9);
|
||||
}
|
||||
case(10): {
|
||||
MOE_SORTING_DISPATCH(10);
|
||||
}
|
||||
case(11): {
|
||||
MOE_SORTING_DISPATCH(11);
|
||||
}
|
||||
default: {
|
||||
MOE_SORTING_DISPATCH(4);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user