From 697558d856c5bbd8b193350c10b856965658b92a Mon Sep 17 00:00:00 2001 From: rocking Date: Sat, 26 Oct 2024 20:02:38 +0000 Subject: [PATCH] Add two pass pipeline --- ...msnorm2d_rdquant_fwd_pipeline_one_pass.hpp | 7 +- ...msnorm2d_rdquant_fwd_pipeline_two_pass.hpp | 151 +++++++++++++++--- 2 files changed, 136 insertions(+), 22 deletions(-) diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp index 6122129524..12a15938ae 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp @@ -54,7 +54,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass const GammaWindow& gamma_window_, XWindow& x_window, YScaleWindow& yscale_window, - QYWindow& y_window, + QYWindow& qy_window, ComputeDataType epsilon, ck_tile::index_t row_size, void* smem) const @@ -121,6 +121,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass block_reduce2d_sync(absmax, reduce_max_func); block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); + // ex: yscale = absmax / 127 if int8 auto yscale = tile_elementwise_in( [&](const auto& v_) { return v_ / type_convert(numeric::max()); @@ -128,14 +129,14 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass absmax); store_tile(yscale_window, cast_tile(yscale)); - // quantize to + // quantize y to qy auto qy = make_static_distributed_tensor(y.get_tile_distribution()); sweep_tile(qy, [&, yscale_ = yscale](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); auto qy_ = y[idx] / yscale_[i_idx]; qy(idx) = saturates{}(qy_); }); - store_tile(y_window, qy); + store_tile(qy_window, qy); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp index 6660defb47..e1d8bfe1e1 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp @@ -52,9 +52,9 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass CK_TILE_DEVICE auto operator()(const AWindow& a_window_, const BWindow& b_window_, const GammaWindow& gamma_window_, - XWindow& x_window, + XWindow& x_window_, YScaleWindow& yscale_window, - QYWindow& y_window, + QYWindow& qy_window, ComputeDataType epsilon, ck_tile::index_t row_size, void* smem) const @@ -63,15 +63,17 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution()); auto b_window = make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution()); + auto x_window = + make_tile_window(x_window_, Policy::template MakeABXBlockTileDistribution()); auto gamma_window = make_tile_window( gamma_window_, Policy::template MakeGammaBlockTileDistribution()); - auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; }; - auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; }; - auto reduce_absmax_func = [](const auto& v0, const auto& v1) { return max(v0, abs(v1)); }; - auto reduce_max_func = [](const auto& v0, const auto& v1) { return max(v0, v1); }; - auto block_reduce2d = Policy::template GetBlockReduce2d(); - auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto reduce_square_sum_func = ReduceOp::SquareAdd{}; + auto reduce_sum_func = ReduceOp::Add{}; + auto reduce_absmax_func = ReduceOp::AbsMax{}; + auto reduce_max_func = ReduceOp::Max{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); auto block_reduce2d_cross_warp_sync = Policy::template GetBlockReduce2dCrossWarpSync(); @@ -81,7 +83,7 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass using XTensorType = decltype(cast_tile(load_tile(a_window))); auto square_sum = block_reduce2d.template MakeYBlockTile(); - set_tile(square_sum, 0); + set_tile(square_sum, reduce_square_sum_func.GetIdentityValue()); for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { @@ -100,6 +102,8 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass block_reduce2d(x, square_sum, reduce_square_sum_func); move_tile_window(x_window, {0, Block_N}); + move_tile_window(a_window, {0, Block_N}); + move_tile_window(b_window, {0, Block_N}); } block_reduce2d_sync(square_sum, reduce_sum_func); @@ -115,33 +119,142 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; - move_tile_window(x_window, {0, -Block_N}); + if constexpr(kSaveX) + move_tile_window(x_window, {0, -Block_N}); + else + { + move_tile_window(a_window, {0, -Block_N}); + move_tile_window(b_window, {0, -Block_N}); + } move_tile_window(gamma_window, {stride_to_right_most_window}); - move_tile_window(y_window, {0, stride_to_right_most_window}); - // rmsnorm computation + absmax + quantization + using YTensorType = XTensorType; + auto absmax = block_reduce2d.template MakeYBlockTile(); + set_tile(absmax, reduce_absmax_func.GetIdentityValue()); + + // rmsnorm computation + absmax(threadwise reduce) + if constexpr(kSaveX) + __syncthreads(); + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - const auto x = load_tile(x_window); - const auto gamma = load_tile(gamma_window); + auto x = [&]() { + if constexpr(kSaveX) + { + return load_tile(x_window); + } + else + { + const auto a = load_tile(a_window); + const auto b = load_tile(b_window); + return tile_elementwise_in( + [&](const auto& a_, const auto& b_) { + return type_convert(a_) + + type_convert(b_); + }, + a, + b); + } + }(); - auto y = make_static_distributed_tensor(x.get_tile_distribution()); + auto gamma = load_tile(gamma_window); + auto y = make_static_distributed_tensor(x.get_tile_distribution()); - sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { + sweep_tile(y, [&](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); const auto gamma_ = type_convert(gamma[j_idx]); const auto x_ = type_convert(x[idx]); - auto y_ = x_ * inv_rms_[i_idx] * gamma_; + auto y_ = x_ * inv_rms[i_idx] * gamma_; y(idx) = type_convert(y_); }); - move_tile_window(x_window, {0, -Block_N}); + block_reduce2d(y, absmax, reduce_absmax_func); + + if constexpr(kSaveX) + move_tile_window(x_window, {0, -Block_N}); + else + { + move_tile_window(a_window, {0, -Block_N}); + move_tile_window(b_window, {0, -Block_N}); + } move_tile_window(gamma_window, {-Block_N}); - move_tile_window(y_window, {0, -Block_N}); + } + + // compute absmax, cross-lane->cross-warp + block_reduce2d_sync(absmax, reduce_max_func); + block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); + + // ex: yscale = absmax / 127 if int8 + auto yscale = tile_elementwise_in( + [&](const auto& v_) { + return v_ / type_convert(numeric::max()); + }, + absmax); + store_tile(yscale_window, cast_tile(yscale)); + + // quantize y to qy + // recompute rmsnorm, try to save y in the future + if constexpr(kSaveX) + move_tile_window(x_window, {0, Block_N}); + else + { + move_tile_window(a_window, {0, Block_N}); + move_tile_window(b_window, {0, Block_N}); + } + move_tile_window(gamma_window, {Block_N}); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + auto x = [&]() { + if constexpr(kSaveX) + { + return load_tile(x_window); + } + else + { + const auto a = load_tile(a_window); + const auto b = load_tile(b_window); + return tile_elementwise_in( + [&](const auto& a_, const auto& b_) { + return type_convert(a_) + + type_convert(b_); + }, + a, + b); + } + }(); + + auto gamma = load_tile(gamma_window); + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + auto qy = make_static_distributed_tensor(y.get_tile_distribution()); + + sweep_tile(y, [&](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = x_ * inv_rms[i_idx] * gamma_; + auto qy_ = y_ / yscale[i_idx]; + qy(idx) = saturates{}(qy_); + }); + + store_tile(qy_window, qy); + + if constexpr(kSaveX) + move_tile_window(x_window, {0, Block_N}); + else + { + move_tile_window(a_window, {0, Block_N}); + move_tile_window(b_window, {0, Block_N}); + } + move_tile_window(gamma_window, {Block_N}); + move_tile_window(qy_window, {0, Block_N}); } } };