mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK_TILE] layernorm support fused-quant/fused-add (#1604)
* add prenorm/postnorm support, refactor using generate.py * update README * update README * fix format * update some description and fix format * update format * format * use non-raw for loading * format and update n4096 * dynamic-quant ready * update readme * support fused dynamic-quant * update fused-quant, with smooth * update README * update args * update some based on comment
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -24,20 +25,25 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr_op"; // block per row
|
||||
return "bpr"; // block per row
|
||||
else
|
||||
return "wpr_op"; // warp per row
|
||||
return "wpr"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -46,20 +52,30 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
typename XResidualWindow,
|
||||
typename GammaWindow,
|
||||
typename BetaWindow,
|
||||
typename YWindow,
|
||||
typename YResidualWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow>
|
||||
typename InvStdWindow,
|
||||
typename XScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
const BetaWindow& beta_window_,
|
||||
YWindow& y_window,
|
||||
YWindow& y_window_,
|
||||
const YResidualWindow& y_residual_window_,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
const XScaleWindow& x_scale_window_,
|
||||
YScaleWindow& y_scale_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
void* smem,
|
||||
Epilogue) const
|
||||
{
|
||||
const auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
@@ -67,8 +83,17 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
const auto beta_window = make_tile_window(
|
||||
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
const auto x_residual_window = make_tile_window(
|
||||
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto y_residual_window = make_tile_window(
|
||||
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto x_scale_window = make_tile_window(
|
||||
x_scale_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
auto x_scale = load_tile(x_scale_window);
|
||||
|
||||
const auto x = load_tile(x_window);
|
||||
int cur_count = 0;
|
||||
int max_count =
|
||||
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
|
||||
@@ -81,6 +106,18 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
|
||||
type_convert<YResidualDataType>(x(idx));
|
||||
});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
store_tile(y_residual_window, x);
|
||||
}
|
||||
|
||||
// compute welford each-thread->cross-lane->cross-warp
|
||||
auto [mean, var] = block_welford(x, cur_count, max_count);
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
@@ -100,8 +137,8 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
|
||||
|
||||
// layernorm computation
|
||||
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
|
||||
sweep_tile(y, [&, mean_ = mean](auto idx) {
|
||||
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
sweep_tile(ln, [&, mean_ = mean](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
@@ -109,11 +146,28 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
|
||||
y(idx) = type_convert<YDataType>(y_);
|
||||
ln(idx) = ln_;
|
||||
});
|
||||
store_tile(y_window, y);
|
||||
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
// smooth-quant pre-scale, then run rowwise-quant
|
||||
sweep_tile(ln, [&](auto idx) {
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
const auto xs_ = type_convert<ComputeDataType>(x_scale[j_idx]);
|
||||
ln(idx) = ln(idx) * xs_;
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
|
||||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
Epilogue{}(y_window_, y_scale_window, ln, smem);
|
||||
}
|
||||
else
|
||||
Epilogue{}(y_window_, ln);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -14,10 +14,10 @@ template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename MeanDataType_,
|
||||
typename InvStdDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename BlockShape_,
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kTwoPass_>
|
||||
typename Traits_>
|
||||
struct Layernorm2dFwdPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
@@ -27,14 +27,14 @@ struct Layernorm2dFwdPipelineProblem
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -24,20 +24,25 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr_tp"; // block per row
|
||||
return "bpr_2p"; // block per row
|
||||
else
|
||||
return "wpr_tp"; // warp per row
|
||||
return "wpr_2p"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -46,20 +51,30 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
typename XResidualWindow,
|
||||
typename GammaWindow,
|
||||
typename BetaWindow,
|
||||
typename YWindow,
|
||||
typename YResidualWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow>
|
||||
typename InvStdWindow,
|
||||
typename XScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
const BetaWindow& beta_window_,
|
||||
YWindow& y_window,
|
||||
const YResidualWindow& y_residual_window_,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
const XScaleWindow& /*x_scale_window*/,
|
||||
YScaleWindow& /*y_scale_window*/,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
void* smem,
|
||||
Epilogue) const
|
||||
{
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
@@ -67,6 +82,10 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
auto beta_window = make_tile_window(
|
||||
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
auto x_residual_window = make_tile_window(
|
||||
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto y_residual_window = make_tile_window(
|
||||
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
// Problem::BlockShape
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
@@ -93,9 +112,26 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
block_welford(x, mean, var, cur_count, max_count);
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
move_tile_window(x_residual_window, {0, Block_N});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
|
||||
type_convert<YResidualDataType>(x(idx));
|
||||
});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
store_tile(y_residual_window, x);
|
||||
move_tile_window(y_residual_window, {0, Block_N});
|
||||
}
|
||||
}
|
||||
block_welford(x, mean, var, cur_count, max_count);
|
||||
}
|
||||
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
@@ -119,6 +155,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(beta_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
@@ -126,14 +163,24 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
// layernorm computation
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
|
||||
type_convert<YResidualDataType>(x(idx));
|
||||
});
|
||||
}
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
|
||||
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
|
||||
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
|
||||
sweep_tile(y, [&, mean_ = mean](auto idx) {
|
||||
sweep_tile(ln, [&, mean_ = mean](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
@@ -141,14 +188,16 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
|
||||
y(idx) = type_convert<YDataType>(y_);
|
||||
ln(idx) = ln_;
|
||||
});
|
||||
|
||||
store_tile(y_window, y);
|
||||
static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT);
|
||||
Epilogue{}(y_window, ln);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(beta_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Selects whether the layernorm forward pipelines fuse a residual add in front of
// the normalization. For both PRE_ADD variants the pipelines compute
// x = x_residual + x before the mean/variance reduction.
enum class Layernorm2dFusedAddEnum
|
||||
{
|
||||
// no fused add: the residual input is not added to x
NO_ADD = 0,
|
||||
// fused add before layernorm, and store the summed result to global
// (the pipelines write it through the y-residual window)
PRE_ADD_STORE = 1,
|
||||
// fused add before layernorm, but do not store the summed result
PRE_ADD = 2,
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
// Compile-time map from each Layernorm2dFusedAddEnum value to a short string tag,
// exposed as the static member `name`. The primary template is declared but not
// defined, so only the enumerators specialized below can be used.
// NOTE(review): presumably these tags are used when composing kernel names — confirm against callers.
template<Layernorm2dFusedAddEnum> struct Layernorm2dFusedAddEnumName;
|
||||
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
|
||||
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
|
||||
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
|
||||
// clang-format on
|
||||
|
||||
// Selects the quantization that is fused behind the layernorm forward computation.
enum class Layernorm2dFusedQuantEnum
|
||||
{
|
||||
// no fused quantization; the one-pass pipeline then calls the epilogue without a y-scale window
// NOTE(review): the enumerator reads "no sweep" while its string tag is "no" — confirm the intended name
NO_SWEEP = 0,
|
||||
SMOOTH_DYNAMIC_QUANT = 1, // smooth outlier + rowwise quant, needs an input x-scale and stores y_scale
|
||||
DYNAMIC_QUANT = 2, // rowwise quant, stores out a y-scale
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
// Compile-time map from each Layernorm2dFusedQuantEnum value to a short string tag,
// exposed as the static member `name`. The primary template is declared but not
// defined, so only the enumerators specialized below can be used.
// NOTE(review): presumably these tags are used when composing kernel names — confirm against callers.
template<Layernorm2dFusedQuantEnum> struct Layernorm2dFusedQuantEnumName;
|
||||
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
|
||||
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
|
||||
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
|
||||
// clang-format on
|
||||
|
||||
// Compile-time configuration bundle for the layernorm2d forward pipelines. Each
// template argument is re-exported unchanged as a static constexpr member of the
// same name (without the trailing underscore).
//   kPadN_           - presumably enables padding along the N dimension (TODO confirm)
//   kSaveMeanInvStd_ - the pipelines derive both kSaveMean and kSaveInvStd from this flag
//   kTwoPass_        - presumably selects the two-pass pipeline over the one-pass one (TODO confirm)
//   kFusedAdd_       - fused residual-add mode (Layernorm2dFusedAddEnum)
//   kFusedQuant_     - fused quantization mode (Layernorm2dFusedQuantEnum)
template <bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kTwoPass_,
|
||||
Layernorm2dFusedAddEnum kFusedAdd_,
|
||||
Layernorm2dFusedQuantEnum kFusedQuant_>
|
||||
struct Layernorm2dFwdTraits
|
||||
{
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user