[CK_TILE] layernorm support fused-quant/fused-add (#1604)

* add prenorm/postnorm support, refactor using generate.py

* update README

* update README

* fix format

* update some description and fix format

* update format

* format

* use non-raw for loading

* format and update n4096

* dynamic-quant ready

* update readme

* support fused dynamic-quant

* update fused-quant, with smooth

* update README

* update args

* update some based on comment
This commit is contained in:
carlushuang
2024-10-31 14:54:53 +08:00
committed by GitHub
parent 9a8a52130d
commit c3a4800c5f
61 changed files with 1790 additions and 766 deletions

View File

@@ -5,19 +5,24 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
namespace ck_tile {
// host side args
struct Layernorm2dFwdHostArgs
{
const void* p_x;
const void* p_gamma;
const void* p_beta;
const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input
void* p_y;
void* p_mean;
void* p_invStd;
void* p_y; // [m, n], output, fp16/bf16
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
float epsilon;
@@ -27,10 +32,11 @@ struct Layernorm2dFwdHostArgs
};
// TODO: Extract some type to wrapper class
template <typename Pipeline_>
template <typename Pipeline_, typename Epilogue_>
struct Layernorm2dFwd
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
@@ -40,18 +46,26 @@ struct Layernorm2dFwd
using YDataType = remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
// for simplicity, shortcut input/output type is same as X
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, null_type>;
static constexpr bool kSaveMeanInvStd = Problem::kSaveMeanInvStd;
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
static constexpr bool kSaveMeanInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
@@ -62,13 +76,18 @@ struct Layernorm2dFwd
struct Kargs
{
const void* p_x;
const void* p_gamma;
const void* p_beta;
const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input
void* p_y;
void* p_mean;
void* p_invStd;
void* p_y; // [m, n], output, fp16/bf16
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
float epsilon;
@@ -81,9 +100,13 @@ struct Layernorm2dFwd
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_x,
hargs.p_x_residual,
hargs.p_x_scale,
hargs.p_gamma,
hargs.p_beta,
hargs.p_y,
hargs.p_y_residual,
hargs.p_y_scale,
hargs.p_mean,
hargs.p_invStd,
hargs.epsilon,
@@ -106,6 +129,7 @@ struct Layernorm2dFwd
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
// in byte
@@ -113,24 +137,41 @@ struct Layernorm2dFwd
CK_TILE_HOST static std::string GetName()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using S_ = typename Problem::BlockShape;
auto surfix = [&] () {
std::string n;
if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
if (kPadN) n += "_pn";
if (kSaveMeanInvStd) n += "_mv";
if (kTwoPass) n += "_2p";
// if (kTwoPass) n += "_2p";
return n; }();
#define _SS_ std::string
#define _TS_ std::to_string
return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
auto prec_str = [&] () {
std::string base_str = _SS_(t2s<XDataType>::name);
if (!std::is_same_v<XDataType, YDataType>) {
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
}
if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
base_str += _SS_("_sx") + _SS_(t2s<XScaleDataType>::name);
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
}
if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) {
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
}
return base_str;
}();
return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix;
#undef _SS_
#undef _TS_
// clang-format on
#undef _SS_
#undef _TS_
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -153,6 +194,31 @@ struct Layernorm2dFwd
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto x_residual_window = [&]() {
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XResidualDataType*>(kargs.p_x_residual),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel
// will check the max count dynamically
const auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<false, false>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma),
@@ -194,6 +260,28 @@ struct Layernorm2dFwd
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
auto y_residual_window = [&]() {
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YResidualDataType*>(kargs.p_y_residual),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
auto mean_window = [&]() {
if constexpr(kSaveMean)
{
@@ -232,17 +320,60 @@ struct Layernorm2dFwd
return make_null_tile_window(make_tuple(number<Block_M>{}));
}();
auto x_scale_window = [&]() {
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_x_scale),
make_tuple(kargs.n),
number<Vector_N>{});
return pad_tensor_view(tmp_0_,
make_tuple(number<Block_N>{}),
sequence<false>{}); // x_scale no need pad
}();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
}
else
return make_null_tile_window(make_tuple(number<Block_N>{}));
}();
auto y_scale_window = [&]() {
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<YScaleDataType*>(kargs.p_y_scale),
make_tuple(kargs.m),
number<1>{});
return pad_tensor_view(
tmp_0_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
}();
return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<Block_M>{}));
}();
__shared__ char smem[GetSmemSize()];
Pipeline{}(x_window,
x_residual_window,
gamma_window,
beta_window,
y_window,
y_residual_window,
mean_window,
inv_std_window,
x_scale_window,
y_scale_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n,
smem);
smem,
Epilogue{});
}
};

