From 6d802e7ba4a2764235f8fd220e35f4a7998136cb Mon Sep 17 00:00:00 2001 From: MHYangAMD Date: Wed, 22 Oct 2025 14:41:35 +0800 Subject: [PATCH] Introduce tree reduction for BlockReduce2dCrossWarpSync (#2588) * Introduce tree reduction for BlockReduce2dCrossWarpSync * Rename original impl to BlockReduce2dLinearCrossWarpSync * Replace warp_size with get_warp_size() --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> [ROCm/composable_kernel commit: 5a27a97391d08652c3da0a5347209c19d3ebb03d] --- .../ops/reduce/block/block_reduce2d.hpp | 255 ++++++++---------- .../rmsnorm2d_fwd_pipeline_default_policy.hpp | 9 - ...rm2d_fwd_pipeline_model_sensitive_pass.hpp | 6 +- 3 files changed, 120 insertions(+), 150 deletions(-) diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index b97a66a3ec..9cddb0abf2 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -230,9 +230,121 @@ struct BlockReduce2dCrossWarpSync template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - using DataType = typename YDistributedTensor_::DataType; - // constexpr auto num_reduce_warps = GetReduceWarps(); + using DataType = typename YDistributedTensor_::DataType; + constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); + // we need to store all data from every wave into smem + // e.g. 2x2 reduce along N + // -------------> reduce N + // | w0 | w1 | ___> | w01 | + // | w2 | w3 | | w23 | + // + // -> store data from every wave into LDS + // + // + // -------------> reduce N + // | w0 | w1 | w2 | w3 | -----> | w0123 | + // + // -> also store data from every wave into LDS + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); + return num_warps * thread_buf_size * sizeof(DataType); + } + + template + CK_TILE_DEVICE void + operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func) + { + using DataType = typename YDistributedTensor_::DataType; + + constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); + + DataType* smem_ptr = reinterpret_cast(smem); + const index_t lane_id = get_lane_id(); + const index_t warp_id = get_warp_id(); + + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); + constexpr index_t num_reduce_warps = GetReduceWarps(); + + if constexpr(num_reduce_warps == 1) + return; + + // Each warp's lane 0 writes its partial results to shared memory + const index_t smem_offset = warp_id; + if(lane_id == 0) + { + static_for<0, thread_buf_size, 1>{}([&](auto i) { + // Store the i-th element of this warp's thread_buffer into SMEM + smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i]; + }); + } + block_sync_lds(); + + // We let each warp holds a duplication to do reduction. + const index_t local_warp_id = warp_id / num_reduce_warps; + const index_t local_smem_os = local_warp_id * num_reduce_warps; + static_for<0, thread_buf_size, 1>{}([&](auto i) { + DataType v[num_reduce_warps]; + static_for<0, num_reduce_warps, 1>{}( + [&](auto idx) { v[idx] = smem_ptr[i * num_warps + local_smem_os + idx]; }); + + static_assert(is_power_of_two_integer(num_reduce_warps), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(num_reduce_warps); + + static_for<0, nstage, 1>{}([&](auto istage) { + constexpr index_t stride = 1 << istage.value; + static_for<0, num_reduce_warps, stride * 2>{}([&](auto idx_) { + constexpr index_t i0 = idx_(); + constexpr index_t i1 = idx_ + stride; + if constexpr(i1 < num_reduce_warps) + { + v[i0] = reduce_func(v[i0], v[i1]); + } + }); + }); + + y_tensor.get_thread_buffer()(i) = v[0]; + }); + } +}; + +template +struct BlockReduce2dLinearCrossWarpSync +{ + using Problem = remove_cvref_t; + using BlockShape = typename Problem::BlockShape; + + template + CK_TILE_DEVICE static constexpr index_t GetReduceWarps() + { + constexpr index_t num_reduce_warps = [&]() { + using Dstr = typename YDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_warp = 0; + + index_t len_ = 1; + static_for<0, NDimR, 1>{}([&](auto idim_r) { + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + len_ *= r_length; + } + }); + return len_; + }(); + return num_reduce_warps; + } + + // return in byte + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + using DataType = typename YDistributedTensor_::DataType; constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); // we need to store all data from every wave into smem @@ -300,7 +412,9 @@ struct BlockReduce2dCrossWarpSync static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) { constexpr auto i_1 = number{}; const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; - v_local = reduce_func(v_local, v_remote); + + // reduce + v_local = reduce_func(v_local, v_remote); }); y_tensor.get_thread_buffer()(i_0) = v_local; @@ -308,139 +422,4 @@ struct BlockReduce2dCrossWarpSync } }; -template -struct BlockReduce2dTreeCrossWarpSync -{ - using Problem = remove_cvref_t; - using BlockShape = typename Problem::BlockShape; - - template - CK_TILE_DEVICE static constexpr index_t GetReduceWarps() - { - constexpr index_t num_reduce_warps = [&]() { - using Dstr = typename YDistributedTensor_::StaticTileDistribution; - using DstrEncode = typename Dstr::DstrEncode; - using DstrEncodeDetail = typename DstrEncode::detail; - - constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); - - constexpr index_t idim_p_warp = 0; - - index_t len_ = 1; - static_for<0, NDimR, 1>{}([&](auto idim_r) { - if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r]) - { - constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; - len_ *= r_length; - } - }); - return len_; - }(); - return num_reduce_warps; - } - - // return in byte - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - using DataType = typename YDistributedTensor_::DataType; - constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); - - // we need to store all data from every wave into smem - // e.g. 2x2 reduce along N - // -------------> reduce N - // | w0 | w1 | ___> | w01 | - // | w2 | w3 | | w23 | - // - // -> store data from every wave into LDS - // - // - // -------------> reduce N - // | w0 | w1 | w2 | w3 | -----> | w0123 | - // - // -> also store data from every wave into LDS - constexpr index_t num_warps = BlockShape::BlockSize / warpSize; - return num_warps * thread_buf_size * sizeof(DataType); - } - - template - CK_TILE_DEVICE void - operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func) - { - using Dstr = typename YDistributedTensor_::StaticTileDistribution; - using DstrEncode = typename Dstr::DstrEncode; - using DstrEncodeDetail = typename DstrEncode::detail; - using DataType = typename YDistributedTensor_::DataType; - - constexpr index_t NDimP = Dstr::get_num_of_dimension_p(); - constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); - - constexpr index_t idim_p_lane = NDimP - 1; - constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); - - DataType* smem_ptr = reinterpret_cast(smem); - const index_t lane_id = get_lane_id(); - const index_t warp_id = get_warp_id(); - - constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); - constexpr index_t num_reduce_warps = GetReduceWarps(); - - if constexpr(num_reduce_warps == 1) - return; - - // Each warp's lane 0 writes its partial results to shared memory - const index_t smem_offset = warp_id; - if(lane_id == 0) - { - static_for<0, thread_buf_size, 1>{}([&](auto i) { - // Store the i-th element of this warp's thread_buffer into SMEM - smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i]; - }); - } - block_sync_lds(); - - // We let each warp holds a duplication to do reduction. - const index_t local_warp_id = warp_id / num_reduce_warps; - const index_t local_smem_os = local_warp_id * num_reduce_warps; - static_for<0, thread_buf_size, 1>{}([&](auto i) { - DataType v = 0; - if(lane_id < num_reduce_warps) - { - v = smem_ptr[i * num_warps + local_smem_os + lane_id]; - } - - // cross-lane reduce for replication - // only reduce on R dimension correspond to lane - // (lane id maps to this R dimension) - static_for<0, NDimR, 1>{}([&](auto idim_r) { - // FIXME: nasty to use does_p_own_r_ - if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) - { - constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; - - constexpr index_t lid_over_rid_derivative = - DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r]; - - static_assert(is_power_of_two_integer(r_length), - "wrong! only support power of 2 reduction"); - - constexpr index_t nstage = integer_log2_floor(r_length); - - // reduction sweep forward - static_for<0, nstage, 1>{}([&](auto istage) { - // pull data from remote lane - const auto o = - __shfl_xor(v, number{}.value); - - // reduce - v = reduce_func(v, o); - }); - } - }); - - y_tensor.get_thread_buffer()(i) = v; - }); - } -}; - } // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp index df689c6b46..356a2e12ca 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp @@ -69,15 +69,6 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy return BlockReduce2dCrossWarpSync{}; } - template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dTreeCrossWarpSync() - { - using P_ = BlockReduce2dProblem; - return BlockReduce2dTreeCrossWarpSync{}; - } - template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp index 1d5467b459..b05197b653 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp @@ -102,8 +102,8 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass auto reduce_sum_func = ReduceOp::Add{}; auto block_reduce2d = Policy::template GetBlockReduce2d(); auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); - auto block_reduce2d_tree_cross_warp_sync = - Policy::template GetBlockReduce2dTreeCrossWarpSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); auto x = load_tile(x_window); auto x_resi = load_tile(x_residual_window); @@ -162,7 +162,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass reduce_square_sum_func); } block_reduce2d_sync(square_sum, reduce_sum_func); - block_reduce2d_tree_cross_warp_sync(square_sum, smem, reduce_sum_func); + block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); // compute inv-rms auto inv_rms = tile_elementwise_in(