Felix/opt sorting (#2902)

* merge felix/sorting
* opt moe sorting  (#2822)
* opt moe storing for 2k
---------
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
felix
2025-10-15 09:24:03 +08:00
committed by GitHub
parent ca1ab083a7
commit 4c826abfff
4 changed files with 812 additions and 217 deletions

View File

@@ -20,7 +20,7 @@ namespace ck_tile {
#endif
#ifndef MOE_SORTING_FUSE_MP_01
#define MOE_SORTING_FUSE_MP_01 0
#define MOE_SORTING_FUSE_MP_01 1
#endif
// whether to use 2d buffer indexing for fmoe ws or 1d
@@ -527,7 +527,7 @@ struct MoeSortingKernel
}
__syncthreads();
#if 1
#if MOE_SORTING_FUSE_MP_01
if(tid < num_experts)
{
tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
@@ -1322,18 +1322,18 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
}
}
template <index_t BLOCK_SIZE = 256>
template <index_t kBlockSize = 256>
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes, index_t gid)
{
// const index_t offset = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x;
long_index_t offset = static_cast<long_index_t>(gid) * BLOCK_SIZE + threadIdx.x;
// const index_t offset = (blockIdx.x - 1) * kBlockSize + threadIdx.x;
long_index_t offset = static_cast<long_index_t>(gid) * kBlockSize + threadIdx.x;
if(offset < buf_bytes / 16)
{
buf[offset] = uint8x16_t{0};
}
}
template <index_t BLOCK_SIZE = 256>
template <index_t kBlockSize = 256>
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(
void* buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks)
{
@@ -1345,7 +1345,7 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
auto zero_ = vector_type{0};
for(long_index_t i = gid * BLOCK_SIZE + threadIdx.x; i < total_elems; i += blocks * BLOCK_SIZE)
for(long_index_t i = gid * kBlockSize + threadIdx.x; i < total_elems; i += blocks * kBlockSize)
{
p_buf[i] = zero_;
}
@@ -1552,7 +1552,7 @@ p_m_cumsum
// count topk_id into mesh
template <typename Problem_>
struct MoeSortingMultiPhaseKernel_P0
struct MoeSortingMultiPhaseKernel_P0_v1
{
using Problem = remove_cvref_t<Problem_>;
@@ -1673,6 +1673,197 @@ struct MoeSortingMultiPhaseKernel_P0
}
}
};
// Phase-0 kernel (v2) of the multi-phase MoE sorting pipeline.
// Launch shape: one thread block per expert (GridSize == num_experts,
// blockIdx.x == expert id), kBlockSize (512) threads per block.
// Each block:
//   1. zeroes its expert's row of the token "mesh" in workspace,
//   2. scatters (topk_slot + 1) into mesh[token] for every (token, topk)
//      entry whose expert id matches this block's expert,
//   3. counts the non-zero mesh entries and writes the per-expert token
//      count into p_expert_cumsum[eid] (a count here; presumably cumsum'd
//      by a later phase — confirm against the P1 kernel).
template <typename Problem_>
struct MoeSortingMultiPhaseKernel_P0_v2
{
using Problem = remove_cvref_t<Problem_>;
using IndexType = typename Problem::IndexType;
using WeightType = typename Problem::WeightType;
using MeshType = typename Problem::MeshType;
// Fixed block size for this kernel; also used for grid-stride steps below.
static constexpr index_t kBlockSize = 512;
typedef MoeSortingHostArgs MoeSortingKargs;
using Hargs = MoeSortingHostArgs;
// Device-side kernel arguments, derived from MoeSortingHostArgs in MakeKargs.
struct Kargs
{
const void* p_topk_ids; // [tokens, topk]
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
void* p_expert_mesh; // [expert, tokens]
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
// used for ws/LDS calculation
index_t mesh_stride; // mesh_stride for p_expert_mesh
mdiv topk_mdiv;
const void* p_local_expert_mask; // [expert]
void* p_expert_cumsum; // [expert]
index_t num_experts;
};
// Host helper: number of compute units (SMs/CUs) on the current device.
// NOTE(review): not referenced inside this struct — presumably used by the
// host-side dispatch logic; confirm caller.
CK_TILE_HOST static constexpr auto get_num_cu()
{
index_t num_cu = [&]() {
hipDeviceProp_t dev_prop;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
return dev_prop.multiProcessorCount;
}();
return num_cu;
}
// Build device kargs from host args. The workspace (h.p_ws) is partitioned
// as: [expert mesh | expert cumsum], with the cumsum array placed right
// after the mesh region (size given by moe_sorting_mp_mesh_smem_size).
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.p_expert_cumsum = reinterpret_cast<void*>(
reinterpret_cast<char*>(h.p_ws) +
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
k.tokens = h.tokens;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
k.p_local_expert_mask = h.p_local_expert_mask;
k.num_experts = h.num_experts;
return k;
}
// One block per expert.
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return h.num_experts; }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
// in byte
// CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
// Static shared memory: one IndexType slot per wave, used for the
// cross-wave stage of the token-count reduction at the end of operator().
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
return kBlockSize / get_warp_size() * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// topk ids are loaded index_pack at a time as a vector.
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
__shared__ char smem[GetSmemSize()];
using topk_id_t = ext_vector_t<IndexType, index_pack>;
// This block handles expert `eid` exclusively.
const int eid = blockIdx.x;
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
const IndexType* p_local_expert_mask =
static_cast<const IndexType*>(kargs.p_local_expert_mask);
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
index_t lane_id = threadIdx.x % get_warp_size();
index_t wave_id = threadIdx.x / get_warp_size();
// Actual token count: read from device memory when LocalToken is enabled
// (dynamic token count), otherwise use the compile-time-known karg.
const index_t tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
}
else
{
return kargs.tokens;
}
}();
// Round token count up to a multiple of index_pack so the vectorized
// topk-id loop below reads whole vectors.
index_t rounded_tokens = [&]() {
if constexpr(Problem::LocalToken)
{
return (tokens + index_pack - 1) / index_pack * index_pack;
}
else
return tokens;
}();
// With LocalToken the mesh stride depends on the runtime token count;
// otherwise it was precomputed on the host.
index_t mesh_stride = [&]() {
if constexpr(Problem::LocalToken)
{
return impl::moe_sorting_mp_mesh_stride(tokens);
}
else
{
return kargs.mesh_stride;
}
}();
// mask == 0 means this expert is masked out locally (EP-style filtering).
IndexType mask = 1;
if constexpr(Problem::LocalExpertMasking)
{
mask = p_local_expert_mask[eid];
}
// This expert's row of the mesh; cleared cooperatively before the scatter.
MeshType* p_expert_mesh =
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride;
for(index_t i = threadIdx.x; i < mesh_stride; i += kBlockSize)
{
p_expert_mesh[i] = 0;
}
// presumably a block-wide barrier + memory sync (project primitive) —
// required between the mesh clear above and the scatter below.
ck_tile::block_sync_load_raw(0);
// Total vectorized (token, topk) entries to scan.
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / index_pack;
#pragma unroll index_pack
for(index_t i = threadIdx.x; i < total_elem; i += kBlockSize)
{
auto x = p_topk_ids[i];
static_for<0, index_pack, 1>{}([&](auto j) {
IndexType eid_x = x[j.value]; // ext_vector_type must use int to []
if(eid_x == eid)
{
// Recover (token, topk-slot) from the flat index via fast divmod.
uint32_t curr_token_id, curr_topk_id;
kargs.topk_mdiv.divmod(i * index_pack + j, curr_token_id, curr_topk_id);
if constexpr(Problem::LocalToken)
{
// Guard against the padded tail introduced by rounding.
if(static_cast<index_t>(curr_token_id) < tokens)
p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff;
}
else
// Store topk slot + 1 (so 0 keeps meaning "no hit"),
// truncated to 16 bits.
p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff;
}
});
}
// Barrier again: mesh writes must be visible before the count phase.
ck_tile::block_sync_load_raw(0);
{
// Count non-zero mesh entries (tokens routed to this expert) with a
// vectorized read + in-wave reduce + cross-wave reduce through smem.
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
const r_t* p_expert_mesh_r = reinterpret_cast<r_t*>(p_expert_mesh);
int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
// NOTE(review): gate uses Problem::LocalToken, but `mask` is only
// meaningful under Problem::LocalExpertMasking — confirm intended.
// Also: when taken, p_expert_cumsum[eid] is left unwritten; confirm a
// later phase (or prior zeroing) covers masked experts.
if(Problem::LocalToken && mask == 0)
return; // skip
index_t cnt = 0; // per-wave cnt
for(int i = 0; i < loops; i++)
{
int position = i * kBlockSize + threadIdx.x;
r_t v{0};
if(position < (mesh_stride / index_pack))
v = p_expert_mesh_r[position];
// Each lane counts its non-zero elements, then reduce across the wave.
index_t local_sum = 0;
static_for<0, index_pack, 1>{}(
[&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; });
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
}
// reduce cross wave
IndexType* s = reinterpret_cast<IndexType*>(smem);
if(lane_id == 0)
{
s[wave_id] = cnt;
}
__syncthreads();
// Thread 0 sums the per-wave partials and publishes this expert's count.
if(threadIdx.x == 0)
{
index_t c = 0;
for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
{
c += s[i];
}
p_expert_cumsum[eid] = c;
}
}
}
};
// cnt total tokens for a expert
template <typename Problem_>