rms norm 9.7us

This commit is contained in:
zanzhang
2025-08-26 19:46:58 +08:00
parent 5f9c2dbb8a
commit 64613296dc
5 changed files with 90 additions and 52 deletions

View File

@@ -100,6 +100,7 @@ struct Default2DAndDynamicQuantEpilogue
const bool isArray,
void* smem)
{
// Default2D{}(o_direct_dram_window_tmp, o_acc_tiles, Problem::BlockShape::Repeat_N, smem);
// DynamicQuant{}(o_quant_dram_window_tmp, sm_scale_window_, y_scale_window, o_acc_tiles, true, smem);
}
};

View File

@@ -109,6 +109,28 @@ struct Default2DEpilogue
{
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
}
template <typename ODramWindowTmp, typename OAccTiles>
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTiles& o_acc_tiles, int Repeat_N, void* = nullptr) const
{
// TODO: this is ugly
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
{
auto o_acc_tmp = o_acc_tiles[repeat_n];
if constexpr(UseRawStore && (kPadM || kPadN))
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
}
o_dram_window_tmp.move({0, 5120 / Repeat_N});
}
}
};
template <typename Problem_, typename Policy_ = void>

View File

@@ -70,6 +70,7 @@ struct Rmsnorm2dFwd
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr index_t Stride_N = Block_N / Repeat_N;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
@@ -194,7 +195,7 @@ struct Rmsnorm2dFwd
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N / Repeat_N>{}), {iM, 0});
tmp2_, make_tuple(number<Block_M>{}, number<Stride_N>{}), {iM, 0});
}();
const auto x_residual_window = [&]() {
@@ -212,7 +213,7 @@ struct Rmsnorm2dFwd
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N / Repeat_N>{}), {iM, 0});
tmp2_, make_tuple(number<Block_M>{}, number<Stride_N>{}), {iM, 0});
}
else
{
@@ -231,7 +232,7 @@ struct Rmsnorm2dFwd
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
return make_tile_window(tmp2_, make_tuple(number<Stride_N>{}), {0});
}();
auto y_window = [&]() {
@@ -262,7 +263,7 @@ struct Rmsnorm2dFwd
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N / Repeat_N>{}), {iM, 0});
tmp2_, make_tuple(number<Block_M>{}, number<Stride_N>{}), {iM, 0});
}
else
{
@@ -302,11 +303,11 @@ struct Rmsnorm2dFwd
make_tuple(number<Block_N>{}),
sequence<false>{}); // sm_scale no need pad
}();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
return make_tile_window(win_, make_tuple(number<Stride_N>{}), {0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_N>{}));
return make_null_tile_window(make_tuple(number<Stride_N>{}));
}
}();
@@ -347,11 +348,11 @@ struct Rmsnorm2dFwd
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
tmp2_, make_tuple(number<Block_M>{}, number<Stride_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Stride_N>{}));
}
}();

View File

@@ -51,11 +51,11 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>,
sequence<0, 3>>{});
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1>,
sequence<2>>{});
}
template <typename Problem>

View File

