Do not use warpSize as a compile-time constant, as that support has been removed

This commit is contained in:
Satyanvesh Dittakavi
2025-06-10 09:30:41 +00:00
parent 5a0bd157db
commit 391e14a24d
11 changed files with 85 additions and 73 deletions

View File

@@ -274,6 +274,12 @@
namespace ck {
// Compile-time wavefront width, for use wherever a constant expression is required
// (template arguments, static constexpr members) instead of HIP's warpSize.
//  - 64 when targeting GFX9, and also in any pass that is not a HIP device compilation
//    (host pass), so the value agrees with get_warp_size(), which uses the same condition.
//    Without the host clause the host pass would fall through to 32 and host-evaluated
//    __host__ __device__ code would disagree with the GFX9 device pass.
//  - 32 for every other device target.
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
__device__ static constexpr int WarpSize = 64;
#else
__device__ static constexpr int WarpSize = 32;
#endif
enum struct InMemoryDataOperationEnum
{
Set,

View File

@@ -141,7 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
using Base::BMmaKStride;
static constexpr index_t WgpPerCU =
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
@@ -631,7 +631,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t WgpPerCU =
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

View File

@@ -143,7 +143,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using Base::BMmaKStride;
static constexpr index_t WgpPerCU =
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

View File

@@ -141,7 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
using Base::BMmaKStride;
static constexpr index_t WgpPerCU =
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
@@ -632,7 +632,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t WgpPerCU =
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

View File

@@ -347,7 +347,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1229,7 +1229,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
@@ -1607,7 +1607,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(

View File

@@ -376,7 +376,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPackPerGroup>{};
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPackPerGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1253,7 +1253,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPackPerGroup * (get_thread_local_1d_id() % warpSize)));
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1693,7 +1693,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPackPerGroup * (get_thread_local_1d_id() % warpSize)));
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds

View File

@@ -404,7 +404,7 @@ struct GridwiseMoeGemm
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1314,7 +1314,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1360,7 +1360,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack * (get_thread_local_1d_id() % WarpSize)));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
@@ -2025,7 +2025,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds

View File

