diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index 78d3a5d02a..2eb7052e1e 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,3 +1,4 @@ -add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp) -add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp) -add_example_executable(example_gemm_xdl_layernorm_single_kernel_fp16 gemm_xdl_layernorm_single_kernel_fp16.cpp) +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) +add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) +add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp similarity index 100% rename from example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp rename to example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp new file mode 100644 index 0000000000..b927ae2828 --- /dev/null +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -0,0 +1,262 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +// DataType +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using D1DataType = F16; +using DsDataType = ck::Tuple; +using EMeanVarDataType = F16; +using GammaDataType = F16; +using BetaDataType = F16; +using HDataType = F16; + +// Layout +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using HLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddReluAdd; +using HElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle +//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| 
CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| +//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| +//######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>; +// clang-format on + +auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); +}; + +auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return 
HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + +void host_gemm_layernorm(Tensor& h_m_n, + const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& bias_n, + const Tensor& d1_m_n, + const Tensor& gamma_n, + const Tensor& beta_n, + AElementOp a_element_op, + BElementOp b_element_op, + CDEElementOp cde_element_op, + int M, + int N, + AccDataType epsilon = 1e-5) +{ + using ReferenceGemm = ck::tensor_operation::host::ReferenceGemm; + + using ReferenceLayernorm = ck::tensor_operation::host::ReferenceLayernorm; + + Tensor e_m_n(HostTensorDescriptor{M, N}); + Tensor c_m_n(HostTensorDescriptor{M, N}); + + auto ref_gemm = ReferenceGemm{}; + auto ref_gemm_invoker = ref_gemm.MakeInvoker(); + + auto ref_gemm_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_gemm_invoker.Run(ref_gemm_argument); + + for(int n = 0; n < N; ++n) + { + AccDataType bias = static_cast(bias_n(n)); + for(int m = 0; m < M; ++m) + { + AccDataType e = static_cast(e_m_n(m, n)); + AccDataType d1 = static_cast(d1_m_n(m, n)); + cde_element_op(e, c_m_n(m, n), bias, d1); + e_m_n(m, n) = static_cast(e); + } + } + + ReferenceLayernorm ref_layernorm; + auto ref_layernorm_invoker = ref_layernorm.MakeInvoker(); + + auto ref_layernorm_argument = ref_layernorm.MakeArgument( + e_m_n, gamma_n, beta_n, h_m_n, HElementOp{}, {M, N}, {1}, epsilon); + ref_layernorm_invoker.Run(ref_layernorm_argument); +} + +int main() +{ + bool do_verification = true; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = N; + ck::index_t StrideH = N; + + float epsilon = 1e-5; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, 
BLayout{})); + Tensor d0_n(f_host_tensor_descriptor1d(N, 1)); + Tensor d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{})); + Tensor gamma_n(f_host_tensor_descriptor1d(N, 1)); + Tensor beta_n(f_host_tensor_descriptor1d(N, 1)); + Tensor h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{})); + + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + d0_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + gamma_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + beta_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize()); + DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize()); + DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_device_buf.ToDevice(d0_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + gamma_device_buf.ToDevice(gamma_n.mData.data()); + beta_device_buf.ToDevice(beta_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + auto h_element_op = HElementOp{}; + + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()}, + gamma_device_buf.GetDeviceBuffer(), + 
beta_device_buf.GetDeviceBuffer(), + h_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {StrideD0, StrideD1}, + StrideH, + epsilon, + a_element_op, + b_element_op, + cde_element_op, + h_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + size_t workspace_sz = device_op.GetWorkSpaceSize(&argument); + DeviceMem workspace_dev(workspace_sz); + device_op.SetWorkSpacePointer(&argument, workspace_dev.GetDeviceBuffer()); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + bool pass = true; + + if(do_verification) + { + Tensor h_m_n_host(HostTensorDescriptor{M, N}); + host_gemm_layernorm(h_m_n_host, + a_m_k, + b_k_n, + d0_n, + d1_m_n, + gamma_n, + beta_n, + a_element_op, + b_element_op, + cde_element_op, + M, + N, + epsilon); + + h_device_buf.FromDevice(h_m_n.mData.data()); + pass &= + ck::utils::check_err(h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2); + } + + return pass ? 
0 : 1; +} diff --git a/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp similarity index 100% rename from example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp rename to example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp diff --git a/example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp similarity index 100% rename from example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp rename to example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp new file mode 100644 index 0000000000..a67a09b874 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// output : H[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// H = layernorm(E) +// Assume: +// D0, D1, ... 
and E have the same layout +// Calculate mean & variance along N dimension in layernorm(E) +template +struct DeviceGemmMultipleDLayernorm : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + const void* p_gamma, + const void* p_beta, + void* p_h, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideH, + double epsilon, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + HElementwiseOperation h_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; // struct DeviceGemmMultipleDLayernorm + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp new file mode 100644 index 0000000000..2f4bf3ee0e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp @@ -0,0 +1,1072 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EMeanVarDataType* __restrict__ p_e_grid, + EMeanVarDataType* __restrict__ p_welford_mean_grid, + EMeanVarDataType* __restrict__ p_welford_var_grid, + int32_t* __restrict__ p_welford_count_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock + mean_var_grid_desc_mblock_mperblock_nblock, + const CountGridDescriptor_MBlock_MPerBlock_NBlock + count_grid_desc_mblock_mperblock_nblock, + const Block2ETileMap block_2_etile_map, + index_t NRaw) +{ 
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()]; + + GridwiseGemmWelford::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_welford_mean_grid, + p_welford_var_grid, + p_welford_count_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + mean_var_grid_desc_mblock_mperblock_nblock, + count_grid_desc_mblock_mperblock_nblock, + block_2_etile_map, + NRaw); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = p_welford_mean_grid; + ignore = p_welford_var_grid; + ignore = p_welford_count_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = mean_var_grid_desc_mblock_mperblock_nblock; + ignore = count_grid_desc_mblock_mperblock_nblock; + ignore = block_2_etile_map; + ignore = NRaw; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_welford_layernorm2d_second_half( + const EMeanVarDataType* __restrict__ p_e_grid, + const EMeanVarDataType* __restrict__ p_in_welford_mean_grid, + const EMeanVarDataType* __restrict__ p_in_welford_var_grid, + const int32_t* __restrict__ p_in_welford_count_grid, + const GammaDataType* __restrict__ p_gamma_grid, + const BetaDataType* __restrict__ p_beta_grid, + HDataType* __restrict__ p_h_grid, + const EHGridDesc_M_N e_grid_desc_m_n, + const EHGridDesc_M_N h_grid_desc_m_n, + const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock, + const LayernormCountGridDesc_M_NBlock 
count_grid_desc_m_nblock, + const GammaBetaGridDesc_N gamma_grid_desc_n, + const GammaBetaGridDesc_N beta_grid_desc_n, + index_t numMeanVarCountBlockTileIteration_N, + index_t NBlockClusterLength, + ComputeDataType epsilon, + HElementwiseOperation h_element_op) +{ + GridwiseWelfordLayernorm::Run(p_e_grid, + p_in_welford_mean_grid, + p_in_welford_var_grid, + p_in_welford_count_grid, + p_gamma_grid, + p_beta_grid, + p_h_grid, + e_grid_desc_m_n, + h_grid_desc_m_n, + mean_var_grid_desc_m_nblock, + count_grid_desc_m_nblock, + gamma_grid_desc_n, + beta_grid_desc_n, + numMeanVarCountBlockTileIteration_N, + NBlockClusterLength, + epsilon, + h_element_op); +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// output : H[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// H = layernorm(E) +// Assume: +// D0, D1, ... and E have the same layout +// Calculate mean & variance along N dimension in layernorm(E) +template +struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle + : public DeviceGemmMultipleDLayernorm +{ + // EDataType, MeanDataType and VarDataType must be the same. + // eg. 
M, N, K = [1, 1, 1], + // in case of layernorm, divisor = 1 / sqrt(var + 1e-5) = 316.227783 + // if (x - mean) != 0, (x - mean) * divisor * gamma might be too large + // However, (x - mean) * divisor * gamma should be 0 in this case + + using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle; + using ELayout = HLayout; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t LayernormHDstVectorSize = PostShuffleScalarPerVector; + static constexpr index_t LayernormGammaSrcVectorSize = PostShuffleScalarPerVector; + static constexpr index_t LayernormBetaSrcVectorSize = PostShuffleScalarPerVector; + static constexpr index_t LayernormESrcVectorSize = PostShuffleScalarPerVector; + static constexpr index_t LayernormThreadSliceSize_N = PostShuffleScalarPerVector; + using LayernormBlockTileSize_M_N = + Sequence; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = MatrixPadder{ + GemmMPerBlock, GemmNPerBlock, GemmKPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return 
matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEHGridDescriptor_M_N(index_t M, index_t N, index_t Stride) + { + // Only support row major for E and H + const auto grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(Stride, I1)); + return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{}); + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + static_assert(is_same::value); + + return DeviceOp:: + MakeEHGridDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>( + MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + template + static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N) + { + const auto grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{}); + } + + template + static auto MakeCountDescriptor_M_N(index_t M, index_t N) + { + // We will broadcast [N] to [M, N] in this descriptor + // Hence, 1st stride is 0 + const auto grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1)); + return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{}); + } + + template + static auto MakeDescriptor_X(index_t X) + { + const auto grid_desc_x = make_naive_tensor_descriptor_packed(make_tuple(X)); + return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + // We have to separate mean var descriptor for gemm and layernorm because of different grid + // layout(different padding) + using GemmMeanVarGridDesc_M_NBlock = decltype( + MakeMeanVarDescriptor_M_N, GemmMPerBlock, 
GemmNPerBlock>(1, 1)); + + using GemmCountGridDesc_M_NBlock = decltype( + MakeCountDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, 1)); + + using LayernormMeanVarGridDesc_M_NBlock = + decltype(MakeMeanVarDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>(1, 1)); + + using LayernormCountGridDesc_M_NBlock = + decltype(MakeCountDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>(1, 1)); + + using GammaBetaGridDesc_N = decltype(MakeDescriptor_X(1)); + using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N, 1, 1>(1, 1, 1)); + + using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EMeanVarDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_M_K, + BGridDesc_N_K, + DsGridDesc_M_N, + EHGridDesc_M_N, + GemmMeanVarGridDesc_M_NBlock, + GemmCountGridDesc_M_NBlock, + NumGemmKPrefetchStage, + BlockSize, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + PostShuffleThreadClusterSize_M_N, + PostShuffleScalarPerVector, + LoopSched, + PipelineVer>; + + using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap; + + using GridwiseWelfordLayernorm = + 
GridwiseWelfordSecondHalfLayernorm2d; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + const void* p_gamma_grid, + const void* p_beta_grid, + void* p_h_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideH, + double epsilon, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + HElementwiseOperation h_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_workspace_e_grid_{nullptr}, + p_workspace_mean_{nullptr}, + p_workspace_var_{nullptr}, + p_workspace_count_{nullptr}, + p_gamma_grid_{static_cast(p_gamma_grid)}, + p_beta_grid_{static_cast(p_beta_grid)}, + p_h_grid_{static_cast(p_h_grid)}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + ds_grid_desc_m_n_{}, + gemm_e_grid_desc_m_n_{ + DeviceOp::MakeEHGridDescriptor_M_N, + GemmMPerBlock, + GemmNPerBlock>(MRaw, NRaw, StrideH)}, + layernorm_e_grid_desc_m_n_{ + DeviceOp::MakeEHGridDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>( + MRaw, NRaw, StrideH)}, + gemm_mean_var_grid_desc_m_nblock_{}, + gemm_count_grid_desc_m_nblock_{}, + layernorm_mean_var_grid_desc_m_nblock_{}, + layernorm_count_grid_desc_m_nblock_{}, + gamma_grid_desc_n_{ + DeviceOp::MakeDescriptor_X(NRaw)}, + beta_grid_desc_n_{ + DeviceOp::MakeDescriptor_X(NRaw)}, + h_grid_desc_m_n_{ + DeviceOp::MakeEHGridDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>( + MRaw, NRaw, StrideH)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemmWelford::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + 
block_2_etile_map_{ + GridwiseGemmWelford::MakeDefaultBlock2ETileMap(gemm_e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + h_element_op_{h_element_op}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw}, + gemm_nblock_{math::integer_divide_ceil(NRaw, GemmNPerBlock)}, + epsilon_{static_cast(epsilon)} + { + // We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1. + gemm_mean_var_grid_desc_m_nblock_ = + DeviceOp::MakeMeanVarDescriptor_M_N, GemmMPerBlock, 1>( + MRaw, gemm_nblock_); + + gemm_count_grid_desc_m_nblock_ = + DeviceOp::MakeCountDescriptor_M_N, GemmMPerBlock, 1>( + MRaw, gemm_nblock_); + + layernorm_mean_var_grid_desc_m_nblock_ = + DeviceOp::MakeMeanVarDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>( + MRaw, gemm_nblock_); + + layernorm_count_grid_desc_m_nblock_ = + DeviceOp::MakeCountDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>(MRaw, + gemm_nblock_); + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEHGridDescriptor_M_N, + GemmMPerBlock, + GemmNPerBlock>(MRaw, NRaw, StrideDs[i]); + }); + + // populate desc for Ds/E/mean/var/count + if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + gemm_e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemmWelford::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + gemm_e_grid_desc_m_n_); + + gemm_mean_var_grid_desc_mblock_mperblock_nblock_ = + 
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_mean_var_grid_desc_m_nblock_); + + gemm_count_grid_desc_mblock_mperblock_nblock_ = + GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_count_grid_desc_m_nblock_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << gemm_e_grid_desc_m_n_ << std::endl; + std::cout << "H[M, N]: " << h_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemmWelford::DsGridPointer p_ds_grid_; + void* p_workspace_e_grid_; + void* p_workspace_mean_; + void* p_workspace_var_; + void* p_workspace_count_; + const GammaDataType* p_gamma_grid_; + const BetaDataType* p_beta_grid_; + HDataType* p_h_grid_; + + // tensor descriptors for problem definition + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EHGridDesc_M_N gemm_e_grid_desc_m_n_; + EHGridDesc_M_N layernorm_e_grid_desc_m_n_; + GemmMeanVarGridDesc_M_NBlock gemm_mean_var_grid_desc_m_nblock_; + GemmCountGridDesc_M_NBlock gemm_count_grid_desc_m_nblock_; + LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_; + LayernormCountGridDesc_M_NBlock layernorm_count_grid_desc_m_nblock_; + GammaBetaGridDesc_N gamma_grid_desc_n_; + GammaBetaGridDesc_N beta_grid_desc_n_; + EHGridDesc_M_N h_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + typename GridwiseGemmWelford::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + 
ds_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock + gemm_mean_var_grid_desc_mblock_mperblock_nblock_; + typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock + gemm_count_grid_desc_mblock_mperblock_nblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + HElementwiseOperation h_element_op_; + + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + index_t gemm_nblock_; + AccDataType epsilon_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0; + + if(!GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.gemm_e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemmWelford has invalid setting"); + } + + index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.gemm_e_grid_desc_m_n_); + + const auto M = arg.h_grid_desc_m_n_.GetLength(I0); + const auto N = arg.h_grid_desc_m_n_.GetLength(I1); + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel_gemm_welford = + kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle< + GridwiseGemmWelford, + ADataType, // TODO: distinguish A/B datatype + typename GridwiseGemmWelford::DsGridPointer, + EMeanVarDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1, + typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1, + typename GridwiseGemmWelford:: + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemmWelford:: + EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock, + typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock, + typename GridwiseGemmWelford::DefaultBlock2ETileMap, + has_main_loop>; + + const auto kernel_welford_layernorm = + kernel_welford_layernorm2d_second_half; + + avg_time += + launch_and_time_kernel(stream_config, + kernel_gemm_welford, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + static_cast(arg.p_workspace_e_grid_), + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_), + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + 
arg.gemm_mean_var_grid_desc_mblock_mperblock_nblock_, + arg.gemm_count_grid_desc_mblock_mperblock_nblock_, + arg.block_2_etile_map_, + arg.NRaw_); + + index_t MBlockClusterLength = + math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0)); + index_t NBlockClusterLength = + math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(1)); + grid_size = MBlockClusterLength * NBlockClusterLength; + + index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil( + arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1)); + + avg_time += launch_and_time_kernel( + stream_config, + kernel_welford_layernorm, + dim3(grid_size), + dim3(BlockSize), + 0, + static_cast(arg.p_workspace_e_grid_), + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_), + arg.p_gamma_grid_, + arg.p_beta_grid_, + arg.p_h_grid_, + arg.layernorm_e_grid_desc_m_n_, + arg.h_grid_desc_m_n_, + arg.layernorm_mean_var_grid_desc_m_nblock_, + arg.layernorm_count_grid_desc_m_nblock_, + arg.gamma_grid_desc_n_, + arg.beta_grid_desc_n_, + numMeanVarCountBlockTileIteration_N, + NBlockClusterLength, + arg.epsilon_, + arg.h_element_op_); + + return avg_time; + }; + + if(GridwiseGemmWelford::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_; + + // workspace for welford intermediate mean + workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64; + + // workspace for welford intermediate mean + workspace_size += gemm_welford_size * 
sizeof(EMeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 64; + + if constexpr(!is_same_v) + workspace_size += pArg_->MRaw_ * pArg_->NRaw_ * sizeof(EMeanVarDataType); + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_; + + // setup buffer used for intermediate welford mean + pArg_->p_workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = gemm_welford_size * sizeof(EMeanVarDataType); + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->p_workspace_var_ = reinterpret_cast(pArg_->p_workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = gemm_welford_size * sizeof(EMeanVarDataType); + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welford count + pArg_->p_workspace_count_ = + reinterpret_cast(pArg_->p_workspace_var_) + variance_space_sz; + + index_t count_space_sz = gemm_welford_size * sizeof(int32_t); + count_space_sz = math::integer_least_multiple(count_space_sz, 64); + + if constexpr(!is_same_v) + pArg_->p_workspace_e_grid_ = + reinterpret_cast(pArg_->p_workspace_count_) + count_space_sz; + else + pArg_->p_workspace_e_grid_ = static_cast(pArg_->p_h_grid_); + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % 
ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of Ds + // only support RowMajor for now + bool all_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // E and H only support RowMajor for now + if constexpr(is_same_v && is_same_v) + { + if(arg.NRaw_ % PostShuffleScalarPerVector != 0 || + arg.NRaw_ % LayernormGammaSrcVectorSize != 0 || + arg.NRaw_ % LayernormBetaSrcVectorSize != 0 || + arg.NRaw_ % LayernormHDstVectorSize != 0) + { + return false; + } + } + else + { + return false; + } + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + const void* p_gamma, + const void* p_beta, + void* p_h, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideH, + double epsilon, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + HElementwiseOperation h_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_gamma, + p_beta, + p_h, 
+ MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideH, + epsilon, + a_element_op, + b_element_op, + cde_element_op, + h_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + const void* p_gamma, + const void* p_beta, + void* p_h, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideH, + double epsilon, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + HElementwiseOperation h_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_gamma, + p_beta, + p_h, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideH, + epsilon, + a_element_op, + b_element_op, + cde_element_op, + h_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmMultipleDLayernorm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << GemmMPerBlock << ", " + << GemmNPerBlock << ", " + << GemmKPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp new file mode 100644 index 0000000000..aa34cfbf84 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -0,0 +1,1111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// output : F[M, N0], where N0 is number of blocks along N dimension +// output : G[M, N0], where N0 is number of blocks along N dimension +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// F, G = welford(E) +// Assume: +// D0, D1, ... 
and E have the same layout +// Calculate mean & variance along N dimension for E +template +struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using 
DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ABDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // A desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = 
b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const EGridDescriptor_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDescriptor_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + template + __host__ __device__ static constexpr auto + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n) + { + const auto M = grid_desc_m_n.GetLength(I0); + const auto NBlock = grid_desc_m_n.GetLength(I1); + const auto MBlock = M / MPerBlock; + + const auto grid_desc_mblock_mperblock_nblock = transform_tensor_descriptor( + grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_pass_through_transform(NBlock)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0, 1>{}, Sequence<2>{})); + + return grid_desc_mblock_mperblock_nblock; + } + + // return block_id to E matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // check block-to-E-tile + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && 
+ b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EMeanVarDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DefaultAGridDesc_AK0_M_AK1 = + remove_cvref_t; + using DefaultBGridDesc_BK0_N_BK1 = + remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t; + using CountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2ETileMap = + remove_cvref_t; + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __device__ static void + Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EMeanVarDataType* __restrict__ p_e_grid, + EMeanVarDataType* __restrict__ p_welford_mean_grid, + EMeanVarDataType* __restrict__ p_welford_var_grid, + int32_t* __restrict__ p_welford_count, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock& + mean_var_grid_desc_mblock_mperblock_nblock, + const CountGridDescriptor_MBlock_MPerBlock_NBlock& count_grid_desc_mblock_mperblock_nblock, + const Block2ETileMap& block_2_etile_map, + index_t NRaw) 
+ { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto mean_grid_buf = make_dynamic_buffer( + p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); + + auto var_grid_buf = make_dynamic_buffer( + p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); + + auto welford_count_grid_buf = make_dynamic_buffer( + p_welford_count, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto 
a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ABDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + 
decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C, Welford and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + 
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence, + false>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_der_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, + false>{}; + + // LDS c_shuffle_block_desc_mperblock_nperblock + constexpr auto 
c_shuffle_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(PostShuffleThreadClusterSize_M_N::At(I0) * + PostShuffleThreadClusterSize_M_N::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + PostShuffleThreadClusterSize_M_N::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + PostShuffleThreadClusterSize_M_N::At(I1) == + 0, + "wrong!"); + + constexpr index_t PostShuffleThreadSliceSize_M = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + PostShuffleThreadClusterSize_M_N::At(I0); + + constexpr index_t PostShuffleThreadSliceSize_N = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + PostShuffleThreadClusterSize_M_N::At(I1); + + constexpr auto PostShuffleThreadSliceSize_M_N = + Sequence{}; + + // VGPR post_shuffle_thread_desc_m_n + constexpr auto post_shuffle_thread_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{})); + + auto e_thread_buf = make_static_buffer( + post_shuffle_thread_desc_m_n.GetElementSpaceSize()); + + // To apply D0, D1, ... and Welford. 
+ // threadwise copy from LDS to VGPR + constexpr auto post_shuffle_thread_cluster_desc = + make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{}); + + const auto post_shuffle_thread_cluster_idx = + post_shuffle_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto post_shuffle_thread_data_idx_begin = + post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N; + + // To apply D0, D1, ... and Welford. + // Copy c shuffle from LDS back to VGPR + auto post_shuffle_thread_copy_lds_to_vgpr = + ThreadwiseTensorSliceTransfer_v2, + 1, + PostShuffleScalarPerVector, + 1, + true>{c_shuffle_block_desc_mperblock_nperblock, + post_shuffle_thread_data_idx_begin}; + + // D0, D1, ..., Dn + constexpr auto post_shuffle_thread_desc_I1_mperblock_I1_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + // FIXME: Decrease usage of VGPR + // Apply pointwise lambda function from multi-source (Global and LDS) into VGPR + auto ds_thread_buf = generate_tuple( + [&](auto) { + return make_static_buffer( + post_shuffle_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize()); + }, + Number{}); + + // Copy D0, D1, ..., Dn from global to VGPR + auto ds_thread_copy_global_to_vgpr = generate_tuple( + [&](auto I) { + using DDataType = remove_cvref_t>; + return ThreadwiseTensorSliceTransfer_v2< + DDataType, + AccDataType, + decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]), + decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + PostShuffleScalarPerVector, + 1, + true>( + ds_grid_desc_mblock_mperblock_nblock_nperblock[I], + make_multi_index( + I0, + m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1])); + }, + Number{}); + + auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + 
EMeanVarDataType, + decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + PostShuffleScalarPerVector, + InMemoryDataOperationEnum::Set, + 1, + true>{ + e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]), + tensor_operation::element_wise::PassThrough{}}; + + // Welford + constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{})); + + constexpr auto thread_welford_dst_desc_m = make_naive_tensor_descriptor_packed( + make_tuple(Number{})); + + using ThreadwiseWelford = ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford, + false>; + + constexpr int num_shuffleM = + MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl); + + constexpr int num_shuffleN = + NPerBlock / (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl); + + using mean_var_vgpr_type = + decltype(make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize())); + + using welford_count_vgpr_type = + decltype(make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize())); + + Array threadwise_welfords; + Array mean_thread_bufs; + Array var_thread_bufs; + Array welford_count_thread_bufs; + + int max_count = PostShuffleThreadSliceSize_N * num_shuffleN; + const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2); + + // tail block + if(block_work_idx[I1] % nblock == nblock - 1) + { + constexpr index_t NPerShuffleBlock = + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl; + + int NPerBlockTail = NRaw - NPerBlock * (nblock - 1); + int thread_max_len = + PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1); + int 
shuffle_step = 0; + while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN) + { + ++shuffle_step; + thread_max_len += NPerShuffleBlock; + } + + int delta = 0; + if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N) + delta = 0; + else if(NPerBlockTail > thread_max_len) + delta = PostShuffleThreadSliceSize_N; + else + delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail; + + max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta; + } + + static_for<0, num_shuffleM, 1>{}([&](auto i) { + threadwise_welfords(i).max_count_ = max_count; + mean_thread_bufs(i) = make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize()); + + var_thread_bufs(i) = make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize()); + + welford_count_thread_bufs(i) = make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize()); + + static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) { + mean_thread_bufs(i)(j) = type_convert(0.0f); + var_thread_bufs(i)(j) = type_convert(0.0f); + welford_count_thread_bufs(i)(j) = 0; + }); + }); + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_der_global.GetNumOfAccess(), "wrong!"); + + int shuffleM_index = __builtin_amdgcn_readfirstlane(0); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to read from LDS + block_sync_lds(); + + // each thread shuffle data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to write to LDS + block_sync_lds(); + + // Get shuffle data from LDS to VGPR + post_shuffle_thread_copy_lds_to_vgpr.Run(c_shuffle_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + post_shuffle_thread_desc_m_n, + make_tuple(I0, I0), + e_thread_buf); + + // Global read D0, D1, ... 
+ static_for<0, NumDTensor, 1>{}([&](auto Id) { + auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(Id); + d_thread_copy_global_to_vgpr.Run( + ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], + ds_grid_buf[Id], + post_shuffle_thread_desc_I1_mperblock_I1_nperblock, + make_tuple(I0, I0, I0, I0), + ds_thread_buf(Id)); + + if constexpr(access_id < num_access - 1) + { + // move on D0, D1, ... + constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + d_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], de_global_step); + } + }); + + // cde_element_op(e, c, d0, d1, ...); + static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) { + const auto c_ds_src_data_refs = concat_tuple_of_reference( + tie(e_thread_buf[i]), + generate_tie( + [&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; }, + Number{})); + auto e_dst_data_refs = tie(e_thread_buf(i)); + unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs); + }); + + // Global write E + e_thread_copy_vgpr_to_global.Run(post_shuffle_thread_desc_I1_mperblock_I1_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + if constexpr(access_id < num_access - 1) + { + // move on E + constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + e_thread_copy_vgpr_to_global.MoveDstSliceWindow( + e_grid_desc_mblock_mperblock_nblock_nperblock, de_global_step); + } + + // Threadwise welford + auto& threadwise_welford = threadwise_welfords(shuffleM_index); + auto& mean_thread_buf = mean_thread_bufs(shuffleM_index); + auto& var_thread_buf = var_thread_bufs(shuffleM_index); + + threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + constexpr int shuffleMInc = + de_global_step[I1] / + 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1); + shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc); + } + }); // copy c, d, e + welford + + // Blockwise welford and write out + static_for<0, num_shuffleM, 1>{}([&](auto i) { + auto& mean_thread_buf = mean_thread_bufs(i); + auto& var_thread_buf = var_thread_bufs(i); + auto& count_thread_buf = welford_count_thread_bufs(i); + + static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) { + block_sync_lds(); + count_thread_buf(j) = threadwise_welfords(i).cur_count_; + BlockwiseWelford::Run( + mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j)); + }); + + if(post_shuffle_thread_cluster_idx[I1] == 0) + { + constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1)); + + constexpr int shuffleMPerBlock = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1); + + auto mean_var_count_thread_copy_index = make_multi_index( + block_work_idx[I0], // mblock + shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock + block_work_idx[I1]); // nblock + + auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + EMeanVarDataType, + decltype(thread_welford_desc_I_m_I), + decltype(mean_var_grid_desc_mblock_mperblock_nblock), + tensor_operation::element_wise::PassThrough, + Sequence<1, PostShuffleThreadSliceSize_M, 1>, + Sequence<0, 1, 2>, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{mean_var_grid_desc_mblock_mperblock_nblock, + mean_var_count_thread_copy_index, + tensor_operation::element_wise::PassThrough{}}; + + mean_var_thread_copy_vgpr_to_global.Run( + thread_welford_desc_I_m_I, + make_tuple(I0, I0, I0), + mean_thread_buf, + mean_var_grid_desc_mblock_mperblock_nblock, + mean_grid_buf); // write mean + + mean_var_thread_copy_vgpr_to_global.Run( + thread_welford_desc_I_m_I, + make_tuple(I0, I0, I0), + var_thread_buf, + 
mean_var_grid_desc_mblock_mperblock_nblock, + var_grid_buf); // write variance + + // Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need + // to be written. + if(i == 0 && block_work_idx[I0] == 0 && + post_shuffle_thread_cluster_idx[I0] == 0) + { + auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + int32_t, + int32_t, + decltype(thread_welford_desc_I_m_I), + decltype(count_grid_desc_mblock_mperblock_nblock), + tensor_operation::element_wise::PassThrough, + Sequence<1, PostShuffleThreadSliceSize_M, 1>, + Sequence<0, 1, 2>, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + false>{count_grid_desc_mblock_mperblock_nblock, + mean_var_count_thread_copy_index, + tensor_operation::element_wise::PassThrough{}}; + + count_thread_copy_vgpr_to_global.Run( + thread_welford_desc_I_m_I, + make_tuple(I0, I0, I0), + count_thread_buf, + count_grid_desc_mblock_mperblock_nblock, + welford_count_grid_buf); // write count + } + } + }); + + } // shuffle C + Ds + welford + write out + } // run +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp new file mode 100644 index 0000000000..fbe89e7e5e --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp @@ -0,0 +1,394 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" + +namespace ck { + +/* Second-half layernorm stage. Step 1 merges the partial Welford mean / variance / count values that are laid out along the NBlock dimension into one mean and one variance per m. Step 2 computes h[m, n] = (e[m, n] - mean[m]) / sqrt(var[m] + epsilon) * gamma[n] + beta[n] and stores H. */ +template +struct GridwiseWelfordSecondHalfLayernorm2d +{ + /* The N thread slice must be a whole number of vector accesses for every vectorized source (E, gamma, beta) and for the H destination. */ + static_assert(NThreadSliceSize % ESrcVectorSize == 0 && + NThreadSliceSize % GammaSrcVectorSize == 0 && + NThreadSliceSize % BetaSrcVectorSize == 0, + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(NThreadSliceSize % HDstVectorSize == 0, + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using ThreadClusterLengths_M_N = Sequence; + using ThreadBufferDimAccessOrder = Sequence<0, 1>; + using ThreadClusterArrangeOrder = Sequence<0, 1>; + + static constexpr auto thread_cluster_desc_m_n = + make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_N = Sequence; + static constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + using ThreadBufferLengths_M_1 = Sequence; + static constexpr auto thread_buffer_desc_m_1 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{})); + + using
ThreadBufferLengths_N = Sequence; + static constexpr auto thread_buffer_desc_n = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1); + using ThreadWelfordDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelfordMerge; + + using BlockwiseWelford = BlockwiseWelford; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize; + + /* Run: executes the merge (step 1) and the normalization (step 2) for the tile owned by the calling block. */ + /* p_e_grid: E input; p_in_welford_mean_grid / p_in_welford_var_grid / p_in_welford_count_grid: partial Welford results, read through mean_var_grid_desc_m_nblock and count_grid_desc_m_nblock; p_gamma_grid / p_beta_grid: per-n scale and shift; p_h_grid: H output. */ + /* numMeanVarCountBlockTileIteration_N: trip count of the merge loop. NBlockClusterLength: divisor that splits the 1-D block id into (m, n) block-work coordinates. epsilon: added to the variance before the square root. h_element_op: handed to the H store transfer. */ + __device__ static void Run(const EMeanVarDataType* __restrict__ p_e_grid, + const EMeanVarDataType* __restrict__ p_in_welford_mean_grid, + const EMeanVarDataType* __restrict__ p_in_welford_var_grid, + const int32_t* __restrict__ p_in_welford_count_grid, + const GammaDataType* __restrict__ p_gamma_grid, + const BetaDataType* __restrict__ p_beta_grid, + HDataType* __restrict__ p_h_grid, + const EHGridDesc_M_N& e_grid_desc_m_n, + const EHGridDesc_M_N& h_grid_desc_m_n, + const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_nblock, + const CountGridDesc_M_NBlock& count_grid_desc_m_nblock, + const GammaBetaGridDesc_N& gamma_grid_desc_n, + const GammaBetaGridDesc_N& beta_grid_desc_n, + index_t numMeanVarCountBlockTileIteration_N, + index_t NBlockClusterLength, + ComputeDataType epsilon, + HElementwiseOperation h_element_op) + { + // Thread/Block id + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + /* Split the 1-D block id into (m, n) block-work coordinates. */ + const auto block_work_idx = make_tuple(block_global_id / NBlockClusterLength, + block_global_id % NBlockClusterLength); + + const auto thread_cluster_idx = + thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id)); + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_n_cluster_id =
thread_cluster_idx[I1]; + + // Global Memory + const auto e_global_val_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_m_n.GetElementSpaceSize()); + + const auto welford_mean_global_val_buf = make_dynamic_buffer( + p_in_welford_mean_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize()); + + const auto welford_var_global_val_buf = make_dynamic_buffer( + p_in_welford_var_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize()); + + const auto welford_count_global_val_buf = make_dynamic_buffer( + p_in_welford_count_grid, count_grid_desc_m_nblock.GetElementSpaceSize()); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_grid, beta_grid_desc_n.GetElementSpaceSize()); + + auto h_global_val_buf = make_dynamic_buffer( + p_h_grid, h_grid_desc_m_n.GetElementSpaceSize()); + + // VGPR + /* in_welford_* receive the values loaded in one merge-loop iteration; welford_* hold the running merged result. */ + StaticBuffer + in_welford_mean_thread_buf; + StaticBuffer + in_welford_var_thread_buf; + StaticBuffer + in_welford_count_thread_buf; + + StaticBuffer + welford_mean_thread_buf; + StaticBuffer + welford_var_thread_buf; + StaticBuffer + welford_count_thread_buf; + + StaticBuffer + e_thread_buf; + StaticBuffer + gamma_thread_buf; + StaticBuffer + beta_thread_buf; + StaticBuffer + h_thread_buf; + + // IO + auto threadwise_mean_load_m_nblock = + ThreadwiseTensorSliceTransfer_v2( + mean_var_grid_desc_m_nblock, + make_multi_index(block_work_idx[I0] * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_n_cluster_id)); + + auto threadwise_var_load_m_nblock = + ThreadwiseTensorSliceTransfer_v2( + mean_var_grid_desc_m_nblock, + make_multi_index(block_work_idx[I0] * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_n_cluster_id)); + + auto threadwise_count_load_m_nblock = + ThreadwiseTensorSliceTransfer_v2( + count_grid_desc_m_nblock, + make_multi_index(block_work_idx[I0] * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, +
thread_n_cluster_id)); + + auto threadwise_e_load_m_n = + ThreadwiseTensorSliceTransfer_v2( + e_grid_desc_m_n, + make_multi_index( + block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize)); + + auto threadwise_gamma_load_n = + ThreadwiseTensorSliceTransfer_v2, // DimAccessOrder, + 0, // SrcVectorDim, + GammaSrcVectorSize, + 1, + true>( + gamma_grid_desc_n, + make_multi_index(block_work_idx[I1] * N_BlockTileSize + + thread_n_cluster_id * NThreadSliceSize)); + + auto threadwise_beta_load_n = + ThreadwiseTensorSliceTransfer_v2, // DimAccessOrder, + 0, // SrcVectorDim, + BetaSrcVectorSize, + 1, + true>( + beta_grid_desc_n, + make_multi_index(block_work_idx[I1] * N_BlockTileSize + + thread_n_cluster_id * NThreadSliceSize)); + + auto threadwise_h_store_m_n = + ThreadwiseTensorSliceTransfer_v1r3( + h_grid_desc_m_n, + make_multi_index( + block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize), + h_element_op); + + // step1: Merge mean and variance + /* Window step applied after every merge-loop iteration: 0 along the first index, NThreadClusterSize along the second. */ + constexpr auto mean_var_count_thread_copy_step_I0_n = + make_multi_index(I0, NThreadClusterSize); + + /* Zero the running accumulators before merging. */ + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_mean_thread_buf(I) = type_convert(0.0f); + welford_var_thread_buf(I) = type_convert(0.0f); + welford_count_thread_buf(I) = 0; + }); + + /* Per-thread merge: load one set of partial mean / variance / count values, fold it into the accumulators, then advance the three source windows. */ + for(index_t n = 0; n < numMeanVarCountBlockTileIteration_N; ++n) + { + threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_nblock, + welford_mean_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_mean_thread_buf); + + threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_nblock, + welford_var_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_var_thread_buf); + + threadwise_count_load_m_nblock.Run(count_grid_desc_m_nblock, + welford_count_global_val_buf, +
thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_count_thread_buf); + + ThreadwiseWelford::Run(in_welford_mean_thread_buf, + in_welford_var_thread_buf, + in_welford_count_thread_buf, + welford_mean_thread_buf, + welford_var_thread_buf, + welford_count_thread_buf); + + threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock, + mean_var_count_thread_copy_step_I0_n); + threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock, + mean_var_count_thread_copy_step_I0_n); + threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_nblock, + mean_var_count_thread_copy_step_I0_n); + } + + /* Block-level merge of the per-thread accumulators, one M-slice element per call. A block_sync_lds() separates consecutive calls; NOTE(review): presumably because BlockwiseWelford reuses the same LDS scratch between calls - confirm in blockwise_welford.hpp. */ + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford::Run( + welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I)); + }); + + // step2: normalization + // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n] + threadwise_e_load_m_n.Run(e_grid_desc_m_n, + e_global_val_buf, + thread_buffer_desc_m_n, + make_tuple(I0, I0), + e_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto m) { + /* Reciprocal of sqrt(var + epsilon), computed once per m and reused for every n. */ + auto divisor = 1 / ck::math::sqrt(welford_var_thread_buf(m) + epsilon); + static_for<0, NThreadSliceSize, 1>{}([&](auto n) { + constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); + h_thread_buf(Number{}) = + (e_thread_buf(Number{}) - welford_mean_thread_buf(m)) * divisor; + }); + }); + + threadwise_gamma_load_n.Run(gamma_grid_desc_n, + gamma_global_val_buf, + thread_buffer_desc_n, + make_tuple(I0), + gamma_thread_buf); + + /* Scale the normalized values by gamma[n]. */ + static_for<0, MThreadSliceSize, 1>{}([&](auto m) { + static_for<0, NThreadSliceSize, 1>{}([&](auto n) { + constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); + h_thread_buf(Number{}) = h_thread_buf(Number{}) * gamma_thread_buf(n); + }); + }); + + threadwise_beta_load_n.Run(beta_grid_desc_n, + beta_global_val_buf, + thread_buffer_desc_n, + make_tuple(I0), +
beta_thread_buf); + + /* Shift by beta[n]. */ + static_for<0, MThreadSliceSize, 1>{}([&](auto m) { + static_for<0, NThreadSliceSize, 1>{}([&](auto n) { + constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); + h_thread_buf(Number{}) = h_thread_buf(Number{}) + beta_thread_buf(n); + }); + }); + + /* Store the finished thread tile to H through the transfer that was constructed with h_element_op. */ + threadwise_h_store_m_n.Run(thread_buffer_desc_m_n, + make_tuple(I0, I0), + h_thread_buf, + h_grid_desc_m_n, + h_global_val_buf); + + } // run +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp index 0d5cbca925..b09a735902 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -434,7 +434,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon); + auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp index 7aefd3c066..70a8c020dd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp @@ -319,7 +319,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon); + auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, XThreadBufferNumber,
1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp index 680d94f7d1..2bac5bc5c8 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp @@ -90,10 +90,13 @@ struct ReferenceLayernorm : public device::BaseOperator for(int m = 0; m < M; ++m) { + AccDataType divisor = + static_cast(1) / ck::math::sqrt(var(m) + arg.epsilon_); + for(int n = 0; n < N; ++n) { auto x_val = ck::type_convert(arg.x_m_n_(m, n)); - auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_); + auto y_val = (x_val - mean(m)) * divisor; y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n); arg.acc_elementwise_op_(y_val, y_val); arg.y_m_n_(m, n) = ck::type_convert(y_val);