From 9bc7574d39fc9b85f4f6fd73fdd9c0e212d58ee5 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 23 Jun 2025 04:29:53 +0000 Subject: [PATCH] Merge commit '7d669440a6a7b25ac539648ce77fe5a7ae87a657' into develop --- example/ck_tile/02_layernorm2d/generate.py | 20 ++-- example/ck_tile/05_reduce/reduce.hpp | 2 +- example/ck_tile/10_rmsnorm2d/generate.py | 22 ++-- .../ck_tile/12_smoothquant/smoothquant.hpp | 20 ++-- .../14_moe_smoothquant/moe_smoothquant.hpp | 20 ++-- .../17_grouped_gemm/grouped_gemm_tileloop.cpp | 3 + .../core/tensor/tile_window_linear.hpp | 19 ++- .../fused_moe/kernel/fused_moegemm_shape.hpp | 2 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 110 +++++++++--------- .../norm_reduce/block/block_norm_reduce.hpp | 4 +- 10 files changed, 112 insertions(+), 110 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 2dc9ccbd77..d77582630a 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,22 +75,22 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); } else { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); } }(); @@ -98,13 +98,13 @@ struct layernorm2d_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); } }(); diff --git a/example/ck_tile/05_reduce/reduce.hpp b/example/ck_tile/05_reduce/reduce.hpp index 50ffb9c1c7..6fbb0b4274 100644 --- a/example/ck_tile/05_reduce/reduce.hpp +++ b/example/ck_tile/05_reduce/reduce.hpp @@ -35,7 +35,7 @@ struct Reduce2dShape static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); static constexpr index_t BlockSize = - WarpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); + ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); }; template ; using UnquantYDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); } else { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); } }(); @@ -97,13 +97,13 @@ struct rmsnorm2d_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); } }(); @@ -712,4 +712,4 @@ if __name__ == "__main__": if args.list_blobs: list_blobs(args) else: - gen_blobs(args) \ No newline at end of file + gen_blobs(args) diff --git a/example/ck_tile/12_smoothquant/smoothquant.hpp b/example/ck_tile/12_smoothquant/smoothquant.hpp index 265399c276..5f8254a664 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.hpp +++ b/example/ck_tile/12_smoothquant/smoothquant.hpp @@ -49,22 +49,22 @@ struct smoothquant_traits_ { using DataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); } else { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); } }(); @@ -72,13 +72,13 @@ struct smoothquant_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); } }(); diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp index b29295f175..36cf477a42 100644 --- a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp @@ -38,22 +38,22 @@ struct moe_smoothquant_traits_ using InputType = ck_tile::remove_cvref_t; using OutputType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); } else { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); } }(); @@ -61,13 +61,13 @@ struct moe_smoothquant_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); } }(); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp index 5c0cb92683..4107181520 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp @@ -116,9 +116,12 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, + ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 56c5066774..596584f3cc 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -314,8 +314,7 @@ struct tile_window_linear constexpr auto tile_dstr = typename Base::TileDstr{}; - auto dst_tensor = - make_static_distributed_tensor(tile_dstr); + auto dst_tensor = make_static_distributed_tensor(tile_dstr); auto issue = [&](auto i_access_) { constexpr auto IAccess = number{}; @@ -348,8 +347,9 @@ struct tile_window_linear constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Base::Traits::PackedSize; - dst_tensor.get_thread_buffer().template at() = vec_value.template get_as< - typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize]; + dst_tensor.get_thread_buffer().template at() = + vec_value + .template get_as()[j / Base::Traits::PackedSize]; }); }; @@ -400,8 +400,9 @@ struct tile_window_linear constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Base::Traits::PackedSize; - dst_tensor.get_thread_buffer().template at() = vec_value.template get_as< - typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize]; + dst_tensor.get_thread_buffer().template at() = + vec_value + .template get_as()[j / Base::Traits::PackedSize]; }); }; @@ -804,8 +805,7 @@ struct tile_window_linear constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Base::Traits::PackedSize; - vec_value.template get_as()( - j / Base::Traits::PackedSize) = + vec_value.template get_as()(j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -860,8 +860,7 @@ struct tile_window_linear constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Base::Traits::PackedSize; - vec_value.template get_as()( - j / Base::Traits::PackedSize) = + vec_value.template get_as()(j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp index 336bdc806f..92f6a48648 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp @@ -101,7 +101,7 @@ struct FusedMoeGemmShape static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1; static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1; - static constexpr index_t BlockSize = WarpSize * NumWarps; + static constexpr index_t BlockSize = get_warp_size() * NumWarps; // some assert static_assert(Block_M0 == Block_M1); diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 3e2e100025..5da675ae42 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -388,7 +388,7 @@ struct MoeSortingKernel } // reduce single pixel within a wave - template + template __device__ static constexpr T wave_reduce(T local, F reduce_f, number = {}) { // constexpr int wave_size = 64; @@ -625,7 +625,7 @@ struct MoeSortingKernel { const index_t prefill_token = topk_mdiv.div(numel); // TODO: only support expert-tile like 8, 16, 32 - static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile; + static constexpr index_t experts_per_wave = get_warp_size() / Problem::ExpertTile; { index_t eid = tid / experts_per_wave; index_t expert_offset = cumsum[eid] + @@ -693,7 +693,7 @@ struct MoeSortingKernel void* smem) const { const index_t tid = static_cast(threadIdx.x); - const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize); + const index_t wid = __builtin_amdgcn_readfirstlane(tid / get_warp_size()); const index_t lid = __lane_id(); constexpr index_t block_size = 256; // blockDim.x; const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor; @@ -798,7 +798,7 @@ struct MoeSortingKernel // NOTE: under this block can never use __syncthreads! int i_e_ = 0; int local_cumsum_ = 0; - for(; i_e_ < num_experts; i_e_ += WarpSize) + for(; i_e_ < num_experts; i_e_ += get_warp_size()) { int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0); int local_cnt = smem_cumsum(i_e_ + lid + 1); @@ -843,7 +843,7 @@ struct MoeSortingKernel // cumsum padded in case local cumsum is zero, but // pre_sumsum has value, which will result int // zero local cumsum(but we want at least padded) - wave_cumsum(local_cumsum_); + wave_cumsum(local_cumsum_); if((i_e_ + lid) < num_experts) smem_cumsum(i_e_ + lid + 1) = local_cumsum_; @@ -851,7 +851,7 @@ struct MoeSortingKernel if constexpr(Problem::LocalExpertMasking) { local_masking += pre_cumsum_masking; - wave_cumsum(local_masking); + wave_cumsum(local_masking); if((i_e_ + lid) < num_experts) smem_cumdup(i_e_ + lid + 1) = local_masking; } @@ -861,7 +861,7 @@ struct MoeSortingKernel // than 0(which is not we want) __builtin_amdgcn_s_waitcnt(0xc07f); } - if((lid + i_e_ - WarpSize) == (num_experts - 1)) + if((lid + i_e_ - get_warp_size()) == (num_experts - 1)) { *p_total_tokens_post_pad = local_cumsum_; } @@ -1109,7 +1109,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size() return chunk * sizeof(index_t); }; -template +template CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number = {}) { // constexpr int wave_size = 64; @@ -1504,7 +1504,7 @@ struct MoeSortingMultiPhaseKernel_P1 // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return BLOCK_SIZE / WarpSize * sizeof(IndexType); + return BLOCK_SIZE / get_warp_size() * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1546,8 +1546,8 @@ struct MoeSortingMultiPhaseKernel_P1 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); } - index_t lane_id = threadIdx.x % WarpSize; - index_t wave_id = threadIdx.x / WarpSize; + index_t lane_id = threadIdx.x % get_warp_size(); + index_t wave_id = threadIdx.x / get_warp_size(); // reduce cross wave IndexType* s = reinterpret_cast(smem); @@ -1560,7 +1560,7 @@ struct MoeSortingMultiPhaseKernel_P1 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++) + for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++) { c += s[i]; } @@ -1660,7 +1660,7 @@ struct MoeSortingMultiPhaseKernel_P01 // in byte CK_TILE_HOST static constexpr auto GetSmemSize() { - return BLOCK_SIZE / WarpSize * sizeof(IndexType); + return BLOCK_SIZE / get_warp_size() * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1786,8 +1786,8 @@ struct MoeSortingMultiPhaseKernel_P01 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); } - index_t lane_id = threadIdx.x % WarpSize; - index_t wave_id = threadIdx.x / WarpSize; + index_t lane_id = threadIdx.x % get_warp_size(); + index_t wave_id = threadIdx.x / get_warp_size(); // reduce cross wave IndexType* s = reinterpret_cast(smem); @@ -1801,7 +1801,7 @@ struct MoeSortingMultiPhaseKernel_P01 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++) + for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++) { c += s[i]; } @@ -1880,7 +1880,7 @@ struct MoeSortingMultiPhaseKernel_P2 CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { // return 2 * BLOCK_SIZE * sizeof(IndexType); - return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType); + return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType); } // reduce single pixel within a wave @@ -1905,8 +1905,8 @@ struct MoeSortingMultiPhaseKernel_P2 IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; - index_t wave_id = threadIdx.x / WarpSize; - index_t lane_id = threadIdx.x % WarpSize; + index_t wave_id = threadIdx.x / get_warp_size(); + index_t lane_id = threadIdx.x % get_warp_size(); IndexType prev_cumsum_a = 0; IndexType prev_cumsum_b = 0; @@ -1951,22 +1951,22 @@ struct MoeSortingMultiPhaseKernel_P2 IndexType cumsum_b = b_; // Note: we first cumsum local round, then add previous cumsum - impl::moe_sorting_wave_cumsum(cumsum_a); - impl::moe_sorting_wave_cumsum(cumsum_b); + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { - s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b; + s[4 + wave_id] = cumsum_a; + s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? prev_b : 0; // mask out cumsum_a += prev_a; @@ -2083,7 +2083,7 @@ struct MoeSortingMultiPhaseKernel_P3 // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType); + return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -2110,8 +2110,8 @@ struct MoeSortingMultiPhaseKernel_P3 } }(); int eid = blockIdx.x; - int wave_id = threadIdx.x / WarpSize; - int lane_id = threadIdx.x % WarpSize; + int wave_id = threadIdx.x / get_warp_size(); + int lane_id = threadIdx.x % get_warp_size(); int e_start = p_expert_cumsum[eid]; int e_end = p_expert_cumsum[eid + 1]; if constexpr(Problem::SkipExpertsWithZeroTokens) @@ -2141,17 +2141,17 @@ struct MoeSortingMultiPhaseKernel_P3 int i_topk = x - 1; // topk of this token int i_show = x != 0 ? 1 : 0; // has this token or not int cumsum = i_show; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; @@ -2196,7 +2196,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_) { constexpr index_t BLOCK_SIZE = 256; // hardcoded 256 const index_t expert_cumsum_elem = num_experts_ + 1; - return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int); + return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int); } } // namespace impl @@ -2303,15 +2303,15 @@ struct MoeSortingMultiPhaseKernel_P23 const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize; + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size(); IndexType* p_total_tokens_post_pad = reinterpret_cast(kargs.p_total_tokens_post_pad); IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; - index_t wave_id = threadIdx.x / WarpSize; - index_t lane_id = threadIdx.x % WarpSize; + index_t wave_id = threadIdx.x / get_warp_size(); + index_t lane_id = threadIdx.x % get_warp_size(); IndexType prev_cumsum_a = 0; IndexType prev_cumsum_b = 0; @@ -2356,22 +2356,22 @@ struct MoeSortingMultiPhaseKernel_P23 IndexType cumsum_b = b_; // Note: we first cumsum local round, then add previous cumsum - impl::moe_sorting_wave_cumsum(cumsum_a); - impl::moe_sorting_wave_cumsum(cumsum_b); + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { - s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b; + s[4 + wave_id] = cumsum_a; + s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? prev_b : 0; // mask out cumsum_a += prev_a; @@ -2441,13 +2441,13 @@ struct MoeSortingMultiPhaseKernel_P23 IndexType* s = reinterpret_cast(smem); MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); IndexType* p_sorted_token_ids = reinterpret_cast(kargs.p_sorted_token_ids); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize; + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size(); const WeightType* p_weights = static_cast(kargs.p_weights); WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); int eid = blockIdx.x; - int wave_id = threadIdx.x / WarpSize; - int lane_id = threadIdx.x % WarpSize; + int wave_id = threadIdx.x / get_warp_size(); + int lane_id = threadIdx.x % get_warp_size(); int e_start = p_expert_cumsum_smem[eid]; int e_end = p_expert_cumsum_smem[eid + 1]; if constexpr(Problem::SkipExpertsWithZeroTokens) @@ -2518,17 +2518,17 @@ struct MoeSortingMultiPhaseKernel_P23 int i_topk = x - 1; // topk of this token int i_show = x != 0 ? 1 : 0; // has this token or not int cumsum = i_show; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; @@ -2569,17 +2569,17 @@ struct MoeSortingMultiPhaseKernel_P23 cumsum_store += i_show[j]; }); int cumsum = cumsum_store; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; @@ -2624,17 +2624,17 @@ struct MoeSortingMultiPhaseKernel_P23 int i_topk_1 = x1 - 1; // topk of this token int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not int cumsum = i_show_0 + i_show_1; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; diff --git a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp index 26437c7126..88da6be86e 100644 --- a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp @@ -250,7 +250,7 @@ struct BlockNormReduceCrossWarpSync // | w0 | w1 | w2 | w3 | -----> | w0123 | // // -> also store data from every wave into LDS - constexpr index_t num_warps = BlockShape::BlockSize / WarpSize; + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); return num_warps * 4 * thread_buf_size * sizeof(float); } @@ -276,7 +276,7 @@ struct BlockNormReduceCrossWarpSync const index_t lane_id = get_lane_id(); const index_t warp_id = get_warp_id(); constexpr auto num_reduce_warps = GetReduceWarps(); - constexpr index_t num_warps = BlockShape::BlockSize / WarpSize; + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); const index_t smem_offset = warp_id; // skip if nonthing to do