mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Support Wave32 in CK_TILE - Part 1 (#2594)
* Support wave32/wave64 in CK_TILE - Part 1
* remove blocksize in kernel launch
* fix build error
* fix clang format
* fix clang format 2
* fix clang format 3
* fix fmha build error
* fix fmha build 2
* fix fmha build 3
* fix build error 4
* address review comment
* update change log
* replace KernelBlockSize with kBlockSize
* fix CI fail
* fix clang format
* address review comment and rebase code.
* fix universal test fail

---------

Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -213,7 +213,7 @@ struct MoeSortingKernel
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
struct Kargs
|
||||
@@ -487,8 +487,8 @@ struct MoeSortingKernel
|
||||
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(long_index_t i = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; i < total_elems;
|
||||
i += (gridDim.x - 1) * BLOCK_SIZE)
|
||||
for(long_index_t i = (blockIdx.x - 1) * kBlockSize + threadIdx.x; i < total_elems;
|
||||
i += (gridDim.x - 1) * kBlockSize)
|
||||
{
|
||||
p_buf[i] = zero_;
|
||||
}
|
||||
@@ -1419,7 +1419,7 @@ template <typename Problem_>
|
||||
struct MoeSortingClearWorkspaceKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
static constexpr index_t BLOCK_SIZE = Problem::BlockSize;
|
||||
static constexpr index_t kBlockSize = Problem::BlockSize;
|
||||
static constexpr index_t OCCUPANCY = Problem::Occu;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
@@ -1461,7 +1461,7 @@ struct MoeSortingClearWorkspaceKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
|
||||
@@ -1499,8 +1499,8 @@ struct MoeSortingClearWorkspaceKernel
|
||||
vector_type* p_expert_mesh = reinterpret_cast<vector_type*>(kargs.p_expert_mesh);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elems;
|
||||
i += gridDim.x * BLOCK_SIZE)
|
||||
for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elems;
|
||||
i += gridDim.x * kBlockSize)
|
||||
{
|
||||
p_expert_mesh[i] = zero_;
|
||||
}
|
||||
@@ -1560,7 +1560,7 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -1604,7 +1604,7 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
|
||||
@@ -1647,8 +1647,8 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
i += gridDim.x * BLOCK_SIZE)
|
||||
for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem;
|
||||
i += gridDim.x * kBlockSize)
|
||||
{
|
||||
auto x = p_topk_ids[i];
|
||||
static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) {
|
||||
@@ -1678,7 +1678,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -1709,12 +1709,12 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
|
||||
return kBlockSize / get_warp_size() * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1756,7 +1756,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
|
||||
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride);
|
||||
|
||||
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
@@ -1768,7 +1768,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
index_t cnt = 0; // per-wave cnt
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int position = i * BLOCK_SIZE + threadIdx.x;
|
||||
int position = i * kBlockSize + threadIdx.x;
|
||||
r_t v{0};
|
||||
if(position < (mesh_stride / index_pack))
|
||||
v = p_expert_mesh[position];
|
||||
@@ -1792,7 +1792,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
|
||||
for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1811,7 +1811,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -1878,12 +1878,12 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr auto WGCounts(const Hargs& h)
|
||||
{
|
||||
index_t total_elem = h.tokens * h.topk / Problem::SubTokenTile;
|
||||
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize;
|
||||
|
||||
// no more than grid_size
|
||||
return min(elem_cnt, GridSize(h));
|
||||
@@ -1892,7 +1892,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
|
||||
return kBlockSize / get_warp_size() * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1921,7 +1921,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile;
|
||||
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize;
|
||||
|
||||
// no more than grid_size
|
||||
return min(elem_cnt, kargs.wg_count);
|
||||
@@ -1940,8 +1940,8 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
i += BLOCK_SIZE * gridDim.x)
|
||||
for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem;
|
||||
i += kBlockSize * gridDim.x)
|
||||
{
|
||||
auto x = p_topk_ids[i];
|
||||
static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) {
|
||||
@@ -1996,7 +1996,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
|
||||
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
|
||||
|
||||
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int loops = (kargs.mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
@@ -2008,7 +2008,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
index_t cnt = 0; // per-wave cnt
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int position = i * BLOCK_SIZE + threadIdx.x;
|
||||
int position = i * kBlockSize + threadIdx.x;
|
||||
r_t v{0};
|
||||
if(position < (kargs.mesh_stride / index_pack))
|
||||
v = p_expert_mesh[position];
|
||||
@@ -2033,7 +2033,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
|
||||
for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -2055,7 +2055,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -2123,17 +2123,17 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
|
||||
#else
|
||||
// use 1 block to cumsum
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
// return 2 * BLOCK_SIZE * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
|
||||
// return 2 * kBlockSize * sizeof(IndexType);
|
||||
return (4 + 2 * kBlockSize / get_warp_size()) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
@@ -2142,7 +2142,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
|
||||
impl::moe_buf_set_zero_kernel_2d<kBlockSize>(kargs.p_moe_buf,
|
||||
kargs.tokens,
|
||||
kargs.moe_buf_interm_dim,
|
||||
kargs.moe_buf_elem_bytes,
|
||||
@@ -2150,7 +2150,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
gridDim.x - 1);
|
||||
return;
|
||||
#else
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
impl::moe_buf_set_zero_kernel<kBlockSize>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - 1);
|
||||
@@ -2167,7 +2167,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
|
||||
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize;
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
|
||||
@@ -2176,7 +2176,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
|
||||
for(index_t i = 0; i < loops; i++)
|
||||
{
|
||||
index_t position = i * BLOCK_SIZE + threadIdx.x;
|
||||
index_t position = i * kBlockSize + threadIdx.x;
|
||||
IndexType a_ = 0; // token count for a expert
|
||||
IndexType b_ = 0; // mask for a expert
|
||||
if(position < kargs.num_experts)
|
||||
@@ -2221,15 +2221,15 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
|
||||
s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
|
||||
IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -2240,7 +2240,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
cumsum_a += prev_cumsum_a;
|
||||
cumsum_b += prev_cumsum_b;
|
||||
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[2] = cumsum_a; // store the last cumsum
|
||||
s[3] = cumsum_b;
|
||||
@@ -2297,7 +2297,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -2341,12 +2341,12 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
|
||||
return (4 + kBlockSize / get_warp_size()) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -2391,11 +2391,11 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
}
|
||||
|
||||
// cumsum one by one
|
||||
int loops = (kargs.mesh_stride + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int loops = (kargs.mesh_stride + kBlockSize - 1) / kBlockSize;
|
||||
int prev_cumsum = 0;
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int i_token = i * BLOCK_SIZE + threadIdx.x;
|
||||
int i_token = i * kBlockSize + threadIdx.x;
|
||||
IndexType x = 0;
|
||||
if(i_token < tokens)
|
||||
{
|
||||
@@ -2414,13 +2414,13 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
@@ -2441,7 +2441,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
}
|
||||
}
|
||||
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
|
||||
@@ -2457,9 +2457,9 @@ namespace impl {
|
||||
// we use dynamic LDS size here
|
||||
// Returns a shared-memory (LDS) size in bytes for a given expert count:
//   (4 + 2 * <block size> / get_warp_size() + (num_experts_ + 1)) * sizeof(int)
// where <block size> is the hard-coded 256 declared in the body.
// NOTE(review): this listing looks like a diff view, so the body below presumably shows the
// removed BLOCK_SIZE line followed by its kBlockSize replacement for each changed statement;
// confirm against the actual header before relying on either spelling.
CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
|
||||
{
|
||||
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
|
||||
constexpr index_t kBlockSize = 256; // hardcoded 256
|
||||
const index_t expert_cumsum_elem = num_experts_ + 1;
|
||||
return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int);
|
||||
return (4 + 2 * kBlockSize / get_warp_size() + expert_cumsum_elem) * sizeof(int);
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
@@ -2473,7 +2473,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -2563,18 +2563,18 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
|
||||
#else
|
||||
// use 1 block to cumsum
|
||||
// return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
// return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16));
|
||||
return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// only use this at host !
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
|
||||
{
|
||||
const auto smem_23 = impl::moe_sorting_get_smem_size_p23(h.num_experts);
|
||||
const auto smem_sf = BLOCK_SIZE * 4 * sizeof(IndexType);
|
||||
const auto smem_sf = kBlockSize * 4 * sizeof(IndexType);
|
||||
return max(smem_23, smem_sf);
|
||||
}
|
||||
|
||||
@@ -2595,7 +2595,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
if(static_cast<index_t>(blockIdx.x) >= kargs.num_experts)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
|
||||
impl::moe_buf_set_zero_kernel_2d<kBlockSize>(kargs.p_moe_buf,
|
||||
tokens,
|
||||
kargs.moe_buf_interm_dim,
|
||||
kargs.moe_buf_elem_bytes,
|
||||
@@ -2603,7 +2603,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
gridDim.x - kargs.num_experts);
|
||||
return;
|
||||
#else
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
impl::moe_buf_set_zero_kernel<kBlockSize>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - kargs.num_experts);
|
||||
@@ -2618,13 +2618,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size();
|
||||
IndexType* p_total_tokens_post_pad =
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
|
||||
IndexType* p_sorted_expert_ids =
|
||||
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize;
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
|
||||
@@ -2633,7 +2633,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
|
||||
for(index_t i = 0; i < loops; i++)
|
||||
{
|
||||
index_t position = i * BLOCK_SIZE + threadIdx.x;
|
||||
index_t position = i * kBlockSize + threadIdx.x;
|
||||
IndexType a_ = 0; // token count for a expert
|
||||
IndexType b_ = 0; // mask for a expert
|
||||
if(position < kargs.num_experts)
|
||||
@@ -2678,15 +2678,15 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
|
||||
s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
|
||||
IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -2697,7 +2697,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
cumsum_a += prev_cumsum_a;
|
||||
cumsum_b += prev_cumsum_b;
|
||||
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[2] = cumsum_a; // store the last cumsum
|
||||
s[3] = cumsum_b;
|
||||
@@ -2758,7 +2758,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size();
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
@@ -2795,13 +2795,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
using d_t = ext_vector_t<index_t, index_pack>;
|
||||
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
|
||||
|
||||
int prev_cumsum = 0;
|
||||
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int i_token_pack = i * BLOCK_SIZE + threadIdx.x;
|
||||
int i_token_pack = i * kBlockSize + threadIdx.x;
|
||||
r_t x_v = 0;
|
||||
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
|
||||
{
|
||||
@@ -2819,7 +2819,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
|
||||
static_for<0, index_pack, 1>{}([&](auto j_) {
|
||||
constexpr auto j = j_.value;
|
||||
x_r[j] = reinterpret_cast<MeshType*>(s)[threadIdx.x + j * BLOCK_SIZE];
|
||||
x_r[j] = reinterpret_cast<MeshType*>(s)[threadIdx.x + j * kBlockSize];
|
||||
});
|
||||
}
|
||||
#else
|
||||
@@ -2830,7 +2830,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
#pragma unroll
|
||||
for(int j = 0; j < index_pack / 2; j++)
|
||||
{
|
||||
int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * BLOCK_SIZE;
|
||||
int i_token = i * kBlockSize * index_pack + threadIdx.x + j * kBlockSize;
|
||||
index_t x = x_d[j];
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
@@ -2845,13 +2845,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
@@ -2896,13 +2896,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
@@ -2912,10 +2912,10 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int position = cumsum - cumsum_store;
|
||||
static_for<0, index_pack, 1>{}([&](auto j_) {
|
||||
constexpr auto j = j_.value;
|
||||
// int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j *
|
||||
// BLOCK_SIZE;
|
||||
// int i_token = i * kBlockSize * index_pack + threadIdx.x + j *
|
||||
// kBlockSize;
|
||||
int i_token =
|
||||
i * BLOCK_SIZE * index_pack + threadIdx.x * index_pack + j;
|
||||
i * kBlockSize * index_pack + threadIdx.x * index_pack + j;
|
||||
|
||||
if(i_show[j])
|
||||
{
|
||||
@@ -2932,7 +2932,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
});
|
||||
|
||||
#if 0
|
||||
int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x * 2 + j * BLOCK_SIZE * 2;
|
||||
int i_token = i * kBlockSize * index_pack + threadIdx.x * 2 + j * kBlockSize * 2;
|
||||
index_t x = x_d[j];
|
||||
index_t x0 = static_cast<index_t>(x & 0xffff);
|
||||
index_t x1 = static_cast<index_t>(x >> 16);
|
||||
@@ -2951,13 +2951,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
@@ -2996,7 +2996,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
}
|
||||
}
|
||||
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
|
||||
|
||||
Reference in New Issue
Block a user