Fix bug of layernorm ckProfiler and refine code (#448)

* Fix bug of profiler for layernorm

* 1. Rename layernorm into normalization
2. Decouple softmax from normalization

* clang-format

[ROCm/composable_kernel commit: 1b62bfaa2a]
This commit is contained in:
rocking5566
2022-10-13 10:06:39 +08:00
committed by GitHub
parent 096571bea9
commit 1dcaa3991f
29 changed files with 423 additions and 461 deletions

View File

@@ -7,7 +7,7 @@
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/layernorm.hpp"
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -75,14 +75,14 @@ bool profile_groupnorm_impl(int do_verification,
beta_dev.ToDevice(beta.mData.data());
// add device normalization instances
using DeviceOp = ck::tensor_operation::device::DeviceLayernorm<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
YDataType,
PassThrough,
5,
3>;
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
YDataType,
PassThrough,
5,
3>;
// get device op instances
const auto instance_ptrs =

View File

@@ -7,7 +7,7 @@
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/layernorm.hpp"
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -28,27 +28,29 @@ void profile_layernorm_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length,
std::vector<index_t> strideXY,
std::vector<index_t> strideGamma,
std::vector<index_t> strideBeta)
std::vector<index_t> length)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
if(length.size() < 2)
return;
// Assume normalize dimension except for first dimension
// Assume normalize dimension except for batch (first) dimension
std::vector<index_t> reduce_length{length.begin() + 1, length.end()};
std::vector<index_t> reduce_dim;
for(int i = 1; i < Rank; ++i)
reduce_dim.push_back(i);
Tensor<XDataType> x(length);
Tensor<GammaDataType> gamma(reduce_length, strideGamma);
Tensor<BetaDataType> beta(reduce_length, strideBeta);
Tensor<YDataType> y(length, strideXY);
Tensor<YDataType> host_y(length, strideXY);
Tensor<GammaDataType> gamma(reduce_length);
Tensor<BetaDataType> beta(reduce_length);
Tensor<YDataType> y(length);
Tensor<YDataType> host_y(length);
std::vector<index_t> strideXY =
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
std::vector<index_t> strideGammaBeta = strideXY;
strideGammaBeta[0] = 0;
switch(init_method)
{
@@ -84,14 +86,14 @@ void profile_layernorm_impl(int do_verification,
constexpr int NumReduceDim = Rank - 1;
// add device normalization instances
using DeviceOp = ck::tensor_operation::device::DeviceLayernorm<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
YDataType,
PassThrough,
Rank,
NumReduceDim>;
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
YDataType,
PassThrough,
Rank,
NumReduceDim>;
// get device op instances
const auto instance_ptrs =
@@ -126,8 +128,8 @@ void profile_layernorm_impl(int do_verification,
{
auto argument_ptr = inst_ptr->MakeArgumentPointer(length,
strideXY,
strideGamma,
strideBeta,
strideGammaBeta,
strideGammaBeta,
strideXY,
reduce_dim,
1e-4,

View File

@@ -69,16 +69,16 @@ template <> std::string type_to_string<int32_t>() { return "int32"; }
// clang-format on
template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
void profile_normalization_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> in_length,
std::vector<index_t> in_strides,
std::vector<index_t> reduce_dims,
AccDataType alpha,
AccDataType beta,
NormType norm_type)
void profile_softmax_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> in_length,
std::vector<index_t> in_strides,
std::vector<index_t> reduce_dims,
AccDataType alpha,
AccDataType beta,
NormType norm_type)
{
if(Rank != in_length.size())
{