diff --git a/client_example/05_layernorm/layernorm2d.cpp b/client_example/05_layernorm/layernorm2d.cpp index bdc6c2bd31..adb41171e1 100644 --- a/client_example/05_layernorm/layernorm2d.cpp +++ b/client_example/05_layernorm/layernorm2d.cpp @@ -90,6 +90,8 @@ int main(int argc, char* argv[]) gamma_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(), + nullptr, + nullptr, PassThrough{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -143,6 +145,8 @@ int main(int argc, char* argv[]) gamma_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(), + nullptr, + nullptr, PassThrough{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/example/27_layernorm/layernorm_blockwise.cpp b/example/27_layernorm/layernorm_blockwise.cpp index 54c4eaf74b..7f3033ee57 100644 --- a/example/27_layernorm/layernorm_blockwise.cpp +++ b/example/27_layernorm/layernorm_blockwise.cpp @@ -100,6 +100,8 @@ int main() gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(), + nullptr, + nullptr, PassThrough{}); if(!device_instance.IsSupportedArgument(argument_ptr.get())) diff --git a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp index 8261b8d6ac..69cacfd143 100644 --- a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp +++ b/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp @@ -128,6 +128,8 @@ int main(int argc, char* argv[]) gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(), + nullptr, + nullptr, y_element_op); if(!device_instance.IsSupportedArgument(argument_ptr.get())) diff --git a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization.hpp index f1a3133c94..227c352cbd 100644 --- a/include/ck/tensor_operation/gpu/device/device_normalization.hpp +++ b/include/ck/tensor_operation/gpu/device/device_normalization.hpp @@ -33,6 +33,8 @@ struct DeviceNormalization : public BaseOperator const void* p_gamma, const void* p_beta, void* p_y, + void* p_savedMean, + void* p_savedInvVar, AccElementwiseOperation acc_elementwise_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp index 0fbeb7d714..47d9df8025 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp @@ -10,7 +10,7 @@ #include "ck/tensor_operation/gpu/device/device_normalization.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -24,17 +24,17 @@ template -__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k, - const GridDesc_M_K gamma_grid_desc_m_k, - const GridDesc_M_K beta_grid_desc_m_k, - const GridDesc_M_K y_grid_desc_m_k, - index_t num_k_block_tile_iteration, - AccDataType epsilon, - const XDataType* const __restrict__ p_x_global, - const GammaDataType* const __restrict__ p_gamma_global, - const BetaDataType* const __restrict__ p_beta_global, - YDataType* const __restrict__ p_y_global, - const AccElementwiseOperation acc_elementwise_op) +__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K beta_grid_desc_m_k, + const GridDesc_M_K y_grid_desc_m_k, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const AccElementwiseOperation acc_elementwise_op) { GridwiseReduction::Run(x_grid_desc_m_k, gamma_grid_desc_m_k, @@ -54,7 +54,7 @@ namespace ck { namespace tensor_operation { namespace device { -// Y = LayerNorm(X, Beta, Gamma) +// Y = Normalization(X, Beta, Gamma) template ; - using GridwiseReduceLayernormSweepOnce = - GridwiseLayernormWelfordVariance_mk_to_mk; + GridwiseNormalizationWelfordVariance_mk_to_mk; + using GridwiseNormalizationSweepOnce = + GridwiseNormalizationWelfordVariance_mk_to_mk; struct Argument : public BaseArgument { @@ -295,22 +295,22 @@ struct DeviceNormalizationImpl : public DeviceNormalization - : kernel_layernorm; + ? kernel_normalization + : kernel_normalization; float avg_time = 0; avg_time += launch_and_time_kernel(stream_config, @@ -426,8 +426,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization(lengths, xStrides, gammaStrides, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp similarity index 99% rename from include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp index f90739eaec..89efea4d6c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp @@ -14,7 +14,7 @@ namespace ck { -// Y = LayerNorm(X, Beta, Gamma) +// Y = Normalization(X, Beta, Gamma) template -struct GridwiseLayernormNaiveVariance_mk_to_mk +struct GridwiseNormalizationNaiveVariance_mk_to_mk { static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp similarity index 99% rename from include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp index 094c79c6f8..7aefd3c066 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp @@ -11,7 +11,7 @@ namespace ck { -// Y = LayerNorm(X, Beta, Gamma) +// Y = Normalization(X, Beta, Gamma) template -struct GridwiseLayernormWelfordVariance_mk_to_mk +struct GridwiseNormalizationWelfordVariance_mk_to_mk { static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), diff --git a/profiler/include/profile_groupnorm_impl.hpp b/profiler/include/profile_groupnorm_impl.hpp index 05966ed412..9b2a3e9f3f 100644 --- a/profiler/include/profile_groupnorm_impl.hpp +++ b/profiler/include/profile_groupnorm_impl.hpp @@ -126,6 +126,8 @@ bool profile_groupnorm_impl(int do_verification, gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(), + nullptr, + nullptr, PassThrough{}); if(inst_ptr->IsSupportedArgument(argument_ptr.get())) @@ -196,7 +198,7 @@ bool profile_groupnorm_impl(int do_verification, if(num_kernel == 0) { - std::cout << "Error: No kernel is tested" << std::endl; + std::cout << "Error: No kernel is applicable" << std::endl; return false; } diff --git a/profiler/include/profile_layernorm_impl.hpp b/profiler/include/profile_layernorm_impl.hpp index 54bf57b521..eb21d4a586 100644 --- a/profiler/include/profile_layernorm_impl.hpp +++ b/profiler/include/profile_layernorm_impl.hpp @@ -22,7 +22,7 @@ template -void profile_layernorm_impl(int do_verification, +bool profile_layernorm_impl(int do_verification, int init_method, bool do_log, bool time_kernel, @@ -31,7 +31,7 @@ void profile_layernorm_impl(int do_verification, using PassThrough = ck::tensor_operation::element_wise::PassThrough; if(length.size() < 2) - return; + return false; // Assume normalize dimension except for batch (first) dimension std::vector reduce_length{length.begin() + 1, length.end()}; @@ -52,7 +52,6 @@ void profile_layernorm_impl(int do_verification, switch(init_method) { - // case 0: break; case 0: x.GenerateTensorValue(GeneratorTensor_1{}); gamma.GenerateTensorValue(GeneratorTensor_1{}); @@ -122,6 +121,8 @@ void profile_layernorm_impl(int do_verification, ref_invoker.Run(ref_argument); } + int num_kernel = 0; + for(auto& inst_ptr : instance_ptrs) { auto argument_ptr = inst_ptr->MakeArgumentPointer(length, @@ -135,12 +136,21 @@ void profile_layernorm_impl(int do_verification, gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(), + nullptr, + nullptr, PassThrough{}); - if(!inst_ptr->IsSupportedArgument(argument_ptr.get())) + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) { - std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; - LogRange(std::cout << "input lengths = ", length, ", ") << std::endl; + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", length, ", ") << std::endl; + } continue; } @@ -156,8 +166,9 @@ void profile_layernorm_impl(int do_verification, float gb_per_sec = num_bytes / 1.E6 / avg_time; - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " - << inst_ptr->GetTypeString() << std::endl; + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; if(avg_time < best_avg_time) { @@ -184,20 +195,32 @@ void profile_layernorm_impl(int do_verification, { std::cout << inst_ptr->GetTypeString() << " failed verification: "; LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; - return; + return false; } else { - std::cout << "pass" << std::endl; + if(time_kernel) + std::cout << "pass" << std::endl; } } } - LogRange(std::cout << "length = ", length, ",") << ", "; - LogRange(std::cout << "stride = ", strideXY, ",") << ", "; - LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl; - std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " - << best_instance_name << std::endl; + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + LogRange(std::cout << "stride = ", strideXY, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; } } // namespace profiler diff --git a/test/normalization/CMakeLists.txt b/test/normalization/CMakeLists.txt index ab6e2d1cd1..e740755bf5 100644 --- a/test/normalization/CMakeLists.txt +++ b/test/normalization/CMakeLists.txt @@ -5,8 +5,8 @@ add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp) add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp) add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp) -target_link_libraries(test_layernorm2d_fp32 PRIVATE utility) -target_link_libraries(test_layernorm2d_fp16 PRIVATE utility) +target_link_libraries(test_layernorm2d_fp32 PRIVATE utility device_normalization_instance) +target_link_libraries(test_layernorm2d_fp16 PRIVATE utility device_normalization_instance) target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance) target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance) diff --git a/test/normalization/test_groupnorm_fp16.cpp b/test/normalization/test_groupnorm_fp16.cpp index ecdf61cade..8f7438247c 100644 --- a/test/normalization/test_groupnorm_fp16.cpp +++ b/test/normalization/test_groupnorm_fp16.cpp @@ -20,7 +20,7 @@ class TestGroupnorm : public ::testing::Test void Run() { - // N, H, W, G, C + // [N, H, W, G, C], reduce H, W, C std::vector> lengths = {{1, 1, 1, 1, 1}, {1, 2, 3, 4, 5}, {256, 9, 9, 9, 9}, diff --git a/test/normalization/test_groupnorm_fp32.cpp b/test/normalization/test_groupnorm_fp32.cpp index 6c5e2f20b7..8dadbb60f8 100644 --- a/test/normalization/test_groupnorm_fp32.cpp +++ b/test/normalization/test_groupnorm_fp32.cpp @@ -20,7 +20,7 @@ class TestGroupnorm : public ::testing::Test void Run() { - // N, H, W, G, C + // [N, H, W, G, C], reduce H, W, C std::vector> lengths = {{1, 1, 1, 1, 1}, {1, 2, 3, 4, 5}, {256, 9, 9, 9, 9}, diff --git a/test/normalization/test_layernorm2d_fp16.cpp b/test/normalization/test_layernorm2d_fp16.cpp index ccc6472660..7e3af7135e 100644 --- a/test/normalization/test_layernorm2d_fp16.cpp +++ b/test/normalization/test_layernorm2d_fp16.cpp @@ -2,28 +2,44 @@ // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" -#include "test_layernorm2d_util.hpp" +#include "profiler/include/profile_layernorm_impl.hpp" -template -using I = ck::Number; +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; template -class TestLayernorm2dFP16 : public ck::TestLayernorm2d +class TestLayernorm2d : public ::testing::Test { + protected: + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using AccDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; + + void Run() + { + // [N, D], reduce D + std::vector> lengths = { + {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}}; + + for(auto length : lengths) + { + bool success = ck::profiler::profile_layernorm_impl(true, 2, false, false, length); + EXPECT_TRUE(success); + } + } }; -// clang-format off using KernelTypes = ::testing::Types< -// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim , GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> - std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>, - std::tuple, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>, - std::tuple, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>, - std::tuple, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>, - std::tuple, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>, - std::tuple, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>, - std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>, - std::tuple, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>> - >; -// clang-format on -TYPED_TEST_SUITE(TestLayernorm2dFP16, KernelTypes); -TYPED_TEST(TestLayernorm2dFP16, Test_FP16) { this->Run(); } + // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> + std::tuple>; + +TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); +TYPED_TEST(TestLayernorm2d, Test_FP16) { this->Run(); } diff --git a/test/normalization/test_layernorm2d_fp32.cpp b/test/normalization/test_layernorm2d_fp32.cpp index 47cf1641e3..a7c4380d59 100644 --- a/test/normalization/test_layernorm2d_fp32.cpp +++ b/test/normalization/test_layernorm2d_fp32.cpp @@ -2,28 +2,44 @@ // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" -#include "test_layernorm2d_util.hpp" +#include "profiler/include/profile_layernorm_impl.hpp" -template -using I = ck::Number; +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; template -class TestLayernorm2dFP32 : public ck::TestLayernorm2d +class TestLayernorm2d : public ::testing::Test { + protected: + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using AccDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; + + void Run() + { + // [N, D], reduce D + std::vector> lengths = { + {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}}; + + for(auto length : lengths) + { + bool success = ck::profiler::profile_layernorm_impl(true, 2, false, false, length); + EXPECT_TRUE(success); + } + } }; -// clang-format off using KernelTypes = ::testing::Types< -// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> - std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>, - std::tuple, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>, - std::tuple, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>, - std::tuple, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>, - std::tuple, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>, - std::tuple, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>, - std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>, - std::tuple, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>> - >; -// clang-format on -TYPED_TEST_SUITE(TestLayernorm2dFP32, KernelTypes); -TYPED_TEST(TestLayernorm2dFP32, Test_FP32) { this->Run(); } + // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> + std::tuple>; + +TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); +TYPED_TEST(TestLayernorm2d, Test_FP32) { this->Run(); } diff --git a/test/normalization/test_layernorm2d_util.hpp b/test/normalization/test_layernorm2d_util.hpp deleted file mode 100644 index c1d4d0f542..0000000000 --- a/test/normalization/test_layernorm2d_util.hpp +++ /dev/null @@ -1,179 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/utility/number.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" - -namespace ck { - -template -std::string serialize_range(const Range& range) -{ - std::stringstream ss; - for(auto& r : range) - { - ss << r << ", "; - } - std::string str = ss.str(); - return std::string(str.begin(), str.end() - 2); -} - -template -class TestLayernorm2d : public ::testing::Test -{ - protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using AccDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; - static constexpr index_t Rank = std::tuple_element_t<5, Tuple>{}.value; - static constexpr index_t NumReduceDim = std::tuple_element_t<6, Tuple>{}.value; - static constexpr index_t BlockSize = std::tuple_element_t<7, Tuple>{}.value; - static constexpr index_t MThreadClusterSize = std::tuple_element_t<8, Tuple>{}.value; - static constexpr index_t KThreadClusterSize = std::tuple_element_t<9, Tuple>{}.value; - static constexpr index_t MThreadSliceSize = std::tuple_element_t<10, Tuple>{}.value; - static constexpr index_t KThreadSliceSize = std::tuple_element_t<11, Tuple>{}.value; - static constexpr index_t XYSrcVectorDim = std::tuple_element_t<12, Tuple>{}.value; - static constexpr index_t XSrcVectorSize = std::tuple_element_t<13, Tuple>{}.value; - static constexpr index_t GammaSrcVectorDim = std::tuple_element_t<14, Tuple>{}.value; - static constexpr index_t GammaSrcVectorSize = std::tuple_element_t<15, Tuple>{}.value; - static constexpr index_t BetaSrcVectorDim = std::tuple_element_t<16, Tuple>{}.value; - static constexpr index_t BetaSrcVectorSize = std::tuple_element_t<17, Tuple>{}.value; - static constexpr index_t YDstVectorSize = std::tuple_element_t<18, Tuple>{}.value; - - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - - using ReferenceInstance = tensor_operation::host::ReferenceLayernorm; - - using DeviceInstance = tensor_operation::device::DeviceNormalizationImpl; - - TestLayernorm2d() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {} - - void RunSingle(const std::vector& lengths, - const std::vector& reduceDims, - const std::vector& GammaLength, - const std::vector& GammaStride, - const std::vector& BetaLength, - const std::vector& BetaStride) - { - Tensor x(lengths); - Tensor gamma(GammaLength); - Tensor beta(BetaLength); - Tensor y(lengths); - Tensor y_ref(lengths); - - x.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); - DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); - DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); - DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); - - x_dev.ToDevice(x.mData.data()); - gamma_dev.ToDevice(gamma.mData.data()); - beta_dev.ToDevice(beta.mData.data()); - - auto device_instance = DeviceInstance{}; - auto argument_ptr = device_instance.MakeArgumentPointer( - lengths, - std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, - GammaStride, - BetaStride, - std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, - reduceDims, - 1e-4, - x_dev.GetDeviceBuffer(), - gamma_dev.GetDeviceBuffer(), - beta_dev.GetDeviceBuffer(), - y_dev.GetDeviceBuffer(), - PassThrough{}); - - if(!device_instance.IsSupportedArgument(argument_ptr.get())) - { - return; - } - - auto invoker_ptr = device_instance.MakeInvokerPointer(); - invoker_ptr->Run(argument_ptr.get()); - - ref_instance_invoker_.Run( - {x, gamma, beta, y_ref, PassThrough{}, lengths, reduceDims, 1e-4}); - - y_dev.FromDevice(y.mData.data()); - - bool pass; - - if(std::is_same::value) - { - EXPECT_TRUE(pass = ck::utils::check_err( - y.mData, y_ref.mData, "Error: Incorrect results!", 0, 1)); - } - else - { - EXPECT_TRUE(pass = ck::utils::check_err( - y.mData, y_ref.mData, "Error: Incorrect results d1", 1e-3, 1e-3)); - } - - if(!pass) - { - FAIL() << "Failure in input lengths = [" << serialize_range(lengths) << "], " - << "reduce dim = [" << serialize_range(reduceDims) << "]."; - } - } - - void Run() - { - std::vector> lengths = { - {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}}; - - for(auto length : lengths) - { - this->RunSingle(length, {1}, {length[1]}, {0, 1}, {length[1]}, {0, 1}); - } - } - - typename ReferenceInstance::Invoker ref_instance_invoker_; -}; - -} // namespace ck