From 5efccadbd991ca17c0583da7da3849b491b6984a Mon Sep 17 00:00:00 2001 From: Yashvardhan Agarwal Date: Tue, 21 Oct 2025 10:26:08 +0000 Subject: [PATCH] merge clear workspace with P0_v1 - Merged the workspace clearing kernel with the P0_v1 kernel to save the time of a separate kernel launch. - Uses the workgroup_barrier to sync the workgroups before they start P0_v1 code --- .../ck_tile/13_moe_sorting/moe_sorting.cpp | 5 +- .../13_moe_sorting/moe_sorting_api.cpp | 34 +----------- .../fused_moe/kernel/moe_sorting_kernel.hpp | 55 ++++++++++++++++--- 3 files changed, 54 insertions(+), 40 deletions(-) diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index b2ad4eb98c..82f370922e 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -248,10 +248,11 @@ bool test_moe_sorting(ck_tile::ArgParser args) topk, #if MOE_SORTING_FMOE_2D_BUF moe_buf_interm_dim, - moe_buf_elem_bytes + moe_buf_elem_bytes, #else - static_cast(moe_buf_size * sizeof(float)) + static_cast(moe_buf_size * sizeof(float)), #endif + false // clear_workspace_in_p0 - will be set by API if needed }; ck_tile::stream_config sc{nullptr, diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 00c6be8f10..b866ceb74f 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -351,7 +351,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = \ ck_tile::launch_kernel(s, \ - maybe_clear_workspace, \ MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ @@ -361,7 +360,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = \ ck_tile::launch_kernel(s, \ 
- maybe_clear_workspace, \ MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ @@ -374,7 +372,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = \ ck_tile::launch_kernel(s, \ - maybe_clear_workspace, \ MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ @@ -384,7 +381,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = ck_tile::launch_kernel( \ s, \ - maybe_clear_workspace, \ MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ @@ -392,16 +388,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } \ } -#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ - [&]() { \ - using problem_ = \ - ck_tile::MoeSortingClearWorkspaceProblem; \ - using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ - }() + float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { @@ -411,21 +398,8 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co using ms_index_t = ck_tile::index_t; using ms_weight_type = float; - auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) { - if(t.clear_workspace_inside_api) - { - if(is_local_token) - { - auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1); - k(s_); - } - else - { - auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 
1); - k(s_); - } - } - }; + // Set the workspace clearing flag in the arguments + a.clear_workspace_in_p0 = t.clear_workspace_inside_api; if(a.tokens < 2048) { @@ -503,7 +477,6 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co { float ave_time = ck_tile::launch_kernel(s, - maybe_clear_workspace, MOE_SORTING_MP_0_V1(ms_index_t, 1, true), MOE_SORTING_MP_1(ms_index_t, 1, true), MOE_SORTING_MP_2(ms_index_t, 1, true), @@ -514,7 +487,6 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co { float ave_time = ck_tile::launch_kernel(s, - maybe_clear_workspace, MOE_SORTING_MP_0_V1(ms_index_t, 1, false), MOE_SORTING_MP_1(ms_index_t, 1, false), MOE_SORTING_MP_2(ms_index_t, 1, false), diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 2918cd33bc..e668962778 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -198,6 +198,7 @@ struct MoeSortingHostArgs #else long_index_t moe_buf_bytes; // byte size of p_moe_buf #endif + bool clear_workspace_in_p0; // flag to enable workspace clearing in P0_v1 kernel }; @@ -1572,10 +1573,14 @@ struct MoeSortingMultiPhaseKernel_P0_v1 const void* p_topk_ids; // [tokens, topk] const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens void* p_expert_mesh; // [expert, tokens] + void* p_sync_counter; //Synchronization counter + index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens // used for ws/LDS calculation index_t num_experts; index_t mesh_stride; // mesh_stride for p_expert_mesh + index_t mesh_byte_size; // byte size per mesh element (for clearing) + bool clear_workspace; // flag to enable workspace clearing mdiv topk_mdiv; }; @@ -1594,13 +1599,16 @@ struct MoeSortingMultiPhaseKernel_P0_v1 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) { 
Kargs k; - k.p_topk_ids = h.p_topk_ids; - k.p_local_tokens = h.p_local_tokens; - k.p_expert_mesh = h.p_ws; - k.tokens = h.tokens; - k.num_experts = h.num_experts; - k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); - k.topk_mdiv = mdiv{static_cast(h.topk)}; + k.p_topk_ids = h.p_topk_ids; + k.p_local_tokens = h.p_local_tokens; + k.p_expert_mesh = h.p_ws; + k.p_sync_counter = reinterpret_cast(reinterpret_cast(h.p_ws) + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.mesh_byte_size = impl::moe_sorting_mesh_byte_size(h.tokens, h.num_experts, h.topk); + k.clear_workspace = h.clear_workspace_in_p0; + k.topk_mdiv = mdiv{static_cast(h.topk)}; return k; } @@ -1617,6 +1625,13 @@ struct MoeSortingMultiPhaseKernel_P0_v1 const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + + if(blockIdx.x == 0 && threadIdx.x == 0) + { + *reinterpret_cast(kargs.p_sync_counter) = 0; + } + __syncthreads(); + index_t tokens = [&]() { if constexpr(Problem::LocalToken) { @@ -1646,6 +1661,32 @@ struct MoeSortingMultiPhaseKernel_P0_v1 return kargs.mesh_stride; } }(); + + // WORKSPACE CLEARING PHASE (if enabled) + if(kargs.clear_workspace) + { + ck_tile::workgroup_barrier wb{reinterpret_cast(kargs.p_sync_counter)}; + + index_t row_size = mesh_stride; + index_t pixels = kargs.num_experts * row_size; + index_t total_bytes = pixels * kargs.mesh_byte_size; + index_t clear_total_elems = total_bytes / 16; // always use dwordx4 + + using vector_type = ext_vector_t; + vector_type* p_expert_mesh_clear = reinterpret_cast(kargs.p_expert_mesh); + auto zero_ = vector_type{0}; + + for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < clear_total_elems; + i += gridDim.x * kBlockSize) + { + p_expert_mesh_clear[i] = zero_; + } + + wb.inc(); + wb.wait_eq(gridDim.x); + + } + index_t 
total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile