mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
hotfix fix sorting int64 (#2025)
* fix sorting int64
* clang format
* fix example issue
* update WA issue #

---------

Co-authored-by: coderfeli <coderfeli@163.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
This commit is contained in:
@@ -192,7 +192,7 @@ struct MoeSortingHostArgs
|
||||
index_t unit_size; // this is the M_a of fused-moe kernel
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
long_index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
@@ -219,7 +219,7 @@ struct MoeSortingKernel
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
index_t num_experts;
|
||||
index_t moe_buf_bytes;
|
||||
long_index_t moe_buf_bytes;
|
||||
|
||||
index_t tokens_per_thread;
|
||||
index_t smem_rows;
|
||||
@@ -426,7 +426,7 @@ struct MoeSortingKernel
|
||||
return row * total_col + col;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes) const
|
||||
{
|
||||
const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
|
||||
if(offset < buf_bytes / 16)
|
||||
@@ -1218,10 +1218,10 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
|
||||
}
|
||||
|
||||
template <index_t BLOCK_SIZE = 256>
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes, index_t gid)
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes, index_t gid)
|
||||
{
|
||||
// const index_t offset = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x;
|
||||
index_t offset = gid * BLOCK_SIZE + threadIdx.x;
|
||||
long_index_t offset = static_cast<long_index_t>(gid) * BLOCK_SIZE + threadIdx.x;
|
||||
if(offset < buf_bytes / 16)
|
||||
{
|
||||
buf[offset] = uint8x16_t{0};
|
||||
@@ -1233,6 +1233,12 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes,
|
||||
// prefer to run mp kernel if is not oneshot
|
||||
CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
|
||||
{
|
||||
#if CK_TILE_WA_ISSUE_2028
|
||||
if(tokens_ >= 65536 * 2)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
auto sub_token_ = moe_sorting_get_sub_token(tokens_, num_experts_);
|
||||
bool is_sub_token_onshot = tokens_ <= sub_token_;
|
||||
return is_sub_token_onshot;
|
||||
@@ -1523,7 +1529,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
index_t num_experts;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv unit_size_mdiv;
|
||||
index_t moe_buf_bytes;
|
||||
long_index_t moe_buf_bytes;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
|
||||
Reference in New Issue
Block a user