@@ -9,8 +9,11 @@ namespace ck {
// Returns the wavefront width as a compile-time constant.
// NOTE(review): this diff excerpt lists the old `return warpSize;` statement together with
// the preprocessor branch introduced to replace it; as written here only the first return
// is reachable - confirm that only the branch remains in the final file.
__host__ __device__ constexpr index_t get_warp_size()
{
// Old implementation: forwards HIP's warpSize.
return warpSize;
// New implementation: 64 when __GFX9__ is defined or when __HIP_DEVICE_COMPILE__ is not
// defined (i.e. outside a HIP device compilation pass), otherwise 32.
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
return 64;
#else
return 32;
#endif
}
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

View File

@@ -32,7 +32,7 @@ static __device__ void gms_init(int NumWarps, int* p_control_bits)
// all the workgroups in the synchronization group are supposed to call this function
static __device__ void gms_barrier(int* p_control_bits)
{
constexpr int mask = warpSize - 1;
constexpr int mask = WarpSize - 1;
if((threadIdx.x & mask) == 0)
{

View File

@@ -50,8 +50,11 @@ enum struct memory_operation_enum : std::uint16_t
// Returns the wavefront width as a compile-time constant.
// NOTE(review): this diff excerpt lists the old `return warpSize;` statement together with
// the preprocessor branch introduced to replace it; as written here only the first return
// is reachable - confirm that only the branch remains in the final file.
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
// Old implementation: forwards HIP's warpSize.
return warpSize;
// New implementation: 64 when __GFX9__ is defined or when __HIP_DEVICE_COMPILE__ is not
// defined (i.e. outside a HIP device compilation pass), otherwise 32.
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
return 64;
#else
return 32;
#endif
}
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }

View File

@@ -396,7 +396,7 @@ struct MoeSortingKernel
}
// reduce single pixel within a wave
template <typename T, typename F, index_t wave_size_ = warpSize>
template <typename T, typename F, index_t wave_size_ = WarpSize>
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
@@ -633,7 +633,7 @@ struct MoeSortingKernel
{
const index_t prefill_token = topk_mdiv.div(numel);
// TODO: only support expert-tile like 8, 16, 32
static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile;
{
index_t eid = tid / experts_per_wave;
index_t expert_offset = cumsum[eid] +
@@ -701,7 +701,7 @@ struct MoeSortingKernel
void* smem) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize);
const index_t lid = __lane_id();
constexpr index_t block_size = 256; // blockDim.x;
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
@@ -806,7 +806,7 @@ struct MoeSortingKernel
// NOTE: under this block can never use __syncthreads!
int i_e_ = 0;
int local_cumsum_ = 0;
for(; i_e_ < num_experts; i_e_ += warpSize)
for(; i_e_ < num_experts; i_e_ += WarpSize)
{
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
int local_cnt = smem_cumsum(i_e_ + lid + 1);
@@ -851,7 +851,7 @@ struct MoeSortingKernel
// cumsum padded in case local cumsum is zero, but
// pre_cumsum has value, which will result in
// zero local cumsum(but we want at least padded)
wave_cumsum<int, warpSize>(local_cumsum_);
wave_cumsum<int, WarpSize>(local_cumsum_);
if((i_e_ + lid) < num_experts)
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
@@ -859,7 +859,7 @@ struct MoeSortingKernel
if constexpr(Problem::LocalExpertMasking)
{
local_masking += pre_cumsum_masking;
wave_cumsum<int, warpSize>(local_masking);
wave_cumsum<int, WarpSize>(local_masking);
if((i_e_ + lid) < num_experts)
smem_cumdup(i_e_ + lid + 1) = local_masking;
}
@@ -869,7 +869,7 @@ struct MoeSortingKernel
// than 0 (which is not what we want)
__builtin_amdgcn_s_waitcnt(0xc07f);
}
if((lid + i_e_ - warpSize) == (num_experts - 1))
if((lid + i_e_ - WarpSize) == (num_experts - 1))
{
*p_total_tokens_post_pad = local_cumsum_;
}
@@ -1106,7 +1106,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
return chunk * sizeof(index_t);
};
template <typename T, typename F, index_t wave_size_ = warpSize>
template <typename T, typename F, index_t wave_size_ = WarpSize>
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
@@ -1471,7 +1471,7 @@ struct MoeSortingMultiPhaseKernel_P1
// in byte
// Shared-memory footprint in bytes: BLOCK_SIZE / WarpSize elements of IndexType
// (presumably one slot per wave of the block - confirm against operator()).
// NOTE(review): this diff excerpt lists both the old (`warpSize`) and the new (`WarpSize`)
// return statement; as written here only the first one is reachable.
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
return BLOCK_SIZE / warpSize * sizeof(IndexType);
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -1513,8 +1513,8 @@ struct MoeSortingMultiPhaseKernel_P1
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
}
index_t lane_id = threadIdx.x % warpSize;
index_t wave_id = threadIdx.x / warpSize;
index_t lane_id = threadIdx.x % WarpSize;
index_t wave_id = threadIdx.x / WarpSize;
// reduce cross wave
IndexType* s = reinterpret_cast<IndexType*>(smem);
@@ -1527,7 +1527,7 @@ struct MoeSortingMultiPhaseKernel_P1
if(threadIdx.x == 0)
{
index_t c = 0;
for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
{
c += s[i];
}
@@ -1616,7 +1616,7 @@ struct MoeSortingMultiPhaseKernel_P01
// in byte
// Shared-memory footprint in bytes: BLOCK_SIZE / WarpSize elements of IndexType
// (presumably one slot per wave of the block - confirm against operator()).
// Host-only, so the width used is whatever WarpSize evaluates to in the host pass.
// NOTE(review): this diff excerpt lists both the old (`warpSize`) and the new (`WarpSize`)
// return statement; as written here only the first one is reachable.
CK_TILE_HOST static constexpr auto GetSmemSize()
{
return BLOCK_SIZE / warpSize * sizeof(IndexType);
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -1700,8 +1700,8 @@ struct MoeSortingMultiPhaseKernel_P01
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
}
index_t lane_id = threadIdx.x % warpSize;
index_t wave_id = threadIdx.x / warpSize;
index_t lane_id = threadIdx.x % WarpSize;
index_t wave_id = threadIdx.x / WarpSize;
// reduce cross wave
IndexType* s = reinterpret_cast<IndexType*>(smem);
@@ -1715,7 +1715,7 @@ struct MoeSortingMultiPhaseKernel_P01
if(threadIdx.x == 0)
{
index_t c = 0;
for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
{
c += s[i];
}
@@ -1792,7 +1792,7 @@ struct MoeSortingMultiPhaseKernel_P2
// Shared-memory footprint in bytes: (4 + 2 * BLOCK_SIZE / WarpSize) elements of IndexType.
// Presumably 4 leading slots followed by two per-wave arrays (the kernel body indexes
// s[4 + wave_id] and s[4 + wave_id + BLOCK_SIZE / WarpSize]) - confirm against operator().
// NOTE(review): this diff excerpt lists both the old (`warpSize`) and the new (`WarpSize`)
// return statement; as written here only the first one is reachable.
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
// return 2 * BLOCK_SIZE * sizeof(IndexType);
return (4 + 2 * BLOCK_SIZE / warpSize) * sizeof(IndexType);
return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType);
}
// reduce single pixel within a wave
@@ -1817,8 +1817,8 @@ struct MoeSortingMultiPhaseKernel_P2
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
index_t wave_id = threadIdx.x / warpSize;
index_t lane_id = threadIdx.x % warpSize;
index_t wave_id = threadIdx.x / WarpSize;
index_t lane_id = threadIdx.x % WarpSize;
IndexType prev_cumsum_a = 0;
IndexType prev_cumsum_b = 0;
@@ -1863,22 +1863,22 @@ struct MoeSortingMultiPhaseKernel_P2
IndexType cumsum_b = b_;
// Note: we first cumsum local round, then add previous cumsum
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
__syncthreads();
if(lane_id == warpSize - 1)
if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev_a = s[4 + i_w];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
prev_a = wave_id > i_w ? prev_a : 0; // mask out
prev_b = wave_id > i_w ? prev_b : 0; // mask out
cumsum_a += prev_a;
@@ -1993,7 +1993,7 @@ struct MoeSortingMultiPhaseKernel_P3
// in byte
// Shared-memory footprint in bytes: (4 + BLOCK_SIZE / WarpSize) elements of IndexType.
// Presumably 4 leading slots followed by one per-wave array (the kernel body indexes
// s[4 + wave_id]) - confirm against operator().
// NOTE(review): this diff excerpt lists both the old (`warpSize`) and the new (`WarpSize`)
// return statement; as written here only the first one is reachable.
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
return (4 + BLOCK_SIZE / warpSize) * sizeof(IndexType);
return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -2010,8 +2010,8 @@ struct MoeSortingMultiPhaseKernel_P3
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
int eid = blockIdx.x;
int wave_id = threadIdx.x / warpSize;
int lane_id = threadIdx.x % warpSize;
int wave_id = threadIdx.x / WarpSize;
int lane_id = threadIdx.x % WarpSize;
int e_start = p_expert_cumsum[eid];
int e_end = p_expert_cumsum[eid + 1];
if constexpr(Problem::SkipExpertsWithZeroTokens)
@@ -2041,17 +2041,17 @@ struct MoeSortingMultiPhaseKernel_P3
int i_topk = x - 1; // topk of this token
int i_show = x != 0 ? 1 : 0; // has this token or not
int cumsum = i_show;
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
__syncthreads();
if(lane_id == warpSize - 1)
if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2096,7 +2096,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
{
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
const index_t expert_cumsum_elem = num_experts_ + 1;
return (4 + 2 * BLOCK_SIZE / warpSize + expert_cumsum_elem) * sizeof(int);
return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int);
}
} // namespace impl
@@ -2201,15 +2201,15 @@ struct MoeSortingMultiPhaseKernel_P23
const IndexType* p_local_expert_mask =
static_cast<const IndexType*>(kargs.p_local_expert_mask);
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
IndexType* p_total_tokens_post_pad =
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
IndexType* p_sorted_expert_ids =
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
index_t wave_id = threadIdx.x / warpSize;
index_t lane_id = threadIdx.x % warpSize;
index_t wave_id = threadIdx.x / WarpSize;
index_t lane_id = threadIdx.x % WarpSize;
IndexType prev_cumsum_a = 0;
IndexType prev_cumsum_b = 0;
@@ -2254,22 +2254,22 @@ struct MoeSortingMultiPhaseKernel_P23
IndexType cumsum_b = b_;
// Note: we first cumsum local round, then add previous cumsum
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
__syncthreads();
if(lane_id == warpSize - 1)
if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev_a = s[4 + i_w];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
prev_a = wave_id > i_w ? prev_a : 0; // mask out
prev_b = wave_id > i_w ? prev_b : 0; // mask out
cumsum_a += prev_a;
@@ -2339,13 +2339,13 @@ struct MoeSortingMultiPhaseKernel_P23
IndexType* s = reinterpret_cast<IndexType*>(smem);
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
int eid = blockIdx.x;
int wave_id = threadIdx.x / warpSize;
int lane_id = threadIdx.x % warpSize;
int wave_id = threadIdx.x / WarpSize;
int lane_id = threadIdx.x % WarpSize;
int e_start = p_expert_cumsum_smem[eid];
int e_end = p_expert_cumsum_smem[eid + 1];
if constexpr(Problem::SkipExpertsWithZeroTokens)
@@ -2405,17 +2405,17 @@ struct MoeSortingMultiPhaseKernel_P23
int i_topk = x - 1; // topk of this token
int i_show = x != 0 ? 1 : 0; // has this token or not
int cumsum = i_show;
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
__syncthreads();
if(lane_id == warpSize - 1)
if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2456,17 +2456,17 @@ struct MoeSortingMultiPhaseKernel_P23
cumsum_store += i_show[j];
});
int cumsum = cumsum_store;
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
__syncthreads();
if(lane_id == warpSize - 1)
if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2511,17 +2511,17 @@ struct MoeSortingMultiPhaseKernel_P23
int i_topk_1 = x1 - 1; // topk of this token
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
int cumsum = i_show_0 + i_show_1;
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
__syncthreads();
if(lane_id == warpSize - 1)
if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;