diff --git a/client_example/18_groupnorm/groupnorm_swish.cpp b/client_example/18_groupnorm/groupnorm_swish.cpp index 8a873e6acd..a79630c237 100644 --- a/client_example/18_groupnorm/groupnorm_swish.cpp +++ b/client_example/18_groupnorm/groupnorm_swish.cpp @@ -13,8 +13,8 @@ #include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp" using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; +using GammaDataType = float; +using BetaDataType = float; using YDataType = ck::half_t; using ComputeDataType = float; using Swish = ck::tensor_operation::element_wise::Swish; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp b/library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp index c04a54455d..367180dea4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp @@ -25,6 +25,10 @@ void add_device_normalization_rank_5_3_swish_f16_instances( void add_device_normalization_rank_5_3_swish_f32_instances( std::vector>>&); +// [x, gamma, beta, y] = [f16, f32, f32, f16] +void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances( + std::vector>>&); + template && is_same_v && + is_same_v && is_same_v) + { + if constexpr(Rank == 5 && NumReduceDim == 3) + { + add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs); + } + } return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt index 6bed36e350..176fb2fbee 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt @@ -7,4 +7,5 @@ add_instance_library(device_normalization_instance device_groupnorm_f32_instance.cpp device_groupnorm_swish_f16_instance.cpp device_groupnorm_swish_f32_instance.cpp + device_groupnorm_swish_f16_f32_f32_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp new file mode 100644 index 0000000000..9f6bf128fa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Swish = ck::tensor_operation::element_wise::Swish; + +void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_normalization_f16_f32_f32_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp b/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp index a58fb6ca35..9dea41e89d 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp +++ b/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp @@ -69,6 +69,32 @@ using device_normalization_f32_instances = std::tuple< // clang-format on >; +template +using device_normalization_f16_f32_f32_f16_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation