Files
composable_kernel/example/42_groupnorm/groupnorm_splitk_fp16.cpp
rocking 3696fe1c76 Layernorm and groupnorm support to save mean and inverse std in forward (#929)
* save mean and inverse std in normalization

* Save mean and inverse std in splitK

* Vector save mean and inv std

* Modify instance for save mean and std

* simplify the layernorm example

* Save mean and std in groupnorm example

* Save mean and inv std in ckProfiler and test

* Remove compute data type from base class

* Save mean and inv std in client example

* Add changelog

* clang format

* Fix compile error

* Refine naming

* Avoid error in bf16

* revert changelog
2023-10-19 07:36:29 +08:00

46 lines
2.5 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
// Compile-time configuration for this example. All names below are forwarded to the
// DeviceNormalizationSplitKImpl instantiation further down in this file.
// NOTE(review): run_groupnorm_example.inc is textually included after these definitions and
// presumably refers to them by name as well — keep the identifiers unchanged.

// Tensor rank and number of reduced dimensions passed to the device op.
// NOTE(review): presumably the input is laid out as [N, H, W, G, C] with H, W, C reduced —
// confirm against run_groupnorm_example.inc.
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
// Element types: half precision for input, gamma, beta and output ...
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
// ... and float for the saved mean / inverse std and for the compute type.
using SaveMeanInvStdDataType = float;
using ComputeDataType = float;
// Element-wise operation supplied as the device op's YElementOp parameter.
using YElementOp = ck::tensor_operation::element_wise::Swish;
// NOTE(review): presumably checked by run_groupnorm_example.inc to enable saving mean and
// inverse std — confirm in that file.
#define SAVE_MEAN_INV_STD
// Device operation exercised by this example: the split-K normalization implementation,
// instantiated with the data types, element-wise op, rank and reduce-dimension count
// declared above, followed by the tuning parameters labelled on each line.
using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationSplitKImpl<
    XDataType,
    GammaDataType,
    BetaDataType,
    ComputeDataType,
    YDataType,
    SaveMeanInvStdDataType,
    YElementOp,
    Rank,
    NumReduceDim,
    256, // BlockSize
    1,   // ClusterM
    256, // ClusterK
    1,   // SliceM
    16,  // SliceK
    1,   // SrcVecDim (0=M, 1=K)
    2,   // SrcScalarPerVector
    1,   // GammaVecDim (0=M, 1=K)
    2,   // GammaScalarPerVector
    1,   // BetaVecDim (0=M, 1=K)
    2,   // BetaScalarPerVector
    2,   // YScalarPerVector
    1>;  // SaveMeanInvStdScalarPerVector
#include "run_groupnorm_example.inc"
// Entry point: runs the groupnorm example with the command-line arguments and uses its
// result as the process exit status. Previously the result was discarded, so the process
// always exited with 0 regardless of what the example reported.
// NOTE(review): assumes run_groupnorm_example (from run_groupnorm_example.inc) returns an
// int status — confirm against that file.
int main(int argc, char* argv[]) { return run_groupnorm_example(argc, argv); }