mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
Felix/opt sorting (#2902)
* merge felix/sorting * opt moe sorting (#2822) * opt moe storing for 2k --------- Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
@@ -20,7 +20,7 @@ namespace ck_tile {
|
||||
#endif
|
||||
|
||||
#ifndef MOE_SORTING_FUSE_MP_01
|
||||
#define MOE_SORTING_FUSE_MP_01 0
|
||||
#define MOE_SORTING_FUSE_MP_01 1
|
||||
#endif
|
||||
|
||||
// weather use 2d buffer indexing for fmoe ws or 1d
|
||||
@@ -527,7 +527,7 @@ struct MoeSortingKernel
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#if 1
|
||||
#if MOE_SORTING_FUSE_MP_01
|
||||
if(tid < num_experts)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
|
||||
@@ -1322,18 +1322,18 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BLOCK_SIZE = 256>
|
||||
template <index_t kBlockSize = 256>
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes, index_t gid)
|
||||
{
|
||||
// const index_t offset = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x;
|
||||
long_index_t offset = static_cast<long_index_t>(gid) * BLOCK_SIZE + threadIdx.x;
|
||||
// const index_t offset = (blockIdx.x - 1) * kBlockSize + threadIdx.x;
|
||||
long_index_t offset = static_cast<long_index_t>(gid) * kBlockSize + threadIdx.x;
|
||||
if(offset < buf_bytes / 16)
|
||||
{
|
||||
buf[offset] = uint8x16_t{0};
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BLOCK_SIZE = 256>
|
||||
template <index_t kBlockSize = 256>
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(
|
||||
void* buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks)
|
||||
{
|
||||
@@ -1345,7 +1345,7 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(
|
||||
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(long_index_t i = gid * BLOCK_SIZE + threadIdx.x; i < total_elems; i += blocks * BLOCK_SIZE)
|
||||
for(long_index_t i = gid * kBlockSize + threadIdx.x; i < total_elems; i += blocks * kBlockSize)
|
||||
{
|
||||
p_buf[i] = zero_;
|
||||
}
|
||||
@@ -1552,7 +1552,7 @@ p_m_cumsum
|
||||
|
||||
// count topk_id into mesh
|
||||
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P0
|
||||
struct MoeSortingMultiPhaseKernel_P0_v1
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
@@ -1673,6 +1673,197 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
}
|
||||
}
|
||||
};
|
||||
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P0_v2
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t kBlockSize = 512;
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
|
||||
// used for ws/LDS calculation
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv topk_mdiv;
|
||||
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
void* p_expert_cumsum; // [expert]
|
||||
index_t num_experts;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
|
||||
k.tokens = h.tokens;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.num_experts = h.num_experts;
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return h.num_experts; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// in byte
|
||||
// CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return kBlockSize / get_warp_size() * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
using topk_id_t = ext_vector_t<IndexType, index_pack>;
|
||||
const int eid = blockIdx.x;
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
const index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
index_t rounded_tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return (tokens + index_pack - 1) / index_pack * index_pack;
|
||||
}
|
||||
else
|
||||
return tokens;
|
||||
}();
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
IndexType mask = 1;
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
mask = p_local_expert_mask[eid];
|
||||
}
|
||||
MeshType* p_expert_mesh =
|
||||
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride;
|
||||
for(index_t i = threadIdx.x; i < mesh_stride; i += kBlockSize)
|
||||
{
|
||||
p_expert_mesh[i] = 0;
|
||||
}
|
||||
ck_tile::block_sync_load_raw(0);
|
||||
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / index_pack;
|
||||
|
||||
#pragma unroll index_pack
|
||||
for(index_t i = threadIdx.x; i < total_elem; i += kBlockSize)
|
||||
{
|
||||
auto x = p_topk_ids[i];
|
||||
static_for<0, index_pack, 1>{}([&](auto j) {
|
||||
IndexType eid_x = x[j.value]; // ext_vector_type must use int to []
|
||||
if(eid_x == eid)
|
||||
{
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
kargs.topk_mdiv.divmod(i * index_pack + j, curr_token_id, curr_topk_id);
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
});
|
||||
}
|
||||
ck_tile::block_sync_load_raw(0);
|
||||
|
||||
{
|
||||
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
|
||||
const r_t* p_expert_mesh_r = reinterpret_cast<r_t*>(p_expert_mesh);
|
||||
|
||||
int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
|
||||
|
||||
if(Problem::LocalToken && mask == 0)
|
||||
return; // skip
|
||||
index_t cnt = 0; // per-wave cnt
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int position = i * kBlockSize + threadIdx.x;
|
||||
r_t v{0};
|
||||
if(position < (mesh_stride / index_pack))
|
||||
v = p_expert_mesh_r[position];
|
||||
index_t local_sum = 0;
|
||||
static_for<0, index_pack, 1>{}(
|
||||
[&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; });
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
if(lane_id == 0)
|
||||
{
|
||||
s[wave_id] = cnt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
p_expert_cumsum[eid] = c;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// cnt total tokens for a expert
|
||||
template <typename Problem_>
|
||||
|
||||
Reference in New Issue
Block a user