mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
rms prefetch
This commit is contained in:
@@ -135,6 +135,8 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(load_tile(x_window));
|
||||
using XResTensorType = decltype(load_tile(x_residual_window));
|
||||
using AccTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
|
||||
using AccResTensorType = decltype(load_tile(x_residual_window));
|
||||
|
||||
@@ -151,37 +153,40 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
|
||||
reduce_square_sum_func)){};
|
||||
clear_tile(square_sum);
|
||||
XTensorType x[2];
|
||||
XResTensorType x_resi[2];
|
||||
x[0] = load_tile(x_window);
|
||||
x_window.move({0, Stride_N});
|
||||
|
||||
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
|
||||
x_resi[0] = load_tile(x_residual_window);
|
||||
if constexpr(x_resi[0].is_valid())
|
||||
move_tile_window(x_residual_window, {0, Stride_N});
|
||||
|
||||
static_for<0, Repeat_N, 1>{}([&](auto repeat_n)
|
||||
{
|
||||
auto x = load_tile(x_window);
|
||||
auto ld_stage = repeat_n % 2 + 1;
|
||||
auto st_stage = repeat_n % 2;
|
||||
x[ld_stage] = load_tile(x_window);
|
||||
x_window.move({0, Stride_N});
|
||||
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
if constexpr(x_resi.is_valid())
|
||||
x_resi[ld_stage] = load_tile(x_residual_window);
|
||||
if constexpr(x_resi[0].is_valid())
|
||||
move_tile_window(x_residual_window, {0, Stride_N});
|
||||
|
||||
// load gamma (TODO: support no gamma?)
|
||||
|
||||
|
||||
x_warp_tensors[repeat_n] = cast_tile<ComputeDataType>(x);
|
||||
x_warp_tensors[repeat_n] = cast_tile<ComputeDataType>(x[st_stage]);
|
||||
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
|
||||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
sweep_tile(x_resi[0], [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
x_warp_tensors[repeat_n](idx) = type_convert<ComputeDataType>(x_resi(idx)) + x_warp_tensors[repeat_n](idx);
|
||||
x_warp_tensors[repeat_n](idx) = type_convert<ComputeDataType>(x_resi[st_stage](idx)) + x_warp_tensors[repeat_n](idx);
|
||||
});
|
||||
}
|
||||
|
||||
gamma_warp_tensors[repeat_n] = load_tile(gamma_window);
|
||||
move_tile_window(gamma_window, {0, Stride_N});
|
||||
|
||||
sm_scale_warp_tensors[repeat_n] = load_tile(sm_scale_window);
|
||||
if constexpr(SmScaleTensorType::is_valid())
|
||||
move_tile_window(sm_scale_window, {0, Stride_N});
|
||||
|
||||
// compute mean square each-thread->cross-lane->cross-warp
|
||||
auto square_sum_local = block_reduce2d(x_warp_tensors[repeat_n],
|
||||
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
|
||||
@@ -190,6 +195,42 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
ck_tile::sweep_tile(square_sum, [&](auto idx) {
|
||||
square_sum(idx) += square_sum_local[idx];
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
{
|
||||
constexpr auto tail_stage = Repeat_N - 1;
|
||||
constexpr auto st_stage = tail_stage % 2;
|
||||
x_warp_tensors[tail_stage] = cast_tile<ComputeDataType>(x[st_stage]);
|
||||
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
|
||||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
sweep_tile(x_resi[0], [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
x_warp_tensors[tail_stage](idx) = type_convert<ComputeDataType>(x_resi[st_stage](idx)) + x_warp_tensors[tail_stage](idx);
|
||||
});
|
||||
}
|
||||
|
||||
// compute mean square each-thread->cross-lane->cross-warp
|
||||
auto square_sum_local = block_reduce2d(x_warp_tensors[tail_stage],
|
||||
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
|
||||
reduce_square_sum_func);
|
||||
|
||||
ck_tile::sweep_tile(square_sum, [&](auto idx) {
|
||||
square_sum(idx) += square_sum_local[idx];
|
||||
});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
|
||||
{
|
||||
gamma_warp_tensors[repeat_n] = load_tile(gamma_window);
|
||||
move_tile_window(gamma_window, {0, Stride_N});
|
||||
|
||||
sm_scale_warp_tensors[repeat_n] = load_tile(sm_scale_window);
|
||||
if constexpr(SmScaleTensorType::is_valid())
|
||||
move_tile_window(sm_scale_window, {0, Stride_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
@@ -205,6 +246,7 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
// rmsnorm computation
|
||||
auto rmsn = make_static_distributed_tensor<ComputeDataType>(Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
#pragma unroll
|
||||
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
|
||||
{
|
||||
sweep_tile(o_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) {
|
||||
|
||||
Reference in New Issue
Block a user