From 1e0c9fde519453835daa6c161f0ec3aee5c4b113 Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 24 Oct 2024 22:22:11 +0000 Subject: [PATCH] Add add_rmsnorm2d_rdquant kernel --- .../07_add_rmsnorm2d_rdquant/CMakeLists.txt | 25 ++ .../example_add_rmsnorm2d_rdquant_fwd.cpp | 181 +++++++++++++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 12 + .../add_rmsnorm2d_rdquant_fwd_kernel.hpp | 239 ++++++++++++++++++ .../add_rmsnorm2d_rdquant_fwd_shape.hpp | 78 ++++++ ...2d_rdquant_fwd_pipeline_default_policy.hpp | 94 +++++++ ...msnorm2d_rdquant_fwd_pipeline_one_pass.hpp | 140 ++++++++++ ...rmsnorm2d_rdquant_fwd_pipeline_problem.hpp | 41 +++ ...msnorm2d_rdquant_fwd_pipeline_two_pass.hpp | 70 +++++ 10 files changed, 881 insertions(+) create mode 100644 example/ck_tile/07_add_rmsnorm2d_rdquant/CMakeLists.txt create mode 100644 example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp diff --git a/example/ck_tile/07_add_rmsnorm2d_rdquant/CMakeLists.txt b/example/ck_tile/07_add_rmsnorm2d_rdquant/CMakeLists.txt new file mode 100644 index 0000000000..ef475a7330 --- /dev/null +++ b/example/ck_tile/07_add_rmsnorm2d_rdquant/CMakeLists.txt @@ -0,0 +1,25 @@ +# set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_rmsnorm2d_fwd") +# # not using add_example_executable() to add this target, since we don't want this to have +# # to be included in "make all/install/check" +# message("adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}") +# file(GLOB INSTANCE_SRCS instances/*.cpp) +# add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp) +# target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +# target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS}) + +set(TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +# target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) + +set(EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD "tile_example_add_rmsnorm2d_rdquant_fwd") +add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL example_add_rmsnorm2d_rdquant_fwd.cpp) +target_compile_options(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp new file mode 100644 index 0000000000..f006a6c527 --- /dev/null +++ b/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -0,0 +1,181 @@ +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp" +#include + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "0", "cold iter") + .insert("repeat", "1", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using ADataType = DataType; + using BDataType = DataType; + using GammaDataType = DataType; + using XDataType = DataType; + using YScaleDataType = DataType; + using QYDataType = int8_t; + using ComputeDataType = float; + + // host verify + ck_tile::HostTensor a_host({m, n}, {stride, 1}); + ck_tile::HostTensor b_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor x_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor x_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.data()); + b_buf.ToDevice(b_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + + constexpr bool kTwoPass = false; + + using BlockWarps = ck_tile::sequence<2, 2>; + using BlockTile = ck_tile::sequence<2, 128>; + using WarpTile = ck_tile::sequence<1, 64>; + using Vector = ck_tile::sequence<1, 1>; + + using Shape = ck_tile::AddRmsnorm2dRdquantShape; + using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem; + + using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; + + ck_tile::AddRmsnorm2dRdquantFwdHostArgs args{a_buf.GetDeviceBuffer(), + b_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + x_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + epsilon, + m, + n, + stride}; + + auto kargs = Kernel::MakeKargs(args); + + const dim3 grids = Kernel::GridSize(args); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat}; + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + bool pass = true; + + if(do_validation) + { + using InvRmsDataType = DataType; + ck_tile::HostTensor invRms_host_ref({m}); + + // ck_tile::reference_rmsnorm2d_fwd( + // x_host, gamma_host, qy_host_ref, invRms_host_ref, epsilon); + + // qy_buf.FromDevice(qy_host_dev.data()); + + // auto [rtol, atol] = ck_tile::make_tuple(1e-3, 1e-3); + // if(stride == n) + // { + // pass = ck_tile::check_err( + // qy_host_dev, qy_host_ref, std::string("OUT Error: Incorrect results!"), rtol, + // atol); + // } + // else + // { + // for(int i_r = 0; i_r < m; i_r++) + // { + // std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + // qy_host_dev.begin() + i_r * stride + n); + // std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + // qy_host_ref.begin() + i_r * stride + n); + // pass &= ck_tile::check_err(qy_host_dev_row, + // qy_host_ref_row, + // std::string("OUT[") + std::to_string(i_r) + + // std::string("] Error: Incorrect results!"), + // rtol, + // atol); + // } + // } + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride + << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 22d51d1e10..a375b0e83f 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -8,3 +8,4 @@ add_subdirectory(03_gemm) add_subdirectory(04_img2col) add_subdirectory(05_reduce) add_subdirectory(06_rmsnorm2d) +add_subdirectory(07_add_rmsnorm2d_rdquant) diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp new file mode 100644 index 0000000000..21c77f2c9c --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp new file mode 100644 index 0000000000..29ba6080da --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +namespace ck_tile { + +// host side args +struct AddRmsnorm2dRdquantFwdHostArgs +{ + const void* p_a; + const void* p_b; + const void* p_gamma; + + void* p_x; + void* p_yscale; + void* p_qy; + + float epsilon; + + index_t m; + index_t n; + index_t stride; // row_stride +}; + +// TODO: Extract some type to wrapper class +template +struct AddRmsnorm2dRdquantFwd +{ + using Pipeline = remove_cvref_t; + using Problem = typename Pipeline::Problem; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using XDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + using QYDataType = remove_cvref_t; + + static constexpr bool kSaveX = Problem::kSaveX; + + static constexpr index_t Block_M = Problem::BlockShape::Block_M; + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + static constexpr bool kPadM = false; // always no need to pad along M + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kTwoPass = Problem::kTwoPass; + + static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; + static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; + static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + struct Kargs + { + const void* p_a; + const void* p_b; + const void* p_gamma; + + void* p_x; + void* p_yscale; + void* p_qy; + + float epsilon; + + index_t m; + index_t n; + index_t stride; // row_stride + }; + using Hargs = AddRmsnorm2dRdquantFwdHostArgs; + + CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) + { + return Kargs{hargs.p_a, + hargs.p_b, + hargs.p_gamma, + hargs.p_x, + hargs.p_yscale, + hargs.p_qy, + hargs.epsilon, + hargs.m, + hargs.n, + hargs.stride}; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) + { + return integer_divide_ceil(hargs.m, Block_M); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + // in byte + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); } + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + using S_ = typename Problem::BlockShape; + auto surfix = [&] () { + std::string n; + if (kPadN) n += "_pn"; + if (kSaveX) n += "_x"; + if (kTwoPass) n += "_2p"; + return n; }(); + + #define _SS_ std::string + #define _TS_ std::to_string + return _SS_("add_rmsnorm2d_rdquant_fwd_") + _SS_(t2s::name) + "_" + + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + + _SS_(Pipeline::name) + surfix; + #undef _SS_ + #undef _TS_ + // clang-format on + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const auto iM = get_block_id() * Block_M; + + const auto a_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_a), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + const auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + const auto b_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_b), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + const auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + const auto gamma_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_gamma), + make_tuple(kargs.n), + make_tuple(1), + number{}, + number<1>{}); + + const auto tmp2_ = + pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); + + return make_tile_window(tmp2_, make_tuple(number{}), {0}); + }(); + + auto x_window = [&]() { + if constexpr(kSaveX) + { + const auto tmp2_ = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_x), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + return pad_tensor_view(tmp_, + make_tuple(number{}, number{}), + sequence{}); + }(); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + } + else + return make_null_tile_window(make_tuple(number{}, number{})); + }(); + + auto yscale_window = [&]() { + auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_yscale), + make_tuple(kargs.m), + make_tuple(1), + number<1>{}); + + auto tmp2_ = pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); + return make_tile_window(tmp2_, make_tuple(number{}), {iM}); + }(); + + auto qy_window = [&]() { + auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_qy), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + __shared__ char smem[GetSmemSize()]; + + Pipeline{}(a_window, + b_window, + gamma_window, + x_window, + yscale_window, + qy_window, + static_cast(kargs.epsilon), + kargs.n, + smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp new file mode 100644 index 0000000000..a17c53c73f --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +/* +// clang-format off + +4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector + + Block_N (Warp_N * WarpPerBlock_N * Repeat_N ) + +<----------------------< Repeat_N(2)>--------------------->+ + | | + +<-- -->+ + Warp_N + +--------------+--------------+--------------+--------------+----+----------------+ + Warp_M | wrap_0 | wrap_1 | | ^ ^ + +--------------+--------------+ | | + | wrap_2 | wrap_3 | | v + +--------------+--------------+--------------+--------------+----+ Block_M + | | | + + + | + | | | v + +--------------+--------------+--------------+--------------+ + + + each Warp-tile (e.g 16 thrd per row) + + Vector_N (contiguous pixels each thrd holds along N, or vector size) + +-----------+-----------+-----------+-----------+-----------+ + | thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M + +-----------+-----------+-----------+-----------+-----------+ + | thrd_16 | thrd_17 | thrd_18 | thrd_19 | ... + +-----------+-----------+-----------+-----------+-----------+ +// clang-format on +*/ +template + typename WarpPerBlock_, // num warps along seq + typename WarpTile_, // warp size, seq + typename Vector_, // contiguous pixels(vector size) along seq + index_t BlockSize_ = + warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> +struct AddRmsnorm2dRdquantShape +{ + // block size + static constexpr index_t Block_M = BlockTile_::at(number<0>{}); + static constexpr index_t Block_N = BlockTile_::at(number<1>{}); + + // num warps along seq, within each block + static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{}); + static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{}); + + // warp size + static constexpr index_t Warp_M = WarpTile_::at(number<0>{}); + static constexpr index_t Warp_N = WarpTile_::at(number<1>{}); + + static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); + static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0); + // repeat of each thread along seq + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + // vector size along seq + static constexpr index_t Vector_M = Vector_::at(number<0>{}); + static constexpr index_t Vector_N = Vector_::at(number<1>{}); + + static_assert(Warp_M % Vector_M == 0); + static_assert(Warp_N % Vector_N == 0); + // num of threads along seq, within each warp + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + static constexpr index_t BlockSize = BlockSize_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp new file mode 100644 index 0000000000..ba20eb43ef --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce2d/block/block_reduce2d_problem.hpp" +#include "ck_tile/ops/reduce2d/block/block_reduce2d.hpp" + +namespace ck_tile { + +struct AddRmsnorm2dRdquantFwdPipelineDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeABXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 2>>, + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + template + CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 1>, + sequence<0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2d{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dCrossWarpSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + if constexpr(Problem::kNeedCrossWarpSync) + { + using P_ = BlockReduce2dProblem; + + using block_reduce2d = BlockReduce2d; + using x_block_tile = + decltype(make_static_distributed_tensor( + MakeABXBlockTileDistribution())); + using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile()); + + return GetBlockReduce2dCrossWarpSync().template GetSmemSize(); + } + else + { + return 1; // zero size arrays are an extension + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp new file mode 100644 index 0000000000..f37f5e6475 --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct AddRmsnorm2dRdquantFwdPipelineOnePass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using ADataType = ck_tile::remove_cvref_t; + using BDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using XDataType = ck_tile::remove_cvref_t; + using YScaleDataType = ck_tile::remove_cvref_t; + using QYDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveX = Problem::kSaveX; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr_op"; // block per row + else + return "wpr_op"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const AWindow& a_window_, + const BWindow& b_window_, + const GammaWindow& gamma_window_, + XWindow& x_window, + YScaleWindow& yscale_window, + QYWindow& y_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem) const + { + const auto a_window = + make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution()); + const auto b_window = + make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution()); + const auto gamma_window = make_tile_window( + gamma_window_, Policy::template MakeGammaBlockTileDistribution()); + + auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; }; + auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; }; + auto reduce_absmax_func = [](const auto& v0, const auto& v1) { return max(v0, abs(v1)); }; + auto reduce_max_func = [](const auto& v0, const auto& v1) { return max(v0, v1); }; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + const auto a = load_tile(a_window); + const auto b = load_tile(b_window); + const auto gamma = load_tile(gamma_window); + + auto x = tile_elementwise_in( + [&](const auto& a_, const auto& b_) { + return type_convert(a_) + type_convert(b_); + }, + a, + b); + + if constexpr(kSaveX) + store_tile(x_window, cast_tile(x)); + + // compute mean square, each-thread->cross-lane->cross-warp + auto square_sum = block_reduce2d(x, 0, reduce_square_sum_func); + block_reduce2d_sync(square_sum, reduce_sum_func); + block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); + + auto inv_rms = tile_elementwise_in( + [&](const auto& v_) { + return type_convert(1.0f) / (sqrt(v_ / row_size + epsilon)); + }, + square_sum); + + // rmsnorm computation + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = x_ * inv_rms_[i_idx] * gamma_; + + y(idx) = type_convert(y_); + }); + + // compute absmax, each-thread->cross-lane->cross-warp + auto absmax = block_reduce2d(x, numeric::min(), reduce_absmax_func); + block_reduce2d_sync(absmax, reduce_max_func); + block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); + + auto yscale = tile_elementwise_in( + [&](const auto& v_) { + return v_ / type_convert(numeric::max()); + }, + absmax); + store_tile(yscale_window, cast_tile(yscale)); + + // quantize to + auto qy = make_static_distributed_tensor(y.get_tile_distribution()); + sweep_tile(qy, [&, yscale_ = yscale](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + auto qy_ = y[idx] / yscale_[i_idx]; + qy(idx) = saturates{}(qy_); + }); + store_tile(y_window, qy); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp new file mode 100644 index 0000000000..51cba63a08 --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// X = A + B, Y = RmsNorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale) +template +struct AddRmsnorm2dRdquantFwdPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using XDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + using QYDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + + static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; + static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveX = kSaveX_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp new file mode 100644 index 0000000000..0c8c402d84 --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct AddRmsnorm2dRdquantFwdPipelineTwoPass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using ADataType = ck_tile::remove_cvref_t; + using BDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using XDataType = ck_tile::remove_cvref_t; + using YScaleDataType = ck_tile::remove_cvref_t; + using QYDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveX = Problem::kSaveX; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr_tp"; // block per row + else + return "wpr_tp"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const AWindow& a_window_, + const BWindow& b_window_, + const GammaWindow& gamma_window_, + XWindow& x_window, + YScaleWindow& yscale_window, + QYWindow& y_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem) const + { + auto a_window = + make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution()); + auto b_window = + make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution()); + const auto gamma_window = make_tile_window( + gamma_window_, Policy::template MakeGammaBlockTileDistribution()); + } +}; +} // namespace ck_tile