rms prefetch

This commit is contained in:
zanzhang
2025-08-26 20:34:24 +08:00
parent 64613296dc
commit 06cc880e79

View File

@@ -135,6 +135,8 @@ struct Rmsnorm2dFwdPipelineOnePass
// Object returned by the policy's GetBlockReduce2dCrossWarpSync<Problem>()
// (by its name, presumably the cross-warp stage of the 2D block reduction — confirm in Policy).
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
// Tile types exactly as load_tile returns them for the x and x_residual windows.
// They are named so that arrays of tiles (staging slots) can be declared further down.
using XTensorType = decltype(load_tile(x_window));
using XResTensorType = decltype(load_tile(x_residual_window));
// Type of an x tile after cast_tile converts it to ComputeDataType.
using AccTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
// Same type as XResTensorType: no cast is applied to the residual tile in this alias.
using AccResTensorType = decltype(load_tile(x_residual_window));
@@ -151,37 +153,40 @@ struct Rmsnorm2dFwdPipelineOnePass
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func)){};
clear_tile(square_sum);
XTensorType x[2];
XResTensorType x_resi[2];
x[0] = load_tile(x_window);
x_window.move({0, Stride_N});
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
x_resi[0] = load_tile(x_residual_window);
if constexpr(x_resi[0].is_valid())
move_tile_window(x_residual_window, {0, Stride_N});
static_for<0, Repeat_N, 1>{}([&](auto repeat_n)
{
auto x = load_tile(x_window);
auto ld_stage = repeat_n % 2 + 1;
auto st_stage = repeat_n % 2;
x[ld_stage] = load_tile(x_window);
x_window.move({0, Stride_N});
auto x_resi = load_tile(x_residual_window);
if constexpr(x_resi.is_valid())
x_resi[ld_stage] = load_tile(x_residual_window);
if constexpr(x_resi[0].is_valid())
move_tile_window(x_residual_window, {0, Stride_N});
// load gamma (TODO: support no gamma?)
x_warp_tensors[repeat_n] = cast_tile<ComputeDataType>(x);
x_warp_tensors[repeat_n] = cast_tile<ComputeDataType>(x[st_stage]);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
sweep_tile(x_resi, [&](auto idx) {
sweep_tile(x_resi[0], [&](auto idx) {
// compute x = x_resi + x
x_warp_tensors[repeat_n](idx) = type_convert<ComputeDataType>(x_resi(idx)) + x_warp_tensors[repeat_n](idx);
x_warp_tensors[repeat_n](idx) = type_convert<ComputeDataType>(x_resi[st_stage](idx)) + x_warp_tensors[repeat_n](idx);
});
}
gamma_warp_tensors[repeat_n] = load_tile(gamma_window);
move_tile_window(gamma_window, {0, Stride_N});
sm_scale_warp_tensors[repeat_n] = load_tile(sm_scale_window);
if constexpr(SmScaleTensorType::is_valid())
move_tile_window(sm_scale_window, {0, Stride_N});
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum_local = block_reduce2d(x_warp_tensors[repeat_n],
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
@@ -190,6 +195,42 @@ struct Rmsnorm2dFwdPipelineOnePass
ck_tile::sweep_tile(square_sum, [&](auto idx) {
square_sum(idx) += square_sum_local[idx];
});
});
// Tail stage: handles tile index Repeat_N - 1 outside the static_for above, reading the
// x / x_resi staging slot selected by that index's parity.
// NOTE(review): the static_for above is instantiated as static_for<0, Repeat_N, 1>, and its
// body already writes x_warp_tensors[repeat_n] and accumulates into square_sum. If that range
// covers Repeat_N - 1 (as the bounds suggest), this scope recomputes the last tile and adds
// its square sum a second time — confirm whether the loop bound should be Repeat_N - 1.
// NOTE(review): the loop picks its load slot as `repeat_n % 2 + 1` (1 or 2) while x and x_resi
// are declared with 2 elements, so slot 2 would be out of range and the slot read here would
// not be the one written — confirm whether `(repeat_n + 1) % 2` was intended.
{
// Index of the last tile and the staging slot (0 or 1) it is read from.
constexpr auto tail_stage = Repeat_N - 1;
constexpr auto st_stage = tail_stage % 2;
// Convert the staged x tile to ComputeDataType and keep it for the later normalisation pass.
x_warp_tensors[tail_stage] = cast_tile<ComputeDataType>(x[st_stage]);
// The residual is folded in only for the two PRE_ADD fused-add modes.
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
// x_resi[0] is only the tile handed to sweep_tile for iteration; the values that are
// added come from x_resi[st_stage].
sweep_tile(x_resi[0], [&](auto idx) {
// compute x = x_resi + x
x_warp_tensors[tail_stage](idx) = type_convert<ComputeDataType>(x_resi[st_stage](idx)) + x_warp_tensors[tail_stage](idx);
});
}
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum_local = block_reduce2d(x_warp_tensors[tail_stage],
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func);
// Add this tile's partial result into the running square_sum accumulator.
ck_tile::sweep_tile(square_sum, [&](auto idx) {
square_sum(idx) += square_sum_local[idx];
});
}
// Load the gamma and sm_scale tiles for every repeat index in a dedicated loop, separate from
// the x / x_resi staging loop above. After each load the corresponding window is advanced by
// {0, Stride_N}.
// Note: load_tile(sm_scale_window) is called unconditionally; only the window move is guarded
// by SmScaleTensorType::is_valid().
#pragma unroll
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
{
gamma_warp_tensors[repeat_n] = load_tile(gamma_window);
move_tile_window(gamma_window, {0, Stride_N});
sm_scale_warp_tensors[repeat_n] = load_tile(sm_scale_window);
// Advance the sm_scale window only when the sm_scale tensor type reports itself valid.
if constexpr(SmScaleTensorType::is_valid())
move_tile_window(sm_scale_window, {0, Stride_N});
}
block_reduce2d_sync(square_sum, reduce_sum_func);
@@ -205,6 +246,7 @@ struct Rmsnorm2dFwdPipelineOnePass
// rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(Policy::template MakeXBlockTileDistribution<Problem>());
#pragma unroll
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
{
sweep_tile(o_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) {