mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
Layernorm4d (#1022)
* Rename folder
* Add layernorm 4d fwd example
* Rename original layernorm example
* Add layernorm 4d f16 test
* Add layernorm4d_fwd client example
* Support layernorm4D in ckProfiler
* Rename groupnorm to groupnorm fwd in example
* Rename layernorm and group fwd in test
* Rename normalization to normalization_fwd (instances)
* Add fwd to DeviceNormalization
* Rename external api header
* Rename folder, because we can also add bwd in this folder
* Add fwd in layernorm and groupnorm (profiler)
* Fix compile error
---------
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
[ROCm/composable_kernel commit: a3d9a2cd42]
This commit is contained in:
@@ -28,7 +28,8 @@ template <typename XDataType,
|
||||
struct ReferenceLayernorm : public device::BaseOperator
|
||||
{
|
||||
// TODO - support generic layernorm
|
||||
static_assert((Rank == 2 && NumReduceDim == 1), "Only support 2D version so far");
|
||||
static_assert((Rank == 2 && NumReduceDim == 1) || (Rank == 4 && NumReduceDim == 3),
|
||||
"Only support 2D & 4D version so far");
|
||||
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
@@ -71,7 +72,7 @@ struct ReferenceLayernorm : public device::BaseOperator
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg)
|
||||
float Run2D(const Argument& arg)
|
||||
{
|
||||
int M = arg.lengths_[0];
|
||||
int N = arg.lengths_[1];
|
||||
@@ -117,6 +118,71 @@ struct ReferenceLayernorm : public device::BaseOperator
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run4D(const Argument& arg)
|
||||
{
|
||||
int N = arg.lengths_[0];
|
||||
int H = arg.lengths_[1];
|
||||
int W = arg.lengths_[2];
|
||||
int C = arg.lengths_[3];
|
||||
|
||||
Tensor<ComputeDataType> mean({N});
|
||||
Tensor<ComputeDataType> var({N});
|
||||
|
||||
int reduce_length = H * W * C;
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
mean(n) = 0;
|
||||
var(n) = 0;
|
||||
|
||||
for(int h = 0; h < H; ++h)
|
||||
for(int w = 0; w < W; ++w)
|
||||
for(int c = 0; c < C; ++c)
|
||||
{
|
||||
auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(n, h, w, c));
|
||||
mean(n) += x_val;
|
||||
var(n) += x_val * x_val;
|
||||
}
|
||||
|
||||
mean(n) = mean(n) / reduce_length;
|
||||
var(n) = (var(n) / reduce_length) - (mean(n) * mean(n));
|
||||
}
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
ComputeDataType divisor =
|
||||
static_cast<ComputeDataType>(1) / ck::math::sqrt(var(n) + arg.epsilon_);
|
||||
|
||||
for(int h = 0; h < H; ++h)
|
||||
for(int w = 0; w < W; ++w)
|
||||
for(int c = 0; c < C; ++c)
|
||||
{
|
||||
auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(n, h, w, c));
|
||||
auto gamma_val =
|
||||
ck::type_convert<ComputeDataType>(arg.gamma_n_(h, w, c));
|
||||
auto beta_val = ck::type_convert<ComputeDataType>(arg.beta_n_(h, w, c));
|
||||
auto y_val = (x_val - mean(n)) * divisor;
|
||||
y_val = (y_val * gamma_val) + beta_val;
|
||||
arg.y_elementwise_op_(y_val, y_val);
|
||||
arg.y_m_n_(n, h, w, c) = ck::type_convert<YDataType>(y_val);
|
||||
}
|
||||
arg.save_mean_m_(n) = ck::type_convert<SaveMeanInvStdDataType>(mean(n));
|
||||
arg.save_inv_std_m_(n) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
if(arg.lengths_.size() == 2)
|
||||
return Run2D(arg);
|
||||
else if(arg.lengths_.size() == 4)
|
||||
return Run4D(arg);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
@@ -134,17 +200,16 @@ struct ReferenceLayernorm : public device::BaseOperator
|
||||
{
|
||||
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
// TODO - support generic layernorm
|
||||
if(p_arg_->lengths_.size() != 2)
|
||||
return false;
|
||||
if(p_arg_->lengths_.size() == 2 && p_arg_->reduceDims_.size() == 1 &&
|
||||
p_arg_->reduceDims_[0] == 1)
|
||||
return true;
|
||||
|
||||
if(p_arg_->reduceDims_.size() != 1)
|
||||
return false;
|
||||
else if(p_arg_->lengths_.size() == 4 && p_arg_->reduceDims_.size() == 3 &&
|
||||
p_arg_->reduceDims_[0] == 1 && p_arg_->reduceDims_[1] == 2 &&
|
||||
p_arg_->reduceDims_[2] == 3)
|
||||
return true;
|
||||
|
||||
if(p_arg_->reduceDims_[0] != 1)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<XDataType>& x_m_n,
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
@@ -18,25 +18,31 @@ namespace device {
|
||||
namespace instance {
|
||||
#ifdef CK_ENABLE_FP16
|
||||
// FP16
|
||||
void add_device_normalization_rank_2_1_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 2, 1>>>&);
|
||||
void add_device_normalization_fwd_rank_2_1_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 2, 1>>>&);
|
||||
|
||||
void add_device_normalization_rank_4_3_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 4, 3>>>&);
|
||||
void add_device_normalization_fwd_rank_4_3_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
void add_device_normalization_rank_5_3_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 5, 3>>>&);
|
||||
void add_device_normalization_fwd_rank_5_3_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 5, 3>>>&);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
// FP32
|
||||
void add_device_normalization_rank_2_1_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
|
||||
void add_device_normalization_fwd_rank_2_1_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
|
||||
|
||||
void add_device_normalization_rank_4_3_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
|
||||
void add_device_normalization_fwd_rank_4_3_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
void add_device_normalization_rank_5_3_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);
|
||||
void add_device_normalization_fwd_rank_5_3_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);
|
||||
#endif
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
@@ -45,7 +51,7 @@ template <typename XDataType,
|
||||
typename SaveMeanInvStdDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormalization<
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormalizationFwd<
|
||||
XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
@@ -55,14 +61,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
|
||||
Rank,
|
||||
NumReduceDim>>
|
||||
{
|
||||
using DeviceOp = DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
SaveMeanInvStdDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
using DeviceOp = DeviceNormalizationFwd<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
SaveMeanInvStdDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
@@ -74,15 +80,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
|
||||
{
|
||||
if constexpr(Rank == 2 && NumReduceDim == 1)
|
||||
{
|
||||
add_device_normalization_rank_2_1_f16_instances(op_ptrs);
|
||||
add_device_normalization_fwd_rank_2_1_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(Rank == 4 && NumReduceDim == 3)
|
||||
{
|
||||
add_device_normalization_rank_4_3_f16_instances(op_ptrs);
|
||||
add_device_normalization_fwd_rank_4_3_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(Rank == 5 && NumReduceDim == 3)
|
||||
{
|
||||
add_device_normalization_rank_5_3_f16_instances(op_ptrs);
|
||||
add_device_normalization_fwd_rank_5_3_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -93,15 +99,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
|
||||
{
|
||||
if constexpr(Rank == 2 && NumReduceDim == 1)
|
||||
{
|
||||
add_device_normalization_rank_2_1_f32_instances(op_ptrs);
|
||||
add_device_normalization_fwd_rank_2_1_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(Rank == 4 && NumReduceDim == 3)
|
||||
{
|
||||
add_device_normalization_rank_4_3_f32_instances(op_ptrs);
|
||||
add_device_normalization_fwd_rank_4_3_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(Rank == 5 && NumReduceDim == 3)
|
||||
{
|
||||
add_device_normalization_rank_5_3_f32_instances(op_ptrs);
|
||||
add_device_normalization_fwd_rank_5_3_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
@@ -18,16 +18,16 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// FP16
|
||||
void add_device_normalization_rank_5_3_swish_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Swish, 5, 3>>>&);
|
||||
void add_device_normalization_fwd_rank_5_3_swish_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Swish, 5, 3>>>&);
|
||||
|
||||
// FP32
|
||||
void add_device_normalization_rank_5_3_swish_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);
|
||||
void add_device_normalization_fwd_rank_5_3_swish_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);
|
||||
|
||||
// [x, gamma, beta, y] = [f16, f32, f32, f16]
|
||||
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F16, F32, Swish, 5, 3>>>&);
|
||||
void add_device_normalization_fwd_rank_5_3_swish_f16_f32_f32_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>>&);
|
||||
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
@@ -37,23 +37,23 @@ template <typename XDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
SaveMeanInvStdDataType,
|
||||
ck::tensor_operation::element_wise::Swish,
|
||||
Rank,
|
||||
NumReduceDim>>
|
||||
ck::tensor_operation::device::DeviceNormalizationFwd<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
SaveMeanInvStdDataType,
|
||||
ck::tensor_operation::element_wise::Swish,
|
||||
Rank,
|
||||
NumReduceDim>>
|
||||
{
|
||||
using DeviceOp = DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
SaveMeanInvStdDataType,
|
||||
ck::tensor_operation::element_wise::Swish,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
using DeviceOp = DeviceNormalizationFwd<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
SaveMeanInvStdDataType,
|
||||
ck::tensor_operation::element_wise::Swish,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
@@ -65,7 +65,7 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
if constexpr(Rank == 5 && NumReduceDim == 3)
|
||||
{
|
||||
add_device_normalization_rank_5_3_swish_f16_instances(op_ptrs);
|
||||
add_device_normalization_fwd_rank_5_3_swish_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
|
||||
@@ -74,7 +74,7 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
if constexpr(Rank == 5 && NumReduceDim == 3)
|
||||
{
|
||||
add_device_normalization_rank_5_3_swish_f32_instances(op_ptrs);
|
||||
add_device_normalization_fwd_rank_5_3_swish_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
|
||||
@@ -83,7 +83,7 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
if constexpr(Rank == 5 && NumReduceDim == 3)
|
||||
{
|
||||
add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs);
|
||||
add_device_normalization_fwd_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user