mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
update layernorm (#1570)
* port layernorm
* change warp_welford.hpp
* Update warpshuffle
* 1. Add save mean and save std back
2. Move construction of tensor_view and tile_window to operator()
* refine welford max count calculation
* unify layernorm api
* Rename file
* Remove save mean and inv std
* Revert "refine welford max count calculation"
This reverts commit 022365802b.
* Fix order of parameter
* refine welford max count calculation again
* Remove fp32 instances
* Fix bug of padding
* refactor api
* Support bf16
* Extract common function
* Refine arg of operator()
* Add kMThreadPerBlock to template parameter
* clang format
* Refine variable name
* Refine file name
* remove redundant line
* refactor layernorm2d pipeline and add block-per-block utility
* fix name
* rename more
* add more block-per-tile instance
* remove duplicated define
* update instance for 2048, 1024 case
* support up to 2048 now
* opt loading
* add n1536
* Add two pass pipeline
* format
* Fix incorrect type
* parallel compilation
* Use smaller N
* fix 2p pass
* Support Repeat_M in distribution
* Refine nameing
* Add reduce example
---------
Co-authored-by: letaoqin <letaoqin@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
This commit is contained in:
@@ -5,37 +5,57 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
|
||||
#include "ck_tile/ops/welford/warp/warp_welford.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// host side args
|
||||
struct Layernorm2dFwdHostArgs
|
||||
{
|
||||
const void* p_x;
|
||||
const void* p_gamma;
|
||||
const void* p_beta;
|
||||
|
||||
void* p_y;
|
||||
void* p_mean;
|
||||
void* p_invStd;
|
||||
|
||||
float epsilon;
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
};
|
||||
|
||||
// TODO: Extract some type to wrapper class
|
||||
template <typename Problem_>
|
||||
template <typename Pipeline_>
|
||||
struct Layernorm2dFwd
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveMean = !std::is_same_v<MeanDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvStd = !std::is_same_v<InvStdDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, null_type>;
|
||||
static constexpr bool kSaveMeanInvStd = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
|
||||
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::kTwoPass;
|
||||
|
||||
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
|
||||
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
@@ -52,400 +72,177 @@ struct Layernorm2dFwd
|
||||
|
||||
float epsilon;
|
||||
|
||||
ck_tile::index_t M;
|
||||
ck_tile::index_t N;
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
};
|
||||
using Hargs = Layernorm2dFwdHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* p_x,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_y,
|
||||
void* p_mean,
|
||||
void* p_invStd,
|
||||
float epsilon,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N)
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N};
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_gamma,
|
||||
hargs.p_beta,
|
||||
hargs.p_y,
|
||||
hargs.p_mean,
|
||||
hargs.p_invStd,
|
||||
hargs.epsilon,
|
||||
hargs.m,
|
||||
hargs.n,
|
||||
hargs.stride};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
|
||||
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>{});
|
||||
return (hargs.m + Block_M - 1) / Block_M;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
// clang-format on
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
// clang-format off
|
||||
using S_ = typename Problem::BlockShape;
|
||||
auto surfix = [&] () {
|
||||
std::string n;
|
||||
if (kPadN) n += "_pn";
|
||||
if (kSaveMeanInvStd) n += "_mv";
|
||||
if (kTwoPass) n += "_2p";
|
||||
return n; }();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp>,
|
||||
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
|
||||
{
|
||||
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
|
||||
|
||||
int thread_id_n = get_thread_id() % kNThreadPerBlock;
|
||||
int max_count =
|
||||
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
|
||||
int n_per_block_tail_loop =
|
||||
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
|
||||
|
||||
if(n_per_block_tail_loop > 0)
|
||||
{
|
||||
int thread_max_n = (thread_id_n + 1) * kNPerThread;
|
||||
int delta = thread_max_n - n_per_block_tail_loop;
|
||||
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
|
||||
max_count += kNPerThread - delta;
|
||||
}
|
||||
|
||||
return max_count;
|
||||
}
|
||||
|
||||
template <typename DistributedTensor>
|
||||
CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor,
|
||||
const ComputeDataType epsilon)
|
||||
{
|
||||
// TODO: Investigate fast inverse square root algorithm with epsilon
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
|
||||
DistributedTensor out_dstr_tensor;
|
||||
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
out_dstr_tensor(i_idx) = type_convert<ComputeDataType>(1.0f) /
|
||||
ck_tile::sqrt(in_dstr_tensor[i_idx] + epsilon);
|
||||
});
|
||||
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
|
||||
template <typename XBlockWindow,
|
||||
typename GammaBlockWindow,
|
||||
typename BetaBlockWindow,
|
||||
typename YBlockWindow,
|
||||
typename MeanBlockWindow,
|
||||
typename InvStdBlockWindow,
|
||||
bool Cond = (kHasGamma && kHasBeta)>
|
||||
CK_TILE_DEVICE std::enable_if_t<Cond>
|
||||
TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
|
||||
GammaBlockWindow& gamma_block_window,
|
||||
BetaBlockWindow& beta_block_window,
|
||||
YBlockWindow& y_block_window,
|
||||
MeanBlockWindow& mean_block_window,
|
||||
InvStdBlockWindow& inv_std_block_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t N) const
|
||||
{
|
||||
// TODO - Optimize tail loop to reduce move_tile_window()
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
|
||||
|
||||
int welford_max_count = GetWelfordMaxCount(N);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
|
||||
|
||||
using XTensorType = decltype(load_tile(x_block_window));
|
||||
auto mean_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
auto var_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
|
||||
clear_tile(mean_compute_block_tensor);
|
||||
clear_tile(var_compute_block_tensor);
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x_block_tensor = load_tile(x_block_window);
|
||||
|
||||
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
|
||||
move_tile_window(x_block_window, {0, kNPerBlock});
|
||||
}
|
||||
|
||||
// TODO: support cross warp Welford
|
||||
WarpMergeWelford<ComputeDataType, true>{}(
|
||||
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
|
||||
|
||||
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_block_window,
|
||||
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
|
||||
|
||||
// reverse read x to reuse cache
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
|
||||
|
||||
move_tile_window(x_block_window, {0, -kNPerBlock});
|
||||
move_tile_window(gamma_block_window, {stride_to_right_most_window});
|
||||
move_tile_window(beta_block_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_block_window, {0, stride_to_right_most_window});
|
||||
|
||||
// Normalization
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x_block_tensor = load_tile(x_block_window);
|
||||
const auto gamma_block_tensor = load_tile(gamma_block_window);
|
||||
const auto beta_block_tensor = load_tile(beta_block_window);
|
||||
|
||||
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
|
||||
|
||||
auto y_block_tensor =
|
||||
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
|
||||
|
||||
sweep_tile_span(x_spans[I1], [&](auto idx1) {
|
||||
constexpr auto j_idx = make_tuple(idx1);
|
||||
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
|
||||
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
|
||||
|
||||
sweep_tile_span(x_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
const auto mean = mean_compute_block_tensor[i_idx];
|
||||
const auto inv_std = inv_std_compute_block_tensor[i_idx];
|
||||
|
||||
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
|
||||
auto y = (x - mean) * inv_std * gamma + beta;
|
||||
|
||||
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(y_block_window, y_block_tensor);
|
||||
|
||||
move_tile_window(x_block_window, {0, -kNPerBlock});
|
||||
move_tile_window(gamma_block_window, {-kNPerBlock});
|
||||
move_tile_window(beta_block_window, {-kNPerBlock});
|
||||
move_tile_window(y_block_window, {0, -kNPerBlock});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename XBlockWindow,
|
||||
typename GammaBlockWindow,
|
||||
typename BetaBlockWindow,
|
||||
typename YBlockWindow,
|
||||
typename MeanBlockWindow,
|
||||
typename InvStdBlockWindow,
|
||||
bool Cond = (kHasGamma && kHasBeta)>
|
||||
CK_TILE_DEVICE std::enable_if_t<Cond>
|
||||
OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
|
||||
GammaBlockWindow& gamma_block_window,
|
||||
BetaBlockWindow& beta_block_window,
|
||||
YBlockWindow& y_block_window,
|
||||
MeanBlockWindow& mean_block_window,
|
||||
InvStdBlockWindow& inv_std_block_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t N) const
|
||||
{
|
||||
int welford_max_count = GetWelfordMaxCount(N);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
|
||||
|
||||
using XTensorType = decltype(load_tile(x_block_window));
|
||||
auto mean_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
auto var_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
|
||||
clear_tile(mean_compute_block_tensor);
|
||||
clear_tile(var_compute_block_tensor);
|
||||
|
||||
const auto x_block_tensor = load_tile(x_block_window);
|
||||
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
|
||||
// TODO: support cross warp Welford
|
||||
WarpMergeWelford<ComputeDataType, true>{}(
|
||||
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
|
||||
|
||||
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_block_window,
|
||||
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
|
||||
|
||||
// normalize
|
||||
const auto gamma_block_tensor = load_tile(gamma_block_window);
|
||||
const auto beta_block_tensor = load_tile(beta_block_window);
|
||||
|
||||
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
|
||||
|
||||
auto y_block_tensor =
|
||||
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
|
||||
|
||||
sweep_tile_span(x_spans[I1], [&](auto idx1) {
|
||||
constexpr auto j_idx = make_tuple(idx1);
|
||||
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
|
||||
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
|
||||
|
||||
sweep_tile_span(x_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
const auto mean = mean_compute_block_tensor[i_idx];
|
||||
const auto inv_std = inv_std_compute_block_tensor[i_idx];
|
||||
|
||||
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
|
||||
auto y = (x - mean) * inv_std * gamma + beta;
|
||||
|
||||
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(y_block_window, y_block_tensor);
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
|
||||
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
|
||||
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
|
||||
_SS_(Pipeline::name) + surfix;
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
const auto x_m_n = [&]() {
|
||||
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
const auto iM = get_block_id() * Block_M;
|
||||
|
||||
const auto x_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XDataType*>(kargs.p_x),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.N, 1),
|
||||
number<kNPerThread>{},
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(x_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
|
||||
// check the max count dynamically
|
||||
const auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
const auto gamma_n = [&]() {
|
||||
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
const auto gamma_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const GammaDataType*>(kargs.p_gamma),
|
||||
make_tuple(kargs.N),
|
||||
make_tuple(kargs.n),
|
||||
make_tuple(1),
|
||||
number<kNPerThread>{},
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
|
||||
}();
|
||||
|
||||
const auto beta_n = [&]() {
|
||||
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
const auto beta_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BetaDataType*>(kargs.p_beta),
|
||||
make_tuple(kargs.N),
|
||||
make_tuple(kargs.n),
|
||||
make_tuple(1),
|
||||
number<kNPerThread>{},
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0});
|
||||
}();
|
||||
|
||||
const auto iM = get_block_id() * kMPerBlock;
|
||||
|
||||
constexpr auto xDstr = MakeXBlockTileDistribution();
|
||||
|
||||
auto x_block_window = make_tile_window(
|
||||
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
|
||||
|
||||
const auto y_m_n = [&]() {
|
||||
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
auto y_window = [&]() {
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YDataType*>(kargs.p_y),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.N, 1),
|
||||
number<kNPerThread>{},
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(y_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
auto y_block_window = make_tile_window(
|
||||
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
|
||||
|
||||
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
|
||||
constexpr auto betaDstr = gammaDstr;
|
||||
|
||||
auto gamma_block_window =
|
||||
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
|
||||
|
||||
auto beta_block_window = make_tile_window(
|
||||
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
|
||||
|
||||
auto mean_block_window = [&]() {
|
||||
auto mean_window = [&]() {
|
||||
if constexpr(kSaveMean)
|
||||
{
|
||||
const auto mean_m = [&]() {
|
||||
const auto mean_dram_naive =
|
||||
make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<MeanDataType*>(kargs.p_mean),
|
||||
make_tuple(kargs.M),
|
||||
make_tuple(kargs.m),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
|
||||
mean_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
return make_tile_window(mean_m, make_tuple(number<Block_M>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
auto inv_std_block_window = [&]() {
|
||||
auto inv_std_window = [&]() {
|
||||
if constexpr(kSaveInvStd)
|
||||
{
|
||||
const auto inv_std_m = [&]() {
|
||||
const auto inv_std_dram_naive =
|
||||
make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<InvStdDataType*>(kargs.p_invStd),
|
||||
make_tuple(kargs.M),
|
||||
make_tuple(kargs.m),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
|
||||
inv_std_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
return make_tile_window(inv_std_m, make_tuple(number<Block_M>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
if(kargs.N <= kNPerBlock)
|
||||
OnePassLayernorm2dFwd(x_block_window,
|
||||
gamma_block_window,
|
||||
beta_block_window,
|
||||
y_block_window,
|
||||
mean_block_window,
|
||||
inv_std_block_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.N);
|
||||
else
|
||||
TwoPassLayernorm2dFwd(x_block_window,
|
||||
gamma_block_window,
|
||||
beta_block_window,
|
||||
y_block_window,
|
||||
mean_block_window,
|
||||
inv_std_block_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.N);
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window,
|
||||
gamma_window,
|
||||
beta_window,
|
||||
y_window,
|
||||
mean_window,
|
||||
inv_std_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.n,
|
||||
smem);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
/*
|
||||
// clang-format off
|
||||
|
||||
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
|
||||
|
||||
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
|
||||
+<----------------------< Repeat_N(2)>--------------------->+
|
||||
| |
|
||||
+<-- <WarpPerBlock_N(2)> -->+
|
||||
Warp_N
|
||||
+--------------+--------------+--------------+--------------+----+----------------+
|
||||
Warp_M | wrap_0 | wrap_1 | | ^ ^
|
||||
+--------------+--------------+ | <WarpPerBlock_M(2)> |
|
||||
| wrap_2 | wrap_3 | | v
|
||||
+--------------+--------------+--------------+--------------+----+ Block_M
|
||||
| | |
|
||||
+ + |
|
||||
| | | v
|
||||
+--------------+--------------+--------------+--------------+ +
|
||||
|
||||
each Warp-tile (e.g 16 thrd per row)
|
||||
|
||||
Vector_N (contiguous pixels each thrd holds along N, or vector size)
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
// clang-format on
|
||||
*/
|
||||
template <typename BlockTile_, // block size, seq<M, N>
|
||||
typename WarpPerBlock_, // num warps along seq<M, N>
|
||||
typename WarpTile_, // warp size, seq<M, N>
|
||||
typename Vector_, // contiguous pixels(vector size) along seq<M, N>
|
||||
index_t BlockSize_ =
|
||||
warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
|
||||
struct Layernorm2dShape
|
||||
{
|
||||
// block size
|
||||
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
|
||||
|
||||
// num warps along seq<M, N>, within each block
|
||||
static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
|
||||
|
||||
// warp size
|
||||
static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
|
||||
static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
|
||||
|
||||
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
|
||||
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
|
||||
// repeat of each thread along seq<M, N>
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
// vector size along seq<M, N>
|
||||
static constexpr index_t Vector_M = Vector_::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector_::at(number<1>{});
|
||||
|
||||
static_assert(Warp_M % Vector_M == 0);
|
||||
static_assert(Warp_N % Vector_N == 0);
|
||||
// num of threads along seq<M, N>, within each warp
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,34 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_,
|
||||
typename GammaDataType_,
|
||||
typename BetaDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename MeanDataType_,
|
||||
typename InvStdDataType_,
|
||||
typename BlockShape_,
|
||||
bool kPadM_,
|
||||
bool kPadN_>
|
||||
struct BlockLayernorm2dFwdProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,99 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
|
||||
#include "ck_tile/ops/welford/block/block_welford.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
return BlockWelford<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
return BlockWelfordSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
return BlockWelfordCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
using block_welford = BlockWelford<P_>;
|
||||
using x_block_tile =
|
||||
decltype(make_static_distributed_tensor<typename Problem::XDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using mean_var_block_tile =
|
||||
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
|
||||
|
||||
return GetBlockWelfordCrossWarpSync<Problem>()
|
||||
.template GetSmemSize<mean_var_block_tile>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,119 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
|
||||
struct Layernorm2dFwdPipelineOnePass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr"; // block per row
|
||||
else
|
||||
return "wpr"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
typename GammaWindow,
|
||||
typename BetaWindow,
|
||||
typename YWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
const BetaWindow& beta_window_,
|
||||
YWindow& y_window,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
{
|
||||
const auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
const auto beta_window = make_tile_window(
|
||||
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
|
||||
const auto x = load_tile(x_window);
|
||||
int cur_count = 0;
|
||||
int max_count =
|
||||
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
|
||||
auto block_welford = Policy::template GetBlockWelford<Problem>();
|
||||
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
|
||||
auto block_welford_cross_warp_sync =
|
||||
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
|
||||
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
|
||||
// compute welford each-thread->cross-lane->cross-warp
|
||||
auto [mean, var] = block_welford(x, cur_count, max_count);
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
block_welford_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count);
|
||||
|
||||
// compute inv-std
|
||||
auto inv_std = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_) + epsilon);
|
||||
},
|
||||
var);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_window, cast_tile<MeanDataType>(mean));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
|
||||
|
||||
// layernorm computation
|
||||
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
|
||||
sweep_tile(y, [&, mean_ = mean](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
|
||||
y(idx) = type_convert<YDataType>(y_);
|
||||
});
|
||||
store_tile(y_window, y);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_,
|
||||
typename GammaDataType_,
|
||||
typename BetaDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename MeanDataType_,
|
||||
typename InvStdDataType_,
|
||||
typename BlockShape_,
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kTwoPass_>
|
||||
struct Layernorm2dFwdPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,160 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
|
||||
struct Layernorm2dFwdPipelineTwoPass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr"; // block per row
|
||||
else
|
||||
return "wpr"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
typename GammaWindow,
|
||||
typename BetaWindow,
|
||||
typename YWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
const BetaWindow& beta_window_,
|
||||
YWindow& y_window,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
{
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
auto beta_window = make_tile_window(
|
||||
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
|
||||
// Problem::BlockShape
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
|
||||
|
||||
// total number of count assume current iter have no pad(only last iter has pad)
|
||||
constexpr index_t count_per_iter =
|
||||
Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N;
|
||||
const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N;
|
||||
|
||||
int cur_count = 0;
|
||||
int max_count =
|
||||
(num_n_tile_iteration - 1) * count_per_iter +
|
||||
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
|
||||
auto block_welford = Policy::template GetBlockWelford<Problem>();
|
||||
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
|
||||
auto block_welford_cross_warp_sync =
|
||||
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(load_tile(x_window));
|
||||
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
block_welford(x, mean, var, cur_count, max_count);
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
}
|
||||
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
block_welford_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count);
|
||||
|
||||
// compute inv-std
|
||||
auto inv_std = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_) + epsilon);
|
||||
},
|
||||
var);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_window, cast_tile<MeanDataType>(mean));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
|
||||
|
||||
// reverse read x to reuse cache
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
// x_window.foo();
|
||||
// gamma_window.foo();
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(beta_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
|
||||
// layernorm computation
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
|
||||
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
|
||||
|
||||
sweep_tile(y, [&, mean_ = mean](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
|
||||
y(idx) = type_convert<YDataType>(y_);
|
||||
});
|
||||
|
||||
store_tile(y_window, y);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(beta_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -1,35 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename ThreadTile, // Sequence<...
|
||||
typename WarpTile, // Sequence<...
|
||||
typename BlockTile> // Sequence<...
|
||||
struct TileLayernorm2dShape
|
||||
{
|
||||
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
|
||||
static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
|
||||
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
|
||||
static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp;
|
||||
|
||||
// TODO - kNNumWarps can only be 1 if we don't support cross warp welford
|
||||
static_assert(kNWarpPerBlock == 1);
|
||||
|
||||
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kNWarpPerBlock;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user