Mirror of https://github.com/ROCm/composable_kernel.git
Standalone layernorm (#315)
* Implement layernorm kernel and deviceOp
* Verify GPU kernel against host code
* Separate gamma and beta from affine; check if the argument is valid
* Clean up
* Sync the naming
* Support sweep-once mode when the whole K dimension fits inside one block
* [What] Get length from the upper length. [Why] If we get the length directly, we may get the length after padding.
* We only use one block in the K dimension, so the indexing of global reads/writes can be simplified.
* Use a 1D descriptor for gamma and beta
* Add accElementwiseOp
* Extract layernorm host code
* Support different YVectorDim in GridwiseLayernorm
* Rename XSrcVectorDim to XYSrcVectorDim, because the same parameter is used in the deviceOp
* Gamma and beta can share the same VGPRs
* Add tests for fp32 and fp16
* Fix a concurrency bug and add a test case that could originally fail
* Propagate NaN for layernorm

Co-authored-by: Chao Liu <chao.liu2@amd.com>
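Note (editorial, for orientation; not part of the commit): the operation added here is standard row-wise layer normalization, matching the CPU reference in reference_layernorm.hpp and the GPU kernel in gridwise_layernorm.hpp below. For a 2D input x of shape M x N, reduced over the second dimension:

    \mu_m       = \tfrac{1}{N} \sum_n x_{m,n}
    \sigma_m^2  = \tfrac{1}{N} \sum_n x_{m,n}^2 - \mu_m^2      % var(x) = E[x^2] - E[x]^2
    y_{m,n}     = \gamma_n \cdot \frac{x_{m,n} - \mu_m}{\sqrt{\sigma_m^2 + \epsilon}} + \beta_n

The E[x^2] - E[x]^2 form of the variance is what lets the kernel accumulate the mean and the mean of squares in a single pass over x; only the normalize-and-affine pass needs x again, and in sweep-once mode that second read is served from registers instead of global memory.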
@@ -129,7 +129,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
                          const Tensor<ADataType>& a_m_k,
                          const Tensor<ADataType>& b_k_n,
                          const Tensor<GammaDataType>& gamma_n,
-                         const Tensor<GammaDataType>& beta_n,
+                         const Tensor<BetaDataType>& beta_n,
                          A_functor a_element_op,
                          B_functor b_element_op,
                          C_functor c_element_op,
@@ -212,6 +212,8 @@ int main(int argc, char* argv[])

    auto device_instance = DeviceInstance{};

    std::cout << i_inLengths.size() << ", " << i_inStrides.size() << std::endl;

    auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
                                                            i_inStrides,
                                                            reduceDims,
example/27_layernorm/CMakeLists.txt (new file, 1 line)
@@ -0,0 +1 @@
add_example_executable(example_layernorm_blockwise layernorm_blockwise.cpp)
example/27_layernorm/layernorm_blockwise.cpp (new file, 133 lines)
@@ -0,0 +1,133 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_common_util.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"

using XDataType     = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType  = ck::half_t;
using YDataType     = ck::half_t;
using AccDataType   = float;
using PassThrough   = ck::tensor_operation::element_wise::PassThrough;

constexpr int Rank         = 2;
constexpr int NumReduceDim = 1;

using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
                                                                     GammaDataType,
                                                                     BetaDataType,
                                                                     AccDataType,
                                                                     YDataType,
                                                                     PassThrough,
                                                                     Rank,
                                                                     NumReduceDim,
                                                                     256, // BlockSize
                                                                     8,   // ClusterM
                                                                     32,  // ClusterK
                                                                     1,   // SliceM
                                                                     8,   // SliceK
                                                                     1,   // SrcVecDim (0=M, 1=K)
                                                                     8,   // SrcScalarPerVector
                                                                     8,   // GammaScalarPerVector
                                                                     8,   // BetaScalarPerVector
                                                                     1>;  // OutScalarPerVector

int main()
{
    bool time_kernel = false;

    ck::index_t M      = 1024;
    ck::index_t N      = 1024;
    ck::index_t Stride = N;

    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({len}),
                                    std::vector<std::size_t>({stride}));
    };

    auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                    std::vector<std::size_t>({stride, 1}));
    };

    Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
    Tensor<GammaDataType> gamma(f_host_tensor_descriptor1d(N, 1));
    Tensor<BetaDataType> beta(f_host_tensor_descriptor1d(N, 1));
    Tensor<YDataType> y(f_host_tensor_descriptor2d(M, N, Stride));

    x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
    gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
    beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});

    DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpace());
    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpace());
    DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpace());
    DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpace());

    x_dev.ToDevice(x.mData.data());
    gamma_dev.ToDevice(gamma.mData.data());
    beta_dev.ToDevice(beta.mData.data());

    auto device_instance = DeviceInstance{};
    auto argument_ptr    = device_instance.MakeArgumentPointer(
        {M, N},
        std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
        std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()},
        std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()},
        {1},
        1e-4,
        x_dev.GetDeviceBuffer(),
        gamma_dev.GetDeviceBuffer(),
        beta_dev.GetDeviceBuffer(),
        y_dev.GetDeviceBuffer(),
        PassThrough{});

    if(!device_instance.IsSupportedArgument(argument_ptr.get()))
    {
        std::cout << "The runtime parameters are not supported" << std::endl;
        return 1;
    };

    auto invoker_ptr = device_instance.MakeInvokerPointer();
    invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    bool pass = true;
    {
        Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride));
        using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm<XDataType,
                                                                                 GammaDataType,
                                                                                 BetaDataType,
                                                                                 YDataType,
                                                                                 AccDataType,
                                                                                 PassThrough,
                                                                                 Rank,
                                                                                 NumReduceDim>;

        ReferenceInstance ref;
        auto ref_argument =
            ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4);
        auto ref_invoker = ref.MakeInvoker();
        ref_invoker.Run(ref_argument);

        y_dev.FromDevice(y.mData.data());
        pass &=
            ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
    }
    return (pass ? 0 : 1);
}
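Note (editorial; inferred from the tile sizes, since the exact gridSize computation lives in DeviceReduceMultiBlock, which this diff reuses rather than adds): with the configuration above, one block tile covers MThreadClusterSize * MThreadSliceSize = 8 * 1 rows and KThreadClusterSize * KThreadSliceSize = 32 * 8 = 256 columns. For M = N = 1024 the row length exceeds 256, so the generic (multi-sweep) kernel is selected and each block walks the K tiles of its rows:

    \text{K tiles per row} = \lceil 1024 / 256 \rceil = 4, \qquad \text{row tiles} = 1024 / 8 = 128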
example/CMakeLists.txt
@@ -45,3 +45,4 @@ add_subdirectory(23_softmax)
 add_subdirectory(24_batched_gemm_c_permute)
 add_subdirectory(25_gemm_bias_c_permute)
 add_subdirectory(26_contraction)
+add_subdirectory(27_layernorm)
include/ck/tensor_operation/gpu/device/device_layernorm.hpp (new file, 346 lines)
@@ -0,0 +1,346 @@
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/reduction_operator.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
|
||||
#include "ck/device_utility/device_prop.hpp"
|
||||
#include "ck/device_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Y = LayerNorm(X, Beta, Gamma)
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename YDataType,
|
||||
typename AccElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t XYSrcVectorDim,
|
||||
index_t XSrcVectorSize,
|
||||
index_t GammaSrcVectorSize,
|
||||
index_t BetaSrcVectorSize,
|
||||
index_t YDstVectorSize>
|
||||
struct DeviceLayernorm : public BaseOperator
|
||||
{
|
||||
static_assert(
|
||||
(KThreadSliceSize % GammaSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
|
||||
|
||||
static_assert(
|
||||
(KThreadSliceSize % BetaSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
|
||||
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Used for freeloading of some handy functions from DeviceReduceMultiBlock
|
||||
using Reduction = DeviceReduceMultiBlock<XDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
reduce::Add,
|
||||
PassThrough, // InElementwiseOperation
|
||||
AccElementwiseOperation, // AccElementwiseOperation
|
||||
InMemoryDataOperationEnum::Set,
|
||||
false, // PropagateNan
|
||||
false, // OutputIndex
|
||||
false, // HaveIndexInputIfOutputIndex
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XYSrcVectorDim,
|
||||
XSrcVectorSize,
|
||||
1>; // YDstVectorSize
|
||||
|
||||
static auto MakeAffine1dDescriptor(const std::vector<index_t>& Lengths,
|
||||
const std::vector<index_t>& Strides,
|
||||
int blkGroupSize,
|
||||
int numBlockTileIteration)
|
||||
{
|
||||
const auto tupleLengths = make_tuple_from_array(Lengths, Number<NumReduceDim>{});
|
||||
const auto tupleStrides = make_tuple_from_array(Strides, Number<NumReduceDim>{});
|
||||
|
||||
auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
|
||||
|
||||
auto grid_desc_k = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, NumReduceDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto reduceTotalLength = grid_desc_k.GetLength(Number<0>{});
|
||||
const int reduceSizePerBlock = Reduction::K_BlockTileSize * numBlockTileIteration;
|
||||
|
||||
const auto Pad_K = reduceSizePerBlock * blkGroupSize - reduceTotalLength;
|
||||
|
||||
auto grid_desc_k_padded = transform_tensor_descriptor(
|
||||
grid_desc_k,
|
||||
make_tuple(make_right_pad_transform(reduceTotalLength, Pad_K)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return (grid_desc_k_padded);
|
||||
};
|
||||
|
||||
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
|
||||
using GridDesc_K = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1));
|
||||
|
||||
using GridwiseReduceLayernormGeneric = GridwiseLayernorm_mk_to_mk<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
AccElementwiseOperation,
|
||||
GridDesc_M_K,
|
||||
GridDesc_K,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XYSrcVectorDim,
|
||||
XSrcVectorSize,
|
||||
GammaSrcVectorSize,
|
||||
BetaSrcVectorSize,
|
||||
XYSrcVectorDim,
|
||||
YDstVectorSize,
|
||||
false>;
|
||||
|
||||
using GridwiseReduceLayernormSweepOnce = GridwiseLayernorm_mk_to_mk<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
AccElementwiseOperation,
|
||||
GridDesc_M_K,
|
||||
GridDesc_K,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XYSrcVectorDim,
|
||||
XSrcVectorSize,
|
||||
GammaSrcVectorSize,
|
||||
BetaSrcVectorSize,
|
||||
XYSrcVectorDim,
|
||||
YDstVectorSize,
|
||||
true>;
|
||||
|
||||
struct Argument : public Reduction::Argument
|
||||
{
|
||||
Argument(const std::vector<index_t> lengths,
|
||||
const std::vector<index_t> xStrides,
|
||||
const std::vector<index_t> gammaStrides,
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccElementwiseOperation acc_elementwise_op,
|
||||
AccDataType epsilon,
|
||||
const XDataType* p_x,
|
||||
const GammaDataType* p_gamma,
|
||||
const BetaDataType* p_beta,
|
||||
YDataType* p_y)
|
||||
: Reduction::Argument(lengths,
|
||||
xStrides,
|
||||
{},
|
||||
{},
|
||||
reduceDims,
|
||||
0.0f, // alpha
|
||||
0.0f, // beta
|
||||
p_x,
|
||||
nullptr,
|
||||
p_y,
|
||||
nullptr,
|
||||
acc_elementwise_op,
|
||||
PassThrough{}),
|
||||
epsilon_(epsilon),
|
||||
p_gamma_(p_gamma),
|
||||
p_beta_(p_beta),
|
||||
gammaStrides_(gammaStrides),
|
||||
betaStrides_(betaStrides)
|
||||
{
|
||||
reduceLength_.resize(NumReduceDim);
|
||||
|
||||
for(int i = 0; i < NumReduceDim; ++i)
|
||||
{
|
||||
reduceLength_[i] = lengths[reduceDims[i]];
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType epsilon_;
|
||||
const GammaDataType* p_gamma_;
|
||||
const BetaDataType* p_beta_;
|
||||
std::vector<index_t> reduceLength_;
|
||||
std::vector<index_t> gammaStrides_;
|
||||
std::vector<index_t> betaStrides_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto x_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
|
||||
const auto gamma_grid_desc_k = MakeAffine1dDescriptor(
|
||||
arg.reduceLength_, arg.gammaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
|
||||
const auto beta_grid_desc_k = MakeAffine1dDescriptor(
|
||||
arg.reduceLength_, arg.betaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
|
||||
const auto y_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
|
||||
|
||||
bool sweep_once =
|
||||
x_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
const auto kernel_main = sweep_once ? kernel_layernorm<GridwiseReduceLayernormSweepOnce,
|
||||
XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
AccElementwiseOperation,
|
||||
GridDesc_M_K,
|
||||
GridDesc_K>
|
||||
: kernel_layernorm<GridwiseReduceLayernormGeneric,
|
||||
XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
AccElementwiseOperation,
|
||||
GridDesc_M_K,
|
||||
GridDesc_K>;
|
||||
|
||||
float avg_time = 0;
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_main,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
x_grid_desc_m_k,
|
||||
gamma_grid_desc_k,
|
||||
beta_grid_desc_k,
|
||||
y_grid_desc_m_k,
|
||||
arg.numBlockTileIteration,
|
||||
arg.epsilon_,
|
||||
arg.in_dev_,
|
||||
arg.p_gamma_,
|
||||
arg.p_beta_,
|
||||
arg.out_dev_,
|
||||
arg.acc_elementwise_op_);
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(!Reduction::IsSupportedArgument(p_arg_))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(p_arg_->inLengths_[Rank - 1] % YDstVectorSize != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(p_arg_->gammaStrides_.size() != NumReduceDim ||
|
||||
p_arg_->betaStrides_.size() != NumReduceDim)
|
||||
return false;
|
||||
|
||||
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
|
||||
bool ret = true;
|
||||
|
||||
if(!isLastDimensionCoalesced)
|
||||
ret = scalarPerVector == 1;
|
||||
else
|
||||
ret = KThreadSliceSize % scalarPerVector == 0;
|
||||
|
||||
return ret;
|
||||
};
|
||||
|
||||
if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize))
|
||||
return false;
|
||||
|
||||
if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> lengths,
|
||||
const std::vector<index_t> xStrides,
|
||||
const std::vector<index_t> gammaStrides,
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccDataType epsilon,
|
||||
const void* p_x,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_y,
|
||||
AccElementwiseOperation acc_elementwise_op)
|
||||
{
|
||||
return std::make_unique<Argument>(lengths,
|
||||
xStrides,
|
||||
gammaStrides,
|
||||
betaStrides,
|
||||
reduceDims,
|
||||
acc_elementwise_op,
|
||||
epsilon,
|
||||
static_cast<const XDataType*>(p_x),
|
||||
static_cast<const GammaDataType*>(p_gamma),
|
||||
static_cast<const BetaDataType*>(p_beta),
|
||||
static_cast<YDataType*>(p_y));
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); };
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceLayernorm<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
|
||||
str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
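Note (editorial, not part of the diff): DeviceLayernorm::Invoker above selects the sweep-once kernel when a single block tile already covers the whole reduction length, i.e. when x_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; in that case x stays resident in VGPRs between the statistics pass and the normalization pass. A minimal standalone restatement of that check, using the ClusterK/SliceK values from the example (32 and 8) as assumed constants and a hypothetical helper name:

#include <iostream>

// Hypothetical restatement of the sweep-once criterion used by DeviceLayernorm::Invoker.
// The constants mirror the example's ClusterK = 32 and SliceK = 8 configuration.
constexpr int KThreadClusterSize = 32;
constexpr int KThreadSliceSize   = 8;

constexpr bool CanSweepOnce(int reduce_length)
{
    // One block tile spans KThreadClusterSize * KThreadSliceSize elements along K.
    return reduce_length <= KThreadClusterSize * KThreadSliceSize;
}

int main()
{
    std::cout << CanSweepOnce(256) << '\n';  // 1: 256 <= 256, single sweep over K
    std::cout << CanSweepOnce(1024) << '\n'; // 0: needs several K tiles and a second read of x
    return 0;
}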
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp (new file, 392 lines)
@@ -0,0 +1,392 @@
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/reduction_common.hpp"
|
||||
#include "ck/utility/reduction_operator.hpp"
|
||||
#include "ck/utility/reduction_functions_accumulate.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseReduction,
|
||||
typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename AccElementwiseOperation,
|
||||
typename GridDesc_M_K,
|
||||
typename GridDesc_K>
|
||||
__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
|
||||
const GridDesc_K gamma_grid_desc_k,
|
||||
const GridDesc_K beta_grid_desc_k,
|
||||
const GridDesc_M_K y_grid_desc_m_k,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType epsilon,
|
||||
const XDataType* const __restrict__ p_x_global,
|
||||
const GammaDataType* const __restrict__ p_gamma_global,
|
||||
const BetaDataType* const __restrict__ p_beta_global,
|
||||
YDataType* const __restrict__ p_y_global,
|
||||
const AccElementwiseOperation acc_elementwise_op)
|
||||
{
|
||||
GridwiseReduction::Run(x_grid_desc_m_k,
|
||||
gamma_grid_desc_k,
|
||||
beta_grid_desc_k,
|
||||
y_grid_desc_m_k,
|
||||
num_k_block_tile_iteration,
|
||||
epsilon,
|
||||
p_x_global,
|
||||
p_gamma_global,
|
||||
p_beta_global,
|
||||
p_y_global,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
// Y = LayerNorm(X, Beta, Gamma)
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename AccElementwiseOperation,
|
||||
typename GridDesc_M_K,
|
||||
typename GridDesc_K,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t XSrcVectorDim,
|
||||
index_t XSrcVectorSize,
|
||||
index_t GammaSrcVectorSize,
|
||||
index_t BetaSrcVectorSize,
|
||||
index_t YDstVectorDim,
|
||||
index_t YDstVectorSize,
|
||||
bool SweepOnce>
|
||||
struct GridwiseLayernorm_mk_to_mk
|
||||
{
|
||||
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
|
||||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
|
||||
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
|
||||
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using BlockwiseSumReduce = PartitionedBlockwiseReduction<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
reduce::Add,
|
||||
true>;
|
||||
|
||||
using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
|
||||
ThreadReduceSrcDesc_M_K,
|
||||
ThreadReduceDstDesc_M,
|
||||
reduce::Add,
|
||||
true>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
__device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
|
||||
const GridDesc_K& gamma_grid_desc_k,
|
||||
const GridDesc_K& beta_grid_desc_k,
|
||||
const GridDesc_M_K& y_grid_desc_m_k,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType epsilon,
|
||||
const XDataType* const __restrict__ p_x_global,
|
||||
const GammaDataType* const __restrict__ p_gamma_global,
|
||||
const BetaDataType* const __restrict__ p_beta_global,
|
||||
YDataType* const __restrict__ p_y_global,
|
||||
const AccElementwiseOperation acc_elementwise_op)
|
||||
{
|
||||
if constexpr(SweepOnce)
|
||||
{
|
||||
num_k_block_tile_iteration = 1;
|
||||
}
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
|
||||
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
auto reduce_work_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
x_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
|
||||
gamma_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
y_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>& x_square_thread_buf = y_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
|
||||
mean_square_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_value_buf =
|
||||
mean_square_thread_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
mean_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
|
||||
mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
|
||||
});
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
using ThreadBufferLengths_K = Sequence<KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
constexpr auto thread_buffer_desc_k =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
|
||||
AccDataType,
|
||||
GridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
ThreadBufferLengths_M_K,
|
||||
ThreadBufferDimAccessOrder,
|
||||
XSrcVectorDim,
|
||||
XSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
x_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_gamma_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
|
||||
AccDataType,
|
||||
GridDesc_K,
|
||||
decltype(thread_buffer_desc_k),
|
||||
ThreadBufferLengths_K,
|
||||
Sequence<0>,
|
||||
0,
|
||||
GammaSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
|
||||
AccDataType,
|
||||
GridDesc_K,
|
||||
decltype(thread_buffer_desc_k),
|
||||
ThreadBufferLengths_K,
|
||||
Sequence<0>,
|
||||
0,
|
||||
BetaSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_y_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
YDataType,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
GridDesc_M_K,
|
||||
AccElementwiseOperation,
|
||||
ThreadBufferLengths_M_K,
|
||||
ThreadBufferDimAccessOrder,
|
||||
YDstVectorDim,
|
||||
YDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
y_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize),
|
||||
acc_elementwise_op);
|
||||
|
||||
// Copy x from Cache
|
||||
// one pass: fwd, second pass: bwd
|
||||
constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
|
||||
constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
|
||||
|
||||
constexpr auto thread_copy_fwd_step_m_k =
|
||||
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
|
||||
constexpr auto thread_copy_bwd_step_m_k =
|
||||
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
|
||||
|
||||
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
|
||||
|
||||
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
|
||||
|
||||
// E(x), E[x^2], var(x)
|
||||
int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
threadwise_x_load.Run(x_grid_desc_m_k,
|
||||
x_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
x_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
x_square_thread_buf(Number<offset_m_k>{}) =
|
||||
x_thread_buf(Number<offset_m_k>{}) * x_thread_buf(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf);
|
||||
ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf);
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
|
||||
|
||||
++reducedTiles;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if constexpr(I > 0)
|
||||
block_sync_lds();
|
||||
|
||||
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
|
||||
mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
|
||||
mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
|
||||
|
||||
// var(x) = E[x^2] - E[x]^2
|
||||
var_value_buf(I) =
|
||||
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
|
||||
});
|
||||
|
||||
// y = (x - E[x]) / sqrt(var[x] + epsilon)
|
||||
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
|
||||
auto thread_copy_tail_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
|
||||
reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
if constexpr(!SweepOnce)
|
||||
{
|
||||
threadwise_x_load.Run(x_grid_desc_m_k,
|
||||
x_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
x_thread_buf);
|
||||
}
|
||||
|
||||
threadwise_gamma_load.Run(gamma_grid_desc_k,
|
||||
gamma_global_val_buf,
|
||||
thread_buffer_desc_k,
|
||||
make_tuple(I0),
|
||||
gamma_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
|
||||
sqrt(var_value_buf(iM) + epsilon);
|
||||
|
||||
// gamma
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
threadwise_beta_load.Run(beta_grid_desc_k,
|
||||
beta_global_val_buf,
|
||||
thread_buffer_desc_k,
|
||||
make_tuple(I0),
|
||||
beta_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
|
||||
|
||||
// beta
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
threadwise_y_store.Run(thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
y_thread_buf,
|
||||
y_grid_desc_m_k,
|
||||
y_global_val_buf);
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
|
||||
++reducedTiles;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp (new file, 170 lines)
@@ -0,0 +1,170 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename AccDataType,
          typename AccElementwiseOperation,
          index_t Rank,
          index_t NumReduceDim>
struct ReferenceLayernorm : public device::BaseOperator
{
    // TODO - support generic layernorm
    static_assert((Rank == 2 && NumReduceDim == 1), "Only support 2D version so far");

    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<XDataType>& x_m_n,
                 const Tensor<GammaDataType>& gamma_n,
                 const Tensor<BetaDataType>& beta_n,
                 Tensor<YDataType>& y_m_n,
                 AccElementwiseOperation acc_elementwise_op,
                 const std::vector<index_t> lengths,
                 const std::vector<index_t> reduceDims,
                 AccDataType epsilon)
            : x_m_n_(x_m_n),
              gamma_n_(gamma_n),
              beta_n_(beta_n),
              y_m_n_(y_m_n),
              acc_elementwise_op_(acc_elementwise_op),
              lengths_(lengths),
              reduceDims_(reduceDims),
              epsilon_(epsilon)
        {
        }

        const Tensor<XDataType> x_m_n_;
        const Tensor<XDataType> gamma_n_;
        const Tensor<XDataType> beta_n_;
        Tensor<YDataType>& y_m_n_;
        AccElementwiseOperation acc_elementwise_op_;
        std::vector<index_t> lengths_;
        std::vector<index_t> reduceDims_;
        AccDataType epsilon_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        float Run(const Argument& arg)
        {
            int M = arg.lengths_[0];
            int N = arg.lengths_[1];

            Tensor<AccDataType> mean({M});
            Tensor<AccDataType> var({M});

            for(int m = 0; m < M; ++m)
            {
                mean(m) = 0;
                var(m)  = 0;

                for(int n = 0; n < N; ++n)
                {
                    auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
                    mean(m) += x_val;
                    var(m) += x_val * x_val;
                }

                mean(m) = mean(m) / N;
                var(m)  = (var(m) / N) - (mean(m) * mean(m));
            }

            for(int m = 0; m < M; ++m)
            {
                for(int n = 0; n < N; ++n)
                {
                    auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
                    auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
                    y_val      = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
                    arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
                }
            }

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument* p_arg) override
    {
        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

        // TODO - support generic layernorm
        if(p_arg_->lengths_.size() != 2)
            return false;

        if(p_arg_->reduceDims_.size() != 1)
            return false;

        if(p_arg_->reduceDims_[0] != 1)
            return false;

        return true;
    }

    static auto MakeArgument(const Tensor<XDataType>& x_m_n,
                             const Tensor<GammaDataType>& gamma_n,
                             const Tensor<BetaDataType>& beta_n,
                             Tensor<YDataType>& y_m_n,
                             AccElementwiseOperation acc_elementwise_op,
                             const std::vector<index_t> lengths,
                             const std::vector<index_t> reduceDims,
                             AccDataType epsilon)
    {
        return Argument{
            x_m_n, gamma_n, beta_n, y_m_n, acc_elementwise_op, lengths, reduceDims, epsilon};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceLayernorm"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
test/CMakeLists.txt
@@ -48,3 +48,4 @@ add_subdirectory(convnd_bwd_weight)
 add_subdirectory(convnd_bwd_data)
 add_subdirectory(block_to_ctile_map)
 add_subdirectory(softmax)
+add_subdirectory(layernorm)
test/layernorm/CMakeLists.txt (new file, 8 lines)
@@ -0,0 +1,8 @@
add_custom_target(test_layernorm)

add_gtest_executable(test_layernorm_fp32 test_layernorm_fp32.cpp)
add_gtest_executable(test_layernorm_fp16 test_layernorm_fp16.cpp)
target_link_libraries(test_layernorm_fp32 PRIVATE host_tensor)
target_link_libraries(test_layernorm_fp16 PRIVATE host_tensor)
add_dependencies(test_layernorm test_layernorm_fp32)
add_dependencies(test_layernorm test_layernorm_fp16)
test/layernorm/test_layernorm_fp16.cpp (new file, 29 lines)
@@ -0,0 +1,29 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "test_layernorm_util.hpp"

template <ck::index_t N>
using I = ck::Number<N>;

template <typename Tuple>
class TestLayernormFP16 : public ck::TestLayernorm<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>
    >;
// clang-format on
TYPED_TEST_SUITE(TestLayernormFP16, KernelTypes);
TYPED_TEST(TestLayernormFP16, Test_FP16) { this->Run(); }
test/layernorm/test_layernorm_fp32.cpp (new file, 29 lines)
@@ -0,0 +1,29 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "test_layernorm_util.hpp"

template <ck::index_t N>
using I = ck::Number<N>;

template <typename Tuple>
class TestLayernormFP32 : public ck::TestLayernorm<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>
    >;
// clang-format on
TYPED_TEST_SUITE(TestLayernormFP32, KernelTypes);
TYPED_TEST(TestLayernormFP32, Test_FP32) { this->Run(); }
test/layernorm/test_layernorm_util.hpp (new file, 178 lines)
@@ -0,0 +1,178 @@
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename Range>
|
||||
std::string serialize_range(const Range& range)
|
||||
{
|
||||
std::stringstream ss;
|
||||
for(auto& r : range)
|
||||
{
|
||||
ss << r << ", ";
|
||||
}
|
||||
std::string str = ss.str();
|
||||
return std::string(str.begin(), str.end() - 2);
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
class TestLayernorm : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using XDataType = std::tuple_element_t<0, Tuple>;
|
||||
using GammaDataType = std::tuple_element_t<1, Tuple>;
|
||||
using BetaDataType = std::tuple_element_t<2, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<3, Tuple>;
|
||||
using YDataType = std::tuple_element_t<4, Tuple>;
|
||||
static constexpr index_t Rank = std::tuple_element_t<5, Tuple>{}.value;
|
||||
static constexpr index_t NumReduceDim = std::tuple_element_t<6, Tuple>{}.value;
|
||||
static constexpr index_t BlockSize = std::tuple_element_t<7, Tuple>{}.value;
|
||||
static constexpr index_t MThreadClusterSize = std::tuple_element_t<8, Tuple>{}.value;
|
||||
static constexpr index_t KThreadClusterSize = std::tuple_element_t<9, Tuple>{}.value;
|
||||
static constexpr index_t MThreadSliceSize = std::tuple_element_t<10, Tuple>{}.value;
|
||||
static constexpr index_t KThreadSliceSize = std::tuple_element_t<11, Tuple>{}.value;
|
||||
static constexpr index_t XYSrcVectorDim = std::tuple_element_t<12, Tuple>{}.value;
|
||||
static constexpr index_t XSrcVectorSize = std::tuple_element_t<13, Tuple>{}.value;
|
||||
static constexpr index_t GammaSrcVectorSize = std::tuple_element_t<14, Tuple>{}.value;
|
||||
static constexpr index_t BetaSrcVectorSize = std::tuple_element_t<15, Tuple>{}.value;
|
||||
static constexpr index_t YDstVectorSize = std::tuple_element_t<16, Tuple>{}.value;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ReferenceInstance = tensor_operation::host::ReferenceLayernorm<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
using DeviceInstance = tensor_operation::device::DeviceLayernorm<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
PassThrough,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XYSrcVectorDim,
|
||||
XSrcVectorSize,
|
||||
GammaSrcVectorSize,
|
||||
BetaSrcVectorSize,
|
||||
YDstVectorSize>;
|
||||
|
||||
TestLayernorm() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {}
|
||||
|
||||
void RunSingle(std::vector<index_t> lengths, std::vector<index_t> reduceDims)
|
||||
{
|
||||
std::vector<index_t> reduceLength(reduceDims.size());
|
||||
for(int i = 0; i < NumReduceDim; ++i)
|
||||
{
|
||||
reduceLength[i] = lengths[reduceDims[i]];
|
||||
}
|
||||
|
||||
Tensor<XDataType> x(lengths);
|
||||
Tensor<GammaDataType> gamma(reduceLength);
|
||||
Tensor<BetaDataType> beta(reduceLength);
|
||||
Tensor<YDataType> y(lengths);
|
||||
Tensor<YDataType> y_ref(lengths);
|
||||
|
||||
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
|
||||
beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpace());
|
||||
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpace());
|
||||
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpace());
|
||||
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpace());
|
||||
|
||||
x_dev.ToDevice(x.mData.data());
|
||||
gamma_dev.ToDevice(gamma.mData.data());
|
||||
beta_dev.ToDevice(beta.mData.data());
|
||||
|
||||
auto device_instance = DeviceInstance{};
|
||||
auto argument_ptr = device_instance.MakeArgumentPointer(
|
||||
lengths,
|
||||
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(),
|
||||
gamma.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(),
|
||||
beta.mDesc.GetStrides().end()},
|
||||
reduceDims,
|
||||
1e-4,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
gamma_dev.GetDeviceBuffer(),
|
||||
beta_dev.GetDeviceBuffer(),
|
||||
y_dev.GetDeviceBuffer(),
|
||||
PassThrough{});
|
||||
|
||||
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto invoker_ptr = device_instance.MakeInvokerPointer();
|
||||
invoker_ptr->Run(argument_ptr.get());
|
||||
|
||||
ref_instance_invoker_.Run(
|
||||
{x, gamma, beta, y_ref, PassThrough{}, lengths, reduceDims, 1e-4});
|
||||
|
||||
y_dev.FromDevice(y.mData.data());
|
||||
|
||||
bool pass;
|
||||
|
||||
if(std::is_same<XDataType, int8_t>::value)
|
||||
{
|
||||
EXPECT_TRUE(pass = ck::utils::check_err(
|
||||
y.mData, y_ref.mData, "Error: Incorrect results!", 0, 1));
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_TRUE(pass = ck::utils::check_err(
|
||||
y.mData, y_ref.mData, "Error: Incorrect results d1", 1e-3, 1e-3));
|
||||
}
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
FAIL() << "Failure in input lengths = [" << serialize_range(lengths) << "], "
|
||||
<< "reduce dim = [" << serialize_range(reduceDims) << "].";
|
||||
}
|
||||
}
|
||||
|
||||
void Run()
|
||||
{
|
||||
for(auto length : this->lengths_)
|
||||
{
|
||||
this->RunSingle(length, reduceDims_[0]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<index_t>> lengths_ = {
|
||||
{4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}};
|
||||
|
||||
std::vector<std::vector<index_t>> reduceDims_ = {{1}};
|
||||
|
||||
typename ReferenceInstance::Invoker ref_instance_invoker_;
|
||||
};
|
||||
} // namespace ck
|
||||