mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] moe_sorting support "local_tokens" feature for EP case (#2335)
* support local_token for hipgraph * update README * fix comment * fix fmoe example
This commit is contained in:
@@ -165,7 +165,8 @@ struct MoeSortingHostArgs
|
||||
const void* p_topk_ids; // [token, topk]
|
||||
const void* p_weights; // [token, topk]
|
||||
|
||||
const void* p_local_expert_mask;
|
||||
const void* p_local_expert_mask; // [experts]
|
||||
const void* p_local_tokens; // [1] if not nullptr, tokens read from here
|
||||
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
@@ -177,7 +178,7 @@ struct MoeSortingHostArgs
|
||||
void* p_ws; // size is moe_sorting_get_workspace_size()
|
||||
// if return zero, then could be nullptr
|
||||
// must be cleard before use
|
||||
index_t tokens;
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens used for ws/LDS calculation
|
||||
index_t unit_size; // this is the M_a of fused-moe kernel
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
@@ -201,6 +202,7 @@ struct MoeSortingKernel
|
||||
const void* p_topk_ids;
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask;
|
||||
const void* p_local_tokens; // [1] if not nullptr, tokens read from here
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
@@ -253,6 +255,7 @@ struct MoeSortingKernel
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
|
||||
@@ -263,9 +266,13 @@ struct MoeSortingKernel
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
|
||||
const auto blocks = BlockSize(h);
|
||||
// NOTE: tokens could from p_local_tokens, so here this variable is useless
|
||||
// hence moe_align_block_size_kernel() will not behavior properly if we have dynamic tokens
|
||||
// (indeed we can deprecate moe_align_block_size_kernel)
|
||||
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
// NOTE: tokens could from p_local_tokens, so here the LDS will be bigger than expected (but works)
|
||||
k.smem_rows = [&](){
|
||||
auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
|
||||
(void) c_;
|
||||
@@ -1009,8 +1016,19 @@ struct MoeSortingKernel
|
||||
}
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
extern __shared__ char smem[];
|
||||
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
(void)numel;
|
||||
index_t tokens_ = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
return moe_align_block_size_kernel_ex(
|
||||
static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
@@ -1020,7 +1038,7 @@ struct MoeSortingKernel
|
||||
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
|
||||
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
|
||||
kargs.num_experts,
|
||||
kargs.tokens,
|
||||
tokens_,
|
||||
kargs.unit_size_mdiv,
|
||||
kargs.topk_mdiv,
|
||||
kargs.expert_mdiv,
|
||||
@@ -1245,6 +1263,7 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by
|
||||
|
||||
} // namespace impl
|
||||
|
||||
// TODO: tokens could be from
|
||||
// prefer to run mp kernel if is not oneshot
|
||||
CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
|
||||
{
|
||||
@@ -1351,9 +1370,11 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens;
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
|
||||
// used for ws/LDS calculation
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv topk_mdiv;
|
||||
};
|
||||
@@ -1373,11 +1394,12 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
@@ -1394,7 +1416,26 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
index_t rounded_tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
|
||||
Problem::SubTokenTile;
|
||||
}
|
||||
else
|
||||
return tokens;
|
||||
}();
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
@@ -1405,8 +1446,15 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
IndexType eid = x[j.value]; // ext_vector_type must use int to []
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1542,6 +1590,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
{
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_expert_sem; // [1]
|
||||
@@ -1569,6 +1618,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
@@ -1580,8 +1630,17 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.wg_count = WGCounts(h);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
k.wg_count = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return GridSize(h);
|
||||
}
|
||||
else
|
||||
{
|
||||
return WGCounts(h);
|
||||
}
|
||||
}();
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
@@ -1607,13 +1666,46 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
workgroup_barrier wb{reinterpret_cast<uint32_t*>(kargs.p_expert_sem)};
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
index_t rounded_tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
|
||||
Problem::SubTokenTile;
|
||||
}
|
||||
else
|
||||
return tokens;
|
||||
}();
|
||||
index_t wg_count = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile;
|
||||
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
// no more than grid_size
|
||||
return min(elem_cnt, kargs.wg_count);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.wg_count;
|
||||
}
|
||||
}();
|
||||
|
||||
{
|
||||
using topk_id_t = ext_vector_t<IndexType, Problem::SubTokenTile>;
|
||||
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
IndexType* p_expert_mesh = reinterpret_cast<IndexType*>(kargs.p_expert_mesh);
|
||||
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
@@ -1625,10 +1717,19 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
kargs.topk_mdiv.divmod(
|
||||
i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
|
||||
// p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
if(static_cast<index_t>(blockIdx.x) < kargs.wg_count)
|
||||
if(static_cast<index_t>(blockIdx.x) < wg_count)
|
||||
{
|
||||
wb.inc();
|
||||
}
|
||||
@@ -1642,7 +1743,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if(eid >= kargs.num_experts)
|
||||
return;
|
||||
|
||||
wb.wait_lt(kargs.wg_count);
|
||||
wb.wait_lt(wg_count);
|
||||
|
||||
for(; eid < kargs.num_experts; eid += gridDim.x)
|
||||
{
|
||||
@@ -1731,6 +1832,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
@@ -1747,6 +1849,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
|
||||
@@ -1942,6 +2045,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
{
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask;
|
||||
const void* p_local_tokens;
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_expert_mesh; // [token, expert]
|
||||
@@ -1958,6 +2062,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
Kargs k;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
@@ -1994,6 +2099,16 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / WarpSize;
|
||||
int lane_id = threadIdx.x % WarpSize;
|
||||
@@ -2019,7 +2134,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
{
|
||||
int i_token = i * BLOCK_SIZE + threadIdx.x;
|
||||
IndexType x = 0;
|
||||
if(i_token < kargs.tokens)
|
||||
if(i_token < tokens)
|
||||
{
|
||||
x = p_expert_mesh[eid * kargs.mesh_stride + i_token];
|
||||
}
|
||||
@@ -2066,7 +2181,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[i] = tokens;
|
||||
#endif
|
||||
@@ -2105,6 +2220,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
{
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
@@ -2127,6 +2243,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
Kargs k;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
@@ -2346,6 +2463,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
return; // skip empty expert
|
||||
}
|
||||
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
// cumsum one by one
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
@@ -2357,7 +2485,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
{
|
||||
int i_token_pack = i * BLOCK_SIZE + threadIdx.x;
|
||||
r_t x_v = 0;
|
||||
if(i_token_pack < (kargs.tokens + index_pack - 1) / index_pack)
|
||||
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
|
||||
{
|
||||
x_v = reinterpret_cast<r_t*>(p_expert_mesh +
|
||||
eid * kargs.mesh_stride)[i_token_pack];
|
||||
@@ -2554,7 +2682,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[i] = tokens;
|
||||
#endif
|
||||
|
||||
@@ -31,6 +31,7 @@ template <typename IndexType_,
|
||||
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
|
||||
bool SubTokenOneShot_, // if we only loop over once or not
|
||||
bool LocalExpertMasking_, // used in EP case
|
||||
bool LocalToken_, // used in EP case
|
||||
bool SkipExpertsWithZeroTokens_ = true,
|
||||
index_t ExpertTile_ = 0>
|
||||
struct MoeSortingProblemEx
|
||||
@@ -44,6 +45,7 @@ struct MoeSortingProblemEx
|
||||
static constexpr index_t SubTokenTile = SubTokenTile_;
|
||||
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
|
||||
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
|
||||
static constexpr bool LocalToken = LocalToken_;
|
||||
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
|
||||
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
|
||||
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
|
||||
@@ -54,6 +56,7 @@ template <typename IndexType_,
|
||||
typename MeshType_,
|
||||
index_t SubTokenTile_, // 1,2,4,8
|
||||
bool LocalExpertMasking_, // used in EP case
|
||||
bool LocalToken_, // used in EP case
|
||||
bool SkipExpertsWithZeroTokens_ = true>
|
||||
struct MoeSortingProblemMp
|
||||
{
|
||||
@@ -64,6 +67,7 @@ struct MoeSortingProblemMp
|
||||
|
||||
static constexpr index_t SubTokenTile = SubTokenTile_;
|
||||
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
|
||||
static constexpr bool LocalToken = LocalToken_;
|
||||
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
|
||||
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 ||
|
||||
SubTokenTile == 8 || SubTokenTile == 16);
|
||||
|
||||
Reference in New Issue
Block a user