[CK_TILE] Fix compilation errors introduced in #2320, #2219 and #2214 (#2388)

* Fix compilation errors

* Fix more ck_tile example compilation errors
This commit is contained in:
Po Yen Chen
2025-06-23 12:29:15 +08:00
committed by GitHub
parent 0366fb2abc
commit 7d669440a6
10 changed files with 112 additions and 110 deletions

View File

@@ -101,7 +101,7 @@ struct FusedMoeGemmShape
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;
static constexpr index_t BlockSize = WarpSize * NumWarps;
static constexpr index_t BlockSize = get_warp_size() * NumWarps;
// some assert
static_assert(Block_M0 == Block_M1);

View File

@@ -388,7 +388,7 @@ struct MoeSortingKernel
}
// reduce single pixel within a wave
template <typename T, typename F, index_t wave_size_ = WarpSize>
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
@@ -625,7 +625,7 @@ struct MoeSortingKernel
{
const index_t prefill_token = topk_mdiv.div(numel);
// TODO: only support expert-tile like 8, 16, 32
static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile;
static constexpr index_t experts_per_wave = get_warp_size() / Problem::ExpertTile;
{
index_t eid = tid / experts_per_wave;
index_t expert_offset = cumsum[eid] +
@@ -693,7 +693,7 @@ struct MoeSortingKernel
void* smem) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize);
const index_t wid = __builtin_amdgcn_readfirstlane(tid / get_warp_size());
const index_t lid = __lane_id();
constexpr index_t block_size = 256; // blockDim.x;
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
@@ -798,7 +798,7 @@ struct MoeSortingKernel
// NOTE: under this block can never use __syncthreads!
int i_e_ = 0;
int local_cumsum_ = 0;
for(; i_e_ < num_experts; i_e_ += WarpSize)
for(; i_e_ < num_experts; i_e_ += get_warp_size())
{
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
int local_cnt = smem_cumsum(i_e_ + lid + 1);
@@ -843,7 +843,7 @@ struct MoeSortingKernel
// cumsum padded in case local cumsum is zero, but
// pre_sumsum has value, which will result int
// zero local cumsum(but we want at least padded)
wave_cumsum<int, WarpSize>(local_cumsum_);
wave_cumsum<int, get_warp_size()>(local_cumsum_);
if((i_e_ + lid) < num_experts)
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
@@ -851,7 +851,7 @@ struct MoeSortingKernel
if constexpr(Problem::LocalExpertMasking)
{
local_masking += pre_cumsum_masking;
wave_cumsum<int, WarpSize>(local_masking);
wave_cumsum<int, get_warp_size()>(local_masking);
if((i_e_ + lid) < num_experts)
smem_cumdup(i_e_ + lid + 1) = local_masking;
}
@@ -861,7 +861,7 @@ struct MoeSortingKernel
// than 0(which is not we want)
__builtin_amdgcn_s_waitcnt(0xc07f);
}
if((lid + i_e_ - WarpSize) == (num_experts - 1))
if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
{
*p_total_tokens_post_pad = local_cumsum_;
}
@@ -1109,7 +1109,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
return chunk * sizeof(index_t);
};
template <typename T, typename F, index_t wave_size_ = WarpSize>
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
@@ -1504,7 +1504,7 @@ struct MoeSortingMultiPhaseKernel_P1
// in byte
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -1546,8 +1546,8 @@ struct MoeSortingMultiPhaseKernel_P1
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
}
index_t lane_id = threadIdx.x % WarpSize;
index_t wave_id = threadIdx.x / WarpSize;
index_t lane_id = threadIdx.x % get_warp_size();
index_t wave_id = threadIdx.x / get_warp_size();
// reduce cross wave
IndexType* s = reinterpret_cast<IndexType*>(smem);
@@ -1560,7 +1560,7 @@ struct MoeSortingMultiPhaseKernel_P1
if(threadIdx.x == 0)
{
index_t c = 0;
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
{
c += s[i];
}
@@ -1660,7 +1660,7 @@ struct MoeSortingMultiPhaseKernel_P01
// in byte
CK_TILE_HOST static constexpr auto GetSmemSize()
{
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -1786,8 +1786,8 @@ struct MoeSortingMultiPhaseKernel_P01
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
}
index_t lane_id = threadIdx.x % WarpSize;
index_t wave_id = threadIdx.x / WarpSize;
index_t lane_id = threadIdx.x % get_warp_size();
index_t wave_id = threadIdx.x / get_warp_size();
// reduce cross wave
IndexType* s = reinterpret_cast<IndexType*>(smem);
@@ -1801,7 +1801,7 @@ struct MoeSortingMultiPhaseKernel_P01
if(threadIdx.x == 0)
{
index_t c = 0;
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
{
c += s[i];
}
@@ -1880,7 +1880,7 @@ struct MoeSortingMultiPhaseKernel_P2
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
// return 2 * BLOCK_SIZE * sizeof(IndexType);
return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType);
return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
}
// reduce single pixel within a wave
@@ -1905,8 +1905,8 @@ struct MoeSortingMultiPhaseKernel_P2
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
index_t wave_id = threadIdx.x / WarpSize;
index_t lane_id = threadIdx.x % WarpSize;
index_t wave_id = threadIdx.x / get_warp_size();
index_t lane_id = threadIdx.x % get_warp_size();
IndexType prev_cumsum_a = 0;
IndexType prev_cumsum_b = 0;
@@ -1951,22 +1951,22 @@ struct MoeSortingMultiPhaseKernel_P2
IndexType cumsum_b = b_;
// Note: we first cumsum local round, then add previous cumsum
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_b);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev_a = s[4 + i_w];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
prev_a = wave_id > i_w ? prev_a : 0; // mask out
prev_b = wave_id > i_w ? prev_b : 0; // mask out
cumsum_a += prev_a;
@@ -2083,7 +2083,7 @@ struct MoeSortingMultiPhaseKernel_P3
// in byte
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType);
return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -2110,8 +2110,8 @@ struct MoeSortingMultiPhaseKernel_P3
}
}();
int eid = blockIdx.x;
int wave_id = threadIdx.x / WarpSize;
int lane_id = threadIdx.x % WarpSize;
int wave_id = threadIdx.x / get_warp_size();
int lane_id = threadIdx.x % get_warp_size();
int e_start = p_expert_cumsum[eid];
int e_end = p_expert_cumsum[eid + 1];
if constexpr(Problem::SkipExpertsWithZeroTokens)
@@ -2141,17 +2141,17 @@ struct MoeSortingMultiPhaseKernel_P3
int i_topk = x - 1; // topk of this token
int i_show = x != 0 ? 1 : 0; // has this token or not
int cumsum = i_show;
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2196,7 +2196,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
{
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
const index_t expert_cumsum_elem = num_experts_ + 1;
return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int);
return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int);
}
} // namespace impl
@@ -2303,15 +2303,15 @@ struct MoeSortingMultiPhaseKernel_P23
const IndexType* p_local_expert_mask =
static_cast<const IndexType*>(kargs.p_local_expert_mask);
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
IndexType* p_total_tokens_post_pad =
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
IndexType* p_sorted_expert_ids =
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
index_t wave_id = threadIdx.x / WarpSize;
index_t lane_id = threadIdx.x % WarpSize;
index_t wave_id = threadIdx.x / get_warp_size();
index_t lane_id = threadIdx.x % get_warp_size();
IndexType prev_cumsum_a = 0;
IndexType prev_cumsum_b = 0;
@@ -2356,22 +2356,22 @@ struct MoeSortingMultiPhaseKernel_P23
IndexType cumsum_b = b_;
// Note: we first cumsum local round, then add previous cumsum
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_b);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev_a = s[4 + i_w];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
prev_a = wave_id > i_w ? prev_a : 0; // mask out
prev_b = wave_id > i_w ? prev_b : 0; // mask out
cumsum_a += prev_a;
@@ -2441,13 +2441,13 @@ struct MoeSortingMultiPhaseKernel_P23
IndexType* s = reinterpret_cast<IndexType*>(smem);
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
int eid = blockIdx.x;
int wave_id = threadIdx.x / WarpSize;
int lane_id = threadIdx.x % WarpSize;
int wave_id = threadIdx.x / get_warp_size();
int lane_id = threadIdx.x % get_warp_size();
int e_start = p_expert_cumsum_smem[eid];
int e_end = p_expert_cumsum_smem[eid + 1];
if constexpr(Problem::SkipExpertsWithZeroTokens)
@@ -2518,17 +2518,17 @@ struct MoeSortingMultiPhaseKernel_P23
int i_topk = x - 1; // topk of this token
int i_show = x != 0 ? 1 : 0; // has this token or not
int cumsum = i_show;
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2569,17 +2569,17 @@ struct MoeSortingMultiPhaseKernel_P23
cumsum_store += i_show[j];
});
int cumsum = cumsum_store;
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2624,17 +2624,17 @@ struct MoeSortingMultiPhaseKernel_P23
int i_topk_1 = x1 - 1; // topk of this token
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
int cumsum = i_show_0 + i_show_1;
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;