View File

@@ -1,78 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
// clang-format off
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
+<----------------------< Repeat_N(2)>--------------------->+
| |
+<-- <WarpPerBlock_N(2)> -->+
Warp_N
+--------------+--------------+--------------+--------------+----+----------------+
Warp_M | wrap_0 | wrap_1 | | ^ ^
+--------------+--------------+ | <WarpPerBlock_M(2)> |
| wrap_2 | wrap_3 | | v
+--------------+--------------+--------------+--------------+----+ Block_M
| | |
+ + |
| | | v
+--------------+--------------+--------------+--------------+ +
each Warp-tile (e.g 16 thrd per row)
Vector_N (contiguous pixels each thrd holds along N, or vector size)
+-----------+-----------+-----------+-----------+-----------+
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
+-----------+-----------+-----------+-----------+-----------+
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
+-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template <typename BlockTile_, // block size, seq<M, N>
typename WarpPerBlock_, // num warps along seq<M, N>
typename WarpTile_, // warp size, seq<M, N>
typename Vector_, // contiguous pixels(vector size) along seq<M, N>
index_t BlockSize_ =
warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
struct Layernorm2dShape
{
// block size
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
// num warps along seq<M, N>, within each block
static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
// warp size
static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
// repeat of each thread along seq<M, N>
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
// vector size along seq<M, N>
static constexpr index_t Vector_M = Vector_::at(number<0>{});
static constexpr index_t Vector_N = Vector_::at(number<1>{});
static_assert(Warp_M % Vector_M == 0);
static_assert(Warp_N % Vector_N == 0);
// num of threads along seq<M, N>, within each warp
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
static constexpr index_t BlockSize = BlockSize_;
};
} // namespace ck_tile

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include <string>
#include <type_traits>
@@ -24,20 +25,25 @@ struct Layernorm2dFwdPipelineOnePass
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_op"; // block per row
return "bpr"; // block per row
else
return "wpr_op"; // warp per row
return "wpr"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -46,20 +52,30 @@ struct Layernorm2dFwdPipelineOnePass
}
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename BetaWindow,
typename YWindow,
typename YResidualWindow,
typename MeanWindow,
typename InvStdWindow>
typename InvStdWindow,
typename XScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
const BetaWindow& beta_window_,
YWindow& y_window,
YWindow& y_window_,
const YResidualWindow& y_residual_window_,
MeanWindow& mean_window,
InvStdWindow& inv_std_window,
const XScaleWindow& x_scale_window_,
YScaleWindow& y_scale_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
void* smem,
Epilogue) const
{
const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
@@ -67,8 +83,17 @@ struct Layernorm2dFwdPipelineOnePass
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto beta_window = make_tile_window(
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto x_scale_window = make_tile_window(
x_scale_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
auto x_scale = load_tile(x_scale_window);
const auto x = load_tile(x_window);
int cur_count = 0;
int max_count =
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
@@ -81,6 +106,18 @@ struct Layernorm2dFwdPipelineOnePass
const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
store_tile(y_residual_window, x);
}
// compute welford each-thread->cross-lane->cross-warp
auto [mean, var] = block_welford(x, cur_count, max_count);
block_welford_sync(mean, var, cur_count);
@@ -100,8 +137,8 @@ struct Layernorm2dFwdPipelineOnePass
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
// layernorm computation
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
sweep_tile(y, [&, mean_ = mean](auto idx) {
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -109,11 +146,28 @@ struct Layernorm2dFwdPipelineOnePass
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
y(idx) = type_convert<YDataType>(y_);
ln(idx) = ln_;
});
store_tile(y_window, y);
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
// smooth-quant pre-scale, then run rowwise-quant
sweep_tile(ln, [&](auto idx) {
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto xs_ = type_convert<ComputeDataType>(x_scale[j_idx]);
ln(idx) = ln(idx) * xs_;
});
}
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
Epilogue{}(y_window_, y_scale_window, ln, smem);
}
else
Epilogue{}(y_window_, ln);
}
};
} // namespace ck_tile