@@ -21,6 +21,7 @@ struct Rmsnorm2dFwdPipelineOnePass
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
using UnquantYDataType= ck_tile::remove_cvref_t<typename Problem::UnquantYDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
@@ -57,11 +58,11 @@ struct Rmsnorm2dFwdPipelineOnePass
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>,
sequence<0, 3>>{});
sequence<1>,
sequence<3>>{});
}
template <typename XWindow,
@@ -93,8 +94,11 @@ struct Rmsnorm2dFwdPipelineOnePass
x_window_.get_window_lengths(),
x_window_.get_window_origin(),
Policy::template MakeXInnerBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
auto gamma_window =
make_tile_window(gamma_window_.get_bottom_tensor_view(),
gamma_window_.get_window_lengths(),
gamma_window_.get_window_origin(),
Policy::template MakeGammaBlockTileDistribution<Problem>());
auto x_residual_window =
make_tile_window(x_residual_window_.get_bottom_tensor_view(),
x_residual_window_.get_window_lengths(),
@@ -112,11 +116,18 @@ struct Rmsnorm2dFwdPipelineOnePass
y_window_.get_window_origin(),
Policy::template MakeXInnerBlockTileDistribution<Problem>());
auto sm_scale_window =
make_tile_window(sm_scale_window_.get_bottom_tensor_view(),
sm_scale_window_.get_window_lengths(),
sm_scale_window_.get_window_origin(),
Policy::template MakeGammaBlockTileDistribution<Problem>());
auto o_all_window =
make_tile_window(y_window_.get_bottom_tensor_view(),
y_window_.get_window_lengths(),
y_window_.get_window_origin(),
Policy::template MakeXBlockTileDistribution<Problem>());
Policy::template MakeXInnerBlockTileDistribution<Problem>());
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
@@ -127,17 +138,20 @@ struct Rmsnorm2dFwdPipelineOnePass
using AccTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
using AccResTensorType = decltype(load_tile(x_residual_window));
using GammaTensorType = decltype(load_tile(gamma_window));
using SmScaleTensorType = decltype(load_tile(sm_scale_window));
AccTensorType x_warp_tensors[Repeat_N];
AccTensorType o_warp_tensors[Repeat_N];
GammaTensorType gamma_warp_tensors[Repeat_N];
SmScaleTensorType sm_scale_warp_tensors[Repeat_N];
auto square_sum = decltype(block_reduce2d(AccTensorType{},
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func)){};
clear_tile(square_sum);
const auto sm_scale_window =
make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
{
auto x = load_tile(x_window);
@@ -161,23 +175,26 @@ struct Rmsnorm2dFwdPipelineOnePass
});
}
gamma_warp_tensors[repeat_n] = load_tile(gamma_window);
move_tile_window(gamma_window, {0, Stride_N});
sm_scale_warp_tensors[repeat_n] = load_tile(sm_scale_window);
if constexpr(SmScaleTensorType::is_valid())
move_tile_window(sm_scale_window, {0, Stride_N});
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum_local = block_reduce2d(x_warp_tensors[repeat_n],
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func);
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func);
ck_tile::sweep_tile(square_sum, [&](auto idx) {
square_sum(idx) += square_sum_local[idx];
});
}
const auto gamma = load_tile(gamma_window);
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
auto sm_scale = load_tile(sm_scale_window);
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
@@ -188,26 +205,23 @@ struct Rmsnorm2dFwdPipelineOnePass
// rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(Policy::template MakeXBlockTileDistribution<Problem>());
static_for<0, Repeat_N, 1>{}([&](auto repeat_n)
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
{
sweep_tile(o_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto gamma_ = type_convert<ComputeDataType>(gamma_warp_tensors[repeat_n][j_idx]);
auto rmsn_ = o_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_;
auto rmsn_ = x_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_;
if constexpr(sm_scale.is_valid())
if constexpr(SmScaleTensorType::is_valid())
{
const auto xs_ = type_convert<ComputeDataType>(sm_scale[j_idx]);
const auto xs_ = type_convert<ComputeDataType>(sm_scale_warp_tensors[repeat_n][j_idx]);
o_warp_tensors[repeat_n](idx) = rmsn_ * xs_;
}
});
});
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
{
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(x_warp_tensors[repeat_n]));
@@ -220,29 +234,29 @@ struct Rmsnorm2dFwdPipelineOnePass
{
if constexpr(kSaveUnquant)
{
Epilogue{}(
unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem);
// Epilogue{}(
// unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem);
}
else
{
Epilogue{}(o_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem);
}
}
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
if constexpr(kSaveUnquant)
{
Epilogue{}(unquant_y_window, o_all_window, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(o_all_window, y_scale_window_, rmsn, smem);
}
}
else
{
Epilogue{}(o_all_window, rmsn);
}
// else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
// {
// if constexpr(kSaveUnquant)
// {
// Epilogue{}(unquant_y_window, o_all_window, y_scale_window_, rmsn, smem);
// }
// else
// {
// Epilogue{}(o_all_window, y_scale_window_, rmsn, smem);
// }
// }
// else
// {
// Epilogue{}(o_all_window, rmsn);
// }
}
};
} // namespace ck_tile