mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
porting fmoe_sorting from moe_sorting (#1884)
* porting fmoe_sorting from moe_sorting * pass default example test * remod
This commit is contained in:
@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
return 1;
|
||||
}();
|
||||
|
||||
auto t0 = fused_moesorting_trait{"int32", "fp32"};
|
||||
auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
|
||||
auto a0 = fused_moesorting_args{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
|
||||
@@ -24,20 +24,63 @@
|
||||
return ave_time;
|
||||
|
||||
#else
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
|
||||
if(is_local_expert_masking) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
@@ -116,45 +159,10 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
auto sub_token_ = r_ - 2;
|
||||
r_ = (r_ - 2) / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
(void)c_;
|
||||
if(is_sub_token_onshot)
|
||||
{
|
||||
if(r_ % 8 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(8, true);
|
||||
}
|
||||
else if(r_ % 4 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(4, true);
|
||||
}
|
||||
else if(r_ % 2 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(2, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(1, true);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(r_ % 8 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(8, false);
|
||||
}
|
||||
else if(r_ % 4 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(4, false);
|
||||
}
|
||||
else if(r_ % 2 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(2, false);
|
||||
}
|
||||
else
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(1, false);
|
||||
}
|
||||
}
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(r_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user