diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 0238a125dc..2dc9ccbd77 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,22 +75,22 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return total_warps * (WarpSize / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); + // static_assert(WarpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / WarpSize); } }(); @@ -98,13 +98,13 @@ struct layernorm2d_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % WarpSize == 0); + return ThreadPerBlock_N_ / WarpSize; } }(); diff --git a/example/ck_tile/05_reduce/reduce.hpp b/example/ck_tile/05_reduce/reduce.hpp index 55e479591c..50ffb9c1c7 100644 --- a/example/ck_tile/05_reduce/reduce.hpp +++ b/example/ck_tile/05_reduce/reduce.hpp @@ -35,7 +35,7 @@ struct Reduce2dShape static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); static constexpr index_t BlockSize = - warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); + WarpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); }; template ; using UnquantYDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return total_warps * (WarpSize / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); + // static_assert(WarpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / WarpSize); } }(); @@ -97,13 +97,13 @@ struct rmsnorm2d_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if 
constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % WarpSize == 0); + return ThreadPerBlock_N_ / WarpSize; } }(); diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index c91b387d62..1d843b5594 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -80,22 +80,22 @@ struct add_rmsnorm2d_rdquant_fwd_traits_ using InputDataType = ck_tile::remove_cvref_t; using QuantizedDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return total_warps * (WarpSize / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); + // static_assert(WarpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / WarpSize); } }(); @@ -103,13 +103,13 @@ struct add_rmsnorm2d_rdquant_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % WarpSize == 0); + return ThreadPerBlock_N_ / WarpSize; } }(); diff --git a/example/ck_tile/12_smoothquant/smoothquant.hpp b/example/ck_tile/12_smoothquant/smoothquant.hpp index 83ad7b012c..265399c276 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.hpp +++ b/example/ck_tile/12_smoothquant/smoothquant.hpp @@ -49,22 +49,22 @@ struct smoothquant_traits_ { using DataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return total_warps * (WarpSize / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / 
warpSize); + // static_assert(WarpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / WarpSize); } }(); @@ -72,13 +72,13 @@ struct smoothquant_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % WarpSize == 0); + return ThreadPerBlock_N_ / WarpSize; } }(); diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp index c1b90b14b2..b29295f175 100644 --- a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp @@ -38,22 +38,22 @@ struct moe_smoothquant_traits_ using InputType = ck_tile::remove_cvref_t; using OutputType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return total_warps * (WarpSize / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); + // static_assert(WarpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / WarpSize); } }(); @@ -61,13 +61,13 @@ struct moe_smoothquant_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % WarpSize == 0); + return ThreadPerBlock_N_ / WarpSize; } }(); diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 26e4787949..3c1373a387 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -274,6 +274,12 @@ namespace ck { +#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) +__device__ static constexpr int WarpSize = 64; +#else +__device__ static constexpr int WarpSize = 32; +#endif + enum struct InMemoryDataOperationEnum { Set, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp index f366f309ff..5370cfa975 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp @@ -45,7 +45,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base using ThisThreadBlock = ThisThreadBlock; - // Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs. + // Hardcode to 64, as HIP-provided "warpSize" (and ck::WarpSize) would be 32 on RDNA GPUs.
static constexpr index_t WaveSize = 64; static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 94772361d3..9296b8136f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -40,7 +40,7 @@ struct BlockwiseGemmXdlops_pipeline_base using ThisThreadBlock = ThisThreadBlock; - // Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs. + // Hardcode to 64, as HIP-provided "warpSize" (and ck::WarpSize) would be 32 on RDNA GPUs. static constexpr index_t WaveSize = 64; static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 54edf0c353..a6b5e272ff 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -141,7 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_v2= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); @@ -631,7 +631,7 @@ struct BlockwiseGemmXdlops_pipeline_v2= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp index c8ad9c5b02..0c030030fe 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -143,7 +143,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp index 776f66dbbb..69002d7962 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp @@ -141,7 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); @@ -632,7 +632,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ?
4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp index 47573107cf..7c9febf4de 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp @@ -202,7 +202,7 @@ struct GridwiseMultiblockBatchNormForward const index_t block_local_id = block_global_id % blkgroup_size; if(block_local_id == 0) - gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]); + gms_init(BlockSize / WarpSize * blkgroup_size, &p_control[blkgroup_id * 2]); const auto thread_cluster_idx = thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index cfa8bfeb2a..8d5c844103 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -347,7 +347,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1229,7 +1229,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment @@ -1607,7 +1607,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment auto a_block_buf_ping = make_dynamic_buffer( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 3eb0f986b3..d31ed19787 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -374,7 +374,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1249,7 +1249,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle 
make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPackPerGroup * (get_thread_local_1d_id() % warpSize))); + KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1687,7 +1687,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPackPerGroup * (get_thread_local_1d_id() % warpSize))); + KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp index 322cd3d162..909376e5f7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -370,7 +370,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1208,7 +1208,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1707,7 +1707,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 223670e3bc..6691c63484 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -422,7 +422,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor_packed( make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); } @@ -1886,7 +1886,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle get_warp_local_1d_id() % NWave, 0, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment auto a_block_buf_ping = make_dynamic_buffer( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 62d94c0bf8..92aab5af52 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -405,7 +405,7 @@ struct GridwiseMoeGemm __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1315,7 +1315,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1361,7 +1361,8 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, @@ -2027,7 +2028,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2077,7 +2078,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index fbfe2509ff..f092c9c1eb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -410,7 +410,7 @@ struct GridwiseMoeGemmBlockScale __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1355,7 +1355,7 @@ struct GridwiseMoeGemmBlockScale make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1467,7 +1467,7 @@ struct GridwiseMoeGemmBlockScale make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; const auto b_scale_grid_buf_up = make_dynamic_buffer( @@ -2105,7 +2105,7 @@ struct GridwiseMoeGemmBlockScale make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % 
WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2221,7 +2221,7 @@ struct GridwiseMoeGemmBlockScale make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; const auto b_scale_grid_buf_up = make_dynamic_buffer( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index fc156a878f..59693a5861 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -409,7 +409,7 @@ struct GridwiseMoeGemmMX __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber), make_tuple(NWave * NXdlPack * K0 * NkSwizzleNumber, @@ -1415,7 +1415,7 @@ struct GridwiseMoeGemmMX make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1508,7 +1508,7 @@ struct GridwiseMoeGemmMX make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; const auto b_scale_grid_buf_up = make_dynamic_buffer( p_b_scale_grid_up + expert_id * expert_scale_stride, @@ -2123,7 +2123,7 @@ struct GridwiseMoeGemmMX get_warp_local_1d_id() % NWave, 0, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2221,7 +2221,7 @@ struct GridwiseMoeGemmMX make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; const auto b_scale_grid_buf_up = make_dynamic_buffer( p_b_scale_grid_up + expert_id * expert_scale_stride, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index 7238917920..9ccd334262 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -2319,7 +2319,7 @@ struct GridwiseMoeGemmMXBNS get_warp_local_1d_id() % NWave, 0, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2417,7 +2417,7 @@ struct GridwiseMoeGemmMXBNS make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleDataType* p_b_scale_grid_up 
= p_b_scale_grid + expert_scale_stride / 2; const auto b_scale_grid_buf_up = make_dynamic_buffer( p_b_scale_grid_up + expert_id * expert_scale_stride, diff --git a/include/ck/utility/workgroup_synchronization.hpp b/include/ck/utility/workgroup_synchronization.hpp index 24858fdbdc..af5b0808fb 100644 --- a/include/ck/utility/workgroup_synchronization.hpp +++ b/include/ck/utility/workgroup_synchronization.hpp @@ -32,7 +32,7 @@ static __device__ void gms_init(int NumWarps, int* p_control_bits) // all the workgroups in the synchronization group is supposed to call this function static __device__ void gms_barrier(int* p_control_bits) { - constexpr int mask = warpSize - 1; + constexpr int mask = WarpSize - 1; if((threadIdx.x & mask) == 0) { diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp index df0f54c5ed..7184f99521 100644 --- a/include/ck_tile/core/arch/utility.hpp +++ b/include/ck_tile/core/arch/utility.hpp @@ -35,7 +35,7 @@ CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta) #elif 1 static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); - const uint32_t wrap_around_lane_delta = warpSize - lane_delta; + const uint32_t wrap_around_lane_delta = get_warp_size() - lane_delta; const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute( (__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast(v_local)); diff --git a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp index 869ab32c2e..1dcd62011a 100644 --- a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp +++ b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp @@ -95,7 +95,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 // constexpr index_t Block_M = Problem::BlockShape::Block_M0; // constexpr index_t Block_K = Problem::BlockShape::Block_K0; // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; - constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t WarpSize = ck_tile::get_warp_size(); // constexpr index_t NumWarps = Problem::BlockShape::NumWarps; constexpr index_t KPack_ = 8; // GetSmemKPack_A(); // LDS @@ -104,11 +104,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 static_assert(Block_K % KVector == 0); constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K - if constexpr(LanesPerK >= warpSize) + if constexpr(LanesPerK >= WarpSize) { // need multiple waves to load K - static_assert(LanesPerK % warpSize == 0); - constexpr index_t wavesPerK = LanesPerK / warpSize; + static_assert(LanesPerK % WarpSize == 0); + constexpr index_t wavesPerK = LanesPerK / WarpSize; if constexpr(wavesPerK > NumWarps) { // TODO: need multiple issues along K to load all data @@ -121,11 +121,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 make_tuple(number{}, // m0 number{}, // m1 number{}, // k0 - number{}, // k1 + number{}, // k1 number{}), // k2 - make_tuple(number{}, // m0 - number{}, // m1 - number{}, // k0 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 number{}, // k1 number<1>{}), // k2 number{}, // lds store vector(actually no explicit store) @@ -136,7 +136,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 make_tuple( make_pass_through_transform(number{}), make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), + make_merge_transform(make_tuple(number{}, number{}))), 
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); @@ -146,8 +146,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 else { // lanes within a wave load different M but same K - static_assert(warpSize % LanesPerK == 0); - constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + static_assert(WarpSize % LanesPerK == 0); + constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( @@ -156,9 +156,9 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 number{}, // m2 number{}, // k0 number{}), // k1 - make_tuple(number{}, // m0 + make_tuple(number{}, // m0 number{}, // m1 - number{}, // m2 + number{}, // m2 number{}, // k0 number<1>{}), // k1 number{}, // lds store vector(actually no explicit store) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 26f7e46f9f..30d07a4754 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -448,19 +448,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for lds constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = KPack; - static_assert(warpSize * KVector >= kKPerBlock && - warpSize * KVector % kKPerBlock == 0); + static_assert(WarpSize * KVector >= kKPerBlock && + WarpSize * KVector % kKPerBlock == 0); constexpr index_t LanesPerK = kKPerBlock / KVector; - constexpr index_t LaneGroups = warpSize / LanesPerK; + constexpr index_t LaneGroups = WarpSize / LanesPerK; constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - return NumIssues * NumWarps * (warpSize * KVector + kPad); + return NumIssues * NumWarps * (WarpSize * KVector + kPad); } }(); @@ -516,18 +516,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for lds constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = KPack; // for async-copy, this pad is between warps. 
Optimize this for lds_read speed - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); constexpr index_t LanesPerK = kKPerBlock / KVector; // how many lane (within a wave) to load K constexpr index_t LaneGroups = - warpSize / + WarpSize / LanesPerK; // how many groups (within a wave), they may load different N, but same K constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); @@ -538,9 +538,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}, // n2 number{}, // k0 number{}), // k1 - make_tuple(number{}, + make_tuple(number{}, number{}, - number{}, + number{}, number{}, number<1>{}), number()>{}, @@ -569,18 +569,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for lds constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = KPack; // for async-copy, this pad is between warps - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave - constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - // constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad); + // constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad); // constexpr index_t SingleVSize = // MakeVLdsBlockDescriptor().get_element_space_size(); constexpr index_t BufferSize = @@ -594,8 +594,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}, // k0 number{}), // k1 make_tuple(number{}, - number{}, - number{}, + number{}, + number{}, number{}, number{}, number<1>{}), @@ -746,13 +746,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for global load - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave - constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp index 4f3f8bb7d3..336bdc806f 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp @@ -101,7 +101,7 @@ struct FusedMoeGemmShape static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1; static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1; - static constexpr index_t BlockSize = warpSize * NumWarps; + static constexpr index_t BlockSize = WarpSize * NumWarps; // some assert static_assert(Block_M0 == Block_M1); diff --git 
a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 4166c1c602..d3c98d7bca 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -381,7 +381,7 @@ struct MoeSortingKernel } // reduce single pixel within a wave - template + template __device__ static constexpr T wave_reduce(T local, F reduce_f, number = {}) { // constexpr int wave_size = 64; @@ -618,7 +618,7 @@ struct MoeSortingKernel { const index_t prefill_token = topk_mdiv.div(numel); // TODO: only support expert-tile like 8, 16, 32 - static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile; + static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile; { index_t eid = tid / experts_per_wave; index_t expert_offset = cumsum[eid] + @@ -686,7 +686,7 @@ struct MoeSortingKernel void* smem) const { const index_t tid = static_cast(threadIdx.x); - const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize); + const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize); const index_t lid = __lane_id(); constexpr index_t block_size = 256; // blockDim.x; const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor; @@ -791,7 +791,7 @@ struct MoeSortingKernel // NOTE: under this block can never use __syncthreads! int i_e_ = 0; int local_cumsum_ = 0; - for(; i_e_ < num_experts; i_e_ += warpSize) + for(; i_e_ < num_experts; i_e_ += WarpSize) { int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0); int local_cnt = smem_cumsum(i_e_ + lid + 1); @@ -836,7 +836,7 @@ struct MoeSortingKernel // cumsum padded in case local cumsum is zero, but // pre_sumsum has value, which will result int // zero local cumsum(but we want at least padded) - wave_cumsum(local_cumsum_); + wave_cumsum(local_cumsum_); if((i_e_ + lid) < num_experts) smem_cumsum(i_e_ + lid + 1) = local_cumsum_; @@ -844,7 +844,7 @@ struct MoeSortingKernel if constexpr(Problem::LocalExpertMasking) { local_masking += pre_cumsum_masking; - wave_cumsum(local_masking); + wave_cumsum(local_masking); if((i_e_ + lid) < num_experts) smem_cumdup(i_e_ + lid + 1) = local_masking; } @@ -854,7 +854,7 @@ struct MoeSortingKernel // than 0(which is not we want) __builtin_amdgcn_s_waitcnt(0xc07f); } - if((lid + i_e_ - warpSize) == (num_experts - 1)) + if((lid + i_e_ - WarpSize) == (num_experts - 1)) { *p_total_tokens_post_pad = local_cumsum_; } @@ -1091,7 +1091,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size() return chunk * sizeof(index_t); }; -template +template CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number = {}) { // constexpr int wave_size = 64; @@ -1456,7 +1456,7 @@ struct MoeSortingMultiPhaseKernel_P1 // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return BLOCK_SIZE / warpSize * sizeof(IndexType); + return BLOCK_SIZE / WarpSize * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1498,8 +1498,8 @@ struct MoeSortingMultiPhaseKernel_P1 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); } - index_t lane_id = threadIdx.x % warpSize; - index_t wave_id = threadIdx.x / warpSize; + index_t lane_id = threadIdx.x % WarpSize; + index_t wave_id = threadIdx.x / WarpSize; // reduce cross wave IndexType* s = reinterpret_cast(smem); @@ -1512,7 +1512,7 @@ struct MoeSortingMultiPhaseKernel_P1 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++) + for(auto i = 0; i < (BLOCK_SIZE / 
WarpSize); i++) { c += s[i]; } @@ -1601,7 +1601,7 @@ struct MoeSortingMultiPhaseKernel_P01 // in byte CK_TILE_HOST static constexpr auto GetSmemSize() { - return BLOCK_SIZE / warpSize * sizeof(IndexType); + return BLOCK_SIZE / WarpSize * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1685,8 +1685,8 @@ struct MoeSortingMultiPhaseKernel_P01 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); } - index_t lane_id = threadIdx.x % warpSize; - index_t wave_id = threadIdx.x / warpSize; + index_t lane_id = threadIdx.x % WarpSize; + index_t wave_id = threadIdx.x / WarpSize; // reduce cross wave IndexType* s = reinterpret_cast(smem); @@ -1700,7 +1700,7 @@ struct MoeSortingMultiPhaseKernel_P01 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++) + for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++) { c += s[i]; } @@ -1777,7 +1777,7 @@ struct MoeSortingMultiPhaseKernel_P2 CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { // return 2 * BLOCK_SIZE * sizeof(IndexType); - return (4 + 2 * BLOCK_SIZE / warpSize) * sizeof(IndexType); + return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType); } // reduce single pixel within a wave @@ -1802,8 +1802,8 @@ struct MoeSortingMultiPhaseKernel_P2 IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; - index_t wave_id = threadIdx.x / warpSize; - index_t lane_id = threadIdx.x % warpSize; + index_t wave_id = threadIdx.x / WarpSize; + index_t lane_id = threadIdx.x % WarpSize; IndexType prev_cumsum_a = 0; IndexType prev_cumsum_b = 0; @@ -1848,22 +1848,22 @@ struct MoeSortingMultiPhaseKernel_P2 IndexType cumsum_b = b_; // Note: we first cumsum local round, then add previous cumsum - impl::moe_sorting_wave_cumsum(cumsum_a); - impl::moe_sorting_wave_cumsum(cumsum_b); + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); __syncthreads(); - if(lane_id == warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b; + s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? prev_b : 0; // mask out cumsum_a += prev_a; @@ -1978,7 +1978,7 @@ struct MoeSortingMultiPhaseKernel_P3 // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return (4 + BLOCK_SIZE / warpSize) * sizeof(IndexType); + return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1995,8 +1995,8 @@ struct MoeSortingMultiPhaseKernel_P3 WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); int eid = blockIdx.x; - int wave_id = threadIdx.x / warpSize; - int lane_id = threadIdx.x % warpSize; + int wave_id = threadIdx.x / WarpSize; + int lane_id = threadIdx.x % WarpSize; int e_start = p_expert_cumsum[eid]; int e_end = p_expert_cumsum[eid + 1]; if constexpr(Problem::SkipExpertsWithZeroTokens) @@ -2026,17 +2026,17 @@ struct MoeSortingMultiPhaseKernel_P3 int i_topk = x - 1; // topk of this token int i_show = x != 0 ? 
1 : 0; // has this token or not int cumsum = i_show; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; @@ -2081,7 +2081,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_) { constexpr index_t BLOCK_SIZE = 256; // hardcoded 256 const index_t expert_cumsum_elem = num_experts_ + 1; - return (4 + 2 * BLOCK_SIZE / warpSize + expert_cumsum_elem) * sizeof(int); + return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int); } } // namespace impl @@ -2186,15 +2186,15 @@ struct MoeSortingMultiPhaseKernel_P23 const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize; + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize; IndexType* p_total_tokens_post_pad = reinterpret_cast(kargs.p_total_tokens_post_pad); IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; - index_t wave_id = threadIdx.x / warpSize; - index_t lane_id = threadIdx.x % warpSize; + index_t wave_id = threadIdx.x / WarpSize; + index_t lane_id = threadIdx.x % WarpSize; IndexType prev_cumsum_a = 0; IndexType prev_cumsum_b = 0; @@ -2239,22 +2239,22 @@ struct MoeSortingMultiPhaseKernel_P23 IndexType cumsum_b = b_; // Note: we first cumsum local round, then add previous cumsum - impl::moe_sorting_wave_cumsum(cumsum_a); - impl::moe_sorting_wave_cumsum(cumsum_b); + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); __syncthreads(); - if(lane_id == warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b; + s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? 
prev_b : 0; // mask out cumsum_a += prev_a; @@ -2324,13 +2324,13 @@ struct MoeSortingMultiPhaseKernel_P23 IndexType* s = reinterpret_cast(smem); MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); IndexType* p_sorted_token_ids = reinterpret_cast(kargs.p_sorted_token_ids); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize; + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize; const WeightType* p_weights = static_cast(kargs.p_weights); WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); int eid = blockIdx.x; - int wave_id = threadIdx.x / warpSize; - int lane_id = threadIdx.x % warpSize; + int wave_id = threadIdx.x / WarpSize; + int lane_id = threadIdx.x % WarpSize; int e_start = p_expert_cumsum_smem[eid]; int e_end = p_expert_cumsum_smem[eid + 1]; if constexpr(Problem::SkipExpertsWithZeroTokens) @@ -2390,17 +2390,17 @@ struct MoeSortingMultiPhaseKernel_P23 int i_topk = x - 1; // topk of this token int i_show = x != 0 ? 1 : 0; // has this token or not int cumsum = i_show; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; @@ -2441,17 +2441,17 @@ struct MoeSortingMultiPhaseKernel_P23 cumsum_store += i_show[j]; }); int cumsum = cumsum_store; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; @@ -2496,17 +2496,17 @@ struct MoeSortingMultiPhaseKernel_P23 int i_topk_1 = x1 - 1; // topk of this token int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not int cumsum = i_show_0 + i_show_1; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? 
prev : 0; // mask out cumsum += prev; diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp index 629f0ee8f1..0c8baaf191 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp @@ -303,7 +303,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy constexpr index_t Block_M = Problem::BlockShape::Block_M0; constexpr index_t Block_K = Problem::BlockShape::Block_K0; // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; - constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t WarpSize = ck_tile::get_warp_size(); constexpr index_t NumWarps = Problem::BlockShape::NumWarps; constexpr index_t KPack = GetSmemKPack_A(); // LDS @@ -312,11 +312,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy static_assert(Block_K % KVector == 0); constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K - if constexpr(LanesPerK >= warpSize) + if constexpr(LanesPerK >= WarpSize) { // need multiple waves to load K - static_assert(LanesPerK % warpSize == 0); - constexpr index_t wavesPerK = LanesPerK / warpSize; + static_assert(LanesPerK % WarpSize == 0); + constexpr index_t wavesPerK = LanesPerK / WarpSize; if constexpr(wavesPerK > NumWarps) { // TODO: need multiple issues along K to load all data @@ -329,11 +329,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy make_tuple(number{}, // m0 number{}, // m1 number{}, // k0 - number{}, // k1 + number{}, // k1 number{}), // k2 - make_tuple(number{}, // m0 - number{}, // m1 - number{}, // k0 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 number{}, // k1 number<1>{}), // k2 number{}, // lds store vector(actually no explicit store) @@ -344,7 +344,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy make_tuple( make_pass_through_transform(number{}), make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); @@ -354,8 +354,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy else { // lanes within a wave load different M but same K - static_assert(warpSize % LanesPerK == 0); - constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + static_assert(WarpSize % LanesPerK == 0); + constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( @@ -364,9 +364,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy number{}, // m2 number{}, // k0 number{}), // k1 - make_tuple(number{}, // m0 + make_tuple(number{}, // m0 number{}, // m1 - number{}, // m2 + number{}, // m2 number{}, // k0 number<1>{}), // k1 number{}, // lds store vector(actually no explicit store) @@ -398,7 +398,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy constexpr index_t Block_M = Problem::BlockShape::Block_M0; constexpr index_t Block_K = Problem::BlockShape::Block_K0; // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; - constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t WarpSize = ck_tile::get_warp_size(); constexpr index_t NumWarps = Problem::BlockShape::NumWarps; constexpr index_t KPack = GetSmemKPack_A(); // LDS @@ -407,11 +407,11 @@ struct 
FusedMoeGemmPipelineFlatmmPolicy static_assert(Block_K % KVector == 0); constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K - if constexpr(LanesPerK >= warpSize) + if constexpr(LanesPerK >= WarpSize) { // need multiple waves to load K - static_assert(LanesPerK % warpSize == 0); - constexpr index_t wavesPerK = LanesPerK / warpSize; + static_assert(LanesPerK % WarpSize == 0); + constexpr index_t wavesPerK = LanesPerK / WarpSize; if constexpr(wavesPerK >= NumWarps) { // TODO: need multiple issues along K to load all data @@ -424,11 +424,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy make_tuple(number{}, // m0 number{}, // m1 number{}, // k0 - number{}, // k1 + number{}, // k1 number{}), // k2 - make_tuple(number{}, // m0 - number{}, // m1 - number{}, // k0 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 number{}, // k1 number<1>{}), // k2 number{}, // lds load vector @@ -439,7 +439,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy make_tuple( make_merge_transform(make_tuple(number{}, number{})), make_merge_transform(make_tuple( - number{}, number{}, number{}))), + number{}, number{}, number{}))), make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -449,8 +449,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy else { // lanes within a wave load different M but same K - static_assert(warpSize % LanesPerK == 0); - constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + static_assert(WarpSize % LanesPerK == 0); + constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( @@ -459,9 +459,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy number{}, // m2 number{}, // k0 number{}), // k1 - make_tuple(number{}, // m0 + make_tuple(number{}, // m0 number{}, // m1 - number{}, // m2 + number{}, // m2 number{}, // k0 number<1>{}), // k1 number{}, // lds load vector diff --git a/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp index b038472fcf..ad513dbd11 100644 --- a/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp +++ b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp @@ -26,7 +26,7 @@ struct TileImageToColumnShape static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp; static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp; - static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock; + static constexpr index_t kBlockSize = get_warp_size() * kMWarpPerBlock * kKWarpPerBlock; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp index 15ac021631..26437c7126 100644 --- a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp @@ -250,7 +250,7 @@ struct BlockNormReduceCrossWarpSync // | w0 | w1 | w2 | w3 | -----> | w0123 | // // -> also store data from every wave into LDS - constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + constexpr index_t num_warps = BlockShape::BlockSize / WarpSize; return num_warps * 4 * thread_buf_size * sizeof(float); } @@ -276,7 +276,7 @@ struct BlockNormReduceCrossWarpSync const index_t lane_id = get_lane_id(); const index_t warp_id = get_warp_id(); constexpr auto num_reduce_warps = 
GetReduceWarps(); - constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + constexpr index_t num_warps = BlockShape::BlockSize / WarpSize; const index_t smem_offset = warp_id; // skip if nothing to do diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index d6ca98e7b4..6a1f926a9a 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -210,7 +210,7 @@ struct BlockReduce2dCrossWarpSync // | w0 | w1 | w2 | w3 | -----> | w0123 | // // -> also store data from every wave into LDS - constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); return num_warps * thread_buf_size * sizeof(DataType); } @@ -226,7 +226,7 @@ struct BlockReduce2dCrossWarpSync const index_t lane_id = get_lane_id(); const index_t warp_id = get_warp_id(); constexpr auto num_reduce_warps = GetReduceWarps(); - constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); const index_t smem_offset = warp_id; // skip if nothing to do
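
A minimal, self-contained sketch of the block-shape math the renamed constant feeds in the example *_traits_ structs above (layernorm2d, rmsnorm2d, smoothquant, moe_smoothquant). WarpSize = 64 mirrors the GFX9/host branch added to include/ck/ck.hpp; warp_layout is a hypothetical stand-in for illustration, not one of CK's real templates:

    // Stand-in for the example *_traits_ structs; WarpSize mirrors ck::WarpSize
    // under GFX9 / host compilation. Compile-time only, no GPU needed.
    static constexpr int WarpSize = 64;

    template <int ThreadPerBlock_M, int ThreadPerBlock_N>
    struct warp_layout
    {
        // does one row of the 2d thread block fit inside a single warp?
        static constexpr bool is_warp_per_row = ThreadPerBlock_N <= WarpSize;
        static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % WarpSize == 0);

        static constexpr int total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / WarpSize;

        // warps along M: a warp covers WarpSize / ThreadPerBlock_N rows when a
        // row fits in a warp; otherwise several warps share a single row
        static constexpr int BlockWarps_M = is_warp_per_row
                                                ? total_warps * (WarpSize / ThreadPerBlock_N)
                                                : total_warps / (ThreadPerBlock_N / WarpSize);

        // warps along N
        static constexpr int BlockWarps_N = is_warp_per_row ? 1 : ThreadPerBlock_N / WarpSize;
    };

    // 4x64 threads: one warp per row -> 4 warps along M, 1 along N
    static_assert(warp_layout<4, 64>::BlockWarps_M == 4 && warp_layout<4, 64>::BlockWarps_N == 1);
    // 2x128 threads: each row spans two warps -> 2 warps along M, 2 along N
    static_assert(warp_layout<2, 128>::BlockWarps_M == 2 && warp_layout<2, 128>::BlockWarps_N == 2);

    int main() { return 0; }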
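
The moe_sorting kernels repeat one two-level cumsum pattern: each wave scans its own lanes, the last lane publishes the wave total to LDS, and every wave then adds the totals of the waves before it. A sketch of that pattern, assuming the kernels' hardcoded block size of 256 and 64-lane waves; wave_cumsum below is a plain shuffle-up scan standing in for impl::moe_sorting_wave_cumsum:

    #include <hip/hip_runtime.h>

    static constexpr int WarpSize  = 64;  // assumption: GFX9 wave
    static constexpr int BlockSize = 256; // the kernels hardcode 256

    // inclusive cumsum over the lanes of one wave (Hillis-Steele scan)
    __device__ int wave_cumsum(int v)
    {
        const int lane_id = threadIdx.x % WarpSize;
        for(int d = 1; d < WarpSize; d <<= 1)
        {
            int up = __shfl_up(v, d, WarpSize);
            if(lane_id >= d)
                v += up;
        }
        return v;
    }

    // block-wide inclusive cumsum: per-wave scan, publish wave totals to LDS,
    // then add the totals of all preceding waves ("mask out" pattern)
    __device__ int block_cumsum(int v, int* s /* BlockSize / WarpSize ints */)
    {
        const int lane_id = threadIdx.x % WarpSize;
        const int wave_id = threadIdx.x / WarpSize;

        int cumsum = wave_cumsum(v);

        if(lane_id == WarpSize - 1)
            s[wave_id] = cumsum; // last lane holds this wave's total
        __syncthreads();

        for(int i_w = 0; i_w < BlockSize / WarpSize - 1; ++i_w)
        {
            int prev = s[i_w];
            prev     = wave_id > i_w ? prev : 0; // mask out own and later waves
            cumsum += prev;
        }
        return cumsum;
    }

    __global__ void block_cumsum_demo(const int* in, int* out)
    {
        __shared__ int s[BlockSize / WarpSize];
        out[threadIdx.x] = block_cumsum(in[threadIdx.x], s);
    }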
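
warp_shuffle_up in include/ck_tile/core/arch/utility.hpp implements shuffle-up as a ds_bpermute gather with a wrapped-around lane delta, which is why the patch switches it from HIP's warpSize to get_warp_size(). A host-side check of that index math, with a 64-lane wave assumed:

    #include <cstdio>

    int main()
    {
        constexpr int warp_size  = 64; // assumption: GFX9 wave size
        constexpr int lane_delta = 3;

        // reading from lane (lane + warp_size - lane_delta) % warp_size is the
        // same as lane - lane_delta for lanes >= lane_delta; lower lanes wrap
        const int wrap_around_lane_delta = warp_size - lane_delta;
        for(int lane = 0; lane < warp_size; ++lane)
        {
            int src = (lane + wrap_around_lane_delta) % warp_size; // bpermute index
            if(lane >= lane_delta && src != lane - lane_delta)
            {
                std::printf("mismatch at lane %d\n", lane);
                return 1;
            }
        }
        std::printf("wrap-around shuffle-up index math holds for all lanes\n");
        return 0;
    }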