Mirror of https://github.com/ROCm/composable_kernel.git
Gemm layernorm welford (#413)
* Add device op of gemm layernorm
* [What] Rename F to H
[Why] F and G are reserved for the welford tensors
* Add gridwise gemm + welford
* Extract template parameter
* Rename kernel. Prepare to add second half kernel
* Extract var
* Add second kernel for gemm+layernorm
* Move to the gemm_layernorm folder
* Rename F and G to mean and var
* Do not use snakeCurved; it makes determining the padding for welford difficult
* Rewrite the device interface and rename some var
* Add welford count
* Update interface
* Sync code, prepare to test on MI200
* Clean the code
* Implement layernorm
* Add comment to mention hipFree
* Write out the e for debug.
This could be removed and h used instead
* 1. Allocate mean, var and count via SetWorkSpacePointer.
2. Add GetWorkSpaceSize to calculate the space size
* Add gemm layernorm host code
* use reference layernorm
* Fix bug of blockwise welford for first kernel
* Fix bug of mean var padding for layernorm
* Use sgpr for shuffleM_index
* padding for GemmMeanVarCountGridDescriptor_M_NBlock
* Add layout parameter
* Check argument for gemm
* calculate max count for tail block
* Share E and H memory in device op
* Hard code the vector dim
* Refine the MakeDescriptor
* 1. Remove E parameter, because E is inside of device op
2. Check vector size
* [What] Rename MakeMeanVarDescriptor_M_N
[Why] Prepare to add count version of make descriptor
* Use 1D global memory for count
* Prevent redundant IO
* Update parameter
* Add pipeline v1/v2 selector
* Rename the example name
* Add base class for gemm layernorm
* Refine naming to distinguish naive and welford
* Add comment to explain in detail
* We don't need to pad the N dimension in gemm for mean/var/count. Set NPerTile to 1
* Rewrite the 2nd kernel, using multiple blocks along the N dimension in the layernorm kernel
* Share the vector size
* Refine var name
* [What] Force LayernormThreadSliceSize_N = vector size.
[Why] Memory coalescing
* Add comment
* Extract divisor out of the loop in reference layernorm
* Pad different size for E and H in layernorm kernel according to different block tile
* Refine naming
* Refine naming
* Prevent implicit cast
* [What] use ck::math::sqrt instead of __builtin_amdgcn_sqrtf
[Why] __builtin_amdgcn_sqrtf only supports float; double would cause casting
* Cast only constant
* Change post-shuffle thread descriptor
* Add EMeanVarDataType parameter.
* Merge the mean and var threadwise copy
* Add missing index
* Fix Typo
* Sync the variable with previous if
* 1. Declare e inside the host_gemm_layernorm()
2. Prevent implicit cast in reference code
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
[ROCm/composable_kernel commit: 7829d729fb]
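The commits above split the fused op into two kernels: the first (GEMM + Welford) writes per-block partial mean, variance and count into the workspace, and the second merges those partials and then applies layernorm. Below is a minimal, self-contained C++ sketch of the merge rule such a second kernel relies on; the struct and function names are illustrative only and are not CK's ThreadwiseWelfordMerge / BlockwiseWelford API.

// Hedged sketch: merging two partial Welford results (mean, biased variance, count).
// Standalone illustration only, not CK code.
#include <cstdio>
#include <vector>

struct Welford
{
    double mean  = 0.0;
    double var   = 0.0; // biased variance: sum((x - mean)^2) / count
    int    count = 0;
};

// Chan-style parallel merge of two partial (mean, var, count) triples.
Welford merge(const Welford& a, const Welford& b)
{
    Welford out;
    out.count = a.count + b.count;
    if(out.count == 0)
        return out;
    const double delta = b.mean - a.mean;
    out.mean = a.mean + delta * b.count / out.count;
    out.var  = (a.var * a.count + b.var * b.count +
                delta * delta * a.count * b.count / out.count) /
               out.count;
    return out;
}

int main()
{
    // Two "blocks" of data along the normalized (N) dimension.
    std::vector<double> x0{1, 2, 3, 4}, x1{10, 11, 12};

    auto naive = [](const std::vector<double>& v) {
        Welford w;
        w.count = static_cast<int>(v.size());
        for(double x : v) w.mean += x;
        w.mean /= w.count;
        for(double x : v) w.var += (x - w.mean) * (x - w.mean);
        w.var /= w.count;
        return w;
    };

    // Merging the two partial results matches a single pass over x0 + x1.
    Welford merged = merge(naive(x0), naive(x1));
    std::printf("mean=%f var=%f count=%d\n", merged.mean, merged.var, merged.count);
    return 0;
}

Because merging partial (mean, var, count) triples this way gives the same result as one pass over the whole row, the second kernel can combine per-block statistics from the workspace before normalizing.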
@@ -1,3 +1,4 @@
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_layernorm_single_kernel_fp16 gemm_xdl_layernorm_single_kernel_fp16.cpp)
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp)
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp)
@@ -0,0 +1,262 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
#include "ck/library/utility/check_err.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd  = ck::tensor_operation::element_wise::AddReluAdd;

// DataType
using ADataType        = F16;
using BDataType        = F16;
using AccDataType      = F32;
using CShuffleDataType = F32;
using D0DataType       = F16;
using D1DataType       = F16;
using DsDataType       = ck::Tuple<D0DataType, D1DataType>;
using EMeanVarDataType = F16;
using GammaDataType    = F16;
using BetaDataType     = F16;
using HDataType        = F16;

// Layout
using ALayout  = Row;
using BLayout  = Col;
using D0Layout = Row;
using D1Layout = Row;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using HLayout  = Row;

using AElementOp   = PassThrough;
using BElementOp   = PassThrough;
using CDEElementOp = AddReluAdd;
using HElementOp   = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize|
//######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>;
// clang-format on

auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
    return HostTensorDescriptor(std::vector<std::size_t>({len}),
                                std::vector<std::size_t>({stride}));
};

auto f_host_tensor_descriptor2d =
    [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                        std::vector<std::size_t>({stride, 1}));
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                        std::vector<std::size_t>({1, stride}));
        }
    };

void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
                         const Tensor<ADataType>& a_m_k,
                         const Tensor<BDataType>& b_k_n,
                         const Tensor<D0DataType>& bias_n,
                         const Tensor<D1DataType>& d1_m_n,
                         const Tensor<GammaDataType>& gamma_n,
                         const Tensor<BetaDataType>& beta_n,
                         AElementOp a_element_op,
                         BElementOp b_element_op,
                         CDEElementOp cde_element_op,
                         int M,
                         int N,
                         AccDataType epsilon = 1e-5)
{
    using ReferenceGemm = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                    BDataType,
                                                                    AccDataType,
                                                                    AccDataType,
                                                                    AElementOp,
                                                                    BElementOp,
                                                                    PassThrough>;

    using ReferenceLayernorm = ck::tensor_operation::host::ReferenceLayernorm<EMeanVarDataType,
                                                                              GammaDataType,
                                                                              BetaDataType,
                                                                              HDataType,
                                                                              AccDataType,
                                                                              HElementOp,
                                                                              2,
                                                                              1>;

    Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N});
    Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});

    auto ref_gemm         = ReferenceGemm{};
    auto ref_gemm_invoker = ref_gemm.MakeInvoker();

    auto ref_gemm_argument =
        ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});

    ref_gemm_invoker.Run(ref_gemm_argument);

    for(int n = 0; n < N; ++n)
    {
        AccDataType bias = static_cast<AccDataType>(bias_n(n));
        for(int m = 0; m < M; ++m)
        {
            AccDataType e  = static_cast<AccDataType>(e_m_n(m, n));
            AccDataType d1 = static_cast<AccDataType>(d1_m_n(m, n));
            cde_element_op(e, c_m_n(m, n), bias, d1);
            e_m_n(m, n) = static_cast<EMeanVarDataType>(e);
        }
    }

    ReferenceLayernorm ref_layernorm;
    auto ref_layernorm_invoker = ref_layernorm.MakeInvoker();

    auto ref_layernorm_argument = ref_layernorm.MakeArgument(
        e_m_n, gamma_n, beta_n, h_m_n, HElementOp{}, {M, N}, {1}, epsilon);
    ref_layernorm_invoker.Run(ref_layernorm_argument);
}

int main()
{
    bool do_verification = true;

    // GEMM shape
    ck::index_t M = 1024;
    ck::index_t N = 1024;
    ck::index_t K = 1024;

    ck::index_t StrideA  = K;
    ck::index_t StrideB  = K;
    ck::index_t StrideD0 = 0;
    ck::index_t StrideD1 = N;
    ck::index_t StrideH  = N;

    float epsilon = 1e-5;

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
    Tensor<D0DataType> d0_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{}));
    Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));

    a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
    b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
    d0_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-1, 1});
    d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-1, 1});
    gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
    beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});

    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem d0_device_buf(sizeof(D0DataType) * d0_n.mDesc.GetElementSpaceSize());
    DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
    DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
    DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
    DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    d0_device_buf.ToDevice(d0_n.mData.data());
    d1_device_buf.ToDevice(d1_m_n.mData.data());
    gamma_device_buf.ToDevice(gamma_n.mData.data());
    beta_device_buf.ToDevice(beta_n.mData.data());

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{};
    auto h_element_op   = HElementOp{};

    auto device_op = DeviceOpInstance{};
    auto invoker   = device_op.MakeInvoker();
    auto argument =
        device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                               b_device_buf.GetDeviceBuffer(),
                               {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
                               gamma_device_buf.GetDeviceBuffer(),
                               beta_device_buf.GetDeviceBuffer(),
                               h_device_buf.GetDeviceBuffer(),
                               M,
                               N,
                               K,
                               StrideA,
                               StrideB,
                               {StrideD0, StrideD1},
                               StrideH,
                               epsilon,
                               a_element_op,
                               b_element_op,
                               cde_element_op,
                               h_element_op);

    if(!device_op.IsSupportedArgument(argument))
    {
        throw std::runtime_error("wrong! this device_op instance does not support this problem");
    }

    size_t workspace_sz = device_op.GetWorkSpaceSize(&argument);
    DeviceMem workspace_dev(workspace_sz);
    device_op.SetWorkSpacePointer(&argument, workspace_dev.GetDeviceBuffer());

    invoker.Run(argument, StreamConfig{nullptr, false});

    bool pass = true;

    if(do_verification)
    {
        Tensor<HDataType> h_m_n_host(HostTensorDescriptor{M, N});
        host_gemm_layernorm(h_m_n_host,
                            a_m_k,
                            b_k_n,
                            d0_n,
                            d1_m_n,
                            gamma_n,
                            beta_n,
                            a_element_op,
                            b_element_op,
                            cde_element_op,
                            M,
                            N,
                            epsilon);

        h_device_buf.FromDevice(h_m_n.mData.data());
        pass &=
            ck::utils::check_err(h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
    }

    return pass ? 0 : 1;
}
@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// GEMM:
//   input : A[M, K]
//   input : B[N, K]
//   input : D0[M, N], D1[M, N], ...
//   output : E[M, N]
//   output : H[M, N]
//   C = a_op(A) * b_op(B)
//   E = cde_op(C, D0, D1, ...)
//   H = layernorm(E)
// Assume:
//   D0, D1, ... and E have the same layout
// Calculate mean & variance along N dimension in layernorm(E)
template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename HLayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename HDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename HElementwiseOperation>
struct DeviceGemmMultipleDLayernorm : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        const void* p_gamma,
                        const void* p_beta,
                        void* p_h,
                        index_t MRaw,
                        index_t NRaw,
                        index_t KRaw,
                        index_t StrideA,
                        index_t StrideB,
                        std::array<index_t, NumDTensor> StrideDs,
                        index_t StrideH,
                        double epsilon,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op,
                        HElementwiseOperation h_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
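The comment block in the interface above defines the fused computation: C = a_op(A) * b_op(B), E = cde_op(C, D0, D1, ...), and H = layernorm(E) with mean and variance taken along the N dimension. As a plain host-side reference, one output row can be written as the sketch below; this is a standalone illustration under those definitions, and the function and parameter names are not part of the device op's API.

// Hedged scalar sketch of one output row of H: given the fused GEMM output row e,
// normalize along N and apply the per-column affine parameters gamma and beta.
#include <cmath>
#include <vector>

std::vector<float> layernorm_row(const std::vector<float>& e,
                                 const std::vector<float>& gamma,
                                 const std::vector<float>& beta,
                                 float epsilon = 1e-5f)
{
    const int N = static_cast<int>(e.size());
    float mean = 0.f, var = 0.f;
    for(float v : e) mean += v;
    mean /= N;
    for(float v : e) var += (v - mean) * (v - mean);
    var /= N; // mean & variance along the N dimension, as stated above

    std::vector<float> h(N);
    const float inv_std = 1.f / std::sqrt(var + epsilon);
    for(int n = 0; n < N; ++n)
        h[n] = (e[n] - mean) * inv_std * gamma[n] + beta[n];
    return h;
}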
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,394 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"

namespace ck {

template <typename EMeanVarDataType,
          typename HDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename ComputeDataType,
          typename EHGridDesc_M_N,
          typename MeanVarGridDesc_M_NBlock,
          typename CountGridDesc_M_NBlock,
          typename GammaBetaGridDesc_N,
          typename HElementwiseOperation,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t NThreadClusterSize,
          index_t MThreadSliceSize,
          index_t NThreadSliceSize,
          index_t ESrcVectorSize,
          index_t HDstVectorSize,
          index_t GammaSrcVectorSize,
          index_t BetaSrcVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d
{
    static_assert(NThreadSliceSize % ESrcVectorSize == 0 &&
                      NThreadSliceSize % GammaSrcVectorSize == 0 &&
                      NThreadSliceSize % BetaSrcVectorSize == 0,
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(NThreadSliceSize % HDstVectorSize == 0,
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    using ThreadClusterLengths_M_N   = Sequence<MThreadClusterSize, NThreadClusterSize>;
    using ThreadBufferDimAccessOrder = Sequence<0, 1>;
    using ThreadClusterArrangeOrder  = Sequence<0, 1>;

    static constexpr auto thread_cluster_desc_m_n =
        make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
    static constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));

    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
    static constexpr auto thread_buffer_desc_m_1 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

    using ThreadBufferLengths_N = Sequence<NThreadSliceSize>;
    static constexpr auto thread_buffer_desc_n =
        make_naive_tensor_descriptor_packed(make_tuple(Number<NThreadSliceSize>{}));

    using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
    using ThreadWelfordDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelfordMerge<ComputeDataType, ThreadWelfordSrcDesc_M_1, ThreadWelfordDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_N,
                                              ThreadClusterArrangeOrder>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;

    __device__ static void Run(const EMeanVarDataType* __restrict__ p_e_grid,
                               const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
                               const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
                               const int32_t* __restrict__ p_in_welford_count_grid,
                               const GammaDataType* __restrict__ p_gamma_grid,
                               const BetaDataType* __restrict__ p_beta_grid,
                               HDataType* __restrict__ p_h_grid,
                               const EHGridDesc_M_N& e_grid_desc_m_n,
                               const EHGridDesc_M_N& h_grid_desc_m_n,
                               const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_nblock,
                               const CountGridDesc_M_NBlock& count_grid_desc_m_nblock,
                               const GammaBetaGridDesc_N& gamma_grid_desc_n,
                               const GammaBetaGridDesc_N& beta_grid_desc_n,
                               index_t numMeanVarCountBlockTileIteration_N,
                               index_t NBlockClusterLength,
                               ComputeDataType epsilon,
                               HElementwiseOperation h_element_op)
    {
        // Thread/Block id
        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const auto block_work_idx     = make_tuple(block_global_id / NBlockClusterLength,
                                                   block_global_id % NBlockClusterLength);

        const auto thread_cluster_idx =
            thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_n_cluster_id = thread_cluster_idx[I1];

        // Global Memory
        const auto e_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());

        const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_mean_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());

        const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_var_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());

        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_count_grid, count_grid_desc_m_nblock.GetElementSpaceSize());

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());

        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_grid, beta_grid_desc_n.GetElementSpaceSize());

        auto h_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_h_grid, h_grid_desc_m_n.GetElementSpaceSize());

        // VGPR
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            in_welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            in_welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            in_welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr,
                     ComputeDataType,
                     MThreadSliceSize * NThreadSliceSize,
                     true>
            e_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr,
                     ComputeDataType,
                     MThreadSliceSize * NThreadSliceSize,
                     true>
            gamma_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr,
                     ComputeDataType,
                     MThreadSliceSize * NThreadSliceSize,
                     true>
            beta_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr,
                     ComputeDataType,
                     MThreadSliceSize * NThreadSliceSize,
                     true>
            h_thread_buf;

        // IO
        auto threadwise_mean_load_m_nblock =
            ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
                                             ComputeDataType,
                                             MeanVarGridDesc_M_NBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             ThreadBufferDimAccessOrder,
                                             1,
                                             1,
                                             1,
                                             true>(
                mean_var_grid_desc_m_nblock,
                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));

        auto threadwise_var_load_m_nblock =
            ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
                                             ComputeDataType,
                                             MeanVarGridDesc_M_NBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             ThreadBufferDimAccessOrder,
                                             1,
                                             1,
                                             1,
                                             true>(
                mean_var_grid_desc_m_nblock,
                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));

        auto threadwise_count_load_m_nblock =
            ThreadwiseTensorSliceTransfer_v2<int32_t,
                                             int32_t,
                                             CountGridDesc_M_NBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             ThreadBufferDimAccessOrder,
                                             1,
                                             1,
                                             1,
                                             true>(
                count_grid_desc_m_nblock,
                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));

        auto threadwise_e_load_m_n =
            ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
                                             ComputeDataType,
                                             decltype(e_grid_desc_m_n),
                                             decltype(thread_buffer_desc_m_n),
                                             ThreadBufferLengths_M_N,
                                             ThreadBufferDimAccessOrder,
                                             1, // SrcVectorDim
                                             ESrcVectorSize,
                                             1,
                                             true>(
                e_grid_desc_m_n,
                make_multi_index(
                    block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                    block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_gamma_load_n =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                             ComputeDataType,
                                             decltype(gamma_grid_desc_n),
                                             decltype(thread_buffer_desc_n),
                                             ThreadBufferLengths_N,
                                             Sequence<0>, // DimAccessOrder,
                                             0,           // SrcVectorDim,
                                             GammaSrcVectorSize,
                                             1,
                                             true>(
                gamma_grid_desc_n,
                make_multi_index(block_work_idx[I1] * N_BlockTileSize +
                                 thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_beta_load_n =
            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                             ComputeDataType,
                                             decltype(beta_grid_desc_n),
                                             decltype(thread_buffer_desc_n),
                                             ThreadBufferLengths_N,
                                             Sequence<0>, // DimAccessOrder,
                                             0,           // SrcVectorDim,
                                             BetaSrcVectorSize,
                                             1,
                                             true>(
                beta_grid_desc_n,
                make_multi_index(block_work_idx[I1] * N_BlockTileSize +
                                 thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_h_store_m_n =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               HDataType,
                                               decltype(thread_buffer_desc_m_n),
                                               decltype(h_grid_desc_m_n),
                                               HElementwiseOperation,
                                               ThreadBufferLengths_M_N,
                                               ThreadBufferDimAccessOrder,
                                               1, // DstVectorDim
                                               HDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                h_grid_desc_m_n,
                make_multi_index(
                    block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                    block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize),
                h_element_op);

        // step1: Merge mean and variance
        constexpr auto mean_var_count_thread_copy_step_I0_n =
            make_multi_index(I0, NThreadClusterSize);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            welford_mean_thread_buf(I)  = type_convert<ComputeDataType>(0.0f);
            welford_var_thread_buf(I)   = type_convert<ComputeDataType>(0.0f);
            welford_count_thread_buf(I) = 0;
        });

        for(index_t n = 0; n < numMeanVarCountBlockTileIteration_N; ++n)
        {
            threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
                                              welford_mean_global_val_buf,
                                              thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              in_welford_mean_thread_buf);

            threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
                                             welford_var_global_val_buf,
                                             thread_buffer_desc_m_1,
                                             make_tuple(I0, I0),
                                             in_welford_var_thread_buf);

            threadwise_count_load_m_nblock.Run(count_grid_desc_m_nblock,
                                               welford_count_global_val_buf,
                                               thread_buffer_desc_m_1,
                                               make_tuple(I0, I0),
                                               in_welford_count_thread_buf);

            ThreadwiseWelford::Run(in_welford_mean_thread_buf,
                                   in_welford_var_thread_buf,
                                   in_welford_count_thread_buf,
                                   welford_mean_thread_buf,
                                   welford_var_thread_buf,
                                   welford_count_thread_buf);

            threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
                                                             mean_var_count_thread_copy_step_I0_n);
            threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
                                                            mean_var_count_thread_copy_step_I0_n);
            threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_nblock,
                                                              mean_var_count_thread_copy_step_I0_n);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseWelford::Run(
                welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
        });

        // step2: normalization
        // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
        threadwise_e_load_m_n.Run(e_grid_desc_m_n,
                                  e_global_val_buf,
                                  thread_buffer_desc_m_n,
                                  make_tuple(I0, I0),
                                  e_thread_buf);

        static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
            auto divisor = 1 / ck::math::sqrt(welford_var_thread_buf(m) + epsilon);
            static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
                constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
                h_thread_buf(Number<m_n>{}) =
                    (e_thread_buf(Number<m_n>{}) - welford_mean_thread_buf(m)) * divisor;
            });
        });

        threadwise_gamma_load_n.Run(gamma_grid_desc_n,
                                    gamma_global_val_buf,
                                    thread_buffer_desc_n,
                                    make_tuple(I0),
                                    gamma_thread_buf);

        static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
            static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
                constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
                h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
            });
        });

        threadwise_beta_load_n.Run(beta_grid_desc_n,
                                   beta_global_val_buf,
                                   thread_buffer_desc_n,
                                   make_tuple(I0),
                                   beta_thread_buf);

        static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
            static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
                constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
                h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
            });
        });

        threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
                                   make_tuple(I0, I0),
                                   h_thread_buf,
                                   h_grid_desc_m_n,
                                   h_global_val_buf);

    } // run
};

} // namespace ck
@@ -434,7 +434,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
         });

         static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+            auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
             static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                     constexpr auto offset_m_k =
@@ -319,7 +319,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         });

         static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+            auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
             static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                     constexpr auto offset_m_k =
@@ -90,10 +90,13 @@ struct ReferenceLayernorm : public device::BaseOperator

             for(int m = 0; m < M; ++m)
             {
+                AccDataType divisor =
+                    static_cast<AccDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
+
                 for(int n = 0; n < N; ++n)
                 {
                     auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
-                    auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
+                    auto y_val = (x_val - mean(m)) * divisor;
                     y_val      = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
                     arg.acc_elementwise_op_(y_val, y_val);
                     arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);