mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[CK_TILE] moe sorting optimize local_token (#2469)
* fix bug in loops that need use local tokens to compute * support extra chain local_token * update * update * refine some main * update * support dispatch_policy * fix 15 example
This commit is contained in:
@@ -23,6 +23,11 @@ namespace ck_tile {
|
||||
#define MOE_SORTING_FUSE_MP_01 0
|
||||
#endif
|
||||
|
||||
// weather use 2d buffer indexing for fmoe ws or 1d
|
||||
#ifndef MOE_SORTING_FMOE_2D_BUF
|
||||
#define MOE_SORTING_FMOE_2D_BUF 1
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
@@ -171,7 +176,7 @@ struct MoeSortingHostArgs
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_total_tokens_post_pad;
|
||||
void* p_total_tokens_post_pad; // [2], [0]:outputed tokens_post_padded, [1]:actual tokens on current rank (local_tokens or tokens)
|
||||
// we fused the setzero of output of fused-moe buffer
|
||||
// set this pointer to nullptr will skip this operation
|
||||
void* p_moe_buf;
|
||||
@@ -182,7 +187,18 @@ struct MoeSortingHostArgs
|
||||
index_t unit_size; // this is the M_a of fused-moe kernel
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
// NOTE:
|
||||
// moe_buf_* is a 2d ws buffer used for the following fmoe kernel
|
||||
// arranged as row*col, where row=tokens(or local_token), col=interm_dim
|
||||
// we fuse this clearing inside sorting kernel
|
||||
// Besides, we require inter_dim to be multiple of 16 byte(make sure when alloc ws for fmoe)
|
||||
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
|
||||
index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
|
||||
#else
|
||||
long_index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
@@ -197,6 +213,9 @@ struct MoeSortingKernel
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids;
|
||||
@@ -210,8 +229,12 @@ struct MoeSortingKernel
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
index_t num_experts;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
|
||||
index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
|
||||
#else
|
||||
long_index_t moe_buf_bytes;
|
||||
|
||||
#endif
|
||||
index_t tokens_per_thread;
|
||||
index_t smem_rows;
|
||||
mdiv unit_size_mdiv;
|
||||
@@ -220,10 +243,27 @@ struct MoeSortingKernel
|
||||
// mdiv sub_tokens_mdiv;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
(void)h;
|
||||
return get_num_cu() * OCCUPANCY;
|
||||
#else
|
||||
// TODO: assume num-experts not too much
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
|
||||
@@ -263,7 +303,12 @@ struct MoeSortingKernel
|
||||
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
|
||||
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
|
||||
#else
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
#endif
|
||||
|
||||
const auto blocks = BlockSize(h);
|
||||
// NOTE: tokens could from p_local_tokens, so here this variable is useless
|
||||
@@ -431,6 +476,24 @@ struct MoeSortingKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
moe_buf_set_zero_kernel_2d(void* buf, index_t row, index_t col, index_t elem_bytes) const
|
||||
{
|
||||
const long_index_t total_pixels = static_cast<long_index_t>(row) * col;
|
||||
const long_index_t total_bytes = total_pixels * elem_bytes;
|
||||
const long_index_t total_elems = total_bytes / 16; // always use dwordx4
|
||||
|
||||
using vector_type = ext_vector_t<index_t, 4>;
|
||||
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(long_index_t i = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; i < total_elems;
|
||||
i += (gridDim.x - 1) * BLOCK_SIZE)
|
||||
{
|
||||
p_buf[i] = zero_;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
|
||||
const WeightType* __restrict__ weights,
|
||||
index_t* p_sorted_token_ids,
|
||||
@@ -863,7 +926,8 @@ struct MoeSortingKernel
|
||||
}
|
||||
if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
|
||||
{
|
||||
*p_total_tokens_post_pad = local_cumsum_;
|
||||
*p_total_tokens_post_pad = local_cumsum_;
|
||||
p_total_tokens_post_pad[1] = tokens;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -1005,20 +1069,6 @@ struct MoeSortingKernel
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
if(kargs.p_moe_buf)
|
||||
{
|
||||
moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes);
|
||||
}
|
||||
return;
|
||||
}
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
extern __shared__ char smem[];
|
||||
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
(void)numel;
|
||||
index_t tokens_ = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
@@ -1029,6 +1079,25 @@ struct MoeSortingKernel
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
if(kargs.p_moe_buf)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
moe_buf_set_zero_kernel_2d(
|
||||
kargs.p_moe_buf, tokens_, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes);
|
||||
#else
|
||||
moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes);
|
||||
#endif
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
extern __shared__ char smem[];
|
||||
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
return moe_align_block_size_kernel_ex(
|
||||
static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
@@ -1045,6 +1114,7 @@ struct MoeSortingKernel
|
||||
kargs.smem_rows,
|
||||
smem);
|
||||
#else
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
static_cast<IndexType*>(kargs.p_sorted_token_ids),
|
||||
@@ -1066,6 +1136,8 @@ namespace impl {
|
||||
// [expert, padded_tokens]
|
||||
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens)
|
||||
{
|
||||
// Pad to multiply of 32. This can make sure even if the mesh is in 8bit,
|
||||
// we can still use dwordx4 load/store
|
||||
constexpr index_t chunk = 32;
|
||||
return (tokens + chunk - 1) / chunk * chunk;
|
||||
};
|
||||
@@ -1261,6 +1333,24 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BLOCK_SIZE = 256>
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(
|
||||
void* buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks)
|
||||
{
|
||||
const long_index_t total_pixels = static_cast<long_index_t>(row) * col;
|
||||
const long_index_t total_bytes = total_pixels * elem_bytes;
|
||||
const long_index_t total_elems = total_bytes / 16; // always use dwordx4
|
||||
|
||||
using vector_type = ext_vector_t<index_t, 4>;
|
||||
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(long_index_t i = gid * BLOCK_SIZE + threadIdx.x; i < total_elems; i += blocks * BLOCK_SIZE)
|
||||
{
|
||||
p_buf[i] = zero_;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
||||
// TODO: tokens could be from
|
||||
@@ -1292,12 +1382,29 @@ CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_expe
|
||||
}
|
||||
|
||||
// return size in byte
|
||||
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_, int topk_)
|
||||
// dispatch_policy: 0-automatically pick up kerel. 1-always use single kernel, 2-always use mp
|
||||
// kernel
|
||||
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_,
|
||||
int num_experts_,
|
||||
int topk_,
|
||||
int dispatch_policy_)
|
||||
{
|
||||
#if 1
|
||||
if(moe_sorting_is_oneshot(tokens_, num_experts_))
|
||||
// return 0;
|
||||
if(dispatch_policy_ == 0)
|
||||
{
|
||||
return 0;
|
||||
if(moe_sorting_is_oneshot(tokens_, num_experts_))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_);
|
||||
}
|
||||
}
|
||||
else if(dispatch_policy_ == 1)
|
||||
{
|
||||
return 0; // always use single kernel
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1308,6 +1415,98 @@ CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem_>
|
||||
struct MoeSortingClearWorkspaceKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
static constexpr index_t BLOCK_SIZE = Problem::BlockSize;
|
||||
static constexpr index_t OCCUPANCY = Problem::Occu;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
|
||||
// used for ws/LDS calculation
|
||||
index_t num_experts;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
index_t mesh_byte_size;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.mesh_byte_size = impl::moe_sorting_mesh_byte_size(h.tokens, h.num_experts, h.topk);
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
index_t row_size = mesh_stride; // impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
index_t pixels = kargs.num_experts * row_size;
|
||||
index_t total_bytes = pixels * kargs.mesh_byte_size;
|
||||
index_t total_elems = total_bytes / 16; // always use dwordx4
|
||||
|
||||
using vector_type = ext_vector_t<index_t, 4>;
|
||||
vector_type* p_expert_mesh = reinterpret_cast<vector_type*>(kargs.p_expert_mesh);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elems;
|
||||
i += gridDim.x * BLOCK_SIZE)
|
||||
{
|
||||
p_expert_mesh[i] = zero_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// below kernel is multi-phase implementation for large token and/or expert case
|
||||
|
||||
// write into a buffer to record the token cnt
|
||||
@@ -1435,6 +1634,16 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
else
|
||||
return tokens;
|
||||
}();
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
@@ -1449,12 +1658,11 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
p_expert_mesh[eid * mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1479,6 +1687,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
@@ -1488,6 +1697,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
@@ -1511,12 +1721,9 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
{
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
int eid = blockIdx.x;
|
||||
|
||||
int eid = blockIdx.x;
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
|
||||
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * kargs.mesh_stride);
|
||||
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
@@ -1524,7 +1731,32 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
|
||||
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
|
||||
|
||||
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0; // will not use if not LocalToken
|
||||
}
|
||||
}();
|
||||
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
|
||||
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride);
|
||||
|
||||
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
@@ -1538,7 +1770,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
{
|
||||
int position = i * BLOCK_SIZE + threadIdx.x;
|
||||
r_t v{0};
|
||||
if(position < (kargs.mesh_stride / index_pack))
|
||||
if(position < (mesh_stride / index_pack))
|
||||
v = p_expert_mesh[position];
|
||||
index_t local_sum = 0;
|
||||
static_for<0, index_pack, 1>{}(
|
||||
@@ -1835,7 +2067,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
void* p_total_tokens_post_pad; // [2]
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
@@ -1863,15 +2095,36 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
|
||||
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
|
||||
#else
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
#endif
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
|
||||
#else
|
||||
// use 1 block to cumsum
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
@@ -1888,11 +2141,21 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
{
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
|
||||
kargs.tokens,
|
||||
kargs.moe_buf_interm_dim,
|
||||
kargs.moe_buf_elem_bytes,
|
||||
blockIdx.x - 1,
|
||||
gridDim.x - 1);
|
||||
return;
|
||||
#else
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - 1);
|
||||
return;
|
||||
#endif
|
||||
}
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -2223,7 +2486,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
void* p_total_tokens_post_pad; // [2]
|
||||
void* p_sorted_expert_ids;
|
||||
|
||||
void* p_sorted_token_ids;
|
||||
@@ -2235,7 +2498,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv unit_size_mdiv;
|
||||
mdiv topk_mdiv;
|
||||
long_index_t moe_buf_bytes;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
// NOTE:
|
||||
// moe_buf_* is a 2d ws buffer used for the following fmoe kernel
|
||||
// arranged as row*col, where row=tokens(or local_token), col=interm_dim
|
||||
// we fuse this clearing inside sorting kernel
|
||||
// Besides, we require inter_dim to be multiple of 16 byte(make sure when alloc ws for fmoe)
|
||||
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
|
||||
index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
|
||||
#else
|
||||
long_index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
#endif
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
@@ -2262,16 +2535,37 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
|
||||
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
|
||||
#else
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
#endif
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
|
||||
#else
|
||||
// use 1 block to cumsum
|
||||
// return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
@@ -2287,13 +2581,34 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
// reduce single pixel within a wave
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
if(static_cast<index_t>(blockIdx.x) >= kargs.num_experts)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
|
||||
tokens,
|
||||
kargs.moe_buf_interm_dim,
|
||||
kargs.moe_buf_elem_bytes,
|
||||
blockIdx.x - kargs.num_experts,
|
||||
gridDim.x - kargs.num_experts);
|
||||
return;
|
||||
#else
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - kargs.num_experts);
|
||||
return;
|
||||
#endif
|
||||
}
|
||||
|
||||
extern __shared__ char smem[];
|
||||
@@ -2428,13 +2743,15 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
{
|
||||
auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
|
||||
if(blockIdx.x == 0)
|
||||
{
|
||||
p_total_tokens_post_pad[0] = total_tokens_post_pad;
|
||||
p_total_tokens_post_pad[1] = tokens;
|
||||
}
|
||||
p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
{
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
@@ -2463,14 +2780,14 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
return; // skip empty expert
|
||||
}
|
||||
|
||||
index_t tokens = [&]() {
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -2478,7 +2795,8 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
using d_t = ext_vector_t<index_t, index_pack>;
|
||||
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
int prev_cumsum = 0;
|
||||
|
||||
for(int i = 0; i < loops; i++)
|
||||
@@ -2487,8 +2805,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
r_t x_v = 0;
|
||||
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
|
||||
{
|
||||
x_v = reinterpret_cast<r_t*>(p_expert_mesh +
|
||||
eid * kargs.mesh_stride)[i_token_pack];
|
||||
x_v = reinterpret_cast<r_t*>(p_expert_mesh + eid * mesh_stride)[i_token_pack];
|
||||
}
|
||||
|
||||
r_t x_r;
|
||||
|
||||
@@ -73,4 +73,12 @@ struct MoeSortingProblemMp
|
||||
SubTokenTile == 8 || SubTokenTile == 16);
|
||||
};
|
||||
|
||||
template <bool LocalToken_, index_t BlockSize_ = 1024, index_t Occu_ = 1>
|
||||
struct MoeSortingClearWorkspaceProblem
|
||||
{
|
||||
static constexpr bool LocalToken = LocalToken_;
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
static constexpr index_t Occu = Occu_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user