mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
merge clear workspace with P0_v1
- Merged the workspace clearing kernel with the P0_v1 kernel to save time from separate kernel launch. - Uses the workgroup_barrier to sync the wqrkgroups before they start P0_v1 code
This commit is contained in:
@@ -248,10 +248,11 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
topk,
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
moe_buf_interm_dim,
|
||||
moe_buf_elem_bytes
|
||||
moe_buf_elem_bytes,
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float)),
|
||||
#endif
|
||||
false // clear_workspace_in_p0 - will be set by API if needed
|
||||
};
|
||||
|
||||
ck_tile::stream_config sc{nullptr,
|
||||
|
||||
@@ -351,7 +351,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
@@ -361,7 +360,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
@@ -374,7 +372,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
@@ -384,7 +381,6 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
@@ -392,16 +388,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
} \
|
||||
}
|
||||
|
||||
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
|
||||
[&]() { \
|
||||
using problem_ = \
|
||||
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
|
||||
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
@@ -411,21 +398,8 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
|
||||
auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
|
||||
if(t.clear_workspace_inside_api)
|
||||
{
|
||||
if(is_local_token)
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
}
|
||||
};
|
||||
// Set the workspace clearing flag in the arguments
|
||||
a.clear_workspace_in_p0 = t.clear_workspace_inside_api;
|
||||
|
||||
if(a.tokens < 2048)
|
||||
{
|
||||
@@ -503,7 +477,6 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
{
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
@@ -514,7 +487,6 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
{
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
|
||||
@@ -198,6 +198,7 @@ struct MoeSortingHostArgs
|
||||
#else
|
||||
long_index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
#endif
|
||||
bool clear_workspace_in_p0; // flag to enable workspace clearing in P0_v1 kernel
|
||||
|
||||
};
|
||||
|
||||
@@ -1572,10 +1573,14 @@ struct MoeSortingMultiPhaseKernel_P0_v1
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_sync_counter; //Synchronization counter
|
||||
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
|
||||
// used for ws/LDS calculation
|
||||
index_t num_experts;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
index_t mesh_byte_size; // byte size per mesh element (for clearing)
|
||||
bool clear_workspace; // flag to enable workspace clearing
|
||||
mdiv topk_mdiv;
|
||||
};
|
||||
|
||||
@@ -1594,13 +1599,16 @@ struct MoeSortingMultiPhaseKernel_P0_v1
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_sync_counter = reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(h.p_ws) + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.mesh_byte_size = impl::moe_sorting_mesh_byte_size(h.tokens, h.num_experts, h.topk);
|
||||
k.clear_workspace = h.clear_workspace_in_p0;
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
@@ -1617,6 +1625,13 @@ struct MoeSortingMultiPhaseKernel_P0_v1
|
||||
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
|
||||
if(blockIdx.x == 0 && threadIdx.x == 0)
|
||||
{
|
||||
*reinterpret_cast<uint32_t*>(kargs.p_sync_counter) = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
@@ -1646,6 +1661,32 @@ struct MoeSortingMultiPhaseKernel_P0_v1
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
// WORKSPACE CLEARING PHASE (if enabled)
|
||||
if(kargs.clear_workspace)
|
||||
{
|
||||
ck_tile::workgroup_barrier wb{reinterpret_cast<uint32_t*>(kargs.p_sync_counter)};
|
||||
|
||||
index_t row_size = mesh_stride;
|
||||
index_t pixels = kargs.num_experts * row_size;
|
||||
index_t total_bytes = pixels * kargs.mesh_byte_size;
|
||||
index_t clear_total_elems = total_bytes / 16; // always use dwordx4
|
||||
|
||||
using vector_type = ext_vector_t<index_t, 4>;
|
||||
vector_type* p_expert_mesh_clear = reinterpret_cast<vector_type*>(kargs.p_expert_mesh);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < clear_total_elems;
|
||||
i += gridDim.x * kBlockSize)
|
||||
{
|
||||
p_expert_mesh_clear[i] = zero_;
|
||||
}
|
||||
|
||||
wb.inc();
|
||||
wb.wait_eq(gridDim.x);
|
||||
|
||||
}
|
||||
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
|
||||
Reference in New Issue
Block a user