mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
rms norm 9.7us
This commit is contained in:
@@ -100,6 +100,7 @@ struct Default2DAndDynamicQuantEpilogue
|
||||
const bool isArray,
|
||||
void* smem)
|
||||
{
|
||||
// Default2D{}(o_direct_dram_window_tmp, o_acc_tiles, Problem::BlockShape::Repeat_N, smem);
|
||||
// DynamicQuant{}(o_quant_dram_window_tmp, sm_scale_window_, y_scale_window, o_acc_tiles, true, smem);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -109,6 +109,28 @@ struct Default2DEpilogue
|
||||
{
|
||||
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmp, typename OAccTiles>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTiles& o_acc_tiles, int Repeat_N, void* = nullptr) const
|
||||
{
|
||||
// TODO: this is ugly
|
||||
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
|
||||
{
|
||||
auto o_acc_tmp = o_acc_tiles[repeat_n];
|
||||
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
|
||||
}
|
||||
o_dram_window_tmp.move({0, 5120 / Repeat_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
|
||||
@@ -70,6 +70,7 @@ struct Rmsnorm2dFwd
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
|
||||
static constexpr index_t Stride_N = Block_N / Repeat_N;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
@@ -194,7 +195,7 @@ struct Rmsnorm2dFwd
|
||||
const auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N / Repeat_N>{}), {iM, 0});
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Stride_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
const auto x_residual_window = [&]() {
|
||||
@@ -212,7 +213,7 @@ struct Rmsnorm2dFwd
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N / Repeat_N>{}), {iM, 0});
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Stride_N>{}), {iM, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -231,7 +232,7 @@ struct Rmsnorm2dFwd
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
|
||||
return make_tile_window(tmp2_, make_tuple(number<Stride_N>{}), {0});
|
||||
}();
|
||||
|
||||
auto y_window = [&]() {
|
||||
@@ -262,7 +263,7 @@ struct Rmsnorm2dFwd
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N / Repeat_N>{}), {iM, 0});
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Stride_N>{}), {iM, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -302,11 +303,11 @@ struct Rmsnorm2dFwd
|
||||
make_tuple(number<Block_N>{}),
|
||||
sequence<false>{}); // sm_scale no need pad
|
||||
}();
|
||||
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
|
||||
return make_tile_window(win_, make_tuple(number<Stride_N>{}), {0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_N>{}));
|
||||
return make_null_tile_window(make_tuple(number<Stride_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -347,11 +348,11 @@ struct Rmsnorm2dFwd
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Stride_N>{}), {iM, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Stride_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -51,11 +51,11 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 3>>{});
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -21,6 +21,7 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
|
||||
using UnquantYDataType= ck_tile::remove_cvref_t<typename Problem::UnquantYDataType>;
|
||||
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
@@ -57,11 +58,11 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 3>>{});
|
||||
sequence<1>,
|
||||
sequence<3>>{});
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
@@ -93,8 +94,11 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
x_window_.get_window_lengths(),
|
||||
x_window_.get_window_origin(),
|
||||
Policy::template MakeXInnerBlockTileDistribution<Problem>());
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
auto gamma_window =
|
||||
make_tile_window(gamma_window_.get_bottom_tensor_view(),
|
||||
gamma_window_.get_window_lengths(),
|
||||
gamma_window_.get_window_origin(),
|
||||
Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
auto x_residual_window =
|
||||
make_tile_window(x_residual_window_.get_bottom_tensor_view(),
|
||||
x_residual_window_.get_window_lengths(),
|
||||
@@ -112,11 +116,18 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
y_window_.get_window_origin(),
|
||||
Policy::template MakeXInnerBlockTileDistribution<Problem>());
|
||||
|
||||
auto sm_scale_window =
|
||||
make_tile_window(sm_scale_window_.get_bottom_tensor_view(),
|
||||
sm_scale_window_.get_window_lengths(),
|
||||
sm_scale_window_.get_window_origin(),
|
||||
Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
|
||||
auto o_all_window =
|
||||
make_tile_window(y_window_.get_bottom_tensor_view(),
|
||||
y_window_.get_window_lengths(),
|
||||
y_window_.get_window_origin(),
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
Policy::template MakeXInnerBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
@@ -127,17 +138,20 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
using AccTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
|
||||
using AccResTensorType = decltype(load_tile(x_residual_window));
|
||||
|
||||
using GammaTensorType = decltype(load_tile(gamma_window));
|
||||
using SmScaleTensorType = decltype(load_tile(sm_scale_window));
|
||||
|
||||
AccTensorType x_warp_tensors[Repeat_N];
|
||||
AccTensorType o_warp_tensors[Repeat_N];
|
||||
|
||||
GammaTensorType gamma_warp_tensors[Repeat_N];
|
||||
SmScaleTensorType sm_scale_warp_tensors[Repeat_N];
|
||||
|
||||
auto square_sum = decltype(block_reduce2d(AccTensorType{},
|
||||
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
|
||||
reduce_square_sum_func)){};
|
||||
clear_tile(square_sum);
|
||||
|
||||
const auto sm_scale_window =
|
||||
make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
|
||||
|
||||
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
|
||||
{
|
||||
auto x = load_tile(x_window);
|
||||
@@ -161,23 +175,26 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
});
|
||||
}
|
||||
|
||||
gamma_warp_tensors[repeat_n] = load_tile(gamma_window);
|
||||
move_tile_window(gamma_window, {0, Stride_N});
|
||||
|
||||
sm_scale_warp_tensors[repeat_n] = load_tile(sm_scale_window);
|
||||
if constexpr(SmScaleTensorType::is_valid())
|
||||
move_tile_window(sm_scale_window, {0, Stride_N});
|
||||
|
||||
// compute mean square each-thread->cross-lane->cross-warp
|
||||
auto square_sum_local = block_reduce2d(x_warp_tensors[repeat_n],
|
||||
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
|
||||
reduce_square_sum_func);
|
||||
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
|
||||
reduce_square_sum_func);
|
||||
|
||||
ck_tile::sweep_tile(square_sum, [&](auto idx) {
|
||||
square_sum(idx) += square_sum_local[idx];
|
||||
});
|
||||
}
|
||||
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
|
||||
auto sm_scale = load_tile(sm_scale_window);
|
||||
|
||||
// compute inv-rms
|
||||
auto inv_rms = tile_elementwise_in(
|
||||
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
|
||||
@@ -188,26 +205,23 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
// rmsnorm computation
|
||||
auto rmsn = make_static_distributed_tensor<ComputeDataType>(Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
static_for<0, Repeat_N, 1>{}([&](auto repeat_n)
|
||||
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
|
||||
{
|
||||
sweep_tile(o_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma_warp_tensors[repeat_n][j_idx]);
|
||||
|
||||
auto rmsn_ = o_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_;
|
||||
auto rmsn_ = x_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_;
|
||||
|
||||
if constexpr(sm_scale.is_valid())
|
||||
if constexpr(SmScaleTensorType::is_valid())
|
||||
{
|
||||
const auto xs_ = type_convert<ComputeDataType>(sm_scale[j_idx]);
|
||||
const auto xs_ = type_convert<ComputeDataType>(sm_scale_warp_tensors[repeat_n][j_idx]);
|
||||
o_warp_tensors[repeat_n](idx) = rmsn_ * xs_;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
|
||||
{
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
store_tile(y_residual_window, cast_tile<YResidualDataType>(x_warp_tensors[repeat_n]));
|
||||
@@ -220,29 +234,29 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
{
|
||||
if constexpr(kSaveUnquant)
|
||||
{
|
||||
Epilogue{}(
|
||||
unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem);
|
||||
// Epilogue{}(
|
||||
// unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(o_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem);
|
||||
}
|
||||
}
|
||||
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
|
||||
{
|
||||
if constexpr(kSaveUnquant)
|
||||
{
|
||||
Epilogue{}(unquant_y_window, o_all_window, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(o_all_window, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(o_all_window, rmsn);
|
||||
}
|
||||
// else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
|
||||
// {
|
||||
// if constexpr(kSaveUnquant)
|
||||
// {
|
||||
// Epilogue{}(unquant_y_window, o_all_window, y_scale_window_, rmsn, smem);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// Epilogue{}(o_all_window, y_scale_window_, rmsn, smem);
|
||||
// }
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// Epilogue{}(o_all_window, rmsn);
|
||||
// }
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user