mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
This commit is contained in:
@@ -0,0 +1,421 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// host side args
|
||||
struct Layernorm2dFwdHostArgs
|
||||
{
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
|
||||
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_x_bias; // [1, n], bias, prec same as input
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
const void* p_beta; // [1, n], beta, prec same as input
|
||||
|
||||
void* p_y; // [m, n], output, fp16/bf16
|
||||
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
|
||||
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
|
||||
void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
|
||||
void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
|
||||
|
||||
float epsilon;
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t x_stride; // x row_stride
|
||||
index_t xr_stride; // x residule row stride
|
||||
index_t y_stride; // y row stride
|
||||
index_t yr_stride; // y residule row stride
|
||||
};
|
||||
|
||||
// TODO: Extract some type to wrapper class
|
||||
template <typename Pipeline_, typename Epilogue_>
|
||||
struct Layernorm2dFwd
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Epilogue = remove_cvref_t<Epilogue_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
|
||||
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
|
||||
// for simplicity, shortcut input/output type is same as X
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, null_type>;
|
||||
static constexpr bool kSaveMeanInvStd = Problem::Traits::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
|
||||
static constexpr auto kXbias = Problem::Traits::kXbias;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
|
||||
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_x_bias; // [1, n], bias, prec same as input
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
const void* p_beta; // [1, n], beta, prec same as input
|
||||
|
||||
void* p_y; // [m, n], output, fp16/bf16
|
||||
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
|
||||
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
|
||||
|
||||
void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
|
||||
void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
|
||||
|
||||
float epsilon;
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t x_stride; // x row_stride
|
||||
index_t xr_stride; // x residule row stride
|
||||
index_t y_stride; // y row stride
|
||||
index_t yr_stride; // y residule row stride
|
||||
};
|
||||
using Hargs = Layernorm2dFwdHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_x_residual,
|
||||
hargs.p_sm_scale,
|
||||
hargs.p_x_bias,
|
||||
hargs.p_gamma,
|
||||
hargs.p_beta,
|
||||
hargs.p_y,
|
||||
hargs.p_y_residual,
|
||||
hargs.p_y_scale,
|
||||
hargs.p_mean,
|
||||
hargs.p_invStd,
|
||||
hargs.epsilon,
|
||||
hargs.m,
|
||||
hargs.n,
|
||||
hargs.x_stride,
|
||||
hargs.xr_stride,
|
||||
hargs.y_stride,
|
||||
hargs.yr_stride};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
{
|
||||
return dim3(integer_divide_ceil(hargs.m, Block_M));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
|
||||
: Problem::BlockShape::template GetBlockSize<false>();
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
|
||||
// clang-format on
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
// clang-format off
|
||||
using S_ = typename Problem::BlockShape;
|
||||
auto surfix = [&] () {
|
||||
std::string n;
|
||||
if (kXbias != Layernorm2dXBiasEnum::NO_BIAS) n += _SS_("_") + Layernorm2dXBiasEnumName<kXbias>::name;
|
||||
if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
|
||||
if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
|
||||
if (kPadN) n += "_pn";
|
||||
if (kSaveMeanInvStd) n += "_mv";
|
||||
// if (kTwoPass) n += "_2p";
|
||||
return n; }();
|
||||
|
||||
auto prec_str = [&] () {
|
||||
std::string base_str = _SS_(t2s<XDataType>::name);
|
||||
if (!std::is_same_v<XDataType, YDataType>) {
|
||||
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
|
||||
}
|
||||
if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
|
||||
base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
|
||||
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
|
||||
}
|
||||
if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) {
|
||||
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" +
|
||||
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
|
||||
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
|
||||
_SS_(Pipeline::name) + surfix;
|
||||
// clang-format on
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
const auto iM = get_block_id() * Block_M;
|
||||
|
||||
const auto x_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XDataType*>(kargs.p_x),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.x_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
|
||||
// check the max count dynamically
|
||||
const auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
const auto x_residual_window = [&]() {
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XResidualDataType*>(kargs.p_x_residual),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.xr_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel
|
||||
// will check the max count dynamically
|
||||
const auto tmp2_ = pad_tensor_view(tmp_,
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
sequence<false, false>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto x_bias_window = [&]() {
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
{
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XBiasDataType*>(kargs.p_x_bias),
|
||||
make_tuple(kargs.n),
|
||||
make_tuple(1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto gamma_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const GammaDataType*>(kargs.p_gamma),
|
||||
make_tuple(kargs.n),
|
||||
make_tuple(1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
|
||||
}();
|
||||
|
||||
const auto beta_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BetaDataType*>(kargs.p_beta),
|
||||
make_tuple(kargs.n),
|
||||
make_tuple(1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0});
|
||||
}();
|
||||
|
||||
auto y_window = [&]() {
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YDataType*>(kargs.p_y),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.y_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
auto y_residual_window = [&]() {
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YResidualDataType*>(kargs.p_y_residual),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.yr_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
auto tmp2_ = pad_tensor_view(tmp_,
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
auto mean_window = [&]() {
|
||||
if constexpr(kSaveMean)
|
||||
{
|
||||
const auto mean_m = [&]() {
|
||||
const auto mean_dram_naive =
|
||||
make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<MeanDataType*>(kargs.p_mean),
|
||||
make_tuple(kargs.m),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
mean_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
|
||||
}();
|
||||
return make_tile_window(mean_m, make_tuple(number<Block_M>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
auto inv_std_window = [&]() {
|
||||
if constexpr(kSaveInvStd)
|
||||
{
|
||||
const auto inv_std_m = [&]() {
|
||||
const auto inv_std_dram_naive =
|
||||
make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<InvStdDataType*>(kargs.p_invStd),
|
||||
make_tuple(kargs.m),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
inv_std_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
|
||||
}();
|
||||
return make_tile_window(inv_std_m, make_tuple(number<Block_M>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
auto sm_scale_window = [&]() {
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
const auto win_ = [&]() {
|
||||
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
|
||||
make_tuple(kargs.n),
|
||||
number<Vector_N>{});
|
||||
|
||||
return pad_tensor_view(tmp_0_,
|
||||
make_tuple(number<Block_N>{}),
|
||||
sequence<false>{}); // sm_scale no need pad
|
||||
}();
|
||||
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<Block_N>{}));
|
||||
}();
|
||||
|
||||
auto y_scale_window = [&]() {
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
|
||||
kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT)
|
||||
{
|
||||
const auto win_ = [&]() {
|
||||
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<YScaleDataType*>(kargs.p_y_scale),
|
||||
make_tuple(kargs.m),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
tmp_0_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
|
||||
}();
|
||||
return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window,
|
||||
x_residual_window,
|
||||
x_bias_window,
|
||||
gamma_window,
|
||||
beta_window,
|
||||
y_window,
|
||||
y_residual_window,
|
||||
mean_window,
|
||||
inv_std_window,
|
||||
sm_scale_window,
|
||||
y_scale_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.n,
|
||||
smem,
|
||||
Epilogue{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,107 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce()
|
||||
{
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
return BlockNormReduce<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync()
|
||||
{
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
|
||||
return BlockNormReduceSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
|
||||
return BlockNormReduceCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
|
||||
using block_welford = BlockNormReduce<P_>;
|
||||
using x_block_tile =
|
||||
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using mean_var_block_tile =
|
||||
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
|
||||
|
||||
return GetBlockNormReduceCrossWarpSync<Problem>()
|
||||
.template GetSmemSize<mean_var_block_tile>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,199 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
|
||||
struct Layernorm2dFwdPipelineOnePass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
|
||||
static constexpr bool kWelford = Problem::Traits::kWelford;
|
||||
static constexpr auto kXbias = Problem::Traits::kXbias;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr"; // block per row
|
||||
else
|
||||
return "wpr"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
typename XResidualWindow,
|
||||
typename XBiasWindow,
|
||||
typename GammaWindow,
|
||||
typename BetaWindow,
|
||||
typename YWindow,
|
||||
typename YResidualWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
const XBiasWindow& x_bias_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
const BetaWindow& beta_window_,
|
||||
YWindow& y_window_,
|
||||
const YResidualWindow& y_residual_window_,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem,
|
||||
Epilogue) const
|
||||
{
|
||||
const auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto x_bias_window = make_tile_window(
|
||||
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
const auto beta_window = make_tile_window(
|
||||
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
const auto x_residual_window = make_tile_window(
|
||||
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto y_residual_window = make_tile_window(
|
||||
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
const auto x_bias = load_tile(x_bias_window);
|
||||
|
||||
int cur_count = 0;
|
||||
int max_count =
|
||||
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
|
||||
auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
|
||||
auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
|
||||
auto block_norm_reduce_cross_warp_sync =
|
||||
Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(cast_tile<ComputeDataType>(x));
|
||||
auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
clear_tile(mean);
|
||||
clear_tile(var);
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
{
|
||||
sweep_tile(x, [&](auto idx) {
|
||||
// compute x = bias + x
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
|
||||
}
|
||||
|
||||
// compute reduce each-thread->cross-lane->cross-warp
|
||||
block_norm_reduce(acc, mean, var, cur_count, max_count);
|
||||
block_norm_reduce_sync(mean, var, cur_count);
|
||||
block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
|
||||
if(kWelford)
|
||||
{
|
||||
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile(mean, [&](auto idx) {
|
||||
mean(idx) = mean(idx) / type_convert<MeanDataType>(row_size);
|
||||
var(idx) = var(idx) / type_convert<MeanDataType>(row_size) - mean(idx) * mean(idx);
|
||||
});
|
||||
}
|
||||
// compute inv-std
|
||||
auto inv_std = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
|
||||
{
|
||||
return type_convert<ComputeDataType>(1.0f) *
|
||||
__builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
|
||||
}
|
||||
else
|
||||
{
|
||||
return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
|
||||
}
|
||||
},
|
||||
var);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_window, cast_tile<MeanDataType>(mean));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
|
||||
|
||||
// layernorm computation
|
||||
auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
|
||||
sweep_tile(ln, [&, mean_ = mean](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
ln(idx) = ln_;
|
||||
});
|
||||
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
|
||||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
Epilogue{}(y_window_, sm_scale_window_, y_scale_window, ln, smem);
|
||||
}
|
||||
else
|
||||
Epilogue{}(y_window_, ln, nullptr);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_,
|
||||
typename XBiasDataType_,
|
||||
typename GammaDataType_,
|
||||
typename BetaDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename MeanDataType_,
|
||||
typename InvStdDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename BlockShape_,
|
||||
typename Traits_>
|
||||
struct Layernorm2dFwdPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using XBiasDataType = remove_cvref_t<XBiasDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,266 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
|
||||
struct Layernorm2dFwdPipelineTwoPass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
|
||||
static constexpr bool kWelford = Problem::Traits::kWelford;
|
||||
static constexpr auto kXbias = Problem::Traits::kXbias;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr_2p"; // block per row
|
||||
else
|
||||
return "wpr_2p"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
typename XResidualWindow,
|
||||
typename XBiasWindow,
|
||||
typename GammaWindow,
|
||||
typename BetaWindow,
|
||||
typename YWindow,
|
||||
typename YResidualWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
const XBiasWindow& x_bias_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
const BetaWindow& beta_window_,
|
||||
YWindow& y_window,
|
||||
const YResidualWindow& y_residual_window_,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
const SmoothScaleWindow& /*sm_scale_window*/,
|
||||
YScaleWindow& /*y_scale_window*/,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem,
|
||||
Epilogue) const
|
||||
{
|
||||
static_assert(kWelford == true, "2 pass only supports welford merge");
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto x_bias_window = make_tile_window(
|
||||
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
auto beta_window = make_tile_window(
|
||||
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
auto x_residual_window = make_tile_window(
|
||||
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto y_residual_window = make_tile_window(
|
||||
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
// Problem::BlockShape
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
index_t num_n_tile_iteration =
|
||||
amd_wave_read_first_lane(integer_divide_ceil(row_size, Block_N));
|
||||
|
||||
// total number of count assume current iter have no pad(only last iter has pad)
|
||||
constexpr index_t count_per_iter =
|
||||
Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N;
|
||||
const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N;
|
||||
|
||||
int cur_count = 0;
|
||||
int max_count =
|
||||
(num_n_tile_iteration - 1) * count_per_iter +
|
||||
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
|
||||
auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
|
||||
auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
|
||||
auto block_norm_reduce_cross_warp_sync =
|
||||
Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
|
||||
auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
|
||||
for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
const auto x_bias = load_tile(x_bias_window);
|
||||
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
move_tile_window(x_residual_window, {0, Block_N});
|
||||
move_tile_window(x_bias_window, {Block_N});
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
{
|
||||
sweep_tile(x, [&](auto idx) {
|
||||
// compute x = bias + x
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
|
||||
move_tile_window(y_residual_window, {0, Block_N});
|
||||
}
|
||||
}
|
||||
block_norm_reduce(acc, mean, var, cur_count, max_count);
|
||||
}
|
||||
|
||||
block_norm_reduce_sync(mean, var, cur_count);
|
||||
block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
|
||||
|
||||
// compute inv-std
|
||||
auto inv_std = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
|
||||
{
|
||||
return type_convert<ComputeDataType>(1.0f) *
|
||||
__builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
|
||||
}
|
||||
else
|
||||
{
|
||||
return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
|
||||
}
|
||||
},
|
||||
var);
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_window, cast_tile<MeanDataType>(mean));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
|
||||
|
||||
// reverse read x to reuse cache
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
move_tile_window(y_residual_window, {0, -Block_N});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
}
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(beta_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
|
||||
// layernorm computation
|
||||
for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
auto acc = make_static_distributed_tensor<ComputeDataType>(
|
||||
decltype(load_tile(x_window))::get_tile_distribution());
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
|
||||
move_tile_window(y_residual_window, {0, -Block_N});
|
||||
}
|
||||
else
|
||||
{
|
||||
acc = cast_tile<ComputeDataType>(load_tile(x_window));
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
{
|
||||
const auto x_bias = load_tile(x_bias_window);
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
|
||||
sweep_tile(acc, [&](auto idx) {
|
||||
// compute x = bias + x
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
|
||||
auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
|
||||
|
||||
sweep_tile(ln, [&, mean_ = mean](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
auto ln_ = (acc(idx) - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
|
||||
ln(idx) = ln_;
|
||||
});
|
||||
|
||||
static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT);
|
||||
Epilogue{}(y_window, ln, nullptr);
|
||||
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(beta_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum class Layernorm2dXBiasEnum
|
||||
{
|
||||
NO_BIAS = 0,
|
||||
// add bias before fused add
|
||||
ADD_BIAS = 1,
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template<Layernorm2dXBiasEnum> struct Layernorm2dXBiasEnumName;
|
||||
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::NO_BIAS> { static constexpr const char * name = "no"; };
|
||||
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::ADD_BIAS> { static constexpr const char * name = "xbias"; };
|
||||
// clang-format on
|
||||
|
||||
enum class Layernorm2dFusedAddEnum
|
||||
{
|
||||
NO_ADD = 0,
|
||||
// fused add before layernorm and store result to global
|
||||
PRE_ADD_STORE = 1,
|
||||
// fused add before layernorm, but not store result
|
||||
PRE_ADD = 2,
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template<Layernorm2dFusedAddEnum> struct Layernorm2dFusedAddEnumName;
|
||||
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
|
||||
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
|
||||
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
|
||||
// clang-format on
|
||||
|
||||
enum class Layernorm2dFusedQuantEnum
|
||||
{
|
||||
NO_SWEEP = 0,
|
||||
SMOOTH_DYNAMIC_QUANT = 1, // smooth oulier + rowwise quant, need input x-scale and store y_scale
|
||||
DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template<Layernorm2dFusedQuantEnum> struct Layernorm2dFusedQuantEnumName;
|
||||
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
|
||||
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
|
||||
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
|
||||
// clang-format on
|
||||
|
||||
template <bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kFastFDiv_,
|
||||
bool kWelford_,
|
||||
bool kTwoPass_,
|
||||
Layernorm2dXBiasEnum kXbias_,
|
||||
Layernorm2dFusedAddEnum kFusedAdd_,
|
||||
Layernorm2dFusedQuantEnum kFusedQuant_>
|
||||
struct Layernorm2dFwdTraits
|
||||
{
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
static constexpr bool kWelford = kWelford_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Layernorm2dXBiasEnum kXbias = kXbias_;
|
||||
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user