[CK_TILE] moe_sorting support "local_tokens" feature for EP case (#2335)

* support local_token for hipgraph

* update README

* fix comment

* fix fmoe example
This commit is contained in:
carlushuang
2025-06-18 10:49:43 +08:00
committed by GitHub
parent c7c6a0ccb3
commit a4e1248dba
11 changed files with 495 additions and 162 deletions

View File

@@ -165,7 +165,8 @@ struct MoeSortingHostArgs
const void* p_topk_ids; // [token, topk]
const void* p_weights; // [token, topk]
const void* p_local_expert_mask;
const void* p_local_expert_mask; // [experts]
const void* p_local_tokens; // [1] if not nullptr, tokens read from here
void* p_sorted_token_ids;
void* p_sorted_weights;
@@ -177,7 +178,7 @@ struct MoeSortingHostArgs
void* p_ws; // size is moe_sorting_get_workspace_size()
// if return zero, then could be nullptr
// must be cleard before use
index_t tokens;
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens used for ws/LDS calculation
index_t unit_size; // this is the M_a of fused-moe kernel
index_t num_experts;
index_t topk;
@@ -201,6 +202,7 @@ struct MoeSortingKernel
const void* p_topk_ids;
const void* p_weights;
const void* p_local_expert_mask;
const void* p_local_tokens; // [1] if not nullptr, tokens read from here
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
@@ -253,6 +255,7 @@ struct MoeSortingKernel
k.p_topk_ids = h.p_topk_ids;
k.p_weights = h.p_weights;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_local_tokens = h.p_local_tokens;
k.p_sorted_token_ids = h.p_sorted_token_ids;
k.p_sorted_weights = h.p_sorted_weights;
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
@@ -263,9 +266,13 @@ struct MoeSortingKernel
k.moe_buf_bytes = h.moe_buf_bytes;
const auto blocks = BlockSize(h);
// NOTE: tokens could from p_local_tokens, so here this variable is useless
// hence moe_align_block_size_kernel() will not behavior properly if we have dynamic tokens
// (indeed we can deprecate moe_align_block_size_kernel)
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
// NOTE: tokens could from p_local_tokens, so here the LDS will be bigger than expected (but works)
k.smem_rows = [&](){
auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
(void) c_;
@@ -1009,8 +1016,19 @@ struct MoeSortingKernel
}
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
extern __shared__ char smem[];
#if MOE_SORTING_USE_EX_KERNEL
(void)numel;
index_t tokens_ = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
return moe_align_block_size_kernel_ex(
static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
@@ -1020,7 +1038,7 @@ struct MoeSortingKernel
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
kargs.num_experts,
kargs.tokens,
tokens_,
kargs.unit_size_mdiv,
kargs.topk_mdiv,
kargs.expert_mdiv,
@@ -1245,6 +1263,7 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by
} // namespace impl
// TODO: tokens could be from
// prefer to run mp kernel if is not oneshot
CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
{
@@ -1351,9 +1370,11 @@ struct MoeSortingMultiPhaseKernel_P0
struct Kargs
{
const void* p_topk_ids; // [tokens, topk]
void* p_expert_mesh; // [expert, tokens]
index_t tokens;
const void* p_topk_ids; // [tokens, topk]
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
void* p_expert_mesh; // [expert, tokens]
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
// used for ws/LDS calculation
index_t mesh_stride; // mesh_stride for p_expert_mesh
mdiv topk_mdiv;
};
@@ -1373,11 +1394,12 @@ struct MoeSortingMultiPhaseKernel_P0
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_expert_mesh = h.p_ws;
k.tokens = h.tokens;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
k.p_topk_ids = h.p_topk_ids;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.tokens = h.tokens;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
return k;
}
@@ -1394,7 +1416,26 @@ struct MoeSortingMultiPhaseKernel_P0
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
index_t rounded_tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
Problem::SubTokenTile;
}
else
return tokens;
}();
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
#pragma unroll Problem::SubTokenTile
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
@@ -1405,8 +1446,15 @@ struct MoeSortingMultiPhaseKernel_P0
IndexType eid = x[j.value]; // ext_vector_type must use int to []
uint32_t curr_token_id, curr_topk_id;
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
if constexpr(Problem::LocalToken)
{
if(static_cast<index_t>(curr_token_id) < tokens)
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
});
}
}
@@ -1542,6 +1590,7 @@ struct MoeSortingMultiPhaseKernel_P01
{
const void* p_topk_ids; // [tokens, topk]
const void* p_local_expert_mask; // [expert]
const void* p_local_tokens; // [1]
void* p_expert_mesh; // [expert, tokens]
void* p_expert_cumsum; // [expert + 1]
void* p_expert_sem; // [1]
@@ -1569,6 +1618,7 @@ struct MoeSortingMultiPhaseKernel_P01
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.p_expert_cumsum = reinterpret_cast<void*>(
reinterpret_cast<char*>(h.p_ws) +
@@ -1580,8 +1630,17 @@ struct MoeSortingMultiPhaseKernel_P01
k.tokens = h.tokens;
k.num_experts = h.num_experts;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.wg_count = WGCounts(h);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
k.wg_count = [&]() {
if constexpr(Problem::LocalToken)
{
return GridSize(h);
}
else
{
return WGCounts(h);
}
}();
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
return k;
}
@@ -1607,13 +1666,46 @@ struct MoeSortingMultiPhaseKernel_P01
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
workgroup_barrier wb{reinterpret_cast<uint32_t*>(kargs.p_expert_sem)};
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
index_t rounded_tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
Problem::SubTokenTile;
}
else
return tokens;
}();
index_t wg_count = [&]() {
if constexpr(Problem::LocalToken)
{
index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile;
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
// no more than grid_size
return min(elem_cnt, kargs.wg_count);
}
else
{
return kargs.wg_count;
}
}();
{
using topk_id_t = ext_vector_t<IndexType, Problem::SubTokenTile>;
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
IndexType* p_expert_mesh = reinterpret_cast<IndexType*>(kargs.p_expert_mesh);
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
#pragma unroll Problem::SubTokenTile
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
@@ -1625,10 +1717,19 @@ struct MoeSortingMultiPhaseKernel_P01
uint32_t curr_token_id, curr_topk_id;
kargs.topk_mdiv.divmod(
i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
// p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
if constexpr(Problem::LocalToken)
{
if(static_cast<index_t>(curr_token_id) < tokens)
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
});
}
if(static_cast<index_t>(blockIdx.x) < kargs.wg_count)
if(static_cast<index_t>(blockIdx.x) < wg_count)
{
wb.inc();
}
@@ -1642,7 +1743,7 @@ struct MoeSortingMultiPhaseKernel_P01
if(eid >= kargs.num_experts)
return;
wb.wait_lt(kargs.wg_count);
wb.wait_lt(wg_count);
for(; eid < kargs.num_experts; eid += gridDim.x)
{
@@ -1731,6 +1832,7 @@ struct MoeSortingMultiPhaseKernel_P2
struct Kargs
{
const void* p_local_expert_mask; // [expert]
const void* p_local_tokens; // [1]
void* p_expert_mesh; // [expert, tokens]
void* p_expert_cumsum; // [expert + 1]
void* p_total_tokens_post_pad; // [1]
@@ -1747,6 +1849,7 @@ struct MoeSortingMultiPhaseKernel_P2
{
Kargs k;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_cumsum = reinterpret_cast<void*>(
reinterpret_cast<char*>(h.p_ws) +
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
@@ -1942,6 +2045,7 @@ struct MoeSortingMultiPhaseKernel_P3
{
const void* p_weights;
const void* p_local_expert_mask;
const void* p_local_tokens;
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_expert_mesh; // [token, expert]
@@ -1958,6 +2062,7 @@ struct MoeSortingMultiPhaseKernel_P3
Kargs k;
k.p_weights = h.p_weights;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_local_tokens = h.p_local_tokens;
k.p_sorted_token_ids = h.p_sorted_token_ids;
k.p_sorted_weights = h.p_sorted_weights;
k.p_expert_mesh = h.p_ws;
@@ -1994,6 +2099,16 @@ struct MoeSortingMultiPhaseKernel_P3
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
int eid = blockIdx.x;
int wave_id = threadIdx.x / WarpSize;
int lane_id = threadIdx.x % WarpSize;
@@ -2019,7 +2134,7 @@ struct MoeSortingMultiPhaseKernel_P3
{
int i_token = i * BLOCK_SIZE + threadIdx.x;
IndexType x = 0;
if(i_token < kargs.tokens)
if(i_token < tokens)
{
x = p_expert_mesh[eid * kargs.mesh_stride + i_token];
}
@@ -2066,7 +2181,7 @@ struct MoeSortingMultiPhaseKernel_P3
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
#else
p_sorted_token_ids[i] = tokens;
#endif
@@ -2105,6 +2220,7 @@ struct MoeSortingMultiPhaseKernel_P23
{
const void* p_weights;
const void* p_local_expert_mask; // [expert]
const void* p_local_tokens; // [1]
void* p_expert_mesh; // [expert, tokens]
void* p_expert_cumsum; // [expert + 1]
void* p_total_tokens_post_pad; // [1]
@@ -2127,6 +2243,7 @@ struct MoeSortingMultiPhaseKernel_P23
Kargs k;
k.p_weights = h.p_weights;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.p_expert_cumsum = reinterpret_cast<void*>(
reinterpret_cast<char*>(h.p_ws) +
@@ -2346,6 +2463,17 @@ struct MoeSortingMultiPhaseKernel_P23
return; // skip empty expert
}
index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
// cumsum one by one
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
@@ -2357,7 +2485,7 @@ struct MoeSortingMultiPhaseKernel_P23
{
int i_token_pack = i * BLOCK_SIZE + threadIdx.x;
r_t x_v = 0;
if(i_token_pack < (kargs.tokens + index_pack - 1) / index_pack)
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
{
x_v = reinterpret_cast<r_t*>(p_expert_mesh +
eid * kargs.mesh_stride)[i_token_pack];
@@ -2554,7 +2682,7 @@ struct MoeSortingMultiPhaseKernel_P23
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
#else
p_sorted_token_ids[i] = tokens;
#endif

View File

@@ -31,6 +31,7 @@ template <typename IndexType_,
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
bool SubTokenOneShot_, // if we only loop over once or not
bool LocalExpertMasking_, // used in EP case
bool LocalToken_, // used in EP case
bool SkipExpertsWithZeroTokens_ = true,
index_t ExpertTile_ = 0>
struct MoeSortingProblemEx
@@ -44,6 +45,7 @@ struct MoeSortingProblemEx
static constexpr index_t SubTokenTile = SubTokenTile_;
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
static constexpr bool LocalToken = LocalToken_;
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
@@ -54,6 +56,7 @@ template <typename IndexType_,
typename MeshType_,
index_t SubTokenTile_, // 1,2,4,8
bool LocalExpertMasking_, // used in EP case
bool LocalToken_, // used in EP case
bool SkipExpertsWithZeroTokens_ = true>
struct MoeSortingProblemMp
{
@@ -64,6 +67,7 @@ struct MoeSortingProblemMp
static constexpr index_t SubTokenTile = SubTokenTile_;
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
static constexpr bool LocalToken = LocalToken_;
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 ||
SubTokenTile == 8 || SubTokenTile == 16);