hotfix fix sorting int64 (#2025)

* fix sorting int64

* clang format

* fix example issue

* update WA issue #

---------

Co-authored-by: coderfeli <coderfeli@163.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>

[ROCm/composable_kernel commit: a82f338fb9]
This commit is contained in:
felix
2025-03-28 11:31:52 +08:00
committed by GitHub
parent 895ba2b497
commit 20ffa0f474
4 changed files with 33 additions and 22 deletions

View File

@@ -74,7 +74,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int topk = args.get_int("k");
int seed = args.get_int("seed");
int unit_size = args.get_int("unit");
int moe_buf_size = args.get_int("moe_buf_size");
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
@@ -175,7 +175,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
unit_size,
num_experts,
topk,
static_cast<ck_tile::index_t>(moe_buf_size * sizeof(float))};
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))};
ck_tile::stream_config sc{nullptr,
true,

View File

@@ -19,20 +19,21 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
auto a0 = fused_moesorting_args{
a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights;
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
a.o_ptr, // void* p_moe_buf;
a.ws_ptr, // void* p_ws;
a.num_tokens, // index_t tokens;
a.block_m, // index_t unit_size;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights;
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
a.o_ptr, // void* p_moe_buf;
a.ws_ptr, // void* p_ws;
a.num_tokens, // index_t tokens;
a.block_m, // index_t unit_size;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
static_cast<ck_tile::long_index_t>(a.num_tokens) * a.stride_token *
o_data_bytes // index_t moe_buf_bytes;
};
auto t1 = fused_moegemm_traits{t.prec_i,

View File

@@ -260,3 +260,7 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0
#endif
#endif
#ifndef CK_TILE_WA_ISSUE_2028
#define CK_TILE_WA_ISSUE_2028 1
#endif

View File

@@ -192,7 +192,7 @@ struct MoeSortingHostArgs
index_t unit_size; // this is the M_a of fused-moe kernel
index_t num_experts;
index_t topk;
index_t moe_buf_bytes; // byte size of p_moe_buf
long_index_t moe_buf_bytes; // byte size of p_moe_buf
};
template <typename Problem_>
@@ -219,7 +219,7 @@ struct MoeSortingKernel
void* p_moe_buf;
index_t tokens;
index_t num_experts;
index_t moe_buf_bytes;
long_index_t moe_buf_bytes;
index_t tokens_per_thread;
index_t smem_rows;
@@ -426,7 +426,7 @@ struct MoeSortingKernel
return row * total_col + col;
}
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes) const
{
const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
if(offset < buf_bytes / 16)
@@ -1218,10 +1218,10 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
}
template <index_t BLOCK_SIZE = 256>
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes, index_t gid)
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes, index_t gid)
{
// const index_t offset = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x;
index_t offset = gid * BLOCK_SIZE + threadIdx.x;
long_index_t offset = static_cast<long_index_t>(gid) * BLOCK_SIZE + threadIdx.x;
if(offset < buf_bytes / 16)
{
buf[offset] = uint8x16_t{0};
@@ -1233,6 +1233,12 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes,
// prefer to run mp kernel if is not oneshot
CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
{
#if CK_TILE_WA_ISSUE_2028
if(tokens_ >= 65536 * 2)
{
return true;
}
#endif
auto sub_token_ = moe_sorting_get_sub_token(tokens_, num_experts_);
bool is_sub_token_onshot = tokens_ <= sub_token_;
return is_sub_token_onshot;
@@ -1523,7 +1529,7 @@ struct MoeSortingMultiPhaseKernel_P2
index_t num_experts;
index_t mesh_stride; // mesh_stride for p_expert_mesh
mdiv unit_size_mdiv;
index_t moe_buf_bytes;
long_index_t moe_buf_bytes;
};
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)