mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Layernorm and groupnorm support to save mean and inverse std in forward (#929)
* save mean and inverse std in normalization * Save mean and inverse std in splitK * Vector save mean and inv std * Modify instance for save mean and std * simplify the layernorm example * Save mean and std in groupnorm example * Save mean and inv std in ckProfiler and test * Remove compute data type from base class * Save mean and inv std in client example * Add changelog * clang format * Fix compile error * Refine naming * Avoid error in bf16 * revert changelog
This commit is contained in:
@@ -12,11 +12,12 @@ template <typename Tuple>
|
||||
class TestGroupnorm : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using XDataType = std::tuple_element_t<0, Tuple>;
|
||||
using GammaDataType = std::tuple_element_t<1, Tuple>;
|
||||
using BetaDataType = std::tuple_element_t<2, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<3, Tuple>;
|
||||
using YDataType = std::tuple_element_t<4, Tuple>;
|
||||
using XDataType = std::tuple_element_t<0, Tuple>;
|
||||
using GammaDataType = std::tuple_element_t<1, Tuple>;
|
||||
using BetaDataType = std::tuple_element_t<2, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<3, Tuple>;
|
||||
using YDataType = std::tuple_element_t<4, Tuple>;
|
||||
using SaveMeanInvStdDataType = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
void Run()
|
||||
{
|
||||
@@ -35,7 +36,9 @@ class TestGroupnorm : public ::testing::Test
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType>(true, 2, false, false, length);
|
||||
YDataType,
|
||||
SaveMeanInvStdDataType,
|
||||
true>(true, 2, false, false, length);
|
||||
EXPECT_TRUE(success);
|
||||
}
|
||||
}
|
||||
@@ -43,7 +46,7 @@ class TestGroupnorm : public ::testing::Test
|
||||
|
||||
using KernelTypes = ::testing::Types<
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
|
||||
std::tuple<F32, F32, F32, F32, F32>>;
|
||||
std::tuple<F32, F32, F32, F32, F32, F32>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
|
||||
TYPED_TEST(TestGroupnorm, Test_FP32) { this->Run(); }
|
||||
|
||||
Reference in New Issue
Block a user