mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Do not use warpSize as a compile-time constant, as it has been removed (#2320)
* Do not use warpSize as a compile-time constant, as it has been removed
* Update tile_image_to_column_shape.hpp
Update warpSize usage.
* clean-up all use of warpSize, make sure code builds
* fix
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
[ROCm/composable_kernel commit: 4c57157d50]
This commit is contained in:
committed by
GitHub
parent
66afddf431
commit
bde406245a
@@ -274,6 +274,12 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
__device__ static constexpr int WarpSize = 64;
|
||||
#else
|
||||
__device__ static constexpr int WarpSize = 32;
|
||||
#endif
|
||||
|
||||
enum struct InMemoryDataOperationEnum
|
||||
{
|
||||
Set,
|
||||
|
||||
@@ -45,7 +45,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs.
|
||||
// Hardcode to 64, as the HIP built-in "warpSize" would return 32 on RDNA GPUs.
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
|
||||
@@ -40,7 +40,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs.
|
||||
// Hardcode to 64, as the HIP built-in "warpSize" would return 32 on RDNA GPUs.
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
|
||||
@@ -141,7 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
|
||||
using Base::BMmaKStride;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
@@ -631,7 +631,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
|
||||
@@ -143,7 +143,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using Base::BMmaKStride;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
|
||||
@@ -141,7 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
using Base::BMmaKStride;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
@@ -632,7 +632,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
|
||||
@@ -202,7 +202,7 @@ struct GridwiseMultiblockBatchNormForward
|
||||
const index_t block_local_id = block_global_id % blkgroup_size;
|
||||
|
||||
if(block_local_id == 0)
|
||||
gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]);
|
||||
gms_init(BlockSize / WarpSize * blkgroup_size, &p_control[blkgroup_id * 2]);
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
@@ -347,7 +347,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
{
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
|
||||
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
|
||||
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
|
||||
@@ -1229,7 +1229,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
|
||||
@@ -1607,7 +1607,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
|
||||
@@ -374,7 +374,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
{
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPackPerGroup>{};
|
||||
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPackPerGroup>{};
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
|
||||
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
|
||||
@@ -1249,7 +1249,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPackPerGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -1687,7 +1687,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPackPerGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
|
||||
@@ -370,7 +370,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
{
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
|
||||
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
|
||||
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
|
||||
@@ -1208,7 +1208,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -1707,7 +1707,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
|
||||
@@ -422,7 +422,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
{
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
|
||||
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
|
||||
}
|
||||
@@ -1886,7 +1886,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
|
||||
@@ -405,7 +405,7 @@ struct GridwiseMoeGemm
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
{
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
|
||||
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
|
||||
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
|
||||
@@ -1315,7 +1315,7 @@ struct GridwiseMoeGemm
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -1361,7 +1361,8 @@ struct GridwiseMoeGemm
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
@@ -2027,7 +2028,7 @@ struct GridwiseMoeGemm
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -2077,7 +2078,7 @@ struct GridwiseMoeGemm
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
|
||||
@@ -410,7 +410,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
{
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
|
||||
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
|
||||
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
|
||||
@@ -1355,7 +1355,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -1467,7 +1467,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -2105,7 +2105,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -2221,7 +2221,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
|
||||
@@ -409,7 +409,7 @@ struct GridwiseMoeGemmMX
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
{
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
|
||||
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber),
|
||||
make_tuple(NWave * NXdlPack * K0 * NkSwizzleNumber,
|
||||
@@ -1415,7 +1415,7 @@ struct GridwiseMoeGemmMX
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -1508,7 +1508,7 @@ struct GridwiseMoeGemmMX
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
@@ -2123,7 +2123,7 @@ struct GridwiseMoeGemmMX
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -2221,7 +2221,7 @@ struct GridwiseMoeGemmMX
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
|
||||
@@ -2319,7 +2319,7 @@ struct GridwiseMoeGemmMXBNS
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -2417,7 +2417,7 @@ struct GridwiseMoeGemmMXBNS
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
|
||||
@@ -32,7 +32,7 @@ static __device__ void gms_init(int NumWarps, int* p_control_bits)
|
||||
// all the workgroups in the synchronization group are supposed to call this function
|
||||
static __device__ void gms_barrier(int* p_control_bits)
|
||||
{
|
||||
constexpr int mask = warpSize - 1;
|
||||
constexpr int mask = WarpSize - 1;
|
||||
|
||||
if((threadIdx.x & mask) == 0)
|
||||
{
|
||||
|
||||
@@ -35,7 +35,7 @@ CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
|
||||
const uint32_t wrap_around_lane_delta = get_warp_size() - lane_delta;
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
@@ -95,7 +95,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
|
||||
@@ -104,11 +104,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
if constexpr(LanesPerK >= WarpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
static_assert(LanesPerK % WarpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / WarpSize;
|
||||
if constexpr(wavesPerK > NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
@@ -121,11 +121,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<WarpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
|
||||
number<WarpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
@@ -136,7 +136,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NumIssues>{}),
|
||||
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
|
||||
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
|
||||
make_merge_transform(make_tuple(number<WarpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
@@ -146,8 +146,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
static_assert(WarpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -156,9 +156,9 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<WarpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
|
||||
@@ -448,19 +448,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad = KPack;
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock &&
|
||||
warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock &&
|
||||
WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector;
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK;
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK;
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
|
||||
return NumIssues * NumWarps * (warpSize * KVector + kPad);
|
||||
return NumIssues * NumWarps * (WarpSize * KVector + kPad);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -516,18 +516,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK =
|
||||
kKPerBlock / KVector; // how many lane (within a wave) to load K
|
||||
constexpr index_t LaneGroups =
|
||||
warpSize /
|
||||
WarpSize /
|
||||
LanesPerK; // how many groups (within a wave), they may load different N, but same K
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
@@ -538,9 +538,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
number<NumWarps>{}, // n2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{},
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<kKPerBlock>{},
|
||||
number<warpSize * KVector + kPad>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<KVector>{},
|
||||
number<1>{}),
|
||||
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
|
||||
@@ -569,18 +569,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
// constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad);
|
||||
// constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
|
||||
// constexpr index_t SingleVSize =
|
||||
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
constexpr index_t BufferSize =
|
||||
@@ -594,8 +594,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
number<kKPerBlock / KPack>{}, // k0
|
||||
number<KPack>{}), // k1
|
||||
make_tuple(number<BufferSize>{},
|
||||
number<NumWarps*(warpSize * KVector + kPad)>{},
|
||||
number<warpSize * KVector + kPad>{},
|
||||
number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<kKPerBlock>{},
|
||||
number<KPack>{},
|
||||
number<1>{}),
|
||||
@@ -746,13 +746,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ struct FusedMoeGemmShape
|
||||
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
|
||||
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;
|
||||
|
||||
static constexpr index_t BlockSize = warpSize * NumWarps;
|
||||
static constexpr index_t BlockSize = WarpSize * NumWarps;
|
||||
|
||||
// some assert
|
||||
static_assert(Block_M0 == Block_M1);
|
||||
|
||||
@@ -381,7 +381,7 @@ struct MoeSortingKernel
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
template <typename T, typename F, index_t wave_size_ = warpSize>
|
||||
template <typename T, typename F, index_t wave_size_ = WarpSize>
|
||||
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
@@ -618,7 +618,7 @@ struct MoeSortingKernel
|
||||
{
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
// TODO: only support expert-tile like 8, 16, 32
|
||||
static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
|
||||
static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile;
|
||||
{
|
||||
index_t eid = tid / experts_per_wave;
|
||||
index_t expert_offset = cumsum[eid] +
|
||||
@@ -686,7 +686,7 @@ struct MoeSortingKernel
|
||||
void* smem) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
|
||||
const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize);
|
||||
const index_t lid = __lane_id();
|
||||
constexpr index_t block_size = 256; // blockDim.x;
|
||||
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
|
||||
@@ -791,7 +791,7 @@ struct MoeSortingKernel
|
||||
// NOTE: __syncthreads can never be used under this block!
|
||||
int i_e_ = 0;
|
||||
int local_cumsum_ = 0;
|
||||
for(; i_e_ < num_experts; i_e_ += warpSize)
|
||||
for(; i_e_ < num_experts; i_e_ += WarpSize)
|
||||
{
|
||||
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
|
||||
int local_cnt = smem_cumsum(i_e_ + lid + 1);
|
||||
@@ -836,7 +836,7 @@ struct MoeSortingKernel
|
||||
// cumsum padded in case local cumsum is zero, but
|
||||
// pre_cumsum has value, which will result in
|
||||
// zero local cumsum(but we want at least padded)
|
||||
wave_cumsum<int, warpSize>(local_cumsum_);
|
||||
wave_cumsum<int, WarpSize>(local_cumsum_);
|
||||
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
|
||||
@@ -844,7 +844,7 @@ struct MoeSortingKernel
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
local_masking += pre_cumsum_masking;
|
||||
wave_cumsum<int, warpSize>(local_masking);
|
||||
wave_cumsum<int, WarpSize>(local_masking);
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumdup(i_e_ + lid + 1) = local_masking;
|
||||
}
|
||||
@@ -854,7 +854,7 @@ struct MoeSortingKernel
|
||||
// than 0 (which is not what we want)
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
}
|
||||
if((lid + i_e_ - warpSize) == (num_experts - 1))
|
||||
if((lid + i_e_ - WarpSize) == (num_experts - 1))
|
||||
{
|
||||
*p_total_tokens_post_pad = local_cumsum_;
|
||||
}
|
||||
@@ -1091,7 +1091,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
|
||||
return chunk * sizeof(index_t);
|
||||
};
|
||||
|
||||
template <typename T, typename F, index_t wave_size_ = warpSize>
|
||||
template <typename T, typename F, index_t wave_size_ = WarpSize>
|
||||
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
@@ -1456,7 +1456,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / warpSize * sizeof(IndexType);
|
||||
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1498,8 +1498,8 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -1512,7 +1512,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
|
||||
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1601,7 +1601,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / warpSize * sizeof(IndexType);
|
||||
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1685,8 +1685,8 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -1700,7 +1700,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
|
||||
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1777,7 +1777,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
// return 2 * BLOCK_SIZE * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / warpSize) * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
@@ -1802,8 +1802,8 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
@@ -1848,22 +1848,22 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
|
||||
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -1978,7 +1978,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return (4 + BLOCK_SIZE / warpSize) * sizeof(IndexType);
|
||||
return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1995,8 +1995,8 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / warpSize;
|
||||
int lane_id = threadIdx.x % warpSize;
|
||||
int wave_id = threadIdx.x / WarpSize;
|
||||
int lane_id = threadIdx.x % WarpSize;
|
||||
int e_start = p_expert_cumsum[eid];
|
||||
int e_end = p_expert_cumsum[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
@@ -2026,17 +2026,17 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2081,7 +2081,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
|
||||
{
|
||||
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
|
||||
const index_t expert_cumsum_elem = num_experts_ + 1;
|
||||
return (4 + 2 * BLOCK_SIZE / warpSize + expert_cumsum_elem) * sizeof(int);
|
||||
return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int);
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
@@ -2186,15 +2186,15 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
|
||||
IndexType* p_total_tokens_post_pad =
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
|
||||
IndexType* p_sorted_expert_ids =
|
||||
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
@@ -2239,22 +2239,22 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
|
||||
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -2324,13 +2324,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / warpSize;
|
||||
int lane_id = threadIdx.x % warpSize;
|
||||
int wave_id = threadIdx.x / WarpSize;
|
||||
int lane_id = threadIdx.x % WarpSize;
|
||||
int e_start = p_expert_cumsum_smem[eid];
|
||||
int e_end = p_expert_cumsum_smem[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
@@ -2390,17 +2390,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2441,17 +2441,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
cumsum_store += i_show[j];
|
||||
});
|
||||
int cumsum = cumsum_store;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2496,17 +2496,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int i_topk_1 = x1 - 1; // topk of this token
|
||||
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show_0 + i_show_1;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
|
||||
@@ -303,7 +303,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
|
||||
@@ -312,11 +312,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
if constexpr(LanesPerK >= WarpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
static_assert(LanesPerK % WarpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / WarpSize;
|
||||
if constexpr(wavesPerK > NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
@@ -329,11 +329,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<WarpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
|
||||
number<WarpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
@@ -344,7 +344,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NumIssues>{}),
|
||||
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
|
||||
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
|
||||
make_merge_transform(make_tuple(number<WarpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
@@ -354,8 +354,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
static_assert(WarpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -364,9 +364,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<WarpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
@@ -398,7 +398,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
|
||||
@@ -407,11 +407,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
if constexpr(LanesPerK >= WarpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
static_assert(LanesPerK % WarpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / WarpSize;
|
||||
if constexpr(wavesPerK >= NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
@@ -424,11 +424,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<WarpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
|
||||
number<WarpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KPack>{}, // lds load vector
|
||||
@@ -439,7 +439,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
|
||||
make_merge_transform(make_tuple(
|
||||
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
|
||||
number<wavesPerK>{}, number<WarpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
@@ -449,8 +449,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
static_assert(WarpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -459,9 +459,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<WarpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KPack>{}, // lds load vector
|
||||
|
||||
@@ -26,7 +26,7 @@ struct TileImageToColumnShape
|
||||
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
|
||||
static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp;
|
||||
|
||||
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock;
|
||||
static constexpr index_t kBlockSize = get_warp_size() * kMWarpPerBlock * kKWarpPerBlock;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -250,7 +250,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
|
||||
return num_warps * 4 * thread_buf_size * sizeof(float);
|
||||
}
|
||||
|
||||
@@ -276,7 +276,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
|
||||
const index_t smem_offset = warp_id;
|
||||
|
||||
// skip if nonthing to do
|
||||
|
||||
@@ -210,7 +210,7 @@ struct BlockReduce2dCrossWarpSync
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
@@ -226,7 +226,7 @@ struct BlockReduce2dCrossWarpSync
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
const index_t smem_offset = warp_id;
|
||||
|
||||
// skip if nonthing to do
|
||||
|
||||
Reference in New Issue
Block a user