[CK_TILE] moe sorting optimize local_token (#2469)

* fix bug in loops that need to use local tokens to compute

* support extra chain local_token

* update

* update

* refine some main

* update

* support dispatch_policy

* fix example 15
This commit is contained in:
carlushuang
2025-07-15 09:42:18 +08:00
committed by GitHub
parent 141bf2d54d
commit cfe211cc60
9 changed files with 579 additions and 94 deletions


@@ -23,6 +23,11 @@ namespace ck_tile {
#define MOE_SORTING_FUSE_MP_01 0
#endif
// whether to use 2d buffer indexing for the fmoe ws, or 1d
#ifndef MOE_SORTING_FMOE_2D_BUF
#define MOE_SORTING_FMOE_2D_BUF 1
#endif
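For context, a rough sketch of the difference between the two modes (a hypothetical helper; none of these names are part of this header):
inline long_index_t bytes_cleared_2d(index_t local_tokens, index_t interm_dim, index_t elem_bytes)
{
    // 1d mode (MOE_SORTING_FMOE_2D_BUF=0) clears the whole worst-case allocation;
    // 2d mode clears only the rows belonging to the actual (local) tokens, which
    // is what makes the local_token optimization pay off
    return static_cast<long_index_t>(local_tokens) * interm_dim * elem_bytes;
}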
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
@@ -171,7 +176,7 @@ struct MoeSortingHostArgs
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
void* p_total_tokens_post_pad; // [2], [0]: output tokens_post_padded, [1]: actual tokens on current rank (local_tokens or tokens)
// we fuse zero-initialization of the fused-moe output buffer here;
// setting this pointer to nullptr will skip this operation
void* p_moe_buf;
@@ -182,7 +187,18 @@ struct MoeSortingHostArgs
index_t unit_size; // this is the M_a of fused-moe kernel
index_t num_experts;
index_t topk;
#if MOE_SORTING_FMOE_2D_BUF
// NOTE:
// moe_buf_* describes a 2d ws buffer used by the following fmoe kernel,
// arranged as row*col, where row=tokens (or local_tokens), col=interm_dim.
// We fuse the clearing of this buffer into the sorting kernel.
// Besides, we require interm_dim to be a multiple of 16 bytes (ensure this
// when allocating the ws for fmoe)
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
index_t moe_buf_elem_bytes; // p_moe_buf element size in bytes (8-bit, 16-bit, 32-bit, etc.)
#else
long_index_t moe_buf_bytes; // byte size of p_moe_buf
#endif
};
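As a reference for callers, a minimal host-side sketch of filling the new 2d-buffer fields (assuming <cassert>; the dim and dtype values are made-up examples, and the assert mirrors the 16-byte requirement stated in the NOTE above):
MoeSortingHostArgs h{};
h.moe_buf_interm_dim = 4096; // hypothetical fmoe interm_dim
h.moe_buf_elem_bytes = 2;    // e.g. a 16-bit output element type
// the fused clearing uses dwordx4 stores, so one row must be a multiple of 16 bytes
assert((static_cast<long_index_t>(h.moe_buf_interm_dim) * h.moe_buf_elem_bytes) % 16 == 0);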
template <typename Problem_>
@@ -197,6 +213,9 @@ struct MoeSortingKernel
using Hargs = MoeSortingHostArgs;
static constexpr index_t BLOCK_SIZE = 256;
static constexpr index_t OCCUPANCY = 2; // hard coded
struct Kargs
{
const void* p_topk_ids;
@@ -210,8 +229,12 @@ struct MoeSortingKernel
void* p_moe_buf;
index_t tokens;
index_t num_experts;
#if MOE_SORTING_FMOE_2D_BUF
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
index_t moe_buf_elem_bytes; // p_moe_buf element size in bytes (8-bit, 16-bit, 32-bit, etc.)
#else
long_index_t moe_buf_bytes;
#endif
index_t tokens_per_thread;
index_t smem_rows;
mdiv unit_size_mdiv;
@@ -220,10 +243,27 @@ struct MoeSortingKernel
// mdiv sub_tokens_mdiv;
};
CK_TILE_HOST static constexpr auto get_num_cu()
{
index_t num_cu = [&]() {
hipDeviceProp_t dev_prop;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
return dev_prop.multiProcessorCount;
}();
return num_cu;
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
#if MOE_SORTING_FMOE_2D_BUF
(void)h;
return get_num_cu() * OCCUPANCY;
#else
// TODO: assumes the number of experts is not too large
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16));
#endif
}
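In the 2d-buffer path the grid is thus sized to keep the device busy rather than to match the buffer size: block 0 runs the sorting, and every other block grid-strides through the ws. A rough trip-count sketch (tokens/interm_dim/elem_bytes are example inputs, not kernel state):
index_t grid = get_num_cu() * OCCUPANCY; // e.g. 80 CUs * 2 = 160 blocks
long_index_t total_elems =
    static_cast<long_index_t>(tokens) * interm_dim * elem_bytes / 16; // dwordx4 units
// grid - 1 clearing blocks, each thread touching roughly this many vectors:
long_index_t iters =
    integer_divide_ceil(total_elems, static_cast<long_index_t>(grid - 1) * BLOCK_SIZE);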
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
@@ -263,7 +303,12 @@ struct MoeSortingKernel
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
k.tokens = h.tokens;
k.num_experts = h.num_experts;
#if MOE_SORTING_FMOE_2D_BUF
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
#else
k.moe_buf_bytes = h.moe_buf_bytes;
#endif
const auto blocks = BlockSize(h);
// NOTE: tokens could come from p_local_tokens, so this variable is not used here
@@ -431,6 +476,24 @@ struct MoeSortingKernel
}
}
CK_TILE_DEVICE void
moe_buf_set_zero_kernel_2d(void* buf, index_t row, index_t col, index_t elem_bytes) const
{
const long_index_t total_pixels = static_cast<long_index_t>(row) * col;
const long_index_t total_bytes = total_pixels * elem_bytes;
const long_index_t total_elems = total_bytes / 16; // always use dwordx4
using vector_type = ext_vector_t<index_t, 4>;
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
auto zero_ = vector_type{0};
for(long_index_t i = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; i < total_elems;
i += (gridDim.x - 1) * BLOCK_SIZE)
{
p_buf[i] = zero_;
}
}
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
const WeightType* __restrict__ weights,
index_t* p_sorted_token_ids,
@@ -863,7 +926,8 @@ struct MoeSortingKernel
}
if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
{
*p_total_tokens_post_pad = local_cumsum_;
p_total_tokens_post_pad[1] = tokens;
}
}
__syncthreads();
@@ -1005,20 +1069,6 @@ struct MoeSortingKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
index_t tokens_ = [&]() {
if constexpr(Problem::LocalToken)
{
@@ -1029,6 +1079,25 @@ struct MoeSortingKernel
return kargs.tokens;
}
}();
if(blockIdx.x > 0)
{
if(kargs.p_moe_buf)
{
#if MOE_SORTING_FMOE_2D_BUF
moe_buf_set_zero_kernel_2d(
kargs.p_moe_buf, tokens_, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes);
#else
moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
kargs.moe_buf_bytes);
#endif
}
return;
}
extern __shared__ char smem[];
#if MOE_SORTING_USE_EX_KERNEL
return moe_align_block_size_kernel_ex(
static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
@@ -1045,6 +1114,7 @@ struct MoeSortingKernel
kargs.smem_rows,
smem);
#else
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
static_cast<IndexType*>(kargs.p_sorted_token_ids),
@@ -1066,6 +1136,8 @@ namespace impl {
// [expert, padded_tokens]
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens)
{
// Pad to a multiple of 32. This makes sure that even if the mesh is 8-bit,
// we can still use dwordx4 load/store
constexpr index_t chunk = 32;
return (tokens + chunk - 1) / chunk * chunk;
};
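A quick host-side sanity check of this padding (assuming <cassert>; the values are worked examples):
inline void moe_sorting_mesh_stride_examples()
{
    assert(moe_sorting_mp_mesh_stride(1)    == 32);   // at least one 32-element chunk
    assert(moe_sorting_mp_mesh_stride(1000) == 1024); // 1000 padded up to 32 * 32
    assert(moe_sorting_mp_mesh_stride(1024) == 1024); // already aligned
}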
@@ -1261,6 +1333,24 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by
}
}
template <index_t BLOCK_SIZE = 256>
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(
void* buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks)
{
const long_index_t total_pixels = static_cast<long_index_t>(row) * col;
const long_index_t total_bytes = total_pixels * elem_bytes;
const long_index_t total_elems = total_bytes / 16; // always use dwordx4
using vector_type = ext_vector_t<index_t, 4>;
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
auto zero_ = vector_type{0};
for(long_index_t i = gid * BLOCK_SIZE + threadIdx.x; i < total_elems; i += blocks * BLOCK_SIZE)
{
p_buf[i] = zero_;
}
}
} // namespace impl
// TODO: tokens could be from
@@ -1292,12 +1382,29 @@ CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_expe
}
// return size in byte
// dispatch_policy: 0 - automatically pick a kernel, 1 - always use the single
// kernel, 2 - always use the mp kernel
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_,
int num_experts_,
int topk_,
int dispatch_policy_)
{
#if 1
if(dispatch_policy_ == 0)
{
if(moe_sorting_is_oneshot(tokens_, num_experts_))
{
return 0;
}
else
{
return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_);
}
}
else if(dispatch_policy_ == 1)
{
return 0; // always use single kernel
}
else
{
@@ -1308,6 +1415,98 @@ CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts
#endif
}
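A hedged usage sketch of the new dispatch_policy parameter (the allocation handling is illustrative; a zero return means the selected path needs no workspace):
index_t ws_bytes = moe_sorting_get_workspace_size(tokens, num_experts, topk, /*dispatch_policy_=*/0);
void* p_ws = nullptr;
if(ws_bytes != 0)
{
    // only the mp path needs the expert-mesh workspace
    HIP_CHECK_ERROR(hipMalloc(&p_ws, ws_bytes));
}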
template <typename Problem_>
struct MoeSortingClearWorkspaceKernel
{
using Problem = remove_cvref_t<Problem_>;
static constexpr index_t BLOCK_SIZE = Problem::BlockSize;
static constexpr index_t OCCUPANCY = Problem::Occu;
using Hargs = MoeSortingHostArgs;
struct Kargs
{
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
void* p_expert_mesh; // [expert, tokens]
index_t tokens; // if p_local_tokens is not nullptr, this indicates the max possible tokens,
// used for ws/LDS calculation
index_t num_experts;
index_t mesh_stride; // mesh_stride for p_expert_mesh
index_t mesh_byte_size;
};
CK_TILE_HOST static constexpr auto get_num_cu()
{
index_t num_cu = [&]() {
hipDeviceProp_t dev_prop;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
return dev_prop.multiProcessorCount;
}();
return num_cu;
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.tokens = h.tokens;
k.num_experts = h.num_experts;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.mesh_byte_size = impl::moe_sorting_mesh_byte_size(h.tokens, h.num_experts, h.topk);
return k;
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
// in byte
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
index_t mesh_stride = [&]() {
if constexpr(Problem::LocalToken)
{
return impl::moe_sorting_mp_mesh_stride(tokens);
}
else
{
return kargs.mesh_stride;
}
}();
index_t row_size = mesh_stride; // impl::moe_sorting_mp_mesh_stride(tokens);
index_t pixels = kargs.num_experts * row_size;
index_t total_bytes = pixels * kargs.mesh_byte_size;
index_t total_elems = total_bytes / 16; // always use dwordx4
using vector_type = ext_vector_t<index_t, 4>;
vector_type* p_expert_mesh = reinterpret_cast<vector_type*>(kargs.p_expert_mesh);
auto zero_ = vector_type{0};
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elems;
i += gridDim.x * BLOCK_SIZE)
{
p_expert_mesh[i] = zero_;
}
}
};
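A minimal sketch of wiring this kernel up (the launch plumbing itself is elided; it follows the same Hargs/Kargs pattern as the other kernels in this file, and MoeSortingClearWorkspaceProblem is defined in the problem header below):
using ClearProblem = MoeSortingClearWorkspaceProblem</*LocalToken_=*/true>;
using ClearKernel  = MoeSortingClearWorkspaceKernel<ClearProblem>;
auto kargs = ClearKernel::MakeKargs(h); // h: a filled MoeSortingHostArgs
auto grid  = ClearKernel::GridSize(h);  // num_cu * Occu persistent blocks
auto block = ClearKernel::BlockSize(h); // dim3(1024) with the default problem
// launch with 0 dynamic LDS (GetSmemSize() == 0), before the mp phase-0 kernel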
// the kernel below is a multi-phase implementation for the large-token and/or large-expert case;
// it writes into a buffer to record the token cnt
@@ -1435,6 +1634,16 @@ struct MoeSortingMultiPhaseKernel_P0
else
return tokens;
}();
index_t mesh_stride = [&]() {
if constexpr(Problem::LocalToken)
{
return impl::moe_sorting_mp_mesh_stride(tokens);
}
else
{
return kargs.mesh_stride;
}
}();
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
#pragma unroll Problem::SubTokenTile
@@ -1449,12 +1658,11 @@ struct MoeSortingMultiPhaseKernel_P0
if constexpr(Problem::LocalToken)
{
if(static_cast<index_t>(curr_token_id) < tokens)
p_expert_mesh[eid * mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff;
});
}
}
@@ -1479,6 +1687,7 @@ struct MoeSortingMultiPhaseKernel_P1
struct Kargs
{
const void* p_local_expert_mask; // [expert]
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
void* p_expert_mesh; // [expert, tokens]
void* p_expert_cumsum;
index_t mesh_stride; // mesh_stride for p_expert_mesh
@@ -1488,6 +1697,7 @@ struct MoeSortingMultiPhaseKernel_P1
{
Kargs k;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.p_expert_cumsum = reinterpret_cast<void*>(
reinterpret_cast<char*>(h.p_ws) +
@@ -1511,12 +1721,9 @@ struct MoeSortingMultiPhaseKernel_P1
{
__shared__ char smem[GetSmemSize()];
int eid = blockIdx.x;
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
const IndexType* p_local_expert_mask =
static_cast<const IndexType*>(kargs.p_local_expert_mask);
@@ -1524,7 +1731,32 @@ struct MoeSortingMultiPhaseKernel_P1
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return 0; // not used when not LocalToken
}
}();
index_t mesh_stride = [&]() {
if constexpr(Problem::LocalToken)
{
return impl::moe_sorting_mp_mesh_stride(tokens);
}
else
{
return kargs.mesh_stride;
}
}();
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride);
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
if constexpr(Problem::LocalExpertMasking)
{
@@ -1538,7 +1770,7 @@ struct MoeSortingMultiPhaseKernel_P1
{
int position = i * BLOCK_SIZE + threadIdx.x;
r_t v{0};
if(position < (mesh_stride / index_pack))
v = p_expert_mesh[position];
index_t local_sum = 0;
static_for<0, index_pack, 1>{}(
@@ -1835,7 +2067,7 @@ struct MoeSortingMultiPhaseKernel_P2
const void* p_local_tokens; // [1]
void* p_expert_mesh; // [expert, tokens]
void* p_expert_cumsum; // [expert + 1]
void* p_total_tokens_post_pad; // [2]
void* p_sorted_expert_ids;
void* p_moe_buf;
index_t tokens;
@@ -1863,15 +2095,36 @@ struct MoeSortingMultiPhaseKernel_P2
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
#if MOE_SORTING_FMOE_2D_BUF
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
#else
k.moe_buf_bytes = h.moe_buf_bytes;
#endif
return k;
}
CK_TILE_HOST static constexpr auto get_num_cu()
{
index_t num_cu = [&]() {
hipDeviceProp_t dev_prop;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
return dev_prop.multiProcessorCount;
}();
return num_cu;
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
#if MOE_SORTING_FMOE_2D_BUF
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
#else
// use 1 block to cumsum
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
#endif
}
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
@@ -1888,11 +2141,21 @@ struct MoeSortingMultiPhaseKernel_P2
{
if(blockIdx.x > 0)
{
#if MOE_SORTING_FMOE_2D_BUF
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
kargs.tokens,
kargs.moe_buf_interm_dim,
kargs.moe_buf_elem_bytes,
blockIdx.x - 1,
gridDim.x - 1);
return;
#else
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
kargs.moe_buf_bytes,
blockIdx.x - 1);
return;
#endif
}
__shared__ char smem[GetSmemSize()];
IndexType* s = reinterpret_cast<IndexType*>(smem);
@@ -2223,7 +2486,7 @@ struct MoeSortingMultiPhaseKernel_P23
const void* p_local_tokens; // [1]
void* p_expert_mesh; // [expert, tokens]
void* p_expert_cumsum; // [expert + 1]
void* p_total_tokens_post_pad; // [2]
void* p_sorted_expert_ids;
void* p_sorted_token_ids;
@@ -2235,7 +2498,17 @@ struct MoeSortingMultiPhaseKernel_P23
index_t mesh_stride; // mesh_stride for p_expert_mesh
mdiv unit_size_mdiv;
mdiv topk_mdiv;
#if MOE_SORTING_FMOE_2D_BUF
// NOTE:
// moe_buf_* describes a 2d ws buffer used by the following fmoe kernel,
// arranged as row*col, where row=tokens (or local_tokens), col=interm_dim.
// We fuse the clearing of this buffer into the sorting kernel.
// Besides, we require interm_dim to be a multiple of 16 bytes (ensure this
// when allocating the ws for fmoe)
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
index_t moe_buf_elem_bytes; // p_moe_buf element size in bytes (8-bit, 16-bit, 32-bit, etc.)
#else
long_index_t moe_buf_bytes; // byte size of p_moe_buf
#endif
};
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
@@ -2262,16 +2535,37 @@ struct MoeSortingMultiPhaseKernel_P23
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
#if MOE_SORTING_FMOE_2D_BUF
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
#else
k.moe_buf_bytes = h.moe_buf_bytes;
#endif
return k;
}
CK_TILE_HOST static constexpr auto get_num_cu()
{
index_t num_cu = [&]() {
hipDeviceProp_t dev_prop;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
return dev_prop.multiProcessorCount;
}();
return num_cu;
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
#if MOE_SORTING_FMOE_2D_BUF
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
#else
// one block per expert for the cumsum, the rest clear the fmoe buffer
return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
#endif
}
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
@@ -2287,13 +2581,34 @@ struct MoeSortingMultiPhaseKernel_P23
// reduce single pixel within a wave
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
if(static_cast<index_t>(blockIdx.x) >= kargs.num_experts)
{
#if MOE_SORTING_FMOE_2D_BUF
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
tokens,
kargs.moe_buf_interm_dim,
kargs.moe_buf_elem_bytes,
blockIdx.x - kargs.num_experts,
gridDim.x - kargs.num_experts);
return;
#else
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
kargs.moe_buf_bytes,
blockIdx.x - kargs.num_experts);
return;
#endif
}
extern __shared__ char smem[];
@@ -2428,13 +2743,15 @@ struct MoeSortingMultiPhaseKernel_P23
{
auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
if(blockIdx.x == 0)
{
p_total_tokens_post_pad[0] = total_tokens_post_pad;
p_total_tokens_post_pad[1] = tokens;
}
p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad;
}
}
__syncthreads();
{
const IndexType* p_local_expert_mask =
static_cast<const IndexType*>(kargs.p_local_expert_mask);
@@ -2463,14 +2780,14 @@ struct MoeSortingMultiPhaseKernel_P23
return; // skip empty expert
}
index_t mesh_stride = [&]() {
if constexpr(Problem::LocalToken)
{
return impl::moe_sorting_mp_mesh_stride(tokens);
}
else
{
return kargs.mesh_stride;
}
}();
@@ -2478,7 +2795,8 @@ struct MoeSortingMultiPhaseKernel_P23
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
using d_t = ext_vector_t<index_t, index_pack>;
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
int prev_cumsum = 0;
for(int i = 0; i < loops; i++)
@@ -2487,8 +2805,7 @@ struct MoeSortingMultiPhaseKernel_P23
r_t x_v = 0;
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
{
x_v = reinterpret_cast<r_t*>(p_expert_mesh + eid * mesh_stride)[i_token_pack];
}
r_t x_r;


@@ -73,4 +73,12 @@ struct MoeSortingProblemMp
SubTokenTile == 8 || SubTokenTile == 16);
};
template <bool LocalToken_, index_t BlockSize_ = 1024, index_t Occu_ = 1>
struct MoeSortingClearWorkspaceProblem
{
static constexpr bool LocalToken = LocalToken_;
static constexpr index_t BlockSize = BlockSize_;
static constexpr index_t Occu = Occu_;
};
} // namespace ck_tile