View File

@@ -14,10 +14,10 @@ template <typename XDataType_,
typename YDataType_,
typename MeanDataType_,
typename InvStdDataType_,
typename XScaleDataType_,
typename YScaleDataType_,
typename BlockShape_,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
typename Traits_>
struct Layernorm2dFwdPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
@@ -27,14 +27,14 @@ struct Layernorm2dFwdPipelineProblem
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
using Traits = remove_cvref_t<Traits_>;
};
} // namespace ck_tile

View File

@@ -24,20 +24,25 @@ struct Layernorm2dFwdPipelineTwoPass
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_tp"; // block per row
return "bpr_2p"; // block per row
else
return "wpr_tp"; // warp per row
return "wpr_2p"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -46,20 +51,30 @@ struct Layernorm2dFwdPipelineTwoPass
}
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename BetaWindow,
typename YWindow,
typename YResidualWindow,
typename MeanWindow,
typename InvStdWindow>
typename InvStdWindow,
typename XScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
const BetaWindow& beta_window_,
YWindow& y_window,
const YResidualWindow& y_residual_window_,
MeanWindow& mean_window,
InvStdWindow& inv_std_window,
const XScaleWindow& /*x_scale_window*/,
YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
void* smem,
Epilogue) const
{
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
@@ -67,6 +82,10 @@ struct Layernorm2dFwdPipelineTwoPass
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto beta_window = make_tile_window(
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
// Problem::BlockShape
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
@@ -93,9 +112,26 @@ struct Layernorm2dFwdPipelineTwoPass
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
block_welford(x, mean, var, cur_count, max_count);
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, x);
move_tile_window(y_residual_window, {0, Block_N});
}
}
block_welford(x, mean, var, cur_count, max_count);
}
block_welford_sync(mean, var, cur_count);
@@ -119,6 +155,7 @@ struct Layernorm2dFwdPipelineTwoPass
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(beta_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window});
@@ -126,14 +163,24 @@ struct Layernorm2dFwdPipelineTwoPass
// layernorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
});
}
// load gamma/beta (TODO: support no gamma/beta?)
const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window);
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(y, [&, mean_ = mean](auto idx) {
sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -141,14 +188,16 @@ struct Layernorm2dFwdPipelineTwoPass
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
y(idx) = type_convert<YDataType>(y_);
ln(idx) = ln_;
});
store_tile(y_window, y);
static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT);
Epilogue{}(y_window, ln);
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(beta_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N});

View File

@@ -0,0 +1,54 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
enum class Layernorm2dFusedAddEnum
{
NO_ADD = 0,
// fused add before layernorm and store result to global
PRE_ADD_STORE = 1,
// fused add before layernorm, but not store result
PRE_ADD = 2,
};
// clang-format off
template<Layernorm2dFusedAddEnum> struct Layernorm2dFusedAddEnumName;
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
// clang-format on
enum class Layernorm2dFusedQuantEnum
{
NO_SWEEP = 0,
SMOOTH_DYNAMIC_QUANT = 1, // smooth oulier + rowwise quant, need input x-scale and store y_scale
DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale
};
// clang-format off
template<Layernorm2dFusedQuantEnum> struct Layernorm2dFusedQuantEnumName;
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
// clang-format on
template <bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_,
Layernorm2dFusedAddEnum kFusedAdd_,
Layernorm2dFusedQuantEnum kFusedQuant_>
struct Layernorm2dFwdTraits
{
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
};
} // namespace ck_tile