diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt index bac5f45cd3..eb63a88382 100644 --- a/example/ck_tile/02_layernorm2d/CMakeLists.txt +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -1,4 +1,20 @@ +set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) -target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD) \ No newline at end of file +message("adding example ${EXAMPLE_LAYERNORM2D_FWD}") +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL example_layernorm2d_fwd.cpp) +target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE layernorm2d_fwd_fp16.cpp layernorm2d_fwd_fp32.cpp ${INSTANCE_SRCS}) + +set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS) + +# list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md index 66b16c1b7f..405325a2a1 100644 --- a/example/ck_tile/02_layernorm2d/README.md +++ b/example/ck_tile/02_layernorm2d/README.md @@ -6,8 +6,7 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm ``` # in the root of ck_tile mkdir build && cd build -# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... make tile_example_layernorm2d_fwd -j ``` This will result in an executable `build/bin/tile_example_layernorm2d_fwd` @@ -20,4 +19,4 @@ args: -e epsilon (default:1e-5) -v cpu validation or not (default:1) -prec precision (default:fp16) -``` \ No newline at end of file +``` diff --git a/example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp new file mode 100644 index 0000000000..47862f72b8 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp @@ -0,0 +1,145 @@ +#include "ck_tile/host.hpp" +#include "layernorm2d_fwd.hpp" +#include + +extern float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern float layernorm2d_fwd_fp32(layernorm2d_fwd_args& param, ck_tile::stream_config stream); + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "m dimension") + .insert("e", "1e-5", "epsilon") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp32", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + + float epsilon = arg_parser.get_float("e"); + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + using TypeConfig = LayerNormTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using YDataType = typename TypeConfig::YDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using BetaDataType = typename TypeConfig::BetaDataType; + + using MeanDataType = ck_tile::null_type; + using InvStdDataType = ck_tile::null_type; + + using ComputeDataType = typename TypeConfig::ComputeDataType; + + // host verify + ck_tile::HostTensor x_host({M, N}); + ck_tile::HostTensor gamma_host({N}); + ck_tile::HostTensor beta_host({N}); + + ck_tile::HostTensor y_host_ref({M, N}); + ck_tile::HostTensor y_host_dev({M, N}); + + ck_tile::HostTensor mean_host_ref({M}); + ck_tile::HostTensor invStd_host_ref({M}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(beta_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + beta_buf.ToDevice(beta_host.data()); + + layernorm2d_fwd_traits traits{data_type}; + + layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + beta_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), + epsilon, + M, + N}; + + float ave_time = .0; + + if constexpr(std::is_same::value) + { + ave_time = + layernorm2d_fwd_fp16(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat}); + } + else if constexpr(std::is_same::value) + { + ave_time = + layernorm2d_fwd_fp32(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat}); + } + + std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << "[" << data_type << "]" + << " m:" << M << ", n:" << N << ", " << ave_time * 1.E6 << " ns, " << gb_per_sec + << " GB/s" << std::flush; + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_layernorm2d_fwd( + x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + + y_buf.FromDevice(y_host_dev.data()); + + pass = ck_tile::check_err(y_host_dev, y_host_ref); + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; + } + + std::cout << std::endl << std::flush; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + if(data_type == "fp32") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel.cpp new file mode 100644 index 0000000000..eab366dc42 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel.cpp @@ -0,0 +1,28 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "layernorm_dispatch.hpp" + +// clang-format off +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel_pad.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel_pad.cpp new file mode 100644 index 0000000000..708c0ca5aa --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel_pad.cpp @@ -0,0 +1,28 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "layernorm_dispatch.hpp" + +// clang-format off +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp32_kernel.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp32_kernel.cpp new file mode 100644 index 0000000000..780431ea3a --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp32_kernel.cpp @@ -0,0 +1,34 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "layernorm_dispatch.hpp" + +// clang-format off +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp deleted file mode 100644 index 35f291e060..0000000000 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ /dev/null @@ -1,193 +0,0 @@ -#include "ck_tile/host.hpp" -#include "layernorm2d_fwd.hpp" -#include - -// Host API implementation -float layernorm2d_fwd(layernorm2d_fwd_traits t, - layernorm2d_fwd_args a, - const ck_tile::stream_config& s) -{ - if(t.data_type.compare("fp16") == 0) - { - using XDataType = ck_tile::half_t; - using YDataType = ck_tile::half_t; - using GammaDataType = ck_tile::half_t; - using BetaDataType = ck_tile::half_t; -#ifdef SAVE_MEAN_INV_STD - using MeanDataType = ck_tile::half_t; - using InvStdDataType = ck_tile::half_t; -#else - using MeanDataType = ck_tile::null_type; - using InvStdDataType = ck_tile::null_type; -#endif - using ComputeDataType = float; - - using thread_tile = ck_tile::sequence<4, 4>; - using warp_tile = ck_tile::sequence<8, 128>; - using block_tile = ck_tile::sequence<32, 128>; - - using Shape = ck_tile::TileLayernorm2dShape; - - using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem; - - using Kernel = ck_tile::Layernorm2dFwd; - - auto kargs = Kernel::MakeKargs( - a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N); - - const dim3 grids = Kernel::GridSize(a.M); - constexpr dim3 blocks = Kernel::BlockSize(); - - constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock; - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - } - - return 0; -} - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3328", "m dimension") - .insert("n", "4096", "m dimension") - .insert("e", "1e-5", "epsilon") - .insert("v", "1", "cpu validation or not") - .insert("prec", "fp16", "precision"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -int main(int argc, char* argv[]) -{ - - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - float epsilon = arg_parser.get_float("e"); - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - std::string data_type = arg_parser.get_str("prec"); - int do_validation = arg_parser.get_int("v"); - - using XDataType = ck_tile::half_t; - using YDataType = ck_tile::half_t; - using GammaDataType = ck_tile::half_t; - using BetaDataType = ck_tile::half_t; -#ifdef SAVE_MEAN_INV_STD - using MeanDataType = ck_tile::half_t; - using InvStdDataType = ck_tile::half_t; -#else - using MeanDataType = ck_tile::null_type; - using InvStdDataType = ck_tile::null_type; -#endif - using ComputeDataType = float; - - // host verify - ck_tile::HostTensor x_host({M, N}); - ck_tile::HostTensor gamma_host({N}); - ck_tile::HostTensor beta_host({N}); - - ck_tile::HostTensor y_host_ref({M, N}); - ck_tile::HostTensor y_host_dev({M, N}); - - ck_tile::HostTensor mean_host_ref({M}); - ck_tile::HostTensor invStd_host_ref({M}); - -#ifdef SAVE_MEAN_INV_STD - ck_tile::HostTensor mean_host_dev({M}); - ck_tile::HostTensor invStd_host_dev({M}); -#endif - - ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); - ck_tile::FillUniformDistribution{-5.f, 5.f}(gamma_host); - ck_tile::FillUniformDistribution{-5.f, 5.f}(beta_host); - - ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); - -#ifdef SAVE_MEAN_INV_STD - ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes()); - ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes()); -#endif - - x_buf.ToDevice(x_host.data()); - gamma_buf.ToDevice(gamma_host.data()); - beta_buf.ToDevice(beta_host.data()); - - layernorm2d_fwd_traits traits{data_type}; - - layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), - gamma_buf.GetDeviceBuffer(), - beta_buf.GetDeviceBuffer(), - y_buf.GetDeviceBuffer(), -#ifdef SAVE_MEAN_INV_STD - mean_buf.GetDeviceBuffer(), - invStd_buf.GetDeviceBuffer(), -#else - nullptr, - nullptr, -#endif - epsilon, - M, - N}; - - float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true}); - - std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + - sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; - - float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << "[" << data_type << "]" - << " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s" - << std::flush; - - bool pass = true; - - if(do_validation) - { - // reference - ck_tile::reference_layernorm2d_fwd( - x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); - - y_buf.FromDevice(y_host_dev.data()); - - pass = ck_tile::check_err(y_host_dev, y_host_ref); - -#ifdef SAVE_MEAN_INV_STD - mean_buf.FromDevice(mean_host_dev.data()); - pass &= ck_tile::check_err(mean_host_dev, mean_host_ref); - - invStd_buf.FromDevice(invStd_host_dev.data()); - pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref); -#endif - - std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; - } - - std::cout << std::endl << std::flush; - - return !pass; -} diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp index 4d1aac0994..0ada3fbac4 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp @@ -13,14 +13,51 @@ struct layernorm2d_fwd_traits std::string data_type; }; +template +struct LayerNormTypeConfig; + +template <> +struct LayerNormTypeConfig +{ + using XDataType = ck_tile::half_t; + using YDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using BetaDataType = ck_tile::half_t; +#ifdef SAVE_MEAN_INV_STD + using MeanDataType = ck_tile::half_t; + using InvStdDataType = ck_tile::half_t; +#else + using MeanDataType = ck_tile::null_type; + using InvStdDataType = ck_tile::null_type; +#endif + using ComputeDataType = float; +}; + +template <> +struct LayerNormTypeConfig +{ + using XDataType = float; + using YDataType = float; + using GammaDataType = float; + using BetaDataType = float; +#ifdef SAVE_MEAN_INV_STD + using MeanDataType = float; + using InvStdDataType = float; +#else + using MeanDataType = ck_tile::null_type; + using InvStdDataType = ck_tile::null_type; +#endif + using ComputeDataType = float; +}; + struct layernorm2d_fwd_args { const void* p_x; const void* p_gamma; const void* p_beta; void* p_y; - void* p_mean; - void* p_invStd; + // void* p_mean; + // void* p_invStd; float epsilon; ck_tile::index_t M; ck_tile::index_t N; diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp16.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp16.cpp new file mode 100644 index 0000000000..5118849043 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp16.cpp @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "layernorm_dispatch.hpp" + +// clang-format off +// extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// clang-format on + +float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream) +{ + // Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler +#if 0 + if(param.N % 8 == 0) + { + if(param.N <= 128) + { + return param.N == 128 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 256) + { + return param.N == 256 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 512) + { + return param.N == 512 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 1024) + { + return param.N == 1024 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else + { + return param.N == 2048 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + } + else if(param.N % 4 == 0) +#endif + if(param.N % 4 == 0) + { + if(param.N <= 128) + { + return param.N == 128 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 256) + { + return param.N == 256 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 512) + { + return param.N == 512 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 1024) + { + return param.N == 1024 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 2048) + { + return param.N == 2048 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else + { + return param.N % 2048 == 0 + ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + } + else if(param.N % 2 == 0) + { + if(param.N <= 128) + { + return param.N == 128 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 256) + { + return param.N == 256 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 512) + { + return param.N == 512 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 1024) + { + return param.N == 1024 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 2048) + { + return param.N == 2048 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else + { + return param.N % 2048 == 0 + ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + } + else + { + throw std::runtime_error("Sequence length sizes not supported!"); + } +}; diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp32.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp32.cpp new file mode 100644 index 0000000000..c03ebc1b75 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp32.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "layernorm_dispatch.hpp" + +// clang-format off +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +extern template float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream); +// clang-format on + +float layernorm2d_fwd_fp32(layernorm2d_fwd_args& param, ck_tile::stream_config stream) +{ + if(param.N % 4 == 0) + { + if(param.N <= 128) + { + return param.N == 128 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 256) + { + return param.N == 256 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 512) + { + return param.N == 512 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 1024) + { + return param.N == 1024 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 2048) + { + return param.N == 2048 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else + { + return param.N % 2048 == 0 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + } + else if(param.N % 2 == 0) + { + if(param.N <= 128) + { + return param.N == 128 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 256) + { + return param.N == 256 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 512) + { + return param.N == 512 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 1024) + { + return param.N == 1024 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else if(param.N <= 2048) + { + return param.N == 2048 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + else + { + return param.N % 2048 == 0 ? run_layernorm(param, stream) + : run_layernorm(param, stream); + } + } + else + { + throw std::runtime_error("Sequence length sizes not supported!"); + } +}; diff --git a/example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp b/example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp new file mode 100644 index 0000000000..ff0edb216e --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "layernorm2d_fwd.hpp" + +template +struct layernorm_dispatch +{ + static constexpr ck_tile::index_t MRepeat = 1; + static_assert(NThread <= 64, "We only support intra-wave reduction"); + static constexpr ck_tile::index_t WaveNum = NThread / 16; + // clang-format off + using thread_tile = ck_tile::sequence; + using warp_tile = ck_tile::sequence; + using block_tile = ck_tile::sequence; + // clang-format on + + using Shape = ck_tile::TileLayernorm2dShape; + + using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem< + typename LayerNormTypeConfig::XDataType, + typename LayerNormTypeConfig::GammaDataType, + typename LayerNormTypeConfig::BetaDataType, + typename LayerNormTypeConfig::ComputeDataType, + typename LayerNormTypeConfig::YDataType, + typename LayerNormTypeConfig::MeanDataType, + typename LayerNormTypeConfig::InvStdDataType, + Shape, + kPadN, + kTwoPass>; + + using Kernel = ck_tile::Layernorm2dFwd; + + static float Run(const layernorm2d_fwd_args& param, ck_tile::stream_config stream) + { + using k_ = Kernel; + + const dim3 grids = k_::GridSize(param.M); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + return ck_tile::launch_kernel(stream, + ck_tile::make_kernel(k_{}, + grids, + blocks, + 0, + param.p_x, + param.p_gamma, + param.p_beta, + param.p_y, + param.epsilon, + param.M, + param.N)); + }; +}; + +template +float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream) +{ + return layernorm_dispatch:: + Run(param, stream); +}; diff --git a/example/ck_tile/02_layernorm2d/perf_test.sh b/example/ck_tile/02_layernorm2d/perf_test.sh new file mode 100644 index 0000000000..f53e958c38 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/perf_test.sh @@ -0,0 +1,32 @@ +./bin/tile_example_layernorm2d_fwd -m=700 -n=80 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=128 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=144 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=168 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=184 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=256 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=288 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=344 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=376 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=448 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=512 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=924 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +./bin/tile_example_layernorm2d_fwd -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 468df793da..a2d817aa87 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -31,14 +31,10 @@ struct Layernorm2dFwd static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; - static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kTwoPass = Problem::kTwoPass; static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; - static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread; - - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; struct Kargs { @@ -47,8 +43,8 @@ struct Layernorm2dFwd const void* p_beta; void* p_y; - void* p_mean; - void* p_invStd; + // void* p_mean; + // void* p_invStd; float epsilon; @@ -69,7 +65,10 @@ struct Layernorm2dFwd return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N}; } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; } + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) + { + return (M + kMPerBlock - 1) / kMPerBlock; + } CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; } @@ -81,11 +80,11 @@ struct Layernorm2dFwd tile_distribution_encoding< sequence<>, tuple, - sequence>, + sequence>, tuple, sequence<1, 2>>, - tuple, sequence<1, 1>>, - sequence<1, 2>, - sequence<2, 2>>{}); + tuple, sequence<1, 2>>, + sequence<1, 2, 2>, + sequence<2, 0, 3>>{}); } CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution() @@ -95,32 +94,26 @@ struct Layernorm2dFwd return make_static_tile_distribution( tile_distribution_encoding< sequence, - tuple>, + tuple>, tuple, sequence<0, 1>>, - tuple, sequence<1, 1>>, - sequence<1>, - sequence<2>>{}); + tuple, sequence<1, 2>>, + sequence<1, 1>, + sequence<0, 3>>{}); } - CK_TILE_DEVICE static int GetWelfordMaxCount(int N) + template + CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr) { - constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread; + constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>(); - int thread_id_n = get_thread_id() % kNThreadPerBlock; - int max_count = - __builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock)); - int n_per_block_tail_loop = - __builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock); + using Lengths = decltype(nDstrSpan.impl_); - if(n_per_block_tail_loop > 0) - { - int thread_max_n = (thread_id_n + 1) * kNPerThread; - int delta = thread_max_n - n_per_block_tail_loop; - delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread); - max_count += kNPerThread - delta; - } + ck_tile::index_t ret = 1; - return max_count; + ck_tile::static_for<0, Lengths::size(), 1>{}( + [&](auto idx) { ret *= Lengths::template at(idx); }); + + return ret; } template @@ -141,29 +134,70 @@ struct Layernorm2dFwd return out_dstr_tensor; } - template - CK_TILE_DEVICE std::enable_if_t - TwoPassLayernorm2dFwd(XBlockWindow& x_block_window, - GammaBlockWindow& gamma_block_window, - BetaBlockWindow& beta_block_window, - YBlockWindow& y_block_window, - MeanBlockWindow& mean_block_window, - InvStdBlockWindow& inv_std_block_window, - ComputeDataType epsilon, - ck_tile::index_t N) const + CK_TILE_HOST_DEVICE static constexpr auto + GetLastloopLayerNormIntraLaneReduceCount(index_t NLength) { - // TODO - Optimize tail loop to reduce move_tile_window() - index_t num_n_tile_iteration = - __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock)); + using S = typename Problem::BlockShape; + // S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread + auto LastloopN = NLength % kNPerBlock == 0 ? kNPerBlock : NLength % kNPerBlock; + constexpr auto NThread = S::kNWarpPerBlock * S::kNThreadPerWarp; + auto iNLane = get_thread_local_1d_id() % NThread; + auto iN0 = LastloopN / (S::kNPerThread * S::kNThreadPerWarp); + auto iN1 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) / S::kNPerThread; + auto N2 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) % S::kNPerThread; + auto iN3 = iNLane < iN1 ? S::kNPerThread : iNLane == iN1 ? N2 : 0; - int welford_max_count = GetWelfordMaxCount(N); - ThreadWelford thread_welford{welford_max_count}; + return iN0 * S::kNPerThread + iN3; + } + + template + CK_TILE_DEVICE std::enable_if_t OnePassLayernorm2dFwd(const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y, + const ComputeDataType epsilon, + ck_tile::index_t M, + ck_tile::index_t N) const + { + using S = typename Problem::BlockShape; + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + const auto x_m_n = [&]() { + const auto x_dram_naive = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + return pad_tensor_view(x_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + const auto gamma_n = [&]() { + const auto gamma_dram_naive = make_naive_tensor_view( + p_gamma, make_tuple(N), make_tuple(1), number{}, number<1>{}); + + return pad_tensor_view( + gamma_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto beta_n = [&]() { + const auto gamma_dram_naive = make_naive_tensor_view( + p_beta, make_tuple(N), make_tuple(1), number{}, number<1>{}); + + return pad_tensor_view( + gamma_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto iM = get_block_id() * kMPerBlock; + + constexpr auto xDstr = MakeXBlockTileDistribution(); + + auto x_block_window = make_tile_window( + x_m_n, make_tuple(number{}, number{}), {iM, 0}, xDstr); + + auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N); + + ThreadWelford thread_welford{intra_thread_count_last}; using XTensorType = decltype(load_tile(x_block_window)); auto mean_compute_block_tensor = @@ -174,13 +208,20 @@ struct Layernorm2dFwd clear_tile(mean_compute_block_tensor); clear_tile(var_compute_block_tensor); - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) - { - const auto x_block_tensor = load_tile(x_block_window); + const auto x_block_tensor = load_tile(x_block_window); + thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); - thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); - move_tile_window(x_block_window, {0, kNPerBlock}); - } + constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); + constexpr auto betaDstr = gammaDstr; + + auto gamma_block_window = + make_tile_window(gamma_n, make_tuple(number{}), {0}, gammaDstr); + + auto beta_block_window = + make_tile_window(beta_n, make_tuple(number{}), {0}, betaDstr); + + const auto gamma_block_tensor = load_tile(gamma_block_window); + const auto beta_block_tensor = load_tile(beta_block_window); // TODO: support cross warp Welford WarpMergeWelford{}( @@ -188,17 +229,155 @@ struct Layernorm2dFwd auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); - if constexpr(kSaveMean) - store_tile(mean_block_window, cast_tile(mean_compute_block_tensor)); - if constexpr(kSaveInvStd) - store_tile(inv_std_block_window, - cast_tile(inv_std_compute_block_tensor)); + // TODO: Extract normalize pipeline + const auto y_m_n = [&]() { + const auto y_dram_naive = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + return pad_tensor_view(y_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto y_block_window = make_tile_window( + y_m_n, make_tuple(number{}, number{}), {iM, 0}); + + constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans(); + + auto y_block_tensor = + make_static_distributed_tensor(x_block_tensor.get_tile_distribution()); + + sweep_tile_span(x_spans[I1], [&](auto idx1) { + constexpr auto j_idx = make_tuple(idx1); + const auto gamma = type_convert(gamma_block_tensor[j_idx]); + const auto beta = type_convert(beta_block_tensor[j_idx]); + + sweep_tile_span(x_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + const auto mean = mean_compute_block_tensor[i_idx]; + const auto inv_std = inv_std_compute_block_tensor[i_idx]; + + const auto x = type_convert(x_block_tensor[i_j_idx]); + auto y = (x - mean) * inv_std * gamma + beta; + + y_block_tensor(i_j_idx) = type_convert(y); + }); + }); + + store_tile(y_block_window, y_block_tensor); + } + + template + CK_TILE_DEVICE std::enable_if_t TwoPassLayernorm2dFwd(const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y, + const ComputeDataType epsilon, + ck_tile::index_t M, + ck_tile::index_t N) const + { + using S = typename Problem::BlockShape; + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + const auto x_m_n = [&]() { + const auto x_dram_naive = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + return pad_tensor_view(x_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + const auto gamma_n = [&]() { + const auto gamma_dram_naive = make_naive_tensor_view( + p_gamma, make_tuple(N), make_tuple(1), number{}, number<1>{}); + + return pad_tensor_view( + gamma_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto beta_n = [&]() { + const auto gamma_dram_naive = make_naive_tensor_view( + p_beta, make_tuple(N), make_tuple(1), number{}, number<1>{}); + + return pad_tensor_view( + gamma_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto iM = get_block_id() * kMPerBlock; + + constexpr auto xDstr = MakeXBlockTileDistribution(); + + auto x_block_window = make_tile_window( + x_m_n, make_tuple(number{}, number{}), {iM, 0}, xDstr); + + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane((N + kNPerBlock - 1) / kNPerBlock); + + auto intra_thread_count = S::kNRepeat * S::kNPerThread * (num_n_tile_iteration - 1); + auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N); + + ThreadWelford thread_welford{intra_thread_count}; + ThreadWelford thread_welford_last{intra_thread_count_last}; + + using XTensorType = decltype(load_tile(x_block_window)); + auto mean_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + auto var_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + + clear_tile(mean_compute_block_tensor); + clear_tile(var_compute_block_tensor); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration - 1; ++iN) + { + const auto x_block_tensor = load_tile(x_block_window); + + thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); + move_tile_window(x_block_window, {0, kNPerBlock}); + } + const auto x_block_tensor_ = load_tile(x_block_window); + + thread_welford_last.cur_count_ += intra_thread_count; + thread_welford_last.max_count_ += intra_thread_count; + thread_welford_last(x_block_tensor_, mean_compute_block_tensor, var_compute_block_tensor); + thread_welford.cur_count_ += intra_thread_count_last; + + // TODO: support cross warp Welford + WarpMergeWelford{}( + mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_); + + auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); + + // TODO: Extract normalize pipeline + const auto y_m_n = [&]() { + const auto y_dram_naive = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + return pad_tensor_view(y_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto y_block_window = make_tile_window( + y_m_n, make_tuple(number{}, number{}), {iM, 0}); + + constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); + constexpr auto betaDstr = gammaDstr; + + auto gamma_block_window = + make_tile_window(gamma_n, make_tuple(number{}), {0}, gammaDstr); + + auto beta_block_window = + make_tile_window(beta_n, make_tuple(number{}), {0}, betaDstr); // reverse read x to reuse cache ck_tile::index_t stride_to_right_most_window = N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock; - move_tile_window(x_block_window, {0, -kNPerBlock}); move_tile_window(gamma_block_window, {stride_to_right_most_window}); move_tile_window(beta_block_window, {stride_to_right_most_window}); move_tile_window(y_block_window, {0, stride_to_right_most_window}); @@ -243,209 +422,35 @@ struct Layernorm2dFwd } } - template - CK_TILE_DEVICE std::enable_if_t - OnePassLayernorm2dFwd(XBlockWindow& x_block_window, - GammaBlockWindow& gamma_block_window, - BetaBlockWindow& beta_block_window, - YBlockWindow& y_block_window, - MeanBlockWindow& mean_block_window, - InvStdBlockWindow& inv_std_block_window, - ComputeDataType epsilon, - ck_tile::index_t N) const + CK_TILE_DEVICE void operator()(const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + const ComputeDataType epsilon, + ck_tile::index_t M, + ck_tile::index_t N) const { - int welford_max_count = GetWelfordMaxCount(N); - ThreadWelford thread_welford{welford_max_count}; - - using XTensorType = decltype(load_tile(x_block_window)); - auto mean_compute_block_tensor = - thread_welford.template MakeInitialMeanVarDistributedTensor(); - auto var_compute_block_tensor = - thread_welford.template MakeInitialMeanVarDistributedTensor(); - - clear_tile(mean_compute_block_tensor); - clear_tile(var_compute_block_tensor); - - const auto x_block_tensor = load_tile(x_block_window); - thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); - // TODO: support cross warp Welford - WarpMergeWelford{}( - mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_); - - auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); - - if constexpr(kSaveMean) - store_tile(mean_block_window, cast_tile(mean_compute_block_tensor)); - if constexpr(kSaveInvStd) - store_tile(inv_std_block_window, - cast_tile(inv_std_compute_block_tensor)); - - // normalize - const auto gamma_block_tensor = load_tile(gamma_block_window); - const auto beta_block_tensor = load_tile(beta_block_window); - - constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans(); - - auto y_block_tensor = - make_static_distributed_tensor(x_block_tensor.get_tile_distribution()); - - sweep_tile_span(x_spans[I1], [&](auto idx1) { - constexpr auto j_idx = make_tuple(idx1); - const auto gamma = type_convert(gamma_block_tensor[j_idx]); - const auto beta = type_convert(beta_block_tensor[j_idx]); - - sweep_tile_span(x_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - const auto mean = mean_compute_block_tensor[i_idx]; - const auto inv_std = inv_std_compute_block_tensor[i_idx]; - - const auto x = type_convert(x_block_tensor[i_j_idx]); - auto y = (x - mean) * inv_std * gamma + beta; - - y_block_tensor(i_j_idx) = type_convert(y); - }); - }); - - store_tile(y_block_window, y_block_tensor); - } - - CK_TILE_DEVICE void operator()(Kargs kargs) const - { - const auto x_m_n = [&]() { - const auto x_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_x), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.N, 1), - number{}, - number<1>{}); - - return pad_tensor_view(x_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - const auto gamma_n = [&]() { - const auto gamma_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_gamma), - make_tuple(kargs.N), - make_tuple(1), - number{}, - number<1>{}); - - return pad_tensor_view( - gamma_dram_naive, make_tuple(number{}), sequence{}); - }(); - - const auto beta_n = [&]() { - const auto gamma_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_beta), - make_tuple(kargs.N), - make_tuple(1), - number{}, - number<1>{}); - - return pad_tensor_view( - gamma_dram_naive, make_tuple(number{}), sequence{}); - }(); - - const auto iM = get_block_id() * kMPerBlock; - - constexpr auto xDstr = MakeXBlockTileDistribution(); - - auto x_block_window = make_tile_window( - x_m_n, make_tuple(number{}, number{}), {iM, 0}, xDstr); - - const auto y_m_n = [&]() { - const auto y_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_y), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.N, 1), - number{}, - number<1>{}); - - return pad_tensor_view(y_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - auto y_block_window = make_tile_window( - y_m_n, make_tuple(number{}, number{}), {iM, 0}); - - constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); - constexpr auto betaDstr = gammaDstr; - - auto gamma_block_window = - make_tile_window(gamma_n, make_tuple(number{}), {0}, gammaDstr); - - auto beta_block_window = make_tile_window( - beta_n, make_tuple(number{}, number{}), {0}, betaDstr); - - auto mean_block_window = [&]() { - if constexpr(kSaveMean) - { - const auto mean_m = [&]() { - const auto mean_dram_naive = - make_naive_tensor_view_packed( - static_cast(kargs.p_mean), - make_tuple(kargs.M), - number<1>{}); - - return pad_tensor_view( - mean_dram_naive, make_tuple(number{}), sequence{}); - }(); - - return make_tile_window(mean_m, make_tuple(number{}), {iM}); - } - else - return make_null_tile_window(make_tuple(number{})); - }(); - - auto inv_std_block_window = [&]() { - if constexpr(kSaveInvStd) - { - const auto inv_std_m = [&]() { - const auto inv_std_dram_naive = - make_naive_tensor_view_packed( - static_cast(kargs.p_invStd), - make_tuple(kargs.M), - number<1>{}); - - return pad_tensor_view( - inv_std_dram_naive, make_tuple(number{}), sequence{}); - }(); - - return make_tile_window(inv_std_m, make_tuple(number{}), {iM}); - } - else - return make_null_tile_window(make_tuple(number{})); - }(); - - if(kargs.N <= kNPerBlock) - OnePassLayernorm2dFwd(x_block_window, - gamma_block_window, - beta_block_window, - y_block_window, - mean_block_window, - inv_std_block_window, - static_cast(kargs.epsilon), - kargs.N); + if constexpr(kTwoPass) + { + TwoPassLayernorm2dFwd(static_cast(p_x), + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y), + static_cast(epsilon), + M, + N); + } else - TwoPassLayernorm2dFwd(x_block_window, - gamma_block_window, - beta_block_window, - y_block_window, - mean_block_window, - inv_std_block_window, - static_cast(kargs.epsilon), - kargs.N); + { + + OnePassLayernorm2dFwd(static_cast(p_x), + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y), + static_cast(epsilon), + M, + N); + } } }; diff --git a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp index 707a38f621..915d843c1b 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp @@ -15,20 +15,20 @@ template + bool kPadN_, + bool kTwoPass_> struct BlockLayernorm2dFwdProblem { - using XDataType = remove_cvref_t; - using GammaDataType = remove_cvref_t; - using BetaDataType = remove_cvref_t; - using ComputeDataType = remove_cvref_t; - using YDataType = remove_cvref_t; - using MeanDataType = remove_cvref_t; - using InvStdDataType = remove_cvref_t; - using BlockShape = remove_cvref_t; - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using BetaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using MeanDataType = remove_cvref_t; + using InvStdDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + static constexpr bool kPadN = kPadN_; + static constexpr bool kTwoPass = kTwoPass_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp b/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp index 1ff541d844..4b5d3793b0 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp @@ -12,13 +12,14 @@ template {}); - static constexpr index_t kNPerThread = ThreadTile::at(number<1>{}); + static constexpr index_t kNRepeat = ThreadTile::at(number<1>{}); + static constexpr index_t kNPerThread = ThreadTile::at(number<2>{}); static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); static constexpr index_t kNPerWarp = WarpTile::at(number<1>{}); static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; - static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread; + static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread / kNRepeat; static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});