From 20ffa0f47476ea1dbf4b2b6b9fd02c681c822ed5 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 28 Mar 2025 11:31:52 +0800 Subject: [PATCH] hotfix fix sorting int64 (#2025) * fix sorting int64 * clang format * fix example issue * update WA issue # --------- Co-authored-by: coderfeli Co-authored-by: carlushuang [ROCm/composable_kernel commit: a82f338fb9fb5743f071c5e6831c3dd92fcd7982] --- .../ck_tile/13_moe_sorting/moe_sorting.cpp | 4 +-- .../15_fused_moe/instances/fused_moe_api.cpp | 29 ++++++++++--------- include/ck_tile/core/config.hpp | 4 +++ .../fused_moe/kernel/moe_sorting_kernel.hpp | 18 ++++++++---- 4 files changed, 33 insertions(+), 22 deletions(-) diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index f00d948f25..e59fcaedad 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -74,7 +74,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) int topk = args.get_int("k"); int seed = args.get_int("seed"); int unit_size = args.get_int("unit"); - int moe_buf_size = args.get_int("moe_buf_size"); + int64_t moe_buf_size = static_cast(args.get_uint64("moe_buf_size")); int kname = args.get_int("kname"); int warmup = args.get_int("warmup"); int repeat = args.get_int("repeat"); @@ -175,7 +175,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) unit_size, num_experts, topk, - static_cast(moe_buf_size * sizeof(float))}; + static_cast(moe_buf_size * sizeof(float))}; ck_tile::stream_config sc{nullptr, true, diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index 466420f066..f887d57aa9 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -19,20 +19,21 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking}; auto a0 = fused_moesorting_args{ - a.topk_ids_ptr, // const void* p_topk_ids; - a.topk_weight_ptr, // const void* p_weights; - a.local_expert_mask_ptr, // const void* p_local_expert_mask; - a.sorted_token_ids_ptr, // void* p_sorted_token_ids; - a.sorted_weight_ptr, // void* p_sorted_weights; - a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; - a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; - a.o_ptr, // void* p_moe_buf; - a.ws_ptr, // void* p_ws; - a.num_tokens, // index_t tokens; - a.block_m, // index_t unit_size; - a.num_experts, // index_t num_experts; - a.topk, // index_t topk; - a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes; + a.topk_ids_ptr, // const void* p_topk_ids; + a.topk_weight_ptr, // const void* p_weights; + a.local_expert_mask_ptr, // const void* p_local_expert_mask; + a.sorted_token_ids_ptr, // void* p_sorted_token_ids; + a.sorted_weight_ptr, // void* p_sorted_weights; + a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; + a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; + a.o_ptr, // void* p_moe_buf; + a.ws_ptr, // void* p_ws; + a.num_tokens, // index_t tokens; + a.block_m, // index_t unit_size; + a.num_experts, // index_t num_experts; + a.topk, // index_t topk; + static_cast(a.num_tokens) * a.stride_token * + o_data_bytes // index_t moe_buf_bytes; }; auto t1 = fused_moegemm_traits{t.prec_i, diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index eeaf0dca6f..b1d201e30e 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -260,3 +260,7 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) #define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0 #endif #endif + +#ifndef CK_TILE_WA_ISSUE_2028 +#define CK_TILE_WA_ISSUE_2028 1 +#endif diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index a1410d1f4f..6a7ccd2472 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -192,7 +192,7 @@ struct MoeSortingHostArgs index_t unit_size; // this is the M_a of fused-moe kernel index_t num_experts; index_t topk; - index_t moe_buf_bytes; // byte size of p_moe_buf + long_index_t moe_buf_bytes; // byte size of p_moe_buf }; template @@ -219,7 +219,7 @@ struct MoeSortingKernel void* p_moe_buf; index_t tokens; index_t num_experts; - index_t moe_buf_bytes; + long_index_t moe_buf_bytes; index_t tokens_per_thread; index_t smem_rows; @@ -426,7 +426,7 @@ struct MoeSortingKernel return row * total_col + col; } - CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const + CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes) const { const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x; if(offset < buf_bytes / 16) @@ -1218,10 +1218,10 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data) } template -CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes, index_t gid) +CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes, index_t gid) { // const index_t offset = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; - index_t offset = gid * BLOCK_SIZE + threadIdx.x; + long_index_t offset = static_cast(gid) * BLOCK_SIZE + threadIdx.x; if(offset < buf_bytes / 16) { buf[offset] = uint8x16_t{0}; @@ -1233,6 +1233,12 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes, // prefer to run mp kernel if is not oneshot CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_) { +#if CK_TILE_WA_ISSUE_2028 + if(tokens_ >= 65536 * 2) + { + return true; + } +#endif auto sub_token_ = moe_sorting_get_sub_token(tokens_, num_experts_); bool is_sub_token_onshot = tokens_ <= sub_token_; return is_sub_token_onshot; @@ -1523,7 +1529,7 @@ struct MoeSortingMultiPhaseKernel_P2 index_t num_experts; index_t mesh_stride; // mesh_stride for p_expert_mesh mdiv unit_size_mdiv; - index_t moe_buf_bytes; + long_index_t moe_buf_bytes; }; CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)