From 91965f3411cf0055b1ddd02c7f31026cd79ff530 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> Date: Tue, 21 Jan 2025 23:23:19 +0100 Subject: [PATCH] Simplify static_cast if-lands (#1828) [ROCm/composable_kernel commit: 3db77bc4f26453a5ba5aad3d49adb03d1accf8de] --- include/ck_tile/core/utility/type_traits.hpp | 18 ++++ include/ck_tile/host/check_err.hpp | 54 ++++-------- ...volution_host_tensor_descriptor_helper.hpp | 84 ++++++------------- 3 files changed, 63 insertions(+), 93 deletions(-) diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index f6e133c759..b432cfcef7 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -109,4 +109,22 @@ CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x) #pragma clang diagnostic pop } +template +struct is_any_of : std::false_type +{ +}; + +template +struct is_any_of : std::is_same +{ +}; + +template +struct is_any_of + : std::integral_constant::value || + is_any_of::value> +{ +}; + } // namespace ck_tile diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index c4ad345d8e..5238b361a2 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -28,14 +28,11 @@ double get_relative_threshold(const int number_of_accumulations = 1) using I8 = int8_t; using I32 = int32_t; - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v, + static_assert(is_any_of::value, "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); + double compute_error = 0; - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v) + if constexpr(is_any_of::value) { return 0; } @@ -44,14 +41,11 @@ double get_relative_threshold(const int number_of_accumulations = 1) compute_error = std::pow(2, -numeric_traits::mant) * 0.5; } - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v, + static_assert(is_any_of::value, "Warning: Unhandled OutDataType for setting up the relative threshold!"); + double output_error = 0; - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v) + if constexpr(is_any_of::value) { return 0; } @@ -61,14 +55,11 @@ double get_relative_threshold(const int number_of_accumulations = 1) } double midway_error = std::max(compute_error, output_error); - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v, + static_assert(is_any_of::value, "Warning: Unhandled AccDataType for setting up the relative threshold!"); + double acc_error = 0; - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v) + if constexpr(is_any_of::value) { return 0; } @@ -89,15 +80,12 @@ double get_absolute_threshold(const double max_possible_num, const int number_of using I8 = int8_t; using I32 = int32_t; - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v, + static_assert(is_any_of::value, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); + auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v) + if constexpr(is_any_of::value) { return 0; } @@ -106,14 +94,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of compute_error = std::pow(2, expo - numeric_traits::mant) * 0.5; } - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v, + static_assert(is_any_of::value, "Warning: Unhandled OutDataType for setting up the absolute threshold!"); + double output_error = 0; - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v) + if constexpr(is_any_of::value) { return 0; } @@ -123,14 +108,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of } double midway_error = std::max(compute_error, output_error); - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v, + static_assert(is_any_of::value, "Warning: Unhandled AccDataType for setting up the absolute threshold!"); + double acc_error = 0; - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v) + if constexpr(is_any_of::value) { return 0; } diff --git a/include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp b/include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp index b7317fc04b..33a85b0d4b 100644 --- a/include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp +++ b/include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp @@ -14,57 +14,41 @@ namespace detail { template CK_TILE_HOST std::vector get_layout_transpose_gnchw_to_old() { - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + using namespace ck_tile::tensor_layout::convolution; + + if constexpr(is_any_of::value) { return {0, 1, 2, 3}; } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { return {0, 1, 2, 3, 4}; } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { return {0, 1, 2, 3, 4, 5}; } - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + if constexpr(is_any_of::value) { return {0, 1, 3, 2}; } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { return {0, 1, 4, 2, 3}; } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { return {0, 1, 5, 2, 3, 4}; } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { return {2, 0, 3, 1}; } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { return {3, 0, 4, 1, 2}; } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { return {4, 0, 5, 1, 2, 3}; } @@ -83,11 +67,11 @@ template CK_TILE_HOST HostTensorDescriptor make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param) { + using namespace ck_tile::tensor_layout::convolution; + std::vector physical_lengths; - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + if constexpr(is_any_of::value) { physical_lengths = std::vector{static_cast(param.G_), static_cast(param.N_), @@ -97,9 +81,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara param.input_spatial_lengths_.begin(), param.input_spatial_lengths_.begin() + param.num_dim_spatial_); } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { physical_lengths = std::vector{static_cast(param.G_), static_cast(param.N_), @@ -109,9 +91,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara param.input_spatial_lengths_.begin(), param.input_spatial_lengths_.begin() + param.num_dim_spatial_); } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { physical_lengths = std::vector{static_cast(param.N_), static_cast(param.G_), @@ -139,11 +119,11 @@ template CK_TILE_HOST HostTensorDescriptor make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param) { + using namespace ck_tile::tensor_layout::convolution; + std::vector physical_lengths; - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + if constexpr(is_any_of::value) { if(param.G_ != 1) { @@ -157,9 +137,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara param.filter_spatial_lengths_.begin(), param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { physical_lengths = std::vector{static_cast(param.G_), static_cast(param.K_), @@ -169,9 +147,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara param.filter_spatial_lengths_.begin(), param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { physical_lengths = std::vector{static_cast(param.G_), static_cast(param.K_), @@ -181,9 +157,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara param.filter_spatial_lengths_.begin(), param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { physical_lengths = std::vector{static_cast(param.K_), static_cast(param.G_), @@ -211,11 +185,11 @@ template CK_TILE_HOST HostTensorDescriptor make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param) { + using namespace ck_tile::tensor_layout::convolution; + std::vector physical_lengths; - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + if constexpr(is_any_of::value) { physical_lengths = std::vector{static_cast(param.G_), static_cast(param.N_), @@ -226,9 +200,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar param.output_spatial_lengths_.begin() + param.num_dim_spatial_); } // separate from legacy code above - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { physical_lengths = std::vector{static_cast(param.G_), static_cast(param.N_), @@ -238,9 +210,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar param.output_spatial_lengths_.begin(), param.output_spatial_lengths_.begin() + param.num_dim_spatial_); } - else if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) + else if constexpr(is_any_of::value) { physical_lengths = std::vector{static_cast(param.N_), static_cast(param.G_),