[CK_TILE] layernorm has more accurate residual (#1623)

* more accurate residual

* modify comment

* Fix literal case in README.md

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
carlushuang
2024-11-02 05:30:16 +00:00
committed by GitHub
parent 03c6448ba3
commit cb6c5d39dc
6 changed files with 97 additions and 63 deletions

View File

@@ -45,7 +45,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
{
-using P_ = BlockWelfordProblem<typename Problem::XDataType,
+using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
@@ -55,7 +55,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
{
-using P_ = BlockWelfordProblem<typename Problem::XDataType,
+using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
@@ -65,7 +65,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
{
-using P_ = BlockWelfordProblem<typename Problem::XDataType,
+using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
@@ -77,13 +77,13 @@ struct Layernorm2dFwdPipelineDefaultPolicy
{
if constexpr(Problem::kNeedCrossWarpSync)
{
-using P_ = BlockWelfordProblem<typename Problem::XDataType,
+using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
using block_welford = BlockWelford<P_>;
using x_block_tile =
-decltype(make_static_distributed_tensor<typename Problem::XDataType>(
+decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>()));
using mean_var_block_tile =
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());

View File

@@ -87,12 +87,9 @@ struct Layernorm2dFwdPipelineOnePass
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
-const auto x_scale_window = make_tile_window(
-x_scale_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
-auto x = load_tile(x_window);
-auto x_resi = load_tile(x_residual_window);
-auto x_scale = load_tile(x_scale_window);
+auto x = load_tile(x_window);
+auto x_resi = load_tile(x_residual_window);
int cur_count = 0;
int max_count =
@@ -106,21 +103,21 @@ struct Layernorm2dFwdPipelineOnePass
const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window);
+auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
-auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
-type_convert<ComputeDataType>(x(idx));
-x(idx) = type_convert<XDataType>(re_);
+acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
-store_tile(y_residual_window, x);
+store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
}
// compute welford each-thread->cross-lane->cross-warp
-auto [mean, var] = block_welford(x, cur_count, max_count);
+auto [mean, var] = block_welford(acc, cur_count, max_count);
block_welford_sync(mean, var, cur_count);
block_welford_cross_warp_sync(mean, var, cur_count, smem);
block_tile_welford_post_scale_var(var, cur_count);
@@ -138,7 +135,7 @@ struct Layernorm2dFwdPipelineOnePass
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
// layernorm computation
-auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
+auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -146,26 +143,15 @@ struct Layernorm2dFwdPipelineOnePass
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
-const auto x_ = type_convert<ComputeDataType>(x[idx]);
-auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
+auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
ln(idx) = ln_;
});
-if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
-{
-// smooth-quant pre-scale, then run rowwise-quant
-sweep_tile(ln, [&](auto idx) {
-constexpr auto j_idx = make_tuple(idx[number<1>{}]);
-const auto xs_ = type_convert<ComputeDataType>(x_scale[j_idx]);
-ln(idx) = ln(idx) * xs_;
-});
-}
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
-Epilogue{}(y_window_, y_scale_window, ln, smem);
+Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem);
}
else
Epilogue{}(y_window_, ln);

View File

@@ -106,7 +106,7 @@ struct Layernorm2dFwdPipelineTwoPass
auto block_welford_cross_warp_sync =
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
-using XTensorType = decltype(load_tile(x_window));
+using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
@@ -117,22 +117,22 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N});
+auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
-auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
-type_convert<ComputeDataType>(x(idx));
-x(idx) = type_convert<XDataType>(re_);
+acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{
-store_tile(y_residual_window, x);
+store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
move_tile_window(y_residual_window, {0, Block_N});
}
}
-block_welford(x, mean, var, cur_count, max_count);
+block_welford(acc, mean, var, cur_count, max_count);
}
block_welford_sync(mean, var, cur_count);
@@ -166,21 +166,21 @@ struct Layernorm2dFwdPipelineTwoPass
{
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
+auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
-auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
-type_convert<ComputeDataType>(x(idx));
-x(idx) = type_convert<XDataType>(re_);
+acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
}
// load gamma/beta (TODO: support no gamma/beta?)
const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window);
-auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
+auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
@@ -189,8 +189,7 @@ struct Layernorm2dFwdPipelineTwoPass
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
-const auto x_ = type_convert<ComputeDataType>(x[idx]);
-auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
+auto ln_ = (acc(idx) - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
ln(idx) = ln_;
});