Files
composable_kernel/example/27_layernorm/layernorm_splitk_fp16.cpp
rocking 3696fe1c76 Layernorm and groupnorm support to save mean and inverse std in forward (#929)
* save mean and inverse std in normalization

* Save mean and inverse std in splitK

* Vector save mean and inv std

* Modify instance for save mean and std

* simplify the layernorm example

* Save mean and std in groupnorm example

* Save mean and inv std in ckProfiler and test

* Remove compute data type from base class

* Save mean and inv std in client example

* Add changelog

* clang format

* Fix compile error

* Refine naming

* Avoid error in bf16

* revert changelog
2023-10-19 07:36:29 +08:00

46 lines
2.5 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
// ---------------------------------------------------------------------------
// Element types for the split-K layernorm example.
// X / gamma / beta / Y are fp16; the compute type and the saved statistics
// are fp32.
// ---------------------------------------------------------------------------
using XDataType     = ck::half_t; // input tensor X
using GammaDataType = ck::half_t; // scale (gamma)
using BetaDataType  = ck::half_t; // shift (beta)
using YDataType     = ck::half_t; // output tensor Y

using ComputeDataType        = float; // compute precision
using SaveMeanInvStdDataType = float; // element type for the saved mean / inverse std

// Element-wise operation handed to the device instance (presumably applied
// to Y — confirm against DeviceNormalizationSplitKImpl).
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Problem shape: a rank-2 tensor with a single reduced dimension.
constexpr int Rank{2};
constexpr int NumReduceDim{1};

// Ask the shared example driver to also save mean / inverse std. This must stay
// above the include of run_layernorm_example.inc (presumably tested there with
// #ifdef — verify).
#define SAVE_MEAN_INV_STD
// Split-K normalization device instance used by this example.
//
// Tile configuration (literal values below):
//   * BlockSize 256 equals ClusterM * ClusterK = 8 * 32.
//   * Slice sizes are SliceM x SliceK = 1 x 8 (presumably elements per thread).
//   * X, Gamma, Beta and Y all use vector dim 1 (K) with 8 elements per vector;
//     the saved mean / inverse std use 1 element per vector.
using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationSplitKImpl<
    XDataType, GammaDataType, BetaDataType, // input / affine element types
    ComputeDataType,                        // compute element type
    YDataType,                              // output element type
    SaveMeanInvStdDataType,                 // saved-statistics element type
    PassThrough,                            // element-wise operation
    Rank,
    NumReduceDim,
    256, // BlockSize
    8,   // ClusterM
    32,  // ClusterK
    1,   // SliceM
    8,   // SliceK
    1,   // XYVectorDim (0=M, 1=K)
    8,   // XScalarPerVector
    1,   // GammaVecDim (0=M, 1=K)
    8,   // GammaScalarPerVector
    1,   // BetaVecDim (0=M, 1=K)
    8,   // BetaScalarPerVector
    8,   // YScalarPerVector
    1>;  // SaveMeanInvStdScalarPerVector
#include "run_layernorm_example.inc"
// Entry point: runs the example driver with the split-K DeviceInstance defined
// in this file and returns the driver's result as the process exit status.
// NOTE(review): the driver is named run_groupnorm_example even though this is a
// layernorm example; it is presumably defined in run_layernorm_example.inc —
// confirm the name there before renaming anything.
int main()
{
    const auto status = run_groupnorm_example<DeviceInstance>();
    return status;
}