merge clear workspace with P0_v1

- Merged the workspace clearing kernel with the P0_v1 kernel to save
time from separate kernel launch.

- Uses the workgroup_barrier to sync the wqrkgroups before they start
P0_v1 code
This commit is contained in:
Yashvardhan Agarwal
2025-10-21 10:26:08 +00:00
parent e135dd518d
commit 5efccadbd9
3 changed files with 54 additions and 40 deletions

View File

@@ -248,10 +248,11 @@ bool test_moe_sorting(ck_tile::ArgParser args)
topk,
#if MOE_SORTING_FMOE_2D_BUF
moe_buf_interm_dim,
moe_buf_elem_bytes
moe_buf_elem_bytes,
#else
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float)),
#endif
false // clear_workspace_in_p0 - will be set by API if needed
};
ck_tile::stream_config sc{nullptr,

View File

@@ -351,7 +351,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
@@ -361,7 +360,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
@@ -374,7 +372,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
@@ -384,7 +381,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
@@ -392,16 +388,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
} \
}
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
[&]() { \
using problem_ = \
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
@@ -411,21 +398,8 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
using ms_index_t = ck_tile::index_t;
using ms_weight_type = float;
auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
if(t.clear_workspace_inside_api)
{
if(is_local_token)
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
k(s_);
}
else
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
k(s_);
}
}
};
// Set the workspace clearing flag in the arguments
a.clear_workspace_in_p0 = t.clear_workspace_inside_api;
if(a.tokens < 2048)
{
@@ -503,7 +477,6 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
@@ -514,7 +487,6 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),

View File

@@ -198,6 +198,7 @@ struct MoeSortingHostArgs
#else
long_index_t moe_buf_bytes; // byte size of p_moe_buf
#endif
bool clear_workspace_in_p0; // flag to enable workspace clearing in P0_v1 kernel
};
@@ -1572,10 +1573,14 @@ struct MoeSortingMultiPhaseKernel_P0_v1
const void* p_topk_ids; // [tokens, topk]
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
void* p_expert_mesh; // [expert, tokens]
void* p_sync_counter; //Synchronization counter
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
// used for ws/LDS calculation
index_t num_experts;
index_t mesh_stride; // mesh_stride for p_expert_mesh
index_t mesh_byte_size; // byte size per mesh element (for clearing)
bool clear_workspace; // flag to enable workspace clearing
mdiv topk_mdiv;
};
@@ -1594,13 +1599,16 @@ struct MoeSortingMultiPhaseKernel_P0_v1
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.tokens = h.tokens;
k.num_experts = h.num_experts;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
k.p_topk_ids = h.p_topk_ids;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.p_sync_counter = reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(h.p_ws) + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
k.tokens = h.tokens;
k.num_experts = h.num_experts;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.mesh_byte_size = impl::moe_sorting_mesh_byte_size(h.tokens, h.num_experts, h.topk);
k.clear_workspace = h.clear_workspace_in_p0;
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
return k;
}
@@ -1617,6 +1625,13 @@ struct MoeSortingMultiPhaseKernel_P0_v1
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
if(blockIdx.x == 0 && threadIdx.x == 0)
{
*reinterpret_cast<uint32_t*>(kargs.p_sync_counter) = 0;
}
__syncthreads();
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
@@ -1646,6 +1661,32 @@ struct MoeSortingMultiPhaseKernel_P0_v1
return kargs.mesh_stride;
}
}();
// WORKSPACE CLEARING PHASE (if enabled)
if(kargs.clear_workspace)
{
ck_tile::workgroup_barrier wb{reinterpret_cast<uint32_t*>(kargs.p_sync_counter)};
index_t row_size = mesh_stride;
index_t pixels = kargs.num_experts * row_size;
index_t total_bytes = pixels * kargs.mesh_byte_size;
index_t clear_total_elems = total_bytes / 16; // always use dwordx4
using vector_type = ext_vector_t<index_t, 4>;
vector_type* p_expert_mesh_clear = reinterpret_cast<vector_type*>(kargs.p_expert_mesh);
auto zero_ = vector_type{0};
for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < clear_total_elems;
i += gridDim.x * kBlockSize)
{
p_expert_mesh_clear[i] = zero_;
}
wb.inc();
wb.wait_eq(gridDim.x);
}
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
#pragma unroll Problem::SubTokenTile