From 64613296dc4332ed42983c5affce719b3cc19dfb Mon Sep 17 00:00:00 2001 From: zanzhang Date: Tue, 26 Aug 2025 19:46:58 +0800 Subject: [PATCH] rms norm 9.7us --- .../default_2d_and_dynamic_quant_epilogue.hpp | 1 + .../ops/epilogue/default_2d_epilogue.hpp | 22 +++++ .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 17 ++-- .../rmsnorm2d_fwd_pipeline_default_policy.hpp | 8 +- .../rmsnorm2d_fwd_pipeline_one_pass.hpp | 94 +++++++++++-------- 5 files changed, 90 insertions(+), 52 deletions(-) diff --git a/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp index 7ea053443a..aa4ee3f67b 100644 --- a/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp @@ -100,6 +100,7 @@ struct Default2DAndDynamicQuantEpilogue const bool isArray, void* smem) { + // Default2D{}(o_direct_dram_window_tmp, o_acc_tiles, Problem::BlockShape::Repeat_N, smem); // DynamicQuant{}(o_quant_dram_window_tmp, sm_scale_window_, y_scale_window, o_acc_tiles, true, smem); } }; diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index fdbe2e7a6d..eb20a57d1f 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -109,6 +109,28 @@ struct Default2DEpilogue { return operator()(o_dram_window_tmp, o_acc_tile); } + + template + CK_TILE_DEVICE auto + operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTiles& o_acc_tiles, int Repeat_N, void* = nullptr) const + { + // TODO: this is ugly + for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n) + { + auto o_acc_tmp = o_acc_tiles[repeat_n]; + + if constexpr(UseRawStore && (kPadM || kPadN)) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tmp)); + buffer_store_fence(); + } + else + { + store_tile(o_dram_window_tmp, cast_tile(o_acc_tmp)); + } + o_dram_window_tmp.move({0, 5120 / Repeat_N}); + } + } }; template diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index d56370acbc..8fcb1bbbbe 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -70,6 +70,7 @@ struct Rmsnorm2dFwd static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + static constexpr index_t Stride_N = Block_N / Repeat_N; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; @@ -194,7 +195,7 @@ struct Rmsnorm2dFwd const auto tmp2_ = pad_tensor_view( tmp_, make_tuple(number{}, number{}), sequence{}); return make_tile_window( - tmp2_, make_tuple(number{}, number{}), {iM, 0}); + tmp2_, make_tuple(number{}, number{}), {iM, 0}); }(); const auto x_residual_window = [&]() { @@ -212,7 +213,7 @@ struct Rmsnorm2dFwd make_tuple(number{}, number{}), sequence{}); return make_tile_window( - tmp2_, make_tuple(number{}, number{}), {iM, 0}); + tmp2_, make_tuple(number{}, number{}), {iM, 0}); } else { @@ -231,7 +232,7 @@ struct Rmsnorm2dFwd const auto tmp2_ = pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); - return make_tile_window(tmp2_, make_tuple(number{}), {0}); + return make_tile_window(tmp2_, make_tuple(number{}), {0}); }(); auto y_window = [&]() { @@ -262,7 +263,7 @@ struct Rmsnorm2dFwd make_tuple(number{}, number{}), sequence{}); return make_tile_window( - tmp2_, make_tuple(number{}, number{}), {iM, 0}); + tmp2_, make_tuple(number{}, number{}), {iM, 0}); } else { @@ -302,11 +303,11 @@ struct Rmsnorm2dFwd make_tuple(number{}), sequence{}); // sm_scale no need pad }(); - return make_tile_window(win_, make_tuple(number{}), {0}); + return make_tile_window(win_, make_tuple(number{}), {0}); } else { - return make_null_tile_window(make_tuple(number{})); + return make_null_tile_window(make_tuple(number{})); } }(); @@ -347,11 +348,11 @@ struct Rmsnorm2dFwd make_tuple(number{}, number{}), sequence{}); return make_tile_window( - tmp2_, make_tuple(number{}, number{}), {iM, 0}); + tmp2_, make_tuple(number{}, number{}), {iM, 0}); } else { - return make_null_tile_window(make_tuple(number{}, number{})); + return make_null_tile_window(make_tuple(number{}, number{})); } }(); diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp index 1ed76a2dc1..818e02c807 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp @@ -51,11 +51,11 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy return make_static_tile_distribution( tile_distribution_encoding< sequence, - tuple>, + tuple>, tuple, sequence<0, 1>>, - tuple, sequence<1, 2>>, - sequence<1, 1>, - sequence<0, 3>>{}); + tuple, sequence<1, 1>>, + sequence<1>, + sequence<2>>{}); } template diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp index b9c32c2537..47b653022d 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -21,6 +21,7 @@ struct Rmsnorm2dFwdPipelineOnePass using ComputeDataType = ck_tile::remove_cvref_t; using YDataType = ck_tile::remove_cvref_t; using InvRmsDataType = ck_tile::remove_cvref_t; + using UnquantYDataType= ck_tile::remove_cvref_t; using XResidualDataType = XDataType; using YResidualDataType = XDataType; @@ -57,11 +58,11 @@ struct Rmsnorm2dFwdPipelineOnePass return make_static_tile_distribution( tile_distribution_encoding< sequence, - tuple>, + tuple>, tuple, sequence<0, 1>>, tuple, sequence<1, 2>>, - sequence<1, 1>, - sequence<0, 3>>{}); + sequence<1>, + sequence<3>>{}); } template ()); - const auto gamma_window = make_tile_window( - gamma_window_, Policy::template MakeGammaBlockTileDistribution()); + auto gamma_window = + make_tile_window(gamma_window_.get_bottom_tensor_view(), + gamma_window_.get_window_lengths(), + gamma_window_.get_window_origin(), + Policy::template MakeGammaBlockTileDistribution()); auto x_residual_window = make_tile_window(x_residual_window_.get_bottom_tensor_view(), x_residual_window_.get_window_lengths(), @@ -112,11 +116,18 @@ struct Rmsnorm2dFwdPipelineOnePass y_window_.get_window_origin(), Policy::template MakeXInnerBlockTileDistribution()); + auto sm_scale_window = + make_tile_window(sm_scale_window_.get_bottom_tensor_view(), + sm_scale_window_.get_window_lengths(), + sm_scale_window_.get_window_origin(), + Policy::template MakeGammaBlockTileDistribution()); + auto o_all_window = make_tile_window(y_window_.get_bottom_tensor_view(), y_window_.get_window_lengths(), y_window_.get_window_origin(), - Policy::template MakeXBlockTileDistribution()); + Policy::template MakeXInnerBlockTileDistribution()); + auto reduce_square_sum_func = ReduceOp::SquareAdd{}; auto reduce_sum_func = ReduceOp::Add{}; auto block_reduce2d = Policy::template GetBlockReduce2d(); @@ -127,17 +138,20 @@ struct Rmsnorm2dFwdPipelineOnePass using AccTensorType = decltype(cast_tile(load_tile(x_window))); using AccResTensorType = decltype(load_tile(x_residual_window)); + using GammaTensorType = decltype(load_tile(gamma_window)); + using SmScaleTensorType = decltype(load_tile(sm_scale_window)); + AccTensorType x_warp_tensors[Repeat_N]; AccTensorType o_warp_tensors[Repeat_N]; + GammaTensorType gamma_warp_tensors[Repeat_N]; + SmScaleTensorType sm_scale_warp_tensors[Repeat_N]; + auto square_sum = decltype(block_reduce2d(AccTensorType{}, reduce_square_sum_func.GetIdentityValue(), reduce_square_sum_func)){}; clear_tile(square_sum); - const auto sm_scale_window = - make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution()); - for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n) { auto x = load_tile(x_window); @@ -161,23 +175,26 @@ struct Rmsnorm2dFwdPipelineOnePass }); } + gamma_warp_tensors[repeat_n] = load_tile(gamma_window); + move_tile_window(gamma_window, {0, Stride_N}); + + sm_scale_warp_tensors[repeat_n] = load_tile(sm_scale_window); + if constexpr(SmScaleTensorType::is_valid()) + move_tile_window(sm_scale_window, {0, Stride_N}); + // compute mean square each-thread->cross-lane->cross-warp auto square_sum_local = block_reduce2d(x_warp_tensors[repeat_n], - reduce_square_sum_func.GetIdentityValue(), - reduce_square_sum_func); + reduce_square_sum_func.GetIdentityValue(), + reduce_square_sum_func); ck_tile::sweep_tile(square_sum, [&](auto idx) { square_sum(idx) += square_sum_local[idx]; }); } - const auto gamma = load_tile(gamma_window); - block_reduce2d_sync(square_sum, reduce_sum_func); block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); - auto sm_scale = load_tile(sm_scale_window); - // compute inv-rms auto inv_rms = tile_elementwise_in( [&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum); @@ -188,26 +205,23 @@ struct Rmsnorm2dFwdPipelineOnePass // rmsnorm computation auto rmsn = make_static_distributed_tensor(Policy::template MakeXBlockTileDistribution()); - static_for<0, Repeat_N, 1>{}([&](auto repeat_n) + for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n) { sweep_tile(o_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); - const auto gamma_ = type_convert(gamma[j_idx]); + const auto gamma_ = type_convert(gamma_warp_tensors[repeat_n][j_idx]); - auto rmsn_ = o_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_; + auto rmsn_ = x_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_; - if constexpr(sm_scale.is_valid()) + if constexpr(SmScaleTensorType::is_valid()) { - const auto xs_ = type_convert(sm_scale[j_idx]); + const auto xs_ = type_convert(sm_scale_warp_tensors[repeat_n][j_idx]); o_warp_tensors[repeat_n](idx) = rmsn_ * xs_; } }); - }); - for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n) - { if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) { store_tile(y_residual_window, cast_tile(x_warp_tensors[repeat_n])); @@ -220,29 +234,29 @@ struct Rmsnorm2dFwdPipelineOnePass { if constexpr(kSaveUnquant) { - Epilogue{}( - unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem); + // Epilogue{}( + // unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem); } else { Epilogue{}(o_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem); } } - else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) - { - if constexpr(kSaveUnquant) - { - Epilogue{}(unquant_y_window, o_all_window, y_scale_window_, rmsn, smem); - } - else - { - Epilogue{}(o_all_window, y_scale_window_, rmsn, smem); - } - } - else - { - Epilogue{}(o_all_window, rmsn); - } + // else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) + // { + // if constexpr(kSaveUnquant) + // { + // Epilogue{}(unquant_y_window, o_all_window, y_scale_window_, rmsn, smem); + // } + // else + // { + // Epilogue{}(o_all_window, y_scale_window_, rmsn, smem); + // } + // } + // else + // { + // Epilogue{}(o_all_window, rmsn); + // } } }; } // namespace ck_tile