Do not use warpSize as a compile-time constant, as it is removed (#2320)

* Do not use warpSize as a compile-time constant, as it is removed

* Update tile_image_to_column_shape.hpp

Update warpSize usage.

* Clean up all uses of warpSize; make sure the code builds

* fix

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>

[ROCm/composable_kernel commit: 4c57157d50]
Satyanvesh Dittakavi authored 2025-06-18 00:24:30 +05:30, committed by GitHub
parent 66afddf431, commit bde406245a
31 changed files with 213 additions and 206 deletions
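For context (not part of the commit): recent HIP releases make warpSize a plain device-side value instead of a compile-time constant, so any use of it in a constant expression (constexpr initializers, non-type template arguments, static_asserts) no longer compiles. A minimal sketch of the before/after pattern the commit applies, mirroring its first hunk below; BlockSize and WavesPerBlock are illustrative names, not from the commit:

// before: breaks once HIP's warpSize stops being a constant expression
// static constexpr int WavesPerBlock = BlockSize / warpSize;

// after: an architecture-gated constant that is constexpr by construction
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
__device__ static constexpr int WarpSize = 64; // GFX9 (CDNA) wave64, and the host pass
#else
__device__ static constexpr int WarpSize = 32; // e.g. RDNA wave32
#endif
static constexpr int WavesPerBlock = 256 / WarpSize; // ok: constant expression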


@@ -274,6 +274,12 @@
namespace ck {
+#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
+__device__ static constexpr int WarpSize = 64;
+#else
+__device__ static constexpr int WarpSize = 32;
+#endif
enum struct InMemoryDataOperationEnum
{
Set,


@@ -45,7 +45,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
-// Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs.
+// Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
static constexpr index_t WaveSize = 64;
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);


@@ -40,7 +40,7 @@ struct BlockwiseGemmXdlops_pipeline_base
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
-// Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs.
+// Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
static constexpr index_t WaveSize = 64;
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);


@@ -141,7 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
using Base::BMmaKStride;
static constexpr index_t WgpPerCU =
-(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
+(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
@@ -631,7 +631,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t WgpPerCU =
-(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
+(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);


@@ -143,7 +143,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using Base::BMmaKStride;
static constexpr index_t WgpPerCU =
-(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
+(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);


@@ -141,7 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
using Base::BMmaKStride;
static constexpr index_t WgpPerCU =
-(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
+(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
@@ -632,7 +632,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t WgpPerCU =
-(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
+(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);


@@ -202,7 +202,7 @@ struct GridwiseMultiblockBatchNormForward
const index_t block_local_id = block_global_id % blkgroup_size;
if(block_local_id == 0)
-gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]);
+gms_init(BlockSize / WarpSize * blkgroup_size, &p_control[blkgroup_id * 2]);
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));


@@ -347,7 +347,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
-constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
+constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1229,7 +1229,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack * (get_thread_local_1d_id() % warpSize)));
+KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
@@ -1607,7 +1607,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack * (get_thread_local_1d_id() % warpSize)));
+KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
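A note on these preshuffle hunks (editorial sketch, with KPack assumed to be 8): they show the two kinds of call site being fixed. Number<WarpSize * KPack> needs a constant expression because the value becomes a non-type template argument, while KPack * (get_thread_local_1d_id() % WarpSize) is ordinary runtime lane math that merely reuses the same constant.

// compile-time: the value lives in the type, so it must be constexpr
constexpr auto nk_swizzle = Number<WarpSize * 8>{}; // 8 stands in for KPack here
// runtime: lane index within the wave, using the same constant
const index_t lane = get_thread_local_1d_id() % WarpSize;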


@@ -374,7 +374,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
-constexpr index_t NkSwizzleNumber = Number<warpSize * KPackPerGroup>{};
+constexpr index_t NkSwizzleNumber = Number<WarpSize * KPackPerGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1249,7 +1249,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPackPerGroup * (get_thread_local_1d_id() % warpSize)));
+KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1687,7 +1687,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPackPerGroup * (get_thread_local_1d_id() % warpSize)));
+KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds


@@ -370,7 +370,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
-constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
+constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1208,7 +1208,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1707,7 +1707,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds


@@ -422,7 +422,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
-constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
+constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
return make_naive_tensor_descriptor_packed(
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
}
@@ -1886,7 +1886,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
get_warp_local_1d_id() % NWave,
0,
0,
-KPack * (get_thread_local_1d_id() % warpSize)));
+KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(


@@ -405,7 +405,7 @@ struct GridwiseMoeGemm
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
-constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
+constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1315,7 +1315,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1361,7 +1361,8 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
@@ -2027,7 +2028,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -2077,7 +2078,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,


@@ -410,7 +410,7 @@ struct GridwiseMoeGemmBlockScale
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
-constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
+constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1355,7 +1355,7 @@ struct GridwiseMoeGemmBlockScale
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1467,7 +1467,7 @@ struct GridwiseMoeGemmBlockScale
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -2105,7 +2105,7 @@ struct GridwiseMoeGemmBlockScale
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -2221,7 +2221,7 @@ struct GridwiseMoeGemmBlockScale
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(


@@ -409,7 +409,7 @@ struct GridwiseMoeGemmMX
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
-constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
+constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber),
make_tuple(NWave * NXdlPack * K0 * NkSwizzleNumber,
@@ -1415,7 +1415,7 @@ struct GridwiseMoeGemmMX
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1508,7 +1508,7 @@ struct GridwiseMoeGemmMX
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
@@ -2123,7 +2123,7 @@ struct GridwiseMoeGemmMX
get_warp_local_1d_id() % NWave,
0,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -2221,7 +2221,7 @@ struct GridwiseMoeGemmMX
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,


@@ -2319,7 +2319,7 @@ struct GridwiseMoeGemmMXBNS
get_warp_local_1d_id() % NWave,
0,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -2417,7 +2417,7 @@ struct GridwiseMoeGemmMXBNS
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
-KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,


@@ -32,7 +32,7 @@ static __device__ void gms_init(int NumWarps, int* p_control_bits)
// all the workgroups in the synchronization group is supposed to call this function
static __device__ void gms_barrier(int* p_control_bits)
{
-constexpr int mask = warpSize - 1;
+constexpr int mask = WarpSize - 1;
if((threadIdx.x & mask) == 0)
{
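Worth spelling out for the mask above: since WarpSize is a power of two, tid & (WarpSize - 1) equals tid % WarpSize, so the guarded block runs on exactly one lane (lane 0) of each wave. A tiny sketch with a hypothetical helper name:

// hypothetical helper, valid because WarpSize is a power of two
__device__ inline int lane_of(int tid) { return tid & (WarpSize - 1); } // == tid % WarpSize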


@@ -35,7 +35,7 @@ CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
-const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
+const uint32_t wrap_around_lane_delta = get_warp_size() - lane_delta;
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
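A sketch of why the wrap-around trick works (relying, as the original code does, on ds_bpermute decoding lane addresses modulo the wave size): to shuffle up by lane_delta, each lane fetches from lane_id + (warp_size - lane_delta), which wraps for the low lanes.

// src byte address = 4 * (lane + warp_size - lane_delta); the mod-warp_size wrap
// is implicit in how __builtin_amdgcn_ds_bpermute decodes lane addresses
const uint32_t wrap = get_warp_size() - lane_delta;
const int32_t v_remote = __builtin_amdgcn_ds_bpermute(
    (__lane_id() << 2) + (wrap << 2), bit_cast<int32_t>(v_local));
// lanes 0..lane_delta-1 receive wrapped values and are expected to keep v_local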


@@ -95,7 +95,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
-constexpr index_t warpSize = ck_tile::get_warp_size();
+constexpr index_t WarpSize = ck_tile::get_warp_size();
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
@@ -104,11 +104,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
-if constexpr(LanesPerK >= warpSize)
+if constexpr(LanesPerK >= WarpSize)
{
// need multiple waves to load K
-static_assert(LanesPerK % warpSize == 0);
-constexpr index_t wavesPerK = LanesPerK / warpSize;
+static_assert(LanesPerK % WarpSize == 0);
+constexpr index_t wavesPerK = LanesPerK / WarpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
@@ -121,11 +121,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
-number<warpSize>{}, // k1
+number<WarpSize>{}, // k1
number<KVector>{}), // k2
-make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
-number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
-number<warpSize * KVector + KPad>{}, // k0
+make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
+number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
+number<WarpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store)
@@ -136,7 +136,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
make_tuple(
make_pass_through_transform(number<NumIssues>{}),
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
-make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
+make_merge_transform(make_tuple(number<WarpSize>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
@@ -146,8 +146,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
else
{
// lanes within a wave load different M but same K
-static_assert(warpSize % LanesPerK == 0);
-constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
+static_assert(WarpSize % LanesPerK == 0);
+constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
@@ -156,9 +156,9 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
-make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
+make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
-number<warpSize * KVector + KPad>{}, // m2
+number<WarpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
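A worked example of the branching above, under assumed tile sizes: take Block_K = 128, KVector = 8, NumWarps = 4, Block_M = 32 on a wave64 target. Then LanesPerK = 16 < WarpSize, so the else branch applies and the lanes of a wave split along M:

// assumed: Block_K = 128, KVector = 8, WarpSize = 64, NumWarps = 4, Block_M = 32
constexpr index_t LanesPerK  = 128 / 8;               // 16 lanes cover one K row
constexpr index_t LaneGroups = 64 / LanesPerK;        // 4 M rows handled per wave
constexpr index_t NumIssues  = 32 / (LaneGroups * 4); // 2 load issues per lane

On a wave32 target the first branch is taken as soon as LanesPerK >= 32, which is exactly why these expressions must see the real per-architecture warp size at compile time.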


@@ -448,19 +448,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
-constexpr index_t warpSize = ck_tile::get_warp_size();
+constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack;
-static_assert(warpSize * KVector >= kKPerBlock &&
-warpSize * KVector % kKPerBlock == 0);
+static_assert(WarpSize * KVector >= kKPerBlock &&
+WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector;
-constexpr index_t LaneGroups = warpSize / LanesPerK;
+constexpr index_t LaneGroups = WarpSize / LanesPerK;
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
-return NumIssues * NumWarps * (warpSize * KVector + kPad);
+return NumIssues * NumWarps * (WarpSize * KVector + kPad);
}
}();
@@ -516,18 +516,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
-constexpr index_t warpSize = ck_tile::get_warp_size();
+constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
-static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
+static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK =
kKPerBlock / KVector; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
-warpSize /
+WarpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
@@ -538,9 +538,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
-make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{},
+make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
-number<warpSize * KVector + kPad>{},
+number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
@@ -569,18 +569,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
-constexpr index_t warpSize = ck_tile::get_warp_size();
+constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
-static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
+static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
-constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
+constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
-// constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad);
+// constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
// constexpr index_t SingleVSize =
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t BufferSize =
@@ -594,8 +594,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<BufferSize>{},
-number<NumWarps*(warpSize * KVector + kPad)>{},
-number<warpSize * KVector + kPad>{},
+number<NumWarps*(WarpSize * KVector + kPad)>{},
+number<WarpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
@@ -746,13 +746,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
-constexpr index_t warpSize = ck_tile::get_warp_size();
+constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
-static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
+static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
-constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
+constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
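The NumIssues static_assert in these hunks encodes a short derivation that is easy to miss. Assuming kBlockSize = NumWarps * WarpSize, which the surrounding code implies, the WarpSize factors cancel:

// NumIssues = kNPerBlock / (LaneGroups * NumWarps)
//           = kNPerBlock / ((WarpSize / LanesPerK) * NumWarps)
//           = kNPerBlock * (kKPerBlock / KVector) / (WarpSize * NumWarps)
//           = kNPerBlock * kKPerBlock / (kBlockSize * KVector)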


@@ -101,7 +101,7 @@ struct FusedMoeGemmShape
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;
-static constexpr index_t BlockSize = warpSize * NumWarps;
+static constexpr index_t BlockSize = WarpSize * NumWarps;
// some assert
static_assert(Block_M0 == Block_M1);


@@ -381,7 +381,7 @@ struct MoeSortingKernel
}
// reduce single pixel within a wave
-template <typename T, typename F, index_t wave_size_ = warpSize>
+template <typename T, typename F, index_t wave_size_ = WarpSize>
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
@@ -618,7 +618,7 @@ struct MoeSortingKernel
{
const index_t prefill_token = topk_mdiv.div(numel);
// TODO: only support expert-tile like 8, 16, 32
-static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
+static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile;
{
index_t eid = tid / experts_per_wave;
index_t expert_offset = cumsum[eid] +
@@ -686,7 +686,7 @@ struct MoeSortingKernel
void* smem) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
-const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
+const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize);
const index_t lid = __lane_id();
constexpr index_t block_size = 256; // blockDim.x;
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
@@ -791,7 +791,7 @@ struct MoeSortingKernel
// NOTE: under this block can never use __syncthreads!
int i_e_ = 0;
int local_cumsum_ = 0;
-for(; i_e_ < num_experts; i_e_ += warpSize)
+for(; i_e_ < num_experts; i_e_ += WarpSize)
{
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
int local_cnt = smem_cumsum(i_e_ + lid + 1);
@@ -836,7 +836,7 @@ struct MoeSortingKernel
// cumsum padded in case local cumsum is zero, but
// pre_sumsum has value, which will result int
// zero local cumsum(but we want at least padded)
-wave_cumsum<int, warpSize>(local_cumsum_);
+wave_cumsum<int, WarpSize>(local_cumsum_);
if((i_e_ + lid) < num_experts)
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
@@ -844,7 +844,7 @@ struct MoeSortingKernel
if constexpr(Problem::LocalExpertMasking)
{
local_masking += pre_cumsum_masking;
-wave_cumsum<int, warpSize>(local_masking);
+wave_cumsum<int, WarpSize>(local_masking);
if((i_e_ + lid) < num_experts)
smem_cumdup(i_e_ + lid + 1) = local_masking;
}
@@ -854,7 +854,7 @@ struct MoeSortingKernel
// than 0(which is not we want)
__builtin_amdgcn_s_waitcnt(0xc07f);
}
-if((lid + i_e_ - warpSize) == (num_experts - 1))
+if((lid + i_e_ - WarpSize) == (num_experts - 1))
{
*p_total_tokens_post_pad = local_cumsum_;
}
@@ -1091,7 +1091,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
return chunk * sizeof(index_t);
};
-template <typename T, typename F, index_t wave_size_ = warpSize>
+template <typename T, typename F, index_t wave_size_ = WarpSize>
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
@@ -1456,7 +1456,7 @@ struct MoeSortingMultiPhaseKernel_P1
// in byte
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
-return BLOCK_SIZE / warpSize * sizeof(IndexType);
+return BLOCK_SIZE / WarpSize * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -1498,8 +1498,8 @@ struct MoeSortingMultiPhaseKernel_P1
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
}
-index_t lane_id = threadIdx.x % warpSize;
-index_t wave_id = threadIdx.x / warpSize;
+index_t lane_id = threadIdx.x % WarpSize;
+index_t wave_id = threadIdx.x / WarpSize;
// reduce cross wave
IndexType* s = reinterpret_cast<IndexType*>(smem);
@@ -1512,7 +1512,7 @@ struct MoeSortingMultiPhaseKernel_P1
if(threadIdx.x == 0)
{
index_t c = 0;
-for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
+for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
{
c += s[i];
}
@@ -1601,7 +1601,7 @@ struct MoeSortingMultiPhaseKernel_P01
// in byte
CK_TILE_HOST static constexpr auto GetSmemSize()
{
-return BLOCK_SIZE / warpSize * sizeof(IndexType);
+return BLOCK_SIZE / WarpSize * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -1685,8 +1685,8 @@ struct MoeSortingMultiPhaseKernel_P01
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
}
-index_t lane_id = threadIdx.x % warpSize;
-index_t wave_id = threadIdx.x / warpSize;
+index_t lane_id = threadIdx.x % WarpSize;
+index_t wave_id = threadIdx.x / WarpSize;
// reduce cross wave
IndexType* s = reinterpret_cast<IndexType*>(smem);
@@ -1700,7 +1700,7 @@ struct MoeSortingMultiPhaseKernel_P01
if(threadIdx.x == 0)
{
index_t c = 0;
-for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
+for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
{
c += s[i];
}
@@ -1777,7 +1777,7 @@ struct MoeSortingMultiPhaseKernel_P2
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
// return 2 * BLOCK_SIZE * sizeof(IndexType);
-return (4 + 2 * BLOCK_SIZE / warpSize) * sizeof(IndexType);
+return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType);
}
// reduce single pixel within a wave
@@ -1802,8 +1802,8 @@ struct MoeSortingMultiPhaseKernel_P2
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
-index_t wave_id = threadIdx.x / warpSize;
-index_t lane_id = threadIdx.x % warpSize;
+index_t wave_id = threadIdx.x / WarpSize;
+index_t lane_id = threadIdx.x % WarpSize;
IndexType prev_cumsum_a = 0;
IndexType prev_cumsum_b = 0;
@@ -1848,22 +1848,22 @@ struct MoeSortingMultiPhaseKernel_P2
IndexType cumsum_b = b_;
// Note: we first cumsum local round, then add previous cumsum
-impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
-impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
+impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
+impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
__syncthreads();
-if(lane_id == warpSize - 1)
+if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum_a;
-s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
+s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
}
__syncthreads();
// reduce cross wave
-static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
+static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev_a = s[4 + i_w];
-IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
+IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
prev_a = wave_id > i_w ? prev_a : 0; // mask out
prev_b = wave_id > i_w ? prev_b : 0; // mask out
cumsum_a += prev_a;
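All of the cumsum hunks in this kernel follow the same two-level scan: an inclusive scan within each wave, the last lane publishing its wave total to LDS, then each wave adding the totals of all earlier waves. A condensed single-counter sketch assembled from the lines above (v, s, lane_id, wave_id as in the surrounding code):

int v = local_value;                             // per-thread input (illustrative)
impl::moe_sorting_wave_cumsum<int, WarpSize>(v); // inclusive scan within the wave
__syncthreads();
if(lane_id == WarpSize - 1)
    s[4 + wave_id] = v;                          // publish this wave's total
__syncthreads();
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
    int prev = s[4 + i_w];
    v += wave_id > i_w ? prev : 0;               // add totals of all earlier waves
});
// v now holds this thread's block-wide inclusive cumsum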
@@ -1978,7 +1978,7 @@ struct MoeSortingMultiPhaseKernel_P3
// in byte
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
-return (4 + BLOCK_SIZE / warpSize) * sizeof(IndexType);
+return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -1995,8 +1995,8 @@ struct MoeSortingMultiPhaseKernel_P3
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
int eid = blockIdx.x;
-int wave_id = threadIdx.x / warpSize;
-int lane_id = threadIdx.x % warpSize;
+int wave_id = threadIdx.x / WarpSize;
+int lane_id = threadIdx.x % WarpSize;
int e_start = p_expert_cumsum[eid];
int e_end = p_expert_cumsum[eid + 1];
if constexpr(Problem::SkipExpertsWithZeroTokens)
@@ -2026,17 +2026,17 @@ struct MoeSortingMultiPhaseKernel_P3
int i_topk = x - 1; // topk of this token
int i_show = x != 0 ? 1 : 0; // has this token or not
int cumsum = i_show;
-impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
+impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
__syncthreads();
-if(lane_id == warpSize - 1)
+if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
-static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
+static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2081,7 +2081,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
{
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
const index_t expert_cumsum_elem = num_experts_ + 1;
-return (4 + 2 * BLOCK_SIZE / warpSize + expert_cumsum_elem) * sizeof(int);
+return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int);
}
} // namespace impl
@@ -2186,15 +2186,15 @@ struct MoeSortingMultiPhaseKernel_P23
const IndexType* p_local_expert_mask =
static_cast<const IndexType*>(kargs.p_local_expert_mask);
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
-IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
+IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
IndexType* p_total_tokens_post_pad =
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
IndexType* p_sorted_expert_ids =
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
-index_t wave_id = threadIdx.x / warpSize;
-index_t lane_id = threadIdx.x % warpSize;
+index_t wave_id = threadIdx.x / WarpSize;
+index_t lane_id = threadIdx.x % WarpSize;
IndexType prev_cumsum_a = 0;
IndexType prev_cumsum_b = 0;
@@ -2239,22 +2239,22 @@ struct MoeSortingMultiPhaseKernel_P23
IndexType cumsum_b = b_;
// Note: we first cumsum local round, then add previous cumsum
-impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
-impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
+impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
+impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
__syncthreads();
-if(lane_id == warpSize - 1)
+if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum_a;
-s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
+s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
}
__syncthreads();
// reduce cross wave
-static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
+static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev_a = s[4 + i_w];
-IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
+IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
prev_a = wave_id > i_w ? prev_a : 0; // mask out
prev_b = wave_id > i_w ? prev_b : 0; // mask out
cumsum_a += prev_a;
@@ -2324,13 +2324,13 @@ struct MoeSortingMultiPhaseKernel_P23
IndexType* s = reinterpret_cast<IndexType*>(smem);
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
-IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
+IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
int eid = blockIdx.x;
-int wave_id = threadIdx.x / warpSize;
-int lane_id = threadIdx.x % warpSize;
+int wave_id = threadIdx.x / WarpSize;
+int lane_id = threadIdx.x % WarpSize;
int e_start = p_expert_cumsum_smem[eid];
int e_end = p_expert_cumsum_smem[eid + 1];
if constexpr(Problem::SkipExpertsWithZeroTokens)
@@ -2390,17 +2390,17 @@ struct MoeSortingMultiPhaseKernel_P23
int i_topk = x - 1; // topk of this token
int i_show = x != 0 ? 1 : 0; // has this token or not
int cumsum = i_show;
-impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
+impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
__syncthreads();
-if(lane_id == warpSize - 1)
+if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
-static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
+static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2441,17 +2441,17 @@ struct MoeSortingMultiPhaseKernel_P23
cumsum_store += i_show[j];
});
int cumsum = cumsum_store;
-impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
+impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
__syncthreads();
-if(lane_id == warpSize - 1)
+if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
-static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
+static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2496,17 +2496,17 @@ struct MoeSortingMultiPhaseKernel_P23
int i_topk_1 = x1 - 1; // topk of this token
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
int cumsum = i_show_0 + i_show_1;
-impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
+impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
__syncthreads();
-if(lane_id == warpSize - 1)
+if(lane_id == WarpSize - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
-static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
+static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;


@@ -303,7 +303,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
-constexpr index_t warpSize = ck_tile::get_warp_size();
+constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
@@ -312,11 +312,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
-if constexpr(LanesPerK >= warpSize)
+if constexpr(LanesPerK >= WarpSize)
{
// need multiple waves to load K
-static_assert(LanesPerK % warpSize == 0);
-constexpr index_t wavesPerK = LanesPerK / warpSize;
+static_assert(LanesPerK % WarpSize == 0);
+constexpr index_t wavesPerK = LanesPerK / WarpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
@@ -329,11 +329,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
-number<warpSize>{}, // k1
+number<WarpSize>{}, // k1
number<KVector>{}), // k2
-make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
-number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
-number<warpSize * KVector + KPad>{}, // k0
+make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
+number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
+number<WarpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store)
@@ -344,7 +344,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
make_tuple(
make_pass_through_transform(number<NumIssues>{}),
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
-make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
+make_merge_transform(make_tuple(number<WarpSize>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
@@ -354,8 +354,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
else
{
// lanes within a wave load different M but same K
-static_assert(warpSize % LanesPerK == 0);
-constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
+static_assert(WarpSize % LanesPerK == 0);
+constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
@@ -364,9 +364,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
-make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
+make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
-number<warpSize * KVector + KPad>{}, // m2
+number<WarpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
@@ -398,7 +398,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
-constexpr index_t warpSize = ck_tile::get_warp_size();
+constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
@@ -407,11 +407,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
-if constexpr(LanesPerK >= warpSize)
+if constexpr(LanesPerK >= WarpSize)
{
// need multiple waves to load K
-static_assert(LanesPerK % warpSize == 0);
-constexpr index_t wavesPerK = LanesPerK / warpSize;
+static_assert(LanesPerK % WarpSize == 0);
+constexpr index_t wavesPerK = LanesPerK / WarpSize;
if constexpr(wavesPerK >= NumWarps)
{
// TODO: need multiple issues along K to load all data
@@ -424,11 +424,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
-number<warpSize>{}, // k1
+number<WarpSize>{}, // k1
number<KVector>{}), // k2
-make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
-number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
-number<warpSize * KVector + KPad>{}, // k0
+make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
+number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
+number<WarpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KPack>{}, // lds load vector
@@ -439,7 +439,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
make_tuple(
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
make_merge_transform(make_tuple(
-number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
+number<wavesPerK>{}, number<WarpSize>{}, number<KVector>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
@@ -449,8 +449,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
else
{
// lanes within a wave load different M but same K
-static_assert(warpSize % LanesPerK == 0);
-constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
+static_assert(WarpSize % LanesPerK == 0);
+constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
@@ -459,9 +459,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
-make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
+make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
-number<warpSize * KVector + KPad>{}, // m2
+number<WarpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KPack>{}, // lds load vector


@@ -26,7 +26,7 @@ struct TileImageToColumnShape
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp;
-static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock;
+static constexpr index_t kBlockSize = get_warp_size() * kMWarpPerBlock * kKWarpPerBlock;
};
} // namespace ck_tile
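Note the different fix on the ck_tile side: rather than the ck::WarpSize constant, call sites switch to get_warp_size(), which must itself be usable in constant expressions given that it feeds the constexpr initializer above (presumably an arch-gated constexpr function). Illustrative use under that assumption, with the warp counts made up:

// assumed warp tiling, mirroring kBlockSize above
constexpr index_t warps_per_block = 2 * 2; // kMWarpPerBlock * kKWarpPerBlock, assumed
constexpr index_t block_size = ck_tile::get_warp_size() * warps_per_block; // 256 on wave64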


@@ -250,7 +250,7 @@ struct BlockNormReduceCrossWarpSync
// | w0 | w1 | w2 | w3 | -----> | w0123 |
//
// -> also store data from every wave into LDS
-constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
+constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
return num_warps * 4 * thread_buf_size * sizeof(float);
}
@@ -276,7 +276,7 @@ struct BlockNormReduceCrossWarpSync
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
-constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
+constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
const index_t smem_offset = warp_id;
// skip if nonthing to do


@@ -210,7 +210,7 @@ struct BlockReduce2dCrossWarpSync
// | w0 | w1 | w2 | w3 | -----> | w0123 |
//
// -> also store data from every wave into LDS
-constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
+constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
return num_warps * thread_buf_size * sizeof(DataType);
}
@@ -226,7 +226,7 @@ struct BlockReduce2dCrossWarpSync
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
-constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
+constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
const index_t smem_offset = warp_id;
// skip if nonthing to do