From 6caec3d429c8400c16afd431a15a6355c41d81f9 Mon Sep 17 00:00:00 2001 From: guangzlu <87220526+guangzlu@users.noreply.github.com> Date: Fri, 10 Feb 2023 01:37:29 +0800 Subject: [PATCH] Add instance for elementwise normalization (#573) * added instances for large N * add instance for elementwise normalization * added support restriction in device_elementwise_normalization_impl.hpp [ROCm/composable_kernel commit: 76d144fa7c396e52631719f79008e7099b6cd30d] --- .../device/impl/device_elementwise_normalization_impl.hpp | 5 +++++ .../device_elementwise_normalization_f16_instance.cpp | 5 +++++ .../test_elementwise_layernorm_fp16.cpp | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp index 1085bdf922..1fa69288a4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp @@ -533,6 +533,11 @@ struct DeviceElementwiseNormalizationImpl return (false); } + if(p_arg_->x_lds_size_ >= 65536) + { + return (false); + } + return true; }; diff --git a/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp index 7f15372ed9..b160d4fe1a 100644 --- a/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp @@ -23,6 +23,11 @@ template + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel for large N + 
DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel for large N + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>, // fallback kernel for large N + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>, // fallback kernel for large N + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 4, 64, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel for large N DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>, // fallback kernel diff --git a/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp b/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp index 403881b3cc..e80995c4f0 100644 --- a/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp +++ b/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp @@ -23,7 +23,7 @@ class TestElementwiseLayernorm : public ::testing::Test { // M, N std::vector> lengths = { - {1, 1}, {25, 16}, {39, 777}, {100, 200}, {1024, 1024}, {48 * 256, 2048}}; + {1, 1}, {25, 16}, {39, 777}, {100, 200}, {1024, 1024}, {48 * 256, 2048}, {4096, 8192}}; for(auto length : lengths) {