Improve normalization (#580)

* Sync the order of type string with template parameter

* Add more instances

* Check the vector size and remove redundant var

* Extract var to static, prepare to separate sweep once kernel

* Separate sweeponce flow and optimize the flow

* 1. Rename AccDatatype in normalization to computeData
2. Rename AccElementwiseOperation to YElementwiseOperation in normalization

* Remove useless code

* Update naive variance kernel

* Refine string

* Fix typo

* Support naive variance for device_normalization

* Check the blocksize

* Share the VGPR of x and y

* Share the VGPR of gamma and beta

* Add more instances

* Support fp16 sqrt for experiment

* Add CHANGELOG

* Fix typo

* clang-format

[ROCm/composable_kernel commit: 6a6163a3d1]
This commit is contained in:
rocking5566
2023-02-16 01:59:35 +08:00
committed by GitHub
parent d6de9bdcbe
commit d5062679f1
17 changed files with 826 additions and 459 deletions

View File

@@ -12,11 +12,11 @@ template <typename Tuple>
class TestGroupnorm : public ::testing::Test
{
protected:
using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>;
using AccDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>;
using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>;
using ComputeDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>;
void Run()
{
@@ -34,7 +34,7 @@ class TestGroupnorm : public ::testing::Test
ck::profiler::profile_groupnorm_impl<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
ComputeDataType,
YDataType>(true, 2, false, false, length);
EXPECT_TRUE(success);
}
@@ -42,7 +42,7 @@ class TestGroupnorm : public ::testing::Test
};
using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType>
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F32, F32, F32, F32, F32>>;
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);