diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt
new file mode 100644
index 0000000000..bac5f45cd3
--- /dev/null
+++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt
@@ -0,0 +1,4 @@
+# not using add_example_executable() to add this target, since we don't want it to be
+# included in "make all/install/check"
+add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
+target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD)
\ No newline at end of file
diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md
new file mode 100644
index 0000000000..433dad04e6
--- /dev/null
+++ b/example/ck_tile/02_layernorm2d/README.md
@@ -0,0 +1,22 @@
+# Layernorm2D forward
+
+This folder contains an example of Layernorm2D forward using the ck_tile tile-programming implementation.
+
+## build
+```
+# in the root of ck_tile
+mkdir build && cd build
+sh ../script/cmake-ck-dev.sh ../ # you can replace this with gfx90a, gfx942...
+make tile_example_layernorm2d_fwd -j
+```
+This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
+
+## example
+```
+args:
+  -m     m dimension (default:3328)
+  -n     n dimension (default:4096)
+  -e     epsilon (default:1e-5)
+  -v     cpu validation or not (default:1)
+  -prec  precision (default:fp16)
+```
\ No newline at end of file
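A typical invocation then looks like the following (illustrative only: the output line mirrors the `std::cout` format in `layernorm2d_fwd.cpp` below, with the device-dependent timing fields left as placeholders rather than measured values):

```
./bin/tile_example_layernorm2d_fwd -m 3328 -n 4096 -e 1e-5 -v 1 -prec fp16
# [fp16] m:3328, n:4096, <time> ms, <bandwidth> GB/s, valid:y
```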
diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
new file mode 100644
index 0000000000..9cbd286104
--- /dev/null
+++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
@@ -0,0 +1,191 @@
+#include "ck_tile/host.hpp"
+#include "layernorm2d_fwd.hpp"
+#include <cstring>
+
+// Host API implementation
+float layernorm2d_fwd(layernorm2d_fwd_traits t,
+                      layernorm2d_fwd_args a,
+                      const ck_tile::stream_config& s)
+{
+    if(t.data_type.compare("fp16") == 0)
+    {
+        using XDataType     = ck_tile::half_t;
+        using YDataType     = ck_tile::half_t;
+        using GammaDataType = ck_tile::half_t;
+        using BetaDataType  = ck_tile::half_t;
+#ifdef SAVE_MEAN_INV_STD
+        using MeanDataType   = ck_tile::half_t;
+        using InvStdDataType = ck_tile::half_t;
+#else
+        using MeanDataType   = ck_tile::null_type;
+        using InvStdDataType = ck_tile::null_type;
+#endif
+        using ComputeDataType = float;
+
+        using thread_tile = ck_tile::sequence<4, 4>;
+        using warp_tile   = ck_tile::sequence<8, 128>;
+        using block_tile  = ck_tile::sequence<32, 128>;
+
+        using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
+
+        using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<XDataType,
+                                                                    GammaDataType,
+                                                                    BetaDataType,
+                                                                    ComputeDataType,
+                                                                    YDataType,
+                                                                    MeanDataType,
+                                                                    InvStdDataType,
+                                                                    Shape>;
+
+        using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
+
+        auto kargs = Kernel::MakeKargs(
+            a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);
+
+        const dim3 grids      = Kernel::GridSize(a.M);
+        constexpr dim3 blocks = Kernel::BlockSize();
+
+        constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock;
+
+        float ave_time = ck_tile::launch_kernel(
+            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
+
+        return ave_time;
+    }
+
+    return 0;
+}
+
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("m", "3328", "m dimension")
+        .insert("n", "4096", "n dimension")
+        .insert("e", "1e-5", "epsilon")
+        .insert("v", "1", "cpu validation or not")
+        .insert("prec", "fp16", "precision");
+
+    bool result = arg_parser.parse(argc, argv);
+    return std::make_tuple(result, arg_parser);
+}
+
+int main(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    float epsilon         = arg_parser.get_float("e");
+    ck_tile::index_t M    = arg_parser.get_int("m");
+    ck_tile::index_t N    = arg_parser.get_int("n");
+    std::string data_type = arg_parser.get_str("prec");
+    int do_validation     = arg_parser.get_int("v");
+
+    using XDataType     = ck_tile::half_t;
+    using YDataType     = ck_tile::half_t;
+    using GammaDataType = ck_tile::half_t;
+    using BetaDataType  = ck_tile::half_t;
+#ifdef SAVE_MEAN_INV_STD
+    using MeanDataType   = ck_tile::half_t;
+    using InvStdDataType = ck_tile::half_t;
+#else
+    using MeanDataType   = ck_tile::null_type;
+    using InvStdDataType = ck_tile::null_type;
+#endif
+    using ComputeDataType = float;
+
+    // host verify
+    ck_tile::HostTensor<XDataType> x_host({M, N});
+    ck_tile::HostTensor<GammaDataType> gamma_host({N});
+    ck_tile::HostTensor<BetaDataType> beta_host({N});
+
+    ck_tile::HostTensor<YDataType> y_host_ref({M, N});
+    ck_tile::HostTensor<YDataType> y_host_dev({M, N});
+
+    ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
+    ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});
+
+#ifdef SAVE_MEAN_INV_STD
+    ck_tile::HostTensor<MeanDataType> mean_host_dev({M});
+    ck_tile::HostTensor<InvStdDataType> invStd_host_dev({M});
+#endif
+
+    ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
+    ck_tile::FillUniformDistribution<GammaDataType>{-5.f, 5.f}(gamma_host);
+    ck_tile::FillUniformDistribution<BetaDataType>{-5.f, 5.f}(beta_host);
+
+    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
+
+#ifdef SAVE_MEAN_INV_STD
+    ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes());
+#endif
+
+    x_buf.ToDevice(x_host.data());
+    gamma_buf.ToDevice(gamma_host.data());
+    beta_buf.ToDevice(beta_host.data());
+
+    layernorm2d_fwd_traits traits{data_type};
+
+    layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
+                              gamma_buf.GetDeviceBuffer(),
+                              beta_buf.GetDeviceBuffer(),
+                              y_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                              mean_buf.GetDeviceBuffer(),
+                              invStd_buf.GetDeviceBuffer(),
+#else
+                              nullptr,
+                              nullptr,
+#endif
+                              epsilon,
+                              M,
+                              N};
+
+    float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true});
+
+    std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
+                           sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
+
+    float gb_per_sec = num_byte / 1.E6 / ave_time;
+    std::cout << "[" << data_type << "]"
+              << " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s"
+              << std::flush;
+
+    bool pass = true;
+
+    if(do_validation)
+    {
+        // reference
+        ck_tile::reference_layernorm2d_fwd<XDataType,
+                                           GammaDataType,
+                                           BetaDataType,
+                                           ComputeDataType,
+                                           YDataType,
+                                           MeanDataType,
+                                           InvStdDataType>(
+            x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
+
+        y_buf.FromDevice(y_host_dev.data());
+
+        pass = ck_tile::check_err(y_host_dev, y_host_ref);
+
+#ifdef SAVE_MEAN_INV_STD
+        mean_buf.FromDevice(mean_host_dev.data());
+        pass &= ck_tile::check_err(mean_host_dev, mean_host_ref);
+
+        invStd_buf.FromDevice(invStd_host_dev.data());
+        pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref);
+#endif
+
+        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
+    }
+
+    std::cout << std::endl << std::flush;
+
+    return !pass;
+}
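The `gb_per_sec` expression in `main()` above is easy to misread, so here is the unit arithmetic spelled out (a standalone restatement, not code from the PR): the kernel reads `x` once, reads `gamma` and `beta` once, and writes `y` once, and dividing bytes by `1e6 * ms` is the same as dividing gigabytes by seconds.

```cpp
#include <cstddef>

// bytes / 1e9 [GB] divided by (ms / 1e3) [s] equals bytes / (1e6 * ms),
// which is exactly `num_byte / 1.E6 / ave_time` in main() above.
inline float gb_per_sec(std::size_t num_byte, float ave_time_ms)
{
    return num_byte / 1.E6f / ave_time_ms;
}
```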
"y" : "n") << std::flush; + } + + std::cout << std::endl << std::flush; + + return !pass; +} diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp new file mode 100644 index 0000000000..4d1aac0994 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/layernorm2d.hpp" +#include + +struct layernorm2d_fwd_traits +{ + std::string data_type; +}; + +struct layernorm2d_fwd_args +{ + const void* p_x; + const void* p_gamma; + const void* p_beta; + void* p_y; + void* p_mean; + void* p_invStd; + float epsilon; + ck_tile::index_t M; + ck_tile::index_t N; +}; + +// host API +float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index d2b086e043..995d193f10 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -3,3 +3,4 @@ include_directories(AFTER ) add_subdirectory(01_fmha) +add_subdirectory(02_layernorm2d) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index bb490cce4a..4cddf6faa9 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -27,6 +27,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" diff --git a/include/ck_tile/core/numeric/null_type.hpp b/include/ck_tile/core/numeric/null_type.hpp new file mode 100644 index 0000000000..8799c0560e --- /dev/null +++ b/include/ck_tile/core/numeric/null_type.hpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp
index bb490cce4a..4cddf6faa9 100644
--- a/include/ck_tile/core.hpp
+++ b/include/ck_tile/core.hpp
@@ -27,6 +27,7 @@
 #include "ck_tile/core/numeric/integer.hpp"
 #include "ck_tile/core/numeric/integral_constant.hpp"
 #include "ck_tile/core/numeric/math.hpp"
+#include "ck_tile/core/numeric/null_type.hpp"
 #include "ck_tile/core/numeric/numeric.hpp"
 #include "ck_tile/core/numeric/type_convert.hpp"
 #include "ck_tile/core/numeric/vector_type.hpp"
diff --git a/include/ck_tile/core/numeric/null_type.hpp b/include/ck_tile/core/numeric/null_type.hpp
new file mode 100644
index 0000000000..8799c0560e
--- /dev/null
+++ b/include/ck_tile/core/numeric/null_type.hpp
@@ -0,0 +1,13 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+namespace ck_tile {
+
+// tag type used where an optional tensor (e.g. the mean/inv-std output) is absent
+struct null_type
+{
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp
index 09030fa6df..0e69a925d5 100644
--- a/include/ck_tile/host.hpp
+++ b/include/ck_tile/host.hpp
@@ -18,6 +18,7 @@
 #include "ck_tile/host/reference/reference_batched_softmax.hpp"
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
+#include "ck_tile/host/reference/reference_layernorm2d.hpp"
 #include "ck_tile/host/reference/reference_reduce.hpp"
 #include "ck_tile/host/reference/reference_softmax.hpp"
 #include "ck_tile/host/stream_config.hpp"
diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp
index 1ef9b24138..529bfdff25 100644
--- a/include/ck_tile/host/check_err.hpp
+++ b/include/ck_tile/host/check_err.hpp
@@ -56,8 +56,9 @@ check_err(const Range& out,
     }
 
     const auto is_infinity_error = [=](auto o, auto r) {
-        const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+        const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
+        const bool both_infinite_and_same =
+            std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r));
 
         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
     };
@@ -114,8 +115,9 @@ check_err(const Range& out,
     }
 
     const auto is_infinity_error = [=](auto o, auto r) {
-        const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+        const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
+        const bool both_infinite_and_same =
+            std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r));
 
         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
     };
@@ -173,8 +175,9 @@ check_err(const Range& out,
     }
 
     const auto is_infinity_error = [=](auto o, auto r) {
-        const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+        const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
+        const bool both_infinite_and_same =
+            std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r));
 
         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
     };
@@ -285,8 +288,9 @@ std::enable_if_t<(std::is_same_v, ranges::range_val
     }
 
     const auto is_infinity_error = [=](auto o, auto r) {
-        const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+        const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
+        const bool both_infinite_and_same =
+            std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r));
 
         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
     };
@@ -357,8 +361,9 @@ std::enable_if_t<(std::is_same_v, ranges::range_val
     }
 
     const auto is_infinity_error = [=](auto o, auto r) {
-        const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+        const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
+        const bool both_infinite_and_same =
+            std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r));
 
         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
     };
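The check_err change compares two infinities through their bit patterns instead of `operator==`. The `bit_cast` target types are elided in this diff and differ per overload, so the float/uint32_t pairing below is only an illustrative stand-in for the idea, using C++20 `std::bit_cast`:

```cpp
#include <bit>
#include <cmath>
#include <cstdint>

bool both_infinite_and_same(float o, float r)
{
    // +inf and -inf differ only in the sign bit, so bitwise equality keeps the
    // "same infinity" test exact even for types with unreliable FP comparisons.
    return std::isinf(o) && std::isinf(r) &&
           std::bit_cast<std::uint32_t>(o) == std::bit_cast<std::uint32_t>(r);
}
```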
diff --git a/include/ck_tile/host/reference/reference_layernorm2d.hpp b/include/ck_tile/host/reference/reference_layernorm2d.hpp
new file mode 100644
index 0000000000..837f52c399
--- /dev/null
+++ b/include/ck_tile/host/reference/reference_layernorm2d.hpp
@@ -0,0 +1,69 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/host_tensor.hpp"
+
+namespace ck_tile {
+
+template <typename XDataType,
+          typename GammaDataType,
+          typename BetaDataType,
+          typename ComputeDataType,
+          typename YDataType,
+          typename MeanDataType,
+          typename InvStdDataType>
+void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
+                               const HostTensor<GammaDataType>& gamma_n,
+                               const HostTensor<BetaDataType>& beta_n,
+                               HostTensor<YDataType>& y_m_n,
+                               HostTensor<MeanDataType>& mean_m,
+                               HostTensor<InvStdDataType>& invStd_m,
+                               ComputeDataType epsilon)
+{
+    auto layernorm2d_fwd_func = [&](auto m) {
+        const int N = x_m_n.mDesc.get_lengths()[1];
+
+        int count                = 0;
+        ComputeDataType mean     = 0;
+        ComputeDataType variance = 0;
+        ComputeDataType divisor  = 0;
+
+        // Welford's online algorithm for mean/variance
+        for(int n = 0; n < N; ++n)
+        {
+            ++count;
+            ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
+            ComputeDataType delta = x - mean;
+            mean += delta / count;
+            ComputeDataType delta2 = x - mean;
+            variance += delta * delta2;
+        }
+
+        // actual variance
+        variance = variance / count;
+        divisor  = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(variance + epsilon);
+
+        if constexpr(!std::is_same_v<MeanDataType, null_type>)
+            mean_m(m) = ck_tile::type_convert<MeanDataType>(mean);
+
+        if constexpr(!std::is_same_v<InvStdDataType, null_type>)
+            invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);
+
+        for(int n = 0; n < N; ++n)
+        {
+            ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
+            ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
+            ComputeDataType beta  = ck_tile::type_convert<ComputeDataType>(beta_n(n));
+            auto y                = (x - mean) * divisor;
+            y                     = y * gamma + beta;
+
+            y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
+        }
+    };
+
+    make_ParallelTensorFunctor(layernorm2d_fwd_func,
+                               mean_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
+}
+} // namespace ck_tile
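The per-row loop above is Welford's online update: `variance` accumulates the sum of `(x - mean_old) * (x - mean_new)` terms and is divided by `count` at the end. A tiny standalone check of that recurrence (plain C++, independent of ck_tile):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

int main()
{
    std::vector<float> xs{1.f, 2.f, 4.f};
    float mean = 0.f, m2 = 0.f;
    int count = 0;
    for(float x : xs)
    {
        ++count;
        float delta = x - mean;   // delta against the old mean
        mean += delta / count;
        m2 += delta * (x - mean); // accumulates sum of squared deviations
    }
    // whole set {1,2,4}: mean = 7/3, population variance = m2 / count = 14/9
    assert(std::fabs(mean - 7.f / 3.f) < 1e-5f);
    assert(std::fabs(m2 / count - 14.f / 9.f) < 1e-5f);
}
```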
diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp
new file mode 100644
index 0000000000..3b66645ed4
--- /dev/null
+++ b/include/ck_tile/ops/layernorm2d.hpp
@@ -0,0 +1,9 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
new file mode 100644
index 0000000000..4be3e56874
--- /dev/null
+++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
@@ -0,0 +1,291 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/welford/thread/thread_welford.hpp"
+#include "ck_tile/ops/welford/warp/warp_welford.hpp"
+
+namespace ck_tile {
+
+// TODO: Extract some type to wrapper class
+template <typename Problem_>
+struct Layernorm2dFwd
+{
+    using Problem = ck_tile::remove_cvref_t<Problem_>;
+
+    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
+    using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
+    using BetaDataType    = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
+    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
+    using YDataType       = ck_tile::remove_cvref_t<typename Problem::YDataType>;
+    using MeanDataType    = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
+    using InvStdDataType  = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
+
+    static constexpr bool kHasGamma   = !std::is_same_v<GammaDataType, null_type>;
+    static constexpr bool kHasBeta    = !std::is_same_v<BetaDataType, null_type>;
+    static constexpr bool kSaveMean   = !std::is_same_v<MeanDataType, null_type>;
+    static constexpr bool kSaveInvStd = !std::is_same_v<InvStdDataType, null_type>;
+
+    static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
+    static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
+
+    static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
+
+    struct Kargs
+    {
+        const void* p_x;
+        const void* p_gamma;
+        const void* p_beta;
+
+        void* p_y;
+        void* p_mean;
+        void* p_invStd;
+
+        float epsilon;
+
+        ck_tile::index_t M;
+        ck_tile::index_t N;
+    };
+
+    CK_TILE_HOST static constexpr Kargs MakeKargs(const void* p_x,
+                                                  const void* p_gamma,
+                                                  const void* p_beta,
+                                                  void* p_y,
+                                                  void* p_mean,
+                                                  void* p_invStd,
+                                                  float epsilon,
+                                                  ck_tile::index_t M,
+                                                  ck_tile::index_t N)
+    {
+        return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N};
+    }
+
+    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; }
+
+    CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
+
+    CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
+    {
+        using S = typename Problem::BlockShape;
+
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
+                      sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
+                tuple<sequence<1, 2>, sequence<1, 2>>,
+                tuple<sequence<0, 0>, sequence<1, 1>>,
+                sequence<1, 2>,
+                sequence<2, 2>>{});
+    }
+
+    CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
+    {
+        using S = typename Problem::BlockShape;
+
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence<S::kMWarpPerBlock, S::kMThreadPerWarp>,
+                tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
+                tuple<sequence<0, 1>, sequence<0, 1>>,
+                tuple<sequence<0, 0>, sequence<1, 1>>,
+                sequence<1>,
+                sequence<2>>{});
+    }
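An orientation note on the two encodings above (an annotation, not part of the diff; the numbers assume the 4x4 thread / 8x128 warp / 32x128 block tiles chosen by the fp16 host path):

```cpp
// How MakeXBlockTileDistribution() carves up one 32x128 X tile:
//   M = kMWarpPerBlock(4) * kMThreadPerWarp(2) * kMPerThread(4)  = 32 rows
//   N = kNWarpPerBlock(1) * kNThreadPerWarp(32) * kNPerThread(4) = 128 cols
// P0 (warp id) picks the (M-warp, N-warp) slots, P1 (lane id) picks the
// (M-thread, N-thread) slots, and Y keeps the per-thread 4x4 fragment,
// i.e. the trailing sequence<1, 2> / sequence<2, 2> pair in the encoding.
// The gamma/beta encoding is the same N decomposition, with the M slots
// moved into the replication dimension R (every row of lanes re-reads them).
```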
+    template <typename Dstr>
+    CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
+    {
+        constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
+
+        using Lengths = decltype(nDstrSpan.impl_);
+
+        ck_tile::index_t ret = 1;
+
+        ck_tile::static_for<0, Lengths::size(), 1>{}(
+            [&](auto idx) { ret *= Lengths::template at(idx); });
+
+        return ret;
+    }
+
+    template <typename DistributedTensor>
+    CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor,
+                                       const ComputeDataType epsilon)
+    {
+        // TODO: Investigate fast inverse square root algorithm with epsilon
+        constexpr auto spans = DistributedTensor::get_distributed_spans();
+
+        DistributedTensor out_dstr_tensor;
+
+        sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
+            constexpr auto i_idx    = make_tuple(idx0);
+            out_dstr_tensor(i_idx) = type_convert<ComputeDataType>(1.0f) /
+                                     ck_tile::sqrt(in_dstr_tensor[i_idx] + epsilon);
+        });
+
+        return out_dstr_tensor;
+    }
+
+    template <bool Cond = (kHasGamma && kHasBeta)>
+    CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x,
+                                                                const GammaDataType* p_gamma,
+                                                                const BetaDataType* p_beta,
+                                                                YDataType* p_y,
+                                                                MeanDataType* p_mean,
+                                                                InvStdDataType* p_invStd,
+                                                                const ComputeDataType epsilon,
+                                                                ck_tile::index_t M,
+                                                                ck_tile::index_t N) const
+    {
+        constexpr auto I0 = number<0>{};
+        constexpr auto I1 = number<1>{};
+
+        const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
+            p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
+
+        const auto gamma_n = make_naive_tensor_view<address_space_enum::global>(
+            p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
+
+        const auto beta_n = make_naive_tensor_view<address_space_enum::global>(
+            p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
+
+        const auto iM = get_block_id() * kMPerBlock;
+
+        constexpr auto xDstr = MakeXBlockTileDistribution();
+
+        auto x_block_window = make_tile_window(
+            x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
+
+        index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock);
+
+        // TODO: padding - handle max_count if N % kNPerBlock != 0
+        constexpr auto NPerThread = GetNPerThread(xDstr);
+        ThreadWelford<XDataType, ComputeDataType> thread_welford{
+            type_convert<int>(NPerThread * N / kNPerBlock)};
+
+        using XTensorType = decltype(load_tile(x_block_window));
+        auto mean_compute_block_tensor =
+            thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
+        auto var_compute_block_tensor =
+            thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
+
+        clear_tile(mean_compute_block_tensor);
+        clear_tile(var_compute_block_tensor);
+
+        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
+        {
+            const auto x_block_tensor = load_tile(x_block_window);
+
+            thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
+            move_tile_window(x_block_window, {0, kNPerBlock});
+        }
+
+        // TODO: support cross warp Welford
+        WarpMergeWelford<ComputeDataType, true>{}(
+            mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
+
+        auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
+
+        if constexpr(kSaveMean)
+        {
+            const auto mean_m = make_naive_tensor_view_packed<address_space_enum::global>(
+                p_mean, make_tuple(M), number<32>{});
+
+            auto mean_block_window =
+                make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
+
+            store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
+        }
+        if constexpr(kSaveInvStd)
+        {
+            const auto inv_std_m = make_naive_tensor_view_packed<address_space_enum::global>(
+                p_invStd, make_tuple(M), number<32>{});
+
+            auto inv_std_block_window =
+                make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
+
+            store_tile(inv_std_block_window,
+                       cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
+        }
+
+        // TODO: Extract normalize pipeline
+        const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
+            p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
+
+        auto y_block_window = make_tile_window(
+            y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
+
+        constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
+        constexpr auto betaDstr  = gammaDstr;
+
+        auto gamma_block_window =
+            make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
+
+        auto beta_block_window =
+            make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
+
+        // reverse read x to reuse cache: start from the right-most (still cached) tile
+        ck_tile::index_t stride_to_right_most_window = N - kNPerBlock;
+
+        move_tile_window(x_block_window, {0, -kNPerBlock});
+        move_tile_window(gamma_block_window, {stride_to_right_most_window});
+        move_tile_window(beta_block_window, {stride_to_right_most_window});
+        move_tile_window(y_block_window, {0, stride_to_right_most_window});
+
+        // Normalization
+        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
+        {
+            const auto x_block_tensor     = load_tile(x_block_window);
+            const auto gamma_block_tensor = load_tile(gamma_block_window);
+            const auto beta_block_tensor  = load_tile(beta_block_window);
+
+            constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
+
+            auto y_block_tensor =
+                make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
+
+            sweep_tile_span(x_spans[I1], [&](auto idx1) {
+                constexpr auto j_idx = make_tuple(idx1);
+                const auto gamma     = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
+                const auto beta      = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
+
+                sweep_tile_span(x_spans[I0], [&](auto idx0) {
+                    constexpr auto i_idx   = make_tuple(idx0);
+                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
+
+                    const auto mean    = mean_compute_block_tensor[i_idx];
+                    const auto inv_std = inv_std_compute_block_tensor[i_idx];
+
+                    const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
+                    auto y       = (x - mean) * inv_std * gamma + beta;
+
+                    y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
+                });
+            });
+
+            store_tile(y_block_window, y_block_tensor);
+
+            move_tile_window(x_block_window, {0, -kNPerBlock});
+            move_tile_window(gamma_block_window, {-kNPerBlock});
+            move_tile_window(beta_block_window, {-kNPerBlock});
+            move_tile_window(y_block_window, {0, -kNPerBlock});
+        }
+    }
+
+    CK_TILE_DEVICE void operator()(Kargs kargs) const
+    {
+        TwoPassLayernorm2dFwd(static_cast<const XDataType*>(kargs.p_x),
+                              static_cast<const GammaDataType*>(kargs.p_gamma),
+                              static_cast<const BetaDataType*>(kargs.p_beta),
+                              static_cast<YDataType*>(kargs.p_y),
+                              static_cast<MeanDataType*>(kargs.p_mean),
+                              static_cast<InvStdDataType*>(kargs.p_invStd),
+                              static_cast<ComputeDataType>(kargs.epsilon),
+                              kargs.M,
+                              kargs.N);
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
new file mode 100644
index 0000000000..5206d36d7d
--- /dev/null
+++ b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
@@ -0,0 +1,30 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core/utility/type_traits.hpp"
+
+namespace ck_tile {
+
+template <typename XDataType_,
+          typename GammaDataType_,
+          typename BetaDataType_,
+          typename ComputeDataType_,
+          typename YDataType_,
+          typename MeanDataType_,
+          typename InvStdDataType_,
+          typename BlockShape_>
+struct BlockLayernorm2dFwdProblem
+{
+    using XDataType       = remove_cvref_t<XDataType_>;
+    using GammaDataType   = remove_cvref_t<GammaDataType_>;
+    using BetaDataType    = remove_cvref_t<BetaDataType_>;
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
+    using YDataType       = remove_cvref_t<YDataType_>;
+    using MeanDataType    = remove_cvref_t<MeanDataType_>;
+    using InvStdDataType  = remove_cvref_t<InvStdDataType_>;
+    using BlockShape      = remove_cvref_t<BlockShape_>;
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp b/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
new file mode 100644
index 0000000000..1ff541d844
--- /dev/null
+++ b/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
@@ -0,0 +1,35 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+namespace ck_tile {
+template <typename ThreadTile, // Sequence<M, N> per-thread tile
+          typename WarpTile,   // Sequence<M, N> per-warp tile
+          typename BlockTile>  // Sequence<M, N> per-block tile
+struct TileLayernorm2dShape
+{
+    static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
+    static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
+
+    static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
+    static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
+
+    static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
+    static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
+
+    static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
+    static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
+
+    static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
+    static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp;
+
+    // TODO - kNWarpPerBlock can only be 1 until cross-warp Welford is supported
+    static_assert(kNWarpPerBlock == 1);
+
+    static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kNWarpPerBlock;
+};
+
+} // namespace ck_tile
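Plugging the tiles chosen by the fp16 host path (thread 4x4, warp 8x128, block 32x128) into the formulas above gives the following, assuming `warpSize` is 64 (a consistency check, not part of the diff):

```cpp
static_assert(8 / 4 == 2 && 128 / 4 == 32);   // kMThreadPerWarp, kNThreadPerWarp: 2 * 32 = 64 lanes
static_assert(32 / 8 == 4 && 128 / 128 == 1); // kMWarpPerBlock, kNWarpPerBlock (must be 1)
static_assert(64 * 4 * 1 == 256);             // kBlockSize: 256 threads, each owning a 4x4 fragment
```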
diff --git a/include/ck_tile/ops/welford.hpp b/include/ck_tile/ops/welford.hpp
new file mode 100644
index 0000000000..dffaad7501
--- /dev/null
+++ b/include/ck_tile/ops/welford.hpp
@@ -0,0 +1,8 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/ops/welford/thread/thread_welford.hpp"
+#include "ck_tile/ops/welford/warp/warp_welford.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
diff --git a/include/ck_tile/ops/welford/thread/thread_welford.hpp b/include/ck_tile/ops/welford/thread/thread_welford.hpp
new file mode 100644
index 0000000000..2ca9a23657
--- /dev/null
+++ b/include/ck_tile/ops/welford/thread/thread_welford.hpp
@@ -0,0 +1,101 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+namespace ck_tile {
+
+template <typename XDataType_, typename ComputeDataType_>
+struct ThreadWelford
+{
+    using XDataType       = remove_cvref_t<XDataType_>;
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
+
+    template <typename T>
+    CK_TILE_DEVICE void Update(T& mean, T& var, T x)
+    {
+        if(ck_tile::isnan(x))
+        {
+            // propagate NaN into the running statistics
+            mean = x;
+            var  = x;
+        }
+        else
+        {
+            T delta = x - mean;
+            mean += delta / cur_count_;
+            T delta2 = x - mean;
+            var += delta * delta2;
+        }
+    }
+
+    // [CAUTION] - max_count_ is to deal with the padding problem
+    // max_count_ depends on the caller; e.g. naive and splitN welford will have different
+    // calculations of max_count_
+    CK_TILE_DEVICE constexpr ThreadWelford(int max_count) : cur_count_(0), max_count_(max_count) {}
+
+    template <typename XDistributedTensor_,
+              typename MeanDistributedTensor_,
+              typename VarDistributedTensor_>
+    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
+                                   MeanDistributedTensor_& mean_tensor,
+                                   VarDistributedTensor_& var_tensor)
+    {
+        constexpr auto I0 = number<0>{};
+        constexpr auto I1 = number<1>{};
+
+        constexpr auto spans = XDistributedTensor_::get_distributed_spans();
+
+        sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
+            if(cur_count_ < max_count_)
+            {
+                ++cur_count_;
+
+                sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
+                    constexpr auto in_dstr_idx  = make_tuple(dstr_idx_i0, dstr_idx_i1);
+                    constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
+
+                    auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
+
+                    Update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x);
+                });
+            }
+        });
+    }
+
+    template <typename XDistributedTensor_>
+    CK_TILE_DEVICE static auto MakeInitialMeanVarDistributedTensor()
+    {
+        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
+
+        constexpr auto reduce_dims = sequence<1>{};
+
+        constexpr auto dstr =
+            make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
+                XDistributedTensor_::get_tile_distribution()
+                    .get_static_tile_distribution_encoding(),
+                reduce_dims));
+
+        auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
+        clear_tile(tensor);
+
+        return tensor;
+    }
+
+    template <typename XDistributedTensor_>
+    CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor)
+    {
+        auto mean_tensor = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
+        auto var_tensor  = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
+
+        (*this)(x_tensor, mean_tensor, var_tensor);
+
+        return ck_tile::make_tuple(mean_tensor, var_tensor);
+    }
+
+    int cur_count_;
+    int max_count_;
+};
+
+} // namespace ck_tile
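`WarpMergeWelford` below combines per-lane partial (mean, variance-sum, count) triples with the standard parallel-variance merge rule. A standalone numeric check of the same formula (plain C++, independent of ck_tile; `merge` mirrors `WarpMergeWelford::Merge`):

```cpp
#include <cassert>
#include <cmath>

// Merge two (mean, M2, count) partials exactly as the Merge below does.
void merge(float& mean_a, float& m2_a, int& count_a, float mean_b, float m2_b, int count_b)
{
    int count   = count_a + count_b;
    float ratio = count == 0 ? 0.f : static_cast<float>(count_b) / count;
    float delta = mean_b - mean_a;
    mean_a += delta * ratio;
    m2_a += m2_b + delta * delta * count_a * ratio;
    count_a = count;
}

int main()
{
    // partials of {1,2} and {4}: (1.5, 0.5, 2) and (4, 0, 1)
    float mean = 1.5f, m2 = 0.5f;
    int count = 2;
    merge(mean, m2, count, 4.f, 0.f, 1);
    // whole set {1,2,4}: mean = 7/3, M2 = 14/3 (population variance 14/9)
    assert(std::fabs(mean - 7.f / 3.f) < 1e-5f);
    assert(std::fabs(m2 - 14.f / 3.f) < 1e-5f);
    assert(count == 3);
}
```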
diff --git a/include/ck_tile/ops/welford/warp/warp_welford.hpp b/include/ck_tile/ops/welford/warp/warp_welford.hpp
new file mode 100644
index 0000000000..687b61f430
--- /dev/null
+++ b/include/ck_tile/ops/welford/warp/warp_welford.hpp
@@ -0,0 +1,154 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+namespace ck_tile {
+
+template <typename ComputeDataType_, bool GetActualVariance = true, bool BroadcastLane = true>
+struct WarpMergeWelford
+{
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
+
+    template <typename T>
+    CK_TILE_DEVICE static void
+    Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
+    {
+        int count            = count_a + count_b;
+        T count_             = type_convert<T>(count);
+        T count_a_           = type_convert<T>(count_a);
+        T count_b_           = type_convert<T>(count_b);
+        T count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
+
+        T delta = mean_b - mean_a;
+        mean_a += delta * count_b_over_count;
+        var_a += var_b + delta * delta * count_a_ * count_b_over_count;
+        count_a = count;
+    }
+
+    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
+    CK_TILE_DEVICE void
+    operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
+    {
+        using Dstr             = typename MeanDistributedTensor_::StaticTileDistribution;
+        using DstrEncode       = typename Dstr::DstrEncode;
+        using DstrEncodeDetail = typename DstrEncode::detail;
+
+        static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
+                      "wrong!");
+
+        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
+        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
+
+        constexpr index_t idim_p_lane = NDimP - 1;
+
+        const auto ps_idx = make_array(get_warp_id(), get_lane_id());
+        const auto rs_idx =
+            mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
+
+        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
+        static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
+
+        const int original_count = count;
+
+        // loop over thread data
+        static_for<0, thread_buf_size, 1>{}([&](auto i) {
+            auto v_local_mean  = mean_tensor.get_thread_buffer()[i];
+            auto v_local_var   = var_tensor.get_thread_buffer()[i];
+            auto v_local_count = original_count;
+
+            // cross-lane reduce for replication
+            // only reduce on the R dimension corresponding to the lane
+            // (lane id maps to this R dimension)
+            static_for<0, NDimR, 1>{}([&](auto idim_r) {
+                // FIXME: nasty to use does_p_own_r_
+                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
+                {
+                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
+
+                    constexpr index_t lid_over_rid_derivative =
+                        DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
+
+                    static_assert(is_power_of_two_integer(r_length),
+                                  "wrong! only support power of 2 reduction");
+
+                    constexpr index_t nstage = integer_log2_floor(r_length);
+
+                    // reduction sweep forward
+                    static_for<0, nstage, 1>{}([&](auto istage) {
+                        constexpr index_t lid_delta =
+                            lid_over_rid_derivative * (1 << (nstage - istage - 1));
+
+                        // pull data from remote lane
+                        const auto v_remote_mean  = warp_shuffle_down(v_local_mean, lid_delta);
+                        const auto v_remote_var   = warp_shuffle_down(v_local_var, lid_delta);
+                        const auto v_remote_count = warp_shuffle_down(v_local_count, lid_delta);
+
+                        // welford merge
+                        Merge(v_local_mean,
+                              v_local_var,
+                              v_local_count,
+                              v_remote_mean,
+                              v_remote_var,
+                              v_remote_count);
+                    });
+                }
+            });
+
+            // cross-lane broadcast for replication
+            // only broadcast on the R dimension corresponding to the lane
+            // (lane id maps to this R dimension)
+            if constexpr(BroadcastLane)
+            {
+                static_for<0, NDimR, 1>{}([&](auto idim_r) {
+                    // FIXME: nasty to use does_p_own_r_
+                    if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
+                    {
+                        const index_t r_id = rs_idx[idim_r];
+
+                        constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
+
+                        constexpr index_t lid_over_rid_derivative =
+                            DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r];
+
+                        static_assert(is_power_of_two_integer(r_length),
+                                      "wrong! only support power of 2 reduction");
+
+                        constexpr index_t nstage = integer_log2_floor(r_length);
+
+                        // broadcast sweep backward
+                        static_for<0, nstage, 1>{}([&](auto istage) {
+                            // do I hold reduced data?
+                            const bool do_i_hold_reduced_data = r_id < (1 << istage);
+
+                            constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage);
+
+                            // pull data from remote lane
+                            const auto v_remote_mean  = warp_shuffle_up(v_local_mean, lid_delta);
+                            const auto v_remote_var   = warp_shuffle_up(v_local_var, lid_delta);
+                            const auto v_remote_count = warp_shuffle_up(v_local_count, lid_delta);
+
+                            // decide whether to update local data with remote data
+                            v_local_mean  = do_i_hold_reduced_data ? v_local_mean : v_remote_mean;
+                            v_local_var   = do_i_hold_reduced_data ? v_local_var : v_remote_var;
+                            v_local_count = do_i_hold_reduced_data ? v_local_count : v_remote_count;
+                        });
+                    }
+                });
+            }
+
+            mean_tensor.get_thread_buffer()(i) = v_local_mean;
+
+            if constexpr(GetActualVariance)
+                var_tensor.get_thread_buffer()(i) = v_local_var / v_local_count;
+            else
+                var_tensor.get_thread_buffer()(i) = v_local_var;
+
+            count = v_local_count;
+        });
+    }
+};
+
+} // namespace ck_tile