mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Improve normalization (#580)
* Sync the order of type string with template parameter
* Add more instances
* Check the vector size and remove redundant var
* Extract var to static, prepare to separate sweep once kernel
* Separate sweeponce flow and optimize the flow
* 1. Rename AccDatatype in normalization to computeData
2. Rename AccElementwiseOperation to YElementwiseOperation in normalization
* Remove useless code
* Update naive variance kernel
* Refine string
* Fix typo
* Support naive variance for device_normalization
* Check the blocksize
* Share the VGPR of x and y
* Share the VGPR of gamma and beta
* Add more instances
* Support fp16 sqrt for experiment
* Add CHANGELOG
* Fix typo
* clang-format
[ROCm/composable_kernel commit: 6a6163a3d1]
This commit is contained in:
@@ -20,12 +20,12 @@
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
|
||||
|
||||
using XDataType = ck::half_t;
|
||||
using GammaDataType = ck::half_t;
|
||||
using BetaDataType = ck::half_t;
|
||||
using YDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using XDataType = ck::half_t;
|
||||
using GammaDataType = ck::half_t;
|
||||
using BetaDataType = ck::half_t;
|
||||
using YDataType = ck::half_t;
|
||||
using ComputeDataType = float;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
constexpr int Rank = 2;
|
||||
constexpr int NumReduceDim = 1;
|
||||
@@ -34,7 +34,7 @@ using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
PassThrough,
|
||||
Rank,
|
||||
@@ -121,7 +121,7 @@ int main()
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ComputeDataType,
|
||||
PassThrough,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
@@ -23,11 +23,11 @@
|
||||
constexpr int Rank = 5;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
using XDataType = ck::half_t;
|
||||
using GammaDataType = ck::half_t;
|
||||
using BetaDataType = ck::half_t;
|
||||
using YDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using XDataType = ck::half_t;
|
||||
using GammaDataType = ck::half_t;
|
||||
using BetaDataType = ck::half_t;
|
||||
using YDataType = ck::half_t;
|
||||
using ComputeDataType = float;
|
||||
|
||||
struct YElementOp
|
||||
{
|
||||
@@ -50,7 +50,7 @@ using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
YElementOp,
|
||||
Rank,
|
||||
@@ -157,7 +157,7 @@ int main(int argc, char* argv[])
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ComputeDataType,
|
||||
YElementOp>;
|
||||
|
||||
ReferenceInstance ref;
|
||||
|
||||
Reference in New Issue
Block a user