diff --git a/example/62_conv_fwd_activ/CMakeLists.txt b/example/62_conv_fwd_activ/CMakeLists.txt
new file mode 100644
index 0000000000..ea38216fa9
--- /dev/null
+++ b/example/62_conv_fwd_activ/CMakeLists.txt
@@ -0,0 +1,35 @@
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
+set(target 0)
+foreach(gpu IN LISTS GPU_TARGETS)
+    if(gpu IN_LIST gpu_list AND target EQUAL 0)
+        add_custom_target(example_convnd_fwd_activ_xdl)
+        # Sigmoid
+        add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp)
+        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_sigmoid_fp16)
+        # Tanh
+        add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp)
+        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_tanh_fp16)
+        # Relu
+        add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp)
+        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_relu_fp16)
+        # SoftRelu
+        add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp)
+        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_softrelu_fp16)
+        # Abs
+        add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp)
+        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_abs_fp16)
+        # Pow
+        add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp)
+        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_pow_fp16)
+        # Clipped Relu
+        add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp)
+        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_clippedrelu_fp16)
+        # Leaky Relu
+        add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp)
+        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_leakyrelu_fp16)
+        # Elu
+        add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp)
+        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_elu_fp16)
+        set(target 1)
+    endif()
+endforeach()
diff --git a/example/62_conv_fwd_activ/convnd_fwd_activ_common.hpp b/example/62_conv_fwd_activ/convnd_fwd_activ_common.hpp
new file mode 100644
index 0000000000..185026b1e3
--- /dev/null
+++ b/example/62_conv_fwd_activ/convnd_fwd_activ_common.hpp
@@ -0,0 +1,238 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <array>
+#include <iostream>
+#include <string>
+#include <type_traits>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
+
+#include "ck/library/utility/algorithm.hpp"
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/convolution_parameter.hpp"
+#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+
+constexpr ck::index_t NDimSpatial = 3;
+
+using InDataType       = ck::half_t;
+using WeiDataType      = ck::half_t;
+using AccDataType      = float;
+using CShuffleDataType = ck::half_t;
+using OutDataType      = ck::half_t;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using InLayout  = ck::tensor_layout::convolution::GNDHWC;
+using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
+using OutLayout = ck::tensor_layout::convolution::GNDHWK;
+
+using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
+using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto ConvSpec =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+template <typename OutElementOp>
+using DeviceGroupedConvNDFwdInstance =
+    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<
+        NDimSpatial,
+        InLayout,
+        WeiLayout,
+        ck::Tuple<>,
+        OutLayout,
+        InDataType,
+        WeiDataType,
+        AccDataType,
+        CShuffleDataType,
+        ck::Tuple<>,
+        OutDataType,
+        InElementOp,
+        WeiElementOp,
+        OutElementOp,
+        ConvSpec,       // ConvForwardSpecialization
+        GemmSpec,       // GemmSpecialization
+        1,              // NumGemmKPrefetchStage
+        256,            // BlockSize
+        128,            // MPerBlock
+        256,            // NPerBlock
+        32,             // KPerBlock
+        8,              // AK1
+        8,              // BK1
+        32,             // MPerXdl
+        32,             // NPerXdl
+        2,              // MXdlPerWave
+        4,              // NXdlPerWave
+        S<4, 64, 1>,    // ABlockTransferThreadClusterLengths_AK0_M_AK1
+        S<1, 0, 2>,     // ABlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,     // ABlockTransferSrcAccessOrder
+        2,              // ABlockTransferSrcVectorDim
+        8,              // ABlockTransferSrcScalarPerVector
+        8,              // ABlockTransferDstScalarPerVector_AK1
+        1,              // ABlockLdsExtraM
+        S<4, 64, 1>,    // BBlockTransferThreadClusterLengths_BK0_N_BK1
+        S<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,     // BBlockTransferSrcAccessOrder
+        2,              // BBlockTransferSrcVectorDim
+        8,              // BBlockTransferSrcScalarPerVector
+        8,              // BBlockTransferDstScalarPerVector_BK1
+        1,              // BBlockLdsExtraN
+        1,              // CShuffleMXdlPerWavePerShuffle
+        1,              // CShuffleNXdlPerWavePerShuffle
+        S<1, 32, 1, 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+        8>;             // CDEBlockTransferScalarPerVector_NPerBlock
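+
+// Implicit-GEMM mapping assumed for this forward-conv instance (annotation
+// only, inferred from the CShuffle kernel's usual convention):
+//   GemmM = N * Do * Ho * Wo, GemmN = K, GemmK = C * Z * Y * X.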
+
+template <typename DeviceConvNDFwdInstance>
+bool run_grouped_conv_fwd(bool do_verification,
+                          int init_method,
+                          bool time_kernel,
+                          const ck::utils::conv::ConvParam& conv_param,
+                          const HostTensorDescriptor& in_g_n_c_wis_desc,
+                          const HostTensorDescriptor& wei_g_k_c_xs_desc,
+                          const HostTensorDescriptor& out_g_n_k_wos_desc,
+                          const InElementOp& in_element_op,
+                          const WeiElementOp& wei_element_op,
+                          const OutElementOp& out_element_op)
+{
+    Tensor<InDataType> in(in_g_n_c_wis_desc);
+    Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
+    Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
+    Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
+
+    std::cout << "in: " << in.mDesc << std::endl;
+    std::cout << "wei: " << wei.mDesc << std::endl;
+    std::cout << "out: " << out_host.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
+        wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
+        break;
+    default:
+        in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
+        wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.05, 0.05});
+    }
+
+    DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
+    DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
+    DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
+
+    in_device_buf.ToDevice(in.mData.data());
+    wei_device_buf.ToDevice(wei.mData.data());
+
+    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
+    std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
+    std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
+    std::array<ck::index_t, NDimSpatial> input_left_pads{};
+    std::array<ck::index_t, NDimSpatial> input_right_pads{};
+
+    auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
+
+    copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
+    copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
+    copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
+    copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
+    copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
+    copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
+    copy(conv_param.conv_filter_strides_, conv_filter_strides);
+    copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
+    copy(conv_param.input_left_pads_, input_left_pads);
+    copy(conv_param.input_right_pads_, input_right_pads);
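+
+    // This example has no extra D tensors (ck::Tuple<> above), so the D-tensor
+    // pointer, length, and stride arguments passed to MakeArgument below are
+    // zero-sized arrays.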
+
+    // do Conv
+    auto conv     = DeviceConvNDFwdInstance{};
+    auto invoker  = conv.MakeInvoker();
+    auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
+                                      wei_device_buf.GetDeviceBuffer(),
+                                      std::array<const void*, 0>{},
+                                      out_device_buf.GetDeviceBuffer(),
+                                      a_g_n_c_wis_lengths,
+                                      a_g_n_c_wis_strides,
+                                      b_g_k_c_xs_lengths,
+                                      b_g_k_c_xs_strides,
+                                      std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
+                                      std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
+                                      e_g_n_k_wos_lengths,
+                                      e_g_n_k_wos_strides,
+                                      conv_filter_strides,
+                                      conv_filter_dilations,
+                                      input_left_pads,
+                                      input_right_pads,
+                                      in_element_op,
+                                      wei_element_op,
+                                      out_element_op);
+
+    if(!conv.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_conv with the specified compilation parameters does "
+            "not support this Conv problem");
+    }
+
+    float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop      = conv_param.GetFlops();
+    std::size_t num_btype = conv_param.GetByte();
+
+    // avg_time is in ms: flop / 1e9 / ms == TFlop/s, bytes / 1e6 / ms == GB/s.
+    float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
+    float gb_per_sec = num_btype / 1.E6 / avg_time;
+    std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
+              << conv.GetTypeString() << std::endl;
+
+    if(do_verification)
+    {
+        auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
+                                                                     InDataType,
+                                                                     WeiDataType,
+                                                                     OutDataType,
+                                                                     InElementOp,
+                                                                     WeiElementOp,
+                                                                     OutElementOp>();
+
+        auto ref_invoker  = ref_conv.MakeInvoker();
+        auto ref_argument = ref_conv.MakeArgument(in,
+                                                  wei,
+                                                  out_host,
+                                                  conv_param.conv_filter_strides_,
+                                                  conv_param.conv_filter_dilations_,
+                                                  conv_param.input_left_pads_,
+                                                  conv_param.input_right_pads_,
+                                                  in_element_op,
+                                                  wei_element_op,
+                                                  out_element_op);
+
+        ref_invoker.Run(ref_argument);
+
+        out_device_buf.FromDevice(out_device.mData.data());
+
+        return ck::utils::check_err(out_device, out_host, "Error: incorrect results!");
+    }
+
+    return true;
+}
diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_abs_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_abs_fp16.cpp
new file mode 100644
index 0000000000..4fe0c857fa
--- /dev/null
+++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_abs_fp16.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "convnd_fwd_activ_common.hpp"
+
+using OutElementOp = ck::tensor_operation::element_wise::UnaryAbs;
+
+using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
+#include "run_convnd_fwd_activ_example.inc"
+
+int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_clippedrelu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_clippedrelu_fp16.cpp
new file mode 100644
index 0000000000..feabacc5c9
--- /dev/null
+++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_clippedrelu_fp16.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "convnd_fwd_activ_common.hpp"
+
+using OutElementOp = ck::tensor_operation::element_wise::ClippedRelu;
+
+using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
+#include "run_convnd_fwd_activ_example.inc"
+
+int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_elu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_elu_fp16.cpp
new file mode 100644
index 0000000000..793102dbc6
--- /dev/null
+++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_elu_fp16.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "convnd_fwd_activ_common.hpp"
+
+using OutElementOp = ck::tensor_operation::element_wise::Elu;
+
+using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
+#include "run_convnd_fwd_activ_example.inc"
+
+int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_leakyrelu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_leakyrelu_fp16.cpp
new file mode 100644
index 0000000000..a77408db7e
--- /dev/null
+++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_leakyrelu_fp16.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "convnd_fwd_activ_common.hpp"
+
+using OutElementOp = ck::tensor_operation::element_wise::LeakyRelu;
+
+using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
+#include "run_convnd_fwd_activ_example.inc"
+
+int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_pow_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_pow_fp16.cpp
new file mode 100644
index 0000000000..2b695cf8c3
--- /dev/null
+++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_pow_fp16.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "convnd_fwd_activ_common.hpp"
+
+using OutElementOp = ck::tensor_operation::element_wise::Power;
+
+using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
+#include "run_convnd_fwd_activ_example.inc"
+
+int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_relu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_relu_fp16.cpp
new file mode 100644
index 0000000000..e1b6e3f0cc
--- /dev/null
+++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_relu_fp16.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "convnd_fwd_activ_common.hpp"
+
+using OutElementOp = ck::tensor_operation::element_wise::Relu;
+
+using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
+#include "run_convnd_fwd_activ_example.inc"
+
+int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_sigmoid_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_sigmoid_fp16.cpp
new file mode 100644
index 0000000000..350c15a787
--- /dev/null
+++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_sigmoid_fp16.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "convnd_fwd_activ_common.hpp"
+
+using OutElementOp = ck::tensor_operation::element_wise::Sigmoid;
+
+using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
+#include "run_convnd_fwd_activ_example.inc"
+
+int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_softrelu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_softrelu_fp16.cpp
new file mode 100644
index 0000000000..ec52e1a3c4
--- /dev/null
+++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_softrelu_fp16.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "convnd_fwd_activ_common.hpp"
+
+using OutElementOp = ck::tensor_operation::element_wise::SoftRelu;
+
+using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
+#include "run_convnd_fwd_activ_example.inc"
+
+int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_tanh_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_tanh_fp16.cpp
new file mode 100644
index 0000000000..dca405669a
--- /dev/null
+++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_tanh_fp16.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "convnd_fwd_activ_common.hpp"
+
+using OutElementOp = ck::tensor_operation::element_wise::TanH;
+
+using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
+#include "run_convnd_fwd_activ_example.inc"
+
+int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
diff --git a/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc b/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc
new file mode 100644
index 0000000000..7c20c01066
--- /dev/null
+++ b/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc
@@ -0,0 +1,91 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+void print_helper_msg()
+{
+    std::cout << "arg1: verification (0=no, 1=yes)\n"
+              << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
+              << "arg3: time kernel (0=no, 1=yes)\n"
+              << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
+}
+
+bool run_convnd_fwd_example(int argc, char* argv[])
+{
+    print_helper_msg();
+
+    bool do_verification = true;
+    // For SoftRelu, default to decimal (float) initialization to avoid overflow in e^x.
+    int init_method =
+        std::is_same_v<OutElementOp, ck::tensor_operation::element_wise::SoftRelu> ? 2 : 1;
+    bool time_kernel = false;
+
+    // The following shapes are selected to avoid overflow. Expect inf for some
+    // element-wise ops if the sizes are increased.
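+    // ConvParam argument order, assumed from ck::utils::conv::ConvParam:
+    // num_dim_spatial, G, N, K, C, filter lengths {Z, Y, X}, input lengths
+    // {Di, Hi, Wi}, strides, dilations, left pads, right pads.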
+    ck::utils::conv::ConvParam conv_param{
+        3, 1, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
+
+    if(argc == 1)
+    {
+        // use default
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+        const ck::index_t num_dim_spatial = std::stoi(argv[4]);
+
+        conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
+    }
+
+    const auto in_element_op  = InElementOp{};
+    const auto wei_element_op = WeiElementOp{};
+    const auto out_element_op = OutElementOp{};
+
+    const auto run = [&]() {
+        const auto in_g_n_c_wis_desc =
+            ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
+                conv_param);
+
+        const auto wei_g_k_c_xs_desc =
+            ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
+                conv_param);
+
+        const auto out_g_n_k_wos_desc =
+            ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
+                conv_param);
+
+        return run_grouped_conv_fwd<DeviceGroupedConvNDFwdActivInstance>(do_verification,
+                                                                         init_method,
+                                                                         time_kernel,
+                                                                         conv_param,
+                                                                         in_g_n_c_wis_desc,
+                                                                         wei_g_k_c_xs_desc,
+                                                                         out_g_n_k_wos_desc,
+                                                                         in_element_op,
+                                                                         wei_element_op,
+                                                                         out_element_op);
+    };
+
+    if(conv_param.num_dim_spatial_ == 3)
+    {
+        return run();
+    }
+
+    return false;
+}
diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
index 23fdd23e10..dabdf649e4 100644
--- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
+++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
@@ -442,10 +442,11 @@ struct Sigmoid
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value,
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");
-
-        y = 1 / (ck::type_convert<T>(1) + exp(-x));
+        constexpr T one = type_convert<T>(1);
+        y               = one / (one + ck::math::exp(-x));
     };
 };
 
@@ -455,7 +456,8 @@ struct TanH
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value,
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");
 
         y = ck::math::tanh(x);
@@ -481,7 +483,101 @@ struct Swish
         y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
     };
 
-    float beta_ = 1.0f;
+    const float beta_;
+};
+
+// SoftRelu (softplus with a sharpness knob): y = log(1 + exp(alpha * x)) / alpha.
+struct SoftRelu
+{
+    SoftRelu(float alpha = 1.f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha  = type_convert<T>(alpha_);
+        constexpr T one = type_convert<T>(1);
+        y               = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+    }
+    const float alpha_;
+};
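+
+// Power: y = (alpha + beta * x)^gamma, i.e. a power of a shifted and scaled input.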
+struct Power
+{
+    Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
+        : alpha_(alpha), beta_(beta), gamma_(gamma){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha     = type_convert<T>(alpha_);
+        T casted_beta      = type_convert<T>(beta_);
+        T casted_gamma     = type_convert<T>(gamma_);
+        T shifted_scaled_x = casted_alpha + casted_beta * x;
+        y                  = ck::math::pow(shifted_scaled_x, casted_gamma);
+    }
+    const float alpha_;
+    const float beta_;
+    const float gamma_;
+};
+
+// ClippedRelu: y = min(beta, max(alpha, x)), i.e. x clamped to [alpha, beta].
+struct ClippedRelu
+{
+    ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        T casted_beta  = type_convert<T>(beta_);
+        y              = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+    }
+    const float alpha_;
+    const float beta_;
+};
+
+// LeakyRelu: y = x for x >= 0, alpha * x otherwise.
+struct LeakyRelu
+{
+    LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        y              = x >= 0 ? x : x * casted_alpha;
+    }
+    const float alpha_;
+};
+
+// Elu: y = x for x > 0, alpha * (e^x - 1) otherwise (expm1 keeps precision near 0).
+struct Elu
+{
+    Elu(float alpha = 1.f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        y              = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+    }
+    const float alpha_;
 };
 
 } // namespace element_wise
diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp
index c5e967c8f4..7efbb3e63a 100644
--- a/include/ck/utility/math.hpp
+++ b/include/ck/utility/math.hpp
@@ -150,28 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
     return min(max(x, lowerbound), upperbound);
 }
 
-// disallow implicit type casting
-template <typename T>
-__device__ T exp(T x);
-
-// TODO: add f16 support using v_exp_f16
-
-template <>
-__device__ float exp(float x)
-{
-    return __expf(x);
-}
-
-template <>
-__device__ double exp(double x)
-{
-    return exp(x);
-}
-
-static inline __host__ float exp(float x) { return std::expf(x); }
-
-static inline __host__ double exp(double x) { return std::exp(x); }
-
 // greatest common divisor, aka highest common factor
 __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
 {
diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp
index 1cac2cc0c7..a07fde3da3 100644
--- a/include/ck/utility/math_v2.hpp
+++ b/include/ck/utility/math_v2.hpp
@@ -9,6 +9,7 @@
 
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/type.hpp"
+#include "ck/utility/type_convert.hpp"
 
 namespace ck {
 namespace math {
@@ -92,14 +93,96 @@
 static inline __host__ float sqrt(float x) { return std::sqrt(x); };
 
 static inline __host__ double sqrt(double x) { return std::sqrt(x); };
 
-static inline __host__ half_t tanh(half_t x)
+template <typename T>
+inline __host__ T tanh(T x)
 {
-    return static_cast<half_t>(std::tanh(static_cast<float>(x)));
+    return ck::type_convert<T>(std::tanhf(ck::type_convert<float>(x)));
 };
 
-static inline __host__ float tanh(float x) { return std::tanh(x); };
+template <>
+inline __host__ float tanh(float x)
+{
+    return std::tanhf(x);
+};
 
-static inline __host__ double tanh(double x) { return std::tanh(x); };
+template <>
+inline __host__ double tanh(double x)
+{
+    return std::tanh(x);
+};
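+
+// Generic host fallbacks below compute through float via ck::type_convert;
+// float and double keep their exact std:: overloads as full specializations.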
+
+template <typename T>
+inline __host__ T exp(T x)
+{
+    return ck::type_convert<T>(std::expf(ck::type_convert<float>(x)));
+}
+
+template <>
+inline __host__ float exp(float x)
+{
+    return std::expf(x);
+}
+
+template <>
+inline __host__ double exp(double x)
+{
+    return std::exp(x);
+}
+
+template <typename T>
+inline __host__ T log(T x)
+{
+    return ck::type_convert<T>(std::logf(ck::type_convert<float>(x)));
+}
+
+template <>
+inline __host__ float log(float x)
+{
+    return std::logf(x);
+}
+
+template <>
+inline __host__ double log(double x)
+{
+    return std::log(x);
+}
+
+template <typename T>
+inline __host__ T pow(T x, T gamma)
+{
+    return ck::type_convert<T>(
+        std::powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
+}
+
+template <>
+inline __host__ float pow(float x, float gamma)
+{
+    return std::powf(x, gamma);
+}
+
+template <>
+inline __host__ double pow(double x, double gamma)
+{
+    return std::pow(x, gamma);
+}
+
+template <typename T>
+inline __host__ T expm1(T x)
+{
+    return ck::type_convert<T>(std::expm1f(ck::type_convert<float>(x)));
+}
+
+template <>
+inline __host__ float expm1(float x)
+{
+    return std::expm1f(x);
+}
+
+template <>
+inline __host__ double expm1(double x)
+{
+    return std::expm1(x);
+}
 
 // math functions for the HIP kernel, some are implemented by calling hip builtin functions
 
@@ -181,14 +264,107 @@
 static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
 
 static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
 
-static inline __device__ half_t tanh(half_t x)
+template <typename T>
+inline __device__ T tanh(T x)
 {
-    return static_cast<half_t>(::tanhf(static_cast<float>(x)));
+    return ck::type_convert<T>(::tanhf(ck::type_convert<float>(x)));
 };
 
-static inline __device__ float tanh(float x) { return ::tanhf(x); };
+template <>
+inline __device__ float tanh(float x)
+{
+    return ::tanhf(x);
+};
 
-static inline __device__ double tanh(double x) { return ::tanh(x); };
+template <>
+inline __device__ double tanh(double x)
+{
+    return ::tanh(x);
+};
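+
+// Device overloads below prefer hardware builtins (hexp/hlog for half_t,
+// __expf/__logf for float); double is routed to the global libm functions.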
+
+template <typename T>
+inline __device__ T exp(T x)
+{
+    return ck::type_convert<T>(__expf(ck::type_convert<float>(x)));
+};
+
+template <>
+inline __device__ half_t exp(half_t x)
+{
+    return hexp(x);
+};
+
+template <>
+inline __device__ float exp(float x)
+{
+    return __expf(x);
+};
+
+template <>
+inline __device__ double exp(double x)
+{
+    // Qualify the call so it reaches libm instead of recursing into this
+    // specialization.
+    return ::exp(x);
+};
+
+template <typename T>
+inline __device__ T log(T x)
+{
+    return ck::type_convert<T>(__logf(ck::type_convert<float>(x)));
+};
+
+template <>
+inline __device__ half_t log(half_t x)
+{
+    return hlog(x);
+};
+
+template <>
+inline __device__ float log(float x)
+{
+    return __logf(x);
+};
+
+template <>
+inline __device__ double log(double x)
+{
+    return ::log(x);
+};
+
+template <typename T>
+inline __device__ T pow(T x, T gamma)
+{
+    return ck::type_convert<T>(::powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
+};
+
+template <>
+inline __device__ float pow(float x, float gamma)
+{
+    return ::powf(x, gamma);
+};
+
+template <>
+inline __device__ double pow(double x, double gamma)
+{
+    return ::pow(x, gamma);
+};
+
+template <typename T>
+inline __device__ T expm1(T x)
+{
+    return ck::type_convert<T>(::expm1f(ck::type_convert<float>(x)));
+};
+
+template <>
+inline __device__ float expm1(float x)
+{
+    return ::expm1f(x);
+};
+
+template <>
+inline __device__ double expm1(double x)
+{
+    return ::expm1(x);
+};
 
 } // namespace math
 } // namespace ck
diff --git a/include/ck/utility/statically_indexed_array_multi_index.hpp b/include/ck/utility/statically_indexed_array_multi_index.hpp
index 4a8b96ae8a..80a865ed87 100644
--- a/include/ck/utility/statically_indexed_array_multi_index.hpp
+++ b/include/ck/utility/statically_indexed_array_multi_index.hpp
@@ -5,6 +5,7 @@
 #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
 
 #include "common_header.hpp"
+#include "ck/utility/math_v2.hpp"
 
 namespace ck {
diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
index 8f4182a231..0be9d83ad3 100644
--- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
@@ -128,11 +128,9 @@ struct ReferenceConvFwd : public device::BaseOperator
                 }
             }
 
-            float v_out;
-
-            arg.out_element_op_(v_out, v_acc);
-
-            arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
+            // Apply the element-wise op in OutDataType rather than float.
+            OutDataType v_out;
+            arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
+            arg.output_(g, n, k, wo) = v_out;
         };
 
         make_ParallelTensorFunctor(func,
@@ -184,11 +182,9 @@
                 }
             }
 
-            float v_out;
-
-            arg.out_element_op_(v_out, v_acc);
-
-            arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
+            OutDataType v_out;
+            arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
+            arg.output_(g, n, k, ho, wo) = v_out;
         };
 
         make_ParallelTensorFunctor(func,
@@ -253,11 +249,9 @@
                 }
             }
 
-            float v_out;
-
-            arg.out_element_op_(v_out, v_acc);
-
-            arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
+            OutDataType v_out;
+            arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
+            arg.output_(g, n, k, d_o, ho, wo) = v_out;
         };
 
         make_ParallelTensorFunctor(func,