mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] layernorm support fused-quant/fused-add (#1604)
* add prenorm/postnorm support, refactor using generate.py * update README * update README * fix format * update some description and fix format * update format * format * use non-raw for loading * format and update n4096 * dynamic-quant ready * update readme * support fused dynamic-quant * update fused-quant, with smooth * update README * update args * update some based on comment
This commit is contained in:
@@ -5,19 +5,24 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// host side args
|
||||
struct Layernorm2dFwdHostArgs
|
||||
{
|
||||
const void* p_x;
|
||||
const void* p_gamma;
|
||||
const void* p_beta;
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
|
||||
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
const void* p_beta; // [1, n], beta, prec same as input
|
||||
|
||||
void* p_y;
|
||||
void* p_mean;
|
||||
void* p_invStd;
|
||||
void* p_y; // [m, n], output, fp16/bf16
|
||||
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
|
||||
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
|
||||
void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
|
||||
void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
|
||||
|
||||
float epsilon;
|
||||
|
||||
@@ -27,10 +32,11 @@ struct Layernorm2dFwdHostArgs
|
||||
};
|
||||
|
||||
// TODO: Extract some type to wrapper class
|
||||
template <typename Pipeline_>
|
||||
template <typename Pipeline_, typename Epilogue_>
|
||||
struct Layernorm2dFwd
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Epilogue = remove_cvref_t<Epilogue_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
@@ -40,18 +46,26 @@ struct Layernorm2dFwd
|
||||
using YDataType = remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
|
||||
// for simplicity, shortcut input/output type is same as X
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, null_type>;
|
||||
static constexpr bool kSaveMeanInvStd = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveMeanInvStd = Problem::Traits::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::kTwoPass;
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
@@ -62,13 +76,18 @@ struct Layernorm2dFwd
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_x;
|
||||
const void* p_gamma;
|
||||
const void* p_beta;
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
|
||||
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
const void* p_beta; // [1, n], beta, prec same as input
|
||||
|
||||
void* p_y;
|
||||
void* p_mean;
|
||||
void* p_invStd;
|
||||
void* p_y; // [m, n], output, fp16/bf16
|
||||
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
|
||||
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
|
||||
|
||||
void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
|
||||
void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
|
||||
|
||||
float epsilon;
|
||||
|
||||
@@ -81,9 +100,13 @@ struct Layernorm2dFwd
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_x_residual,
|
||||
hargs.p_x_scale,
|
||||
hargs.p_gamma,
|
||||
hargs.p_beta,
|
||||
hargs.p_y,
|
||||
hargs.p_y_residual,
|
||||
hargs.p_y_scale,
|
||||
hargs.p_mean,
|
||||
hargs.p_invStd,
|
||||
hargs.epsilon,
|
||||
@@ -106,6 +129,7 @@ struct Layernorm2dFwd
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
|
||||
// clang-format on
|
||||
|
||||
// in byte
|
||||
@@ -113,24 +137,41 @@ struct Layernorm2dFwd
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
// clang-format off
|
||||
using S_ = typename Problem::BlockShape;
|
||||
auto surfix = [&] () {
|
||||
std::string n;
|
||||
if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
|
||||
if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
|
||||
if (kPadN) n += "_pn";
|
||||
if (kSaveMeanInvStd) n += "_mv";
|
||||
if (kTwoPass) n += "_2p";
|
||||
// if (kTwoPass) n += "_2p";
|
||||
return n; }();
|
||||
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
|
||||
auto prec_str = [&] () {
|
||||
std::string base_str = _SS_(t2s<XDataType>::name);
|
||||
if (!std::is_same_v<XDataType, YDataType>) {
|
||||
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
|
||||
}
|
||||
if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
|
||||
base_str += _SS_("_sx") + _SS_(t2s<XScaleDataType>::name);
|
||||
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
|
||||
}
|
||||
if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) {
|
||||
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" +
|
||||
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
|
||||
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
|
||||
_SS_(Pipeline::name) + surfix;
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -153,6 +194,31 @@ struct Layernorm2dFwd
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
const auto x_residual_window = [&]() {
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XResidualDataType*>(kargs.p_x_residual),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel
|
||||
// will check the max count dynamically
|
||||
const auto tmp2_ = pad_tensor_view(tmp_,
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
sequence<false, false>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto gamma_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const GammaDataType*>(kargs.p_gamma),
|
||||
@@ -194,6 +260,28 @@ struct Layernorm2dFwd
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
auto y_residual_window = [&]() {
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YResidualDataType*>(kargs.p_y_residual),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
auto tmp2_ = pad_tensor_view(tmp_,
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
auto mean_window = [&]() {
|
||||
if constexpr(kSaveMean)
|
||||
{
|
||||
@@ -232,17 +320,60 @@ struct Layernorm2dFwd
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
auto x_scale_window = [&]() {
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
const auto win_ = [&]() {
|
||||
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<const XScaleDataType*>(kargs.p_x_scale),
|
||||
make_tuple(kargs.n),
|
||||
number<Vector_N>{});
|
||||
|
||||
return pad_tensor_view(tmp_0_,
|
||||
make_tuple(number<Block_N>{}),
|
||||
sequence<false>{}); // x_scale no need pad
|
||||
}();
|
||||
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<Block_N>{}));
|
||||
}();
|
||||
|
||||
auto y_scale_window = [&]() {
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
|
||||
kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT)
|
||||
{
|
||||
const auto win_ = [&]() {
|
||||
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<YScaleDataType*>(kargs.p_y_scale),
|
||||
make_tuple(kargs.m),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
tmp_0_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
|
||||
}();
|
||||
return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window,
|
||||
x_residual_window,
|
||||
gamma_window,
|
||||
beta_window,
|
||||
y_window,
|
||||
y_residual_window,
|
||||
mean_window,
|
||||
inv_std_window,
|
||||
x_scale_window,
|
||||
y_scale_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.n,
|
||||
smem);
|
||||
smem,
|
||||
Epilogue{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
/*
|
||||
// clang-format off
|
||||
|
||||
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
|
||||
|
||||
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
|
||||
+<----------------------< Repeat_N(2)>--------------------->+
|
||||
| |
|
||||
+<-- <WarpPerBlock_N(2)> -->+
|
||||
Warp_N
|
||||
+--------------+--------------+--------------+--------------+----+----------------+
|
||||
Warp_M | wrap_0 | wrap_1 | | ^ ^
|
||||
+--------------+--------------+ | <WarpPerBlock_M(2)> |
|
||||
| wrap_2 | wrap_3 | | v
|
||||
+--------------+--------------+--------------+--------------+----+ Block_M
|
||||
| | |
|
||||
+ + |
|
||||
| | | v
|
||||
+--------------+--------------+--------------+--------------+ +
|
||||
|
||||
each Warp-tile (e.g 16 thrd per row)
|
||||
|
||||
Vector_N (contiguous pixels each thrd holds along N, or vector size)
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
// clang-format on
|
||||
*/
|
||||
template <typename BlockTile_, // block size, seq<M, N>
|
||||
typename WarpPerBlock_, // num warps along seq<M, N>
|
||||
typename WarpTile_, // warp size, seq<M, N>
|
||||
typename Vector_, // contiguous pixels(vector size) along seq<M, N>
|
||||
index_t BlockSize_ =
|
||||
warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
|
||||
struct Layernorm2dShape
|
||||
{
|
||||
// block size
|
||||
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
|
||||
|
||||
// num warps along seq<M, N>, within each block
|
||||
static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
|
||||
|
||||
// warp size
|
||||
static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
|
||||
static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
|
||||
|
||||
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
|
||||
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
|
||||
// repeat of each thread along seq<M, N>
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
// vector size along seq<M, N>
|
||||
static constexpr index_t Vector_M = Vector_::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector_::at(number<1>{});
|
||||
|
||||
static_assert(Warp_M % Vector_M == 0);
|
||||
static_assert(Warp_N % Vector_N == 0);
|
||||
// num of threads along seq<M, N>, within each warp
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user