From 06cc880e79235dc92143c0bbeaf02bf06d63152c Mon Sep 17 00:00:00 2001 From: zanzhang Date: Tue, 26 Aug 2025 20:34:24 +0800 Subject: [PATCH] rms prefetch --- .../rmsnorm2d_fwd_pipeline_one_pass.hpp | 70 +++++++++++++++---- 1 file changed, 56 insertions(+), 14 deletions(-) diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp index 47b653022d..3531a72c8a 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -135,6 +135,8 @@ struct Rmsnorm2dFwdPipelineOnePass auto block_reduce2d_cross_warp_sync = Policy::template GetBlockReduce2dCrossWarpSync(); + using XTensorType = decltype(load_tile(x_window)); + using XResTensorType = decltype(load_tile(x_residual_window)); using AccTensorType = decltype(cast_tile(load_tile(x_window))); using AccResTensorType = decltype(load_tile(x_residual_window)); @@ -151,37 +153,40 @@ struct Rmsnorm2dFwdPipelineOnePass reduce_square_sum_func.GetIdentityValue(), reduce_square_sum_func)){}; clear_tile(square_sum); + XTensorType x[2]; + XResTensorType x_resi[2]; + x[0] = load_tile(x_window); + x_window.move({0, Stride_N}); - for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n) + x_resi[0] = load_tile(x_residual_window); + if constexpr(x_resi[0].is_valid()) + move_tile_window(x_residual_window, {0, Stride_N}); + + static_for<0, Repeat_N, 1>{}([&](auto repeat_n) { - auto x = load_tile(x_window); + auto ld_stage = repeat_n % 2 + 1; + auto st_stage = repeat_n % 2; + x[ld_stage] = load_tile(x_window); x_window.move({0, Stride_N}); - auto x_resi = load_tile(x_residual_window); - if constexpr(x_resi.is_valid()) + x_resi[ld_stage] = load_tile(x_residual_window); + if constexpr(x_resi[0].is_valid()) move_tile_window(x_residual_window, {0, Stride_N}); // load gamma (TODO: support no gamma?) - x_warp_tensors[repeat_n] = cast_tile(x); + x_warp_tensors[repeat_n] = cast_tile(x[st_stage]); if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD || kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) { - sweep_tile(x_resi, [&](auto idx) { + sweep_tile(x_resi[0], [&](auto idx) { // compute x = x_resi + x - x_warp_tensors[repeat_n](idx) = type_convert(x_resi(idx)) + x_warp_tensors[repeat_n](idx); + x_warp_tensors[repeat_n](idx) = type_convert(x_resi[st_stage](idx)) + x_warp_tensors[repeat_n](idx); }); } - gamma_warp_tensors[repeat_n] = load_tile(gamma_window); - move_tile_window(gamma_window, {0, Stride_N}); - - sm_scale_warp_tensors[repeat_n] = load_tile(sm_scale_window); - if constexpr(SmScaleTensorType::is_valid()) - move_tile_window(sm_scale_window, {0, Stride_N}); - // compute mean square each-thread->cross-lane->cross-warp auto square_sum_local = block_reduce2d(x_warp_tensors[repeat_n], reduce_square_sum_func.GetIdentityValue(), @@ -190,6 +195,42 @@ struct Rmsnorm2dFwdPipelineOnePass ck_tile::sweep_tile(square_sum, [&](auto idx) { square_sum(idx) += square_sum_local[idx]; }); + }); + + + { + constexpr auto tail_stage = Repeat_N - 1; + constexpr auto st_stage = tail_stage % 2; + x_warp_tensors[tail_stage] = cast_tile(x[st_stage]); + + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD || + kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + sweep_tile(x_resi[0], [&](auto idx) { + // compute x = x_resi + x + x_warp_tensors[tail_stage](idx) = type_convert(x_resi[st_stage](idx)) + x_warp_tensors[tail_stage](idx); + }); + } + + // compute mean square each-thread->cross-lane->cross-warp + auto square_sum_local = block_reduce2d(x_warp_tensors[tail_stage], + reduce_square_sum_func.GetIdentityValue(), + reduce_square_sum_func); + + ck_tile::sweep_tile(square_sum, [&](auto idx) { + square_sum(idx) += square_sum_local[idx]; + }); + } + +#pragma unroll + for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n) + { + gamma_warp_tensors[repeat_n] = load_tile(gamma_window); + move_tile_window(gamma_window, {0, Stride_N}); + + sm_scale_warp_tensors[repeat_n] = load_tile(sm_scale_window); + if constexpr(SmScaleTensorType::is_valid()) + move_tile_window(sm_scale_window, {0, Stride_N}); } block_reduce2d_sync(square_sum, reduce_sum_func); @@ -205,6 +246,7 @@ struct Rmsnorm2dFwdPipelineOnePass // rmsnorm computation auto rmsn = make_static_distributed_tensor(Policy::template MakeXBlockTileDistribution()); +#pragma unroll for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n) { sweep_tile(o_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) {