mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] optimize moe sorting kernel, boost large context case up to 20x (#2153)
* combine 2-3 as single stage * support zeroing * improve long tokens * update specialization * b16 ws * 8bit topk optimize * update 15 example
This commit is contained in:
@@ -19,6 +19,10 @@ namespace ck_tile {
|
||||
#define MOE_SORTING_USE_EX_KERNEL 1
|
||||
#endif
|
||||
|
||||
#ifndef MOE_SORTING_FUSE_MP_01
|
||||
#define MOE_SORTING_FUSE_MP_01 0
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
@@ -118,7 +122,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex
|
||||
int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here
|
||||
int smem_rows = [&](){
|
||||
index_t target_occupancy_ = 2;
|
||||
constexpr index_t total_ = 65536 / sizeof(int);
|
||||
constexpr index_t total_ = get_smem_capacity() / sizeof(index_t);
|
||||
constexpr index_t sub_unroll = 8;
|
||||
constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
|
||||
// at lease 2 lines, one for sub_token unroll, one for cumsum
|
||||
@@ -250,7 +254,7 @@ struct MoeSortingKernel
|
||||
{
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
|
||||
return smem_rows * smem_cols * sizeof(int);
|
||||
return smem_rows * smem_cols * sizeof(index_t);
|
||||
#else
|
||||
const auto blocks = BlockSize(h);
|
||||
// usually num_experts is power of 2, we pad 1 dword here for the row-size
|
||||
@@ -1063,17 +1067,43 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens)
|
||||
return (tokens + chunk - 1) / chunk * chunk;
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_elem(index_t tokens, index_t num_experts)
|
||||
// 4-i32 mesh, 2-i16 mseh, 1-i8 mesh
|
||||
CK_TILE_HOST index_t moe_sorting_mesh_byte_size(index_t tokens_,
|
||||
index_t /*num_experts_*/,
|
||||
index_t topk_)
|
||||
{
|
||||
// small token case, let's run mesh with dword score board
|
||||
if(tokens_ < 512)
|
||||
return 4;
|
||||
else
|
||||
{
|
||||
if(topk_ >= 255)
|
||||
return 2; // 16bit mesh
|
||||
else
|
||||
return 1; // 8bit mesh if small enough
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_smem_size(index_t tokens,
|
||||
index_t num_experts,
|
||||
index_t topk)
|
||||
{
|
||||
index_t row_size = moe_sorting_mp_mesh_stride(tokens);
|
||||
return num_experts * row_size;
|
||||
index_t elem = num_experts * row_size;
|
||||
return elem * moe_sorting_mesh_byte_size(tokens, num_experts, topk);
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_elem(index_t num_experts)
|
||||
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_smem_size(index_t num_experts)
|
||||
{
|
||||
constexpr index_t chunk = 32;
|
||||
index_t row_size = num_experts + 1;
|
||||
return (row_size + chunk - 1) / chunk * chunk;
|
||||
return (row_size + chunk - 1) / chunk * chunk * sizeof(index_t);
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
|
||||
{
|
||||
constexpr index_t chunk = 32;
|
||||
return chunk * sizeof(index_t);
|
||||
};
|
||||
|
||||
template <typename T, typename F, index_t wave_size_ = warpSize>
|
||||
@@ -1245,15 +1275,20 @@ CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
|
||||
}
|
||||
|
||||
// return size in byte
|
||||
CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_)
|
||||
CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_, int topk_)
|
||||
{
|
||||
index_t elem = impl::moe_sorting_mp_mesh_elem(tokens_, num_experts_) +
|
||||
impl::moe_sorting_mp_cumsum_elem(num_experts_);
|
||||
return elem * sizeof(index_t);
|
||||
index_t s_ = impl::moe_sorting_mp_mesh_smem_size(tokens_, num_experts_, topk_) +
|
||||
impl::moe_sorting_mp_cumsum_smem_size(num_experts_)
|
||||
#if MOE_SORTING_FUSE_MP_01
|
||||
+ impl::moe_sorting_mp_sem_smem_size();
|
||||
#else
|
||||
;
|
||||
#endif
|
||||
return s_;
|
||||
}
|
||||
|
||||
// return size in byte
|
||||
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_)
|
||||
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_, int topk_)
|
||||
{
|
||||
#if 1
|
||||
if(moe_sorting_is_oneshot(tokens_, num_experts_))
|
||||
@@ -1262,10 +1297,10 @@ CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts
|
||||
}
|
||||
else
|
||||
{
|
||||
return moe_sorting_mp_get_workspace_size(tokens_, num_experts_);
|
||||
return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_);
|
||||
}
|
||||
#else
|
||||
return moe_sorting_mp_get_workspace_size(tokens_, num_experts_);
|
||||
return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -1320,6 +1355,7 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
@@ -1371,22 +1407,21 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
{
|
||||
using topk_id_t = ext_vector_t<IndexType, Problem::SubTokenTile>;
|
||||
|
||||
static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 ||
|
||||
Problem::SubTokenTile == 4);
|
||||
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
IndexType* p_expert_mesh = reinterpret_cast<IndexType*>(kargs.p_expert_mesh);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; i += blockDim.x)
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
i += gridDim.x * BLOCK_SIZE)
|
||||
{
|
||||
auto x = p_topk_ids[i];
|
||||
static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) {
|
||||
IndexType eid = x[j.value]; // ext_vector_type must use int to []
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1400,6 +1435,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
@@ -1420,9 +1456,9 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
Kargs k;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum =
|
||||
reinterpret_cast<void*>(reinterpret_cast<IndexType*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts));
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
|
||||
return k;
|
||||
@@ -1444,13 +1480,11 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
|
||||
int eid = blockIdx.x;
|
||||
|
||||
constexpr index_t index_pack = 4; // always packed
|
||||
using r_t = ext_vector_t<IndexType, index_pack>; // always use int32x4
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
|
||||
reinterpret_cast<index_t*>(kargs.p_expert_mesh) + eid * kargs.mesh_stride);
|
||||
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * kargs.mesh_stride);
|
||||
|
||||
static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 ||
|
||||
Problem::SubTokenTile == 4);
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
@@ -1502,6 +1536,197 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
}
|
||||
};
|
||||
|
||||
#if MOE_SORTING_FUSE_MP_01
|
||||
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P01
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_expert_sem; // [1]
|
||||
index_t tokens;
|
||||
index_t num_experts;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
index_t wg_count; // used for semaphore
|
||||
mdiv topk_mdiv;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
|
||||
k.p_expert_sem = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk) +
|
||||
impl::moe_sorting_mp_cumsum_smem_size(h.num_experts));
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.wg_count = WGCounts(h);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
|
||||
CK_TILE_HOST static constexpr auto WGCounts(const Hargs& h)
|
||||
{
|
||||
index_t total_elem = h.tokens * h.topk / Problem::SubTokenTile;
|
||||
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
// no more than grid_size
|
||||
return min(elem_cnt, GridSize(h));
|
||||
}
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / warpSize * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
workgroup_barrier wb{reinterpret_cast<uint32_t*>(kargs.p_expert_sem)};
|
||||
|
||||
{
|
||||
using topk_id_t = ext_vector_t<IndexType, Problem::SubTokenTile>;
|
||||
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
IndexType* p_expert_mesh = reinterpret_cast<IndexType*>(kargs.p_expert_mesh);
|
||||
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
i += BLOCK_SIZE * gridDim.x)
|
||||
{
|
||||
auto x = p_topk_ids[i];
|
||||
static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) {
|
||||
IndexType eid = x[j.value]; // ext_vector_type must use int to []
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
kargs.topk_mdiv.divmod(
|
||||
i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
|
||||
});
|
||||
}
|
||||
if(static_cast<index_t>(blockIdx.x) < kargs.wg_count)
|
||||
{
|
||||
wb.inc();
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
int eid = blockIdx.x;
|
||||
|
||||
// early exist in case of extra atomic wait
|
||||
if(eid >= kargs.num_experts)
|
||||
return;
|
||||
|
||||
wb.wait_lt(kargs.wg_count);
|
||||
|
||||
for(; eid < kargs.num_experts; eid += gridDim.x)
|
||||
{
|
||||
// if(threadIdx.x == 0)
|
||||
// printf("!!! bid:%d, eid:%d (%d, %d)\n",
|
||||
// static_cast<int>(blockIdx.x),
|
||||
// eid,
|
||||
// kargs.num_experts,
|
||||
// static_cast<int>(blockDim.x));
|
||||
constexpr index_t index_pack = 4; // always packed
|
||||
using r_t = ext_vector_t<IndexType, index_pack>; // always use int32x4
|
||||
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
|
||||
reinterpret_cast<index_t*>(kargs.p_expert_mesh) + eid * kargs.mesh_stride);
|
||||
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
|
||||
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
|
||||
|
||||
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
IndexType mask = p_local_expert_mask[eid];
|
||||
if(mask == 0)
|
||||
continue; // skip
|
||||
}
|
||||
|
||||
index_t cnt = 0; // per-wave cnt
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int position = i * BLOCK_SIZE + threadIdx.x;
|
||||
r_t v{0};
|
||||
if(position < (kargs.mesh_stride / index_pack))
|
||||
v = p_expert_mesh[position];
|
||||
index_t local_sum = 0;
|
||||
static_for<0, index_pack, 1>{}(
|
||||
[&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; });
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
__syncthreads();
|
||||
if(lane_id == 0)
|
||||
{
|
||||
s[wave_id] = cnt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
p_expert_cumsum[eid] = c;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// token count cumsum
|
||||
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P2
|
||||
@@ -1510,6 +1735,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
@@ -1536,10 +1762,9 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
// k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum =
|
||||
reinterpret_cast<void*>(reinterpret_cast<IndexType*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts));
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
|
||||
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
|
||||
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
|
||||
|
||||
@@ -1566,7 +1791,8 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return 2 * BLOCK_SIZE * sizeof(IndexType);
|
||||
// return 2 * BLOCK_SIZE * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / warpSize) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
@@ -1718,6 +1944,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
@@ -1749,9 +1976,9 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum =
|
||||
reinterpret_cast<void*>(reinterpret_cast<IndexType*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts));
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
@@ -1782,9 +2009,6 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 ||
|
||||
Problem::SubTokenTile == 4);
|
||||
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / warpSize;
|
||||
int lane_id = threadIdx.x % warpSize;
|
||||
@@ -1866,6 +2090,495 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
// we use dynamic LDS size here
|
||||
CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
|
||||
{
|
||||
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
|
||||
const index_t expert_cumsum_elem = num_experts_ + 1;
|
||||
return (4 + 2 * BLOCK_SIZE / warpSize + expert_cumsum_elem) * sizeof(int);
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
// token count cumsum
|
||||
template <typename Problem_>
|
||||
struct MoeSortingMultiPhaseKernel_P23
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
void* p_sorted_expert_ids;
|
||||
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_moe_buf;
|
||||
|
||||
index_t tokens;
|
||||
index_t num_experts;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv unit_size_mdiv;
|
||||
mdiv topk_mdiv;
|
||||
long_index_t moe_buf_bytes;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
|
||||
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
|
||||
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
|
||||
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
|
||||
k.p_moe_buf = h.p_moe_buf;
|
||||
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
// use 1 block to cumsum
|
||||
// return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
|
||||
// only use this at host !
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
|
||||
{
|
||||
const auto smem_23 = impl::moe_sorting_get_smem_size_p23(h.num_experts);
|
||||
const auto smem_sf = BLOCK_SIZE * 4 * sizeof(IndexType);
|
||||
return max(smem_23, smem_sf);
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
if(static_cast<index_t>(blockIdx.x) >= kargs.num_experts)
|
||||
{
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - kargs.num_experts);
|
||||
return;
|
||||
}
|
||||
|
||||
extern __shared__ char smem[];
|
||||
{
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
|
||||
IndexType* p_total_tokens_post_pad =
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
|
||||
IndexType* p_sorted_expert_ids =
|
||||
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
|
||||
for(index_t i = 0; i < loops; i++)
|
||||
{
|
||||
index_t position = i * BLOCK_SIZE + threadIdx.x;
|
||||
IndexType a_ = 0; // token count for a expert
|
||||
IndexType b_ = 0; // mask for a expert
|
||||
if(position < kargs.num_experts)
|
||||
{
|
||||
a_ = p_expert_cumsum[position];
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
b_ = p_local_expert_mask[position];
|
||||
}
|
||||
|
||||
int blocks_pers_expert =
|
||||
kargs.unit_size_mdiv.div(a_ + kargs.unit_size_mdiv.divisor - 1);
|
||||
// pad token
|
||||
int padded_blocks_per_expert = [&]() {
|
||||
int x_ = [&]() {
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
// if local_cnt is zero, blocks_pers_expert will be zero
|
||||
// this is what we want to achieve
|
||||
return blocks_pers_expert; // * kargs.unit_size_mdiv.divisor;
|
||||
}
|
||||
else
|
||||
{
|
||||
return max(blocks_pers_expert, 1);
|
||||
}
|
||||
}();
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
return b_ ? x_ : 0;
|
||||
}
|
||||
else
|
||||
return x_;
|
||||
}();
|
||||
|
||||
IndexType cumsum_a = padded_blocks_per_expert;
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
cumsum_b += prev_b;
|
||||
});
|
||||
|
||||
// Now let's add previous cumsum
|
||||
cumsum_a += prev_cumsum_a;
|
||||
cumsum_b += prev_cumsum_b;
|
||||
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
{
|
||||
s[2] = cumsum_a; // store the last cumsum
|
||||
s[3] = cumsum_b;
|
||||
}
|
||||
|
||||
IndexType out_0 = cumsum_a - padded_blocks_per_expert; // exclusive cumsum tok cnt
|
||||
IndexType out_1 = cumsum_b - b_; // exclusive cumsum mask cnt
|
||||
|
||||
__syncthreads();
|
||||
prev_cumsum_a = s[2];
|
||||
prev_cumsum_b = s[3];
|
||||
|
||||
if(position < kargs.num_experts)
|
||||
{
|
||||
p_expert_cumsum_smem[position] = out_0 * kargs.unit_size_mdiv.divisor;
|
||||
}
|
||||
|
||||
{
|
||||
if(blockIdx.x == 0)
|
||||
{
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
if(b_)
|
||||
{
|
||||
for(int j = 0; j < blocks_pers_expert; j++)
|
||||
{
|
||||
p_sorted_expert_ids[out_0 + j] = out_1;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int j = 0; j < blocks_pers_expert; j++)
|
||||
{
|
||||
p_sorted_expert_ids[out_0 + j] = position;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
|
||||
if(blockIdx.x == 0)
|
||||
p_total_tokens_post_pad[0] = total_tokens_post_pad;
|
||||
p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
{
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / warpSize;
|
||||
int lane_id = threadIdx.x % warpSize;
|
||||
int e_start = p_expert_cumsum_smem[eid];
|
||||
int e_end = p_expert_cumsum_smem[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
if(e_start == e_end)
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
int e_mask = p_local_expert_mask[eid];
|
||||
if(e_mask == 0)
|
||||
return; // skip empty expert
|
||||
}
|
||||
|
||||
// cumsum one by one
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
using d_t = ext_vector_t<index_t, index_pack>;
|
||||
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int prev_cumsum = 0;
|
||||
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int i_token_pack = i * BLOCK_SIZE + threadIdx.x;
|
||||
r_t x_v = 0;
|
||||
if(i_token_pack < (kargs.tokens + index_pack - 1) / index_pack)
|
||||
{
|
||||
x_v = reinterpret_cast<r_t*>(p_expert_mesh +
|
||||
eid * kargs.mesh_stride)[i_token_pack];
|
||||
}
|
||||
|
||||
r_t x_r;
|
||||
#if 0
|
||||
if constexpr(index_pack != 1)
|
||||
{
|
||||
// shuffle, we must have contiguout thread holds contiguout token
|
||||
__syncthreads();
|
||||
reinterpret_cast<r_t*>(s)[threadIdx.x] = x_v;
|
||||
__syncthreads();
|
||||
|
||||
static_for<0, index_pack, 1>{}([&](auto j_) {
|
||||
constexpr auto j = j_.value;
|
||||
x_r[j] = reinterpret_cast<MeshType*>(s)[threadIdx.x + j * BLOCK_SIZE];
|
||||
});
|
||||
}
|
||||
#else
|
||||
x_r = x_v;
|
||||
#endif
|
||||
{
|
||||
#if 0
|
||||
#pragma unroll
|
||||
for(int j = 0; j < index_pack / 2; j++)
|
||||
{
|
||||
int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * BLOCK_SIZE;
|
||||
index_t x = x_d[j];
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int position = cumsum - i_show;
|
||||
prev_cumsum = s[0]; // update the last cumsum
|
||||
|
||||
if(i_show)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start + position] =
|
||||
MOE_SORTING_MOCK_ID(i_token, i_topk);
|
||||
#else
|
||||
p_sorted_token_ids[e_start + position] = i_token;
|
||||
#endif
|
||||
p_sorted_weights[e_start + position] =
|
||||
p_weights[i_token * kargs.topk_mdiv.divisor + i_topk];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
{
|
||||
d_t i_topk;
|
||||
d_t i_show;
|
||||
// = 0;
|
||||
int cumsum_store = 0;
|
||||
|
||||
static_for<0, index_pack, 1>{}([&](auto j_) {
|
||||
constexpr auto j = j_.value;
|
||||
i_topk[j] = static_cast<index_t>(x_r[j] - 1);
|
||||
i_show[j] = static_cast<index_t>(x_r[j] != 0 ? 1 : 0);
|
||||
cumsum_store += i_show[j];
|
||||
});
|
||||
int cumsum = cumsum_store;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
prev_cumsum = s[0]; // update the last cumsum
|
||||
|
||||
int position = cumsum - cumsum_store;
|
||||
static_for<0, index_pack, 1>{}([&](auto j_) {
|
||||
constexpr auto j = j_.value;
|
||||
// int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j *
|
||||
// BLOCK_SIZE;
|
||||
int i_token =
|
||||
i * BLOCK_SIZE * index_pack + threadIdx.x * index_pack + j;
|
||||
|
||||
if(i_show[j])
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start + position] =
|
||||
MOE_SORTING_MOCK_ID(i_token, i_topk[j]);
|
||||
#else
|
||||
p_sorted_token_ids[e_start + position] = i_token;
|
||||
#endif
|
||||
p_sorted_weights[e_start + position] =
|
||||
p_weights[i_token * kargs.topk_mdiv.divisor + i_topk[j]];
|
||||
}
|
||||
position += i_show[j];
|
||||
});
|
||||
|
||||
#if 0
|
||||
int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x * 2 + j * BLOCK_SIZE * 2;
|
||||
index_t x = x_d[j];
|
||||
index_t x0 = static_cast<index_t>(x & 0xffff);
|
||||
index_t x1 = static_cast<index_t>(x >> 16);
|
||||
int i_topk_0 = x0 - 1; // topk of this token
|
||||
int i_show_0 = x0 != 0 ? 1 : 0; // has this token or not
|
||||
int i_topk_1 = x1 - 1; // topk of this token
|
||||
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show_0 + i_show_1;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int position_0 = cumsum - i_show_0 - i_show_1;
|
||||
prev_cumsum = s[0]; // update the last cumsum
|
||||
|
||||
if(i_show_0)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start + position_0] =
|
||||
MOE_SORTING_MOCK_ID(i_token, i_topk_0);
|
||||
#else
|
||||
p_sorted_token_ids[e_start + position_0] = i_token;
|
||||
#endif
|
||||
p_sorted_weights[e_start + position_0] =
|
||||
p_weights[i_token * kargs.topk_mdiv.divisor + i_topk_0];
|
||||
}
|
||||
|
||||
int position_1 = cumsum - i_show_1;
|
||||
|
||||
if(i_show_1)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start + position_1] =
|
||||
MOE_SORTING_MOCK_ID(i_token + 1, i_topk_1);
|
||||
#else
|
||||
p_sorted_token_ids[e_start + position_1] = i_token + 1;
|
||||
#endif
|
||||
p_sorted_weights[e_start + position_1] =
|
||||
p_weights[(i_token + 1) * kargs.topk_mdiv.divisor + i_topk_1];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[i] = tokens;
|
||||
#endif
|
||||
p_sorted_weights[i] = static_cast<WeightType>(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#undef MOE_SORTING_MOCK_ID
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -50,20 +50,23 @@ struct MoeSortingProblemEx
|
||||
};
|
||||
|
||||
template <typename IndexType_,
|
||||
typename WeightType_,
|
||||
index_t SubTokenTile_, // 1,2,4
|
||||
typename WeightType_, // used for expert mesh in ws
|
||||
typename MeshType_,
|
||||
index_t SubTokenTile_, // 1,2,4,8
|
||||
bool LocalExpertMasking_, // used in EP case
|
||||
bool SkipExpertsWithZeroTokens_ = true>
|
||||
struct MoeSortingProblemMp
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using WeightType = remove_cvref_t<WeightType_>;
|
||||
using MeshType = remove_cvref_t<MeshType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
|
||||
static constexpr index_t SubTokenTile = SubTokenTile_;
|
||||
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
|
||||
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
|
||||
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4);
|
||||
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 ||
|
||||
SubTokenTile == 8 || SubTokenTile == 16);
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user