mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
Fix bug of layernorm ckProfiler and refine code (#448)
* Fix bug of profiler for layernorm
* 1. Rename layernorm into normalization
2. Decouple softmax from normalization
* clang-format
[ROCm/composable_kernel commit: 1b62bfaa2a]
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/layernorm.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -75,14 +75,14 @@ bool profile_groupnorm_impl(int do_verification,
|
||||
beta_dev.ToDevice(beta.mData.data());
|
||||
|
||||
// add device normalization instances
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceLayernorm<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
PassThrough,
|
||||
5,
|
||||
3>;
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
PassThrough,
|
||||
5,
|
||||
3>;
|
||||
|
||||
// get device op instances
|
||||
const auto instance_ptrs =
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/layernorm.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -28,27 +28,29 @@ void profile_layernorm_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
std::vector<index_t> length,
|
||||
std::vector<index_t> strideXY,
|
||||
std::vector<index_t> strideGamma,
|
||||
std::vector<index_t> strideBeta)
|
||||
std::vector<index_t> length)
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
if(length.size() < 2)
|
||||
return;
|
||||
|
||||
// Assume normalize dimension except for first dimension
|
||||
// Assume normalize dimension except for batch (first) dimension
|
||||
std::vector<index_t> reduce_length{length.begin() + 1, length.end()};
|
||||
std::vector<index_t> reduce_dim;
|
||||
for(int i = 1; i < Rank; ++i)
|
||||
reduce_dim.push_back(i);
|
||||
|
||||
Tensor<XDataType> x(length);
|
||||
Tensor<GammaDataType> gamma(reduce_length, strideGamma);
|
||||
Tensor<BetaDataType> beta(reduce_length, strideBeta);
|
||||
Tensor<YDataType> y(length, strideXY);
|
||||
Tensor<YDataType> host_y(length, strideXY);
|
||||
Tensor<GammaDataType> gamma(reduce_length);
|
||||
Tensor<BetaDataType> beta(reduce_length);
|
||||
Tensor<YDataType> y(length);
|
||||
Tensor<YDataType> host_y(length);
|
||||
|
||||
std::vector<index_t> strideXY =
|
||||
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
|
||||
std::vector<index_t> strideGammaBeta = strideXY;
|
||||
strideGammaBeta[0] = 0;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
@@ -84,14 +86,14 @@ void profile_layernorm_impl(int do_verification,
|
||||
constexpr int NumReduceDim = Rank - 1;
|
||||
|
||||
// add device normalization instances
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceLayernorm<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
PassThrough,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
PassThrough,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
// get device op instances
|
||||
const auto instance_ptrs =
|
||||
@@ -126,8 +128,8 @@ void profile_layernorm_impl(int do_verification,
|
||||
{
|
||||
auto argument_ptr = inst_ptr->MakeArgumentPointer(length,
|
||||
strideXY,
|
||||
strideGamma,
|
||||
strideBeta,
|
||||
strideGammaBeta,
|
||||
strideGammaBeta,
|
||||
strideXY,
|
||||
reduce_dim,
|
||||
1e-4,
|
||||
|
||||
@@ -69,16 +69,16 @@ template <> std::string type_to_string<int32_t>() { return "int32"; }
|
||||
// clang-format on
|
||||
|
||||
template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
|
||||
void profile_normalization_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
std::vector<index_t> in_length,
|
||||
std::vector<index_t> in_strides,
|
||||
std::vector<index_t> reduce_dims,
|
||||
AccDataType alpha,
|
||||
AccDataType beta,
|
||||
NormType norm_type)
|
||||
void profile_softmax_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
std::vector<index_t> in_length,
|
||||
std::vector<index_t> in_strides,
|
||||
std::vector<index_t> reduce_dims,
|
||||
AccDataType alpha,
|
||||
AccDataType beta,
|
||||
NormType norm_type)
|
||||
{
|
||||
if(Rank != in_length.size())
|
||||
{
|
||||
Reference in New Issue
Block a user