mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] layernorm have more accurate residual (#1623)
* more accurate residual * modify comment * Fix literal case in README.md --------- Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -45,7 +45,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
@@ -55,7 +55,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
@@ -65,7 +65,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
@@ -77,13 +77,13 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
using block_welford = BlockWelford<P_>;
|
||||
using x_block_tile =
|
||||
decltype(make_static_distributed_tensor<typename Problem::XDataType>(
|
||||
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using mean_var_block_tile =
|
||||
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
|
||||
|
||||
@@ -87,12 +87,9 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto y_residual_window = make_tile_window(
|
||||
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto x_scale_window = make_tile_window(
|
||||
x_scale_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
auto x_scale = load_tile(x_scale_window);
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
|
||||
int cur_count = 0;
|
||||
int max_count =
|
||||
@@ -106,21 +103,21 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
|
||||
type_convert<ComputeDataType>(x(idx));
|
||||
x(idx) = type_convert<XDataType>(re_);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
store_tile(y_residual_window, x);
|
||||
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
|
||||
}
|
||||
|
||||
// compute welford each-thread->cross-lane->cross-warp
|
||||
auto [mean, var] = block_welford(x, cur_count, max_count);
|
||||
auto [mean, var] = block_welford(acc, cur_count, max_count);
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
block_welford_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count);
|
||||
@@ -138,7 +135,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
|
||||
|
||||
// layernorm computation
|
||||
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
|
||||
sweep_tile(ln, [&, mean_ = mean](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
@@ -146,26 +143,15 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
|
||||
ln(idx) = ln_;
|
||||
});
|
||||
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
// smooth-quant pre-scale, then run rowwise-quant
|
||||
sweep_tile(ln, [&](auto idx) {
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
const auto xs_ = type_convert<ComputeDataType>(x_scale[j_idx]);
|
||||
ln(idx) = ln(idx) * xs_;
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
|
||||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
Epilogue{}(y_window_, y_scale_window, ln, smem);
|
||||
Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem);
|
||||
}
|
||||
else
|
||||
Epilogue{}(y_window_, ln);
|
||||
|
||||
@@ -106,7 +106,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
auto block_welford_cross_warp_sync =
|
||||
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(load_tile(x_window));
|
||||
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
|
||||
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
|
||||
|
||||
@@ -117,22 +117,22 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
move_tile_window(x_residual_window, {0, Block_N});
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
|
||||
type_convert<ComputeDataType>(x(idx));
|
||||
x(idx) = type_convert<XDataType>(re_);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
store_tile(y_residual_window, x);
|
||||
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
|
||||
move_tile_window(y_residual_window, {0, Block_N});
|
||||
}
|
||||
}
|
||||
block_welford(x, mean, var, cur_count, max_count);
|
||||
block_welford(acc, mean, var, cur_count, max_count);
|
||||
}
|
||||
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
@@ -166,21 +166,21 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
{
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
|
||||
type_convert<ComputeDataType>(x(idx));
|
||||
x(idx) = type_convert<XDataType>(re_);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
}
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
|
||||
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
|
||||
|
||||
sweep_tile(ln, [&, mean_ = mean](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
@@ -189,8 +189,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
auto ln_ = (acc(idx) - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
|
||||
ln(idx) = ln_;
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user