diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index a963399dc7..3cf4812509 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -4,6 +4,8 @@ add_example_executable(example_elementwise_permute_4D_fp32_row elementwise_permu add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permute_4D_fp16_row.cpp) add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp) add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp) +add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp) +add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp) add_example_executable(example_elementwise_permute elementwise_permute.cpp) if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942")) add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp) diff --git a/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp new file mode 100644 index 0000000000..8819bb65e6 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include <iostream>
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
+
+#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
+
+#include "ck/library/utility/algorithm.hpp"
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using ADataType = F16;
+using BDataType = F16;
+
+using UnaryScale  = ck::tensor_operation::element_wise::Scale;
+using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
+using UnaryScaleSquare =
+    ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
+using BinaryAdd = ck::tensor_operation::element_wise::Add;
+// B = alpha * A0 * A0 + beta * A1 * A1
+using BinaryAddUnaryScaleSquare = ck::tensor_operation::element_wise::
+    BinaryWithUnaryCombinedOp<BinaryAdd, UnaryScaleSquare, UnaryScaleSquare>;
+using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
+    ck::Tuple<ADataType, ADataType>, // InDataTypeTuple
+    ck::Tuple<BDataType>,            // OutDataTypeTuple
+    BinaryAddUnaryScaleSquare,       // ElementwiseOp
+    4,                               // NumDim
+    256,                             // BlockSize
+    128,                             // M0PerBlock
+    128,                             // M1PerBlock
+    8,                               // M0PerThread
+    8,                               // M1PerThread
+    ck::Sequence<1, 0>,              // ThreadClusterArrangeOrder
+    ck::Sequence<8, 8>,              // InScalarPerVectorSeq
+    ck::Sequence<8>>;                // OutScalarPerVectorSeq
+
+int main()
+{
+    bool do_verification = true;
+    bool time_kernel     = true;
+
+    std::vector<std::size_t> nchw = {16, 128, 32, 64};
+    std::array<ck::index_t, 4> ab_lengths;
+    std::array<ck::index_t, 4> ab_strides = {
+        static_cast<ck::index_t>(nchw[1] * nchw[2] * nchw[3]),
+        static_cast<ck::index_t>(nchw[2] * nchw[3]),
+        static_cast<ck::index_t>(nchw[3]),
+        1};
+    ck::ranges::copy(nchw, ab_lengths.begin());
+
+    std::array<Tensor<ADataType>, 2> as = {Tensor<ADataType>(ab_lengths, ab_strides),
+                                           Tensor<ADataType>(ab_lengths, ab_strides)};
+    Tensor<ADataType>& a0 = as[0];
+    Tensor<ADataType>& a1 = as[1];
+    Tensor<BDataType> b(ab_lengths, ab_strides);
+    float alpha = 3.f;
+    float beta  = 2.f;
+    a0.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+    a1.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+
+    DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize());
+    DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize());
+    DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
+
+    a0_device_buf.ToDevice(a0.mData.data());
+    a1_device_buf.ToDevice(a1.mData.data());
+
+    std::array<const void*, 2> inputs = {a0_device_buf.GetDeviceBuffer(),
+                                         a1_device_buf.GetDeviceBuffer()};
+    std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
+
+    auto broadcastPermute  = DeviceElementwisePermuteInstance{};
+    auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}};
+    auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}};
+    auto argument          = broadcastPermute.MakeArgumentPointer(
+        ab_lengths,
+        {ab_strides, ab_strides},
+        {ab_strides},
+        inputs,
+        output,
+        BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1});
+
+    if(!broadcastPermute.IsSupportedArgument(argument.get()))
+    {
+        throw std::runtime_error(
+            "The runtime parameters seems not supported by the device instance, exiting!");
+    };
+
+    std::cout << "A0 (nchw): " << a0.mDesc << std::endl;
+    std::cout << "A1 (nchw): " << a1.mDesc << std::endl;
+    std::cout << "B (nchw): " << b.mDesc << std::endl;
+
+    auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
+    float ave_time =
+        broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
+    std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
+
+    std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
+                            sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, ab_strides); + + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, + host_b, + BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute.cpp b/example/44_elementwise_permute/elementwise_permute.cpp index 24e161c6d3..d3c3085eb8 100644 --- a/example/44_elementwise_permute/elementwise_permute.cpp +++ b/example/44_elementwise_permute/elementwise_permute.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -30,20 +32,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor) -{ - for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) - for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d) - for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h) - for(std::size_t w = 0; w < 
A_ncdhw.mDesc.GetLengths()[4]; ++w) - { - auto a_val = A_ncdhw(n, c, d, h, w); - functor(B_ndhwc(n, d, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -51,32 +39,7 @@ int main() std::vector ncdhw = {16, 8, 8, 8, 8}; std::vector ndhwc = {16, 8, 8, 8, 8}; - Tensor a(ncdhw); - Tensor b(ndhwc); - - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - /**std::array a_strides = { - static_cast(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[2] * ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[4]), - 1}; - std::array b_strides = { - static_cast(ndhwc[1] * ndhwc[2] * ndhwc[3] * ndhwc[4]), - static_cast(ndhwc[2] * ndhwc[3] * ndhwc[4]), - 1, - static_cast(ndhwc[3] * ndhwc[4]), - static_cast(ndhwc[4])};**/ std::array a_strides = { static_cast(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]), @@ -93,6 +56,20 @@ int main() 1}; ck::ranges::copy(ncdhw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -126,10 +103,16 @@ int main() if(do_verification) { - 
b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(ndhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_3d.cpp b/example/44_elementwise_permute/elementwise_permute_3d.cpp index f3aca57c35..47d8c4de65 100644 --- a/example/44_elementwise_permute/elementwise_permute_3d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_3d.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -34,20 +36,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<4>, // InScalarPerVectorSeq ck::Sequence<4>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor) -{ - for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) - for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d) - for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h) - for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w) - { - auto a_val = A_ncdhw(n, c, d, h, w); - functor(B_ndhwc(n, d, h, w, c), a_val); - } -} - 
int main() { bool do_verification = true; @@ -59,10 +47,13 @@ int main() const int W = 5; const int D = 16; - std::vector ncdhw = {N, C, D, H, W}; - std::vector ndhwc = {N, D, H, W, C}; - Tensor a(ncdhw); - Tensor b(ndhwc); + std::array ab_lengths{N, C, H, W, D}; + std::array a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W + std::array b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -74,10 +65,6 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths{N, C, H, W, D}; - std::array a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W - std::array b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C - auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -94,11 +81,12 @@ int main() auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]; + std::size_t flop = std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * + ab_lengths[3] * ab_lengths[4]; std::size_t num_btype = - sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) + - sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]); + (sizeof(ADataType) + sizeof(BDataType)) * + (ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3] * ab_lengths[4]); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -111,10 +99,17 @@ int main() if(do_verification) { - 
b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(ndhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 1b28a901cb..3ea1aa4bf8 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -35,19 +37,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - 
int main() { bool do_verification = true; @@ -55,18 +44,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -77,9 +54,22 @@ int main() 1, static_cast(nhwc[2] * nhwc[3]), static_cast(nhwc[3])}; - ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -111,10 +101,16 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + 
+ b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp index 30231a3758..1747e6dd8b 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -30,22 +32,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - const std::vector& shape_nchw, - Functor functor) -{ - for(std::size_t n = 0; n < shape_nchw[0]; ++n) - for(std::size_t c = 0; c < shape_nchw[1]; ++c) - for(std::size_t h = 0; h < shape_nchw[2]; ++h) - for(std::size_t w = 0; w < shape_nchw[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -54,13 +40,16 @@ int main() const int N = 120; const int C = 128; const int H = 32; - const int W = 1024; + const int W = 32; - std::vector nchw = {N, C, H, W}; - std::vector nhwc = {N, H, W, C}; + std::array ab_lengths{N, H, W, C}; - Tensor a(nchw); - Tensor b(nhwc); + std::array a_strides = {C * H * W, W, 1, H * W}; + std::array b_strides = {H * W * C, W * C, C, 1}; + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); a.GenerateTensorValue(GeneratorTensor_3{0.0, 
1.0}); @@ -72,11 +61,6 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths{N, H, W, C}; - - std::array a_strides = {C * H * W, W, 1, H * W}; - std::array b_strides = {H * W * C, W * C, C, 1}; - auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -94,10 +78,11 @@ int main() float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + std::size_t flop = + std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3]; - std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + - sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + std::size_t num_btype = (sizeof(ADataType) + sizeof(BDataType)) * + (ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3]); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -110,11 +95,16 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); - Tensor host_b(nhwc); - host_elementwise4D, Tensor, PassThrough>( - host_b, a, nchw, PassThrough{}); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp 
b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index f832601f07..13c67fce05 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -6,9 +6,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -21,11 +23,14 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -36,23 +41,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - std::size_t N = A_nchw.mDesc.GetLengths()[0]; - std::size_t C = A_nchw.mDesc.GetLengths()[1]; - std::size_t H = A_nchw.mDesc.GetLengths()[2]; - std::size_t W = A_nchw.mDesc.GetLengths()[3]; - for(std::size_t w = 0; w < W; ++w) - for(std::size_t h = 0; h < H; ++h) - for(std::size_t c = 0; c < C; ++c) - 
for(std::size_t n = 0; n < N; ++n) - { - auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); - } -} - int main() { bool do_verification = true; @@ -60,8 +48,21 @@ int main() std::vector nchw = {16, 8, 32, 64}; std::vector nhwc = {16, 32, 64, 8}; - Tensor a(nchw); - Tensor b(nhwc); + std::array ab_lengths; + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); float scale = 1.f; auto i = 0; std::mt19937 gen(11939); @@ -84,22 +85,14 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - - std::array a_strides = {1, - static_cast(nchw[0]), - static_cast(nchw[0] * nchw[1]), - static_cast(nchw[0] * nchw[1] * nchw[2])}; - - std::array b_strides = {1, - static_cast(nhwc[0] * nhwc[1] * nhwc[2]), - static_cast(nhwc[0]), - static_cast(nhwc[0] * nhwc[1])}; - ck::ranges::copy(nchw, ab_lengths.begin()); - auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -113,11 +106,10 @@ int main() auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, 
time_kernel}); - std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; - - std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + - sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + std::size_t num_btype = + (2 * sizeof(ADataType) + sizeof(BDataType)) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -129,10 +121,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index bae85f53c1..0a0f6fec10 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -5,9 +5,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + 
#include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -55,18 +47,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - float scale = 2.f; - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array 
a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -80,9 +60,29 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -112,10 +112,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index fe7acd3010..fc664186be 100644 --- 
a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -5,9 +5,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,32 +40,29 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - std::size_t N = A_nchw.mDesc.GetLengths()[0]; - std::size_t C = A_nchw.mDesc.GetLengths()[1]; - std::size_t H = A_nchw.mDesc.GetLengths()[2]; - std::size_t W = A_nchw.mDesc.GetLengths()[3]; - for(std::size_t w = 0; w < W; ++w) - for(std::size_t h = 0; h < H; ++h) - for(std::size_t c = 0; c < C; ++c) - for(std::size_t n = 0; n < N; ++n) - { - auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - 
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); - } -} - int main() { bool do_verification = true; bool time_kernel = true; - std::vector nchw = {5, 4, 2, 3}; - std::vector nhwc = {5, 2, 3, 4}; - Tensor a(nchw); - Tensor b(nhwc); + std::vector nchw = {16, 8, 32, 64}; + std::vector nhwc = {16, 32, 64, 8}; + std::array ab_lengths; + + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); float scale = 1.f; auto i = 0; @@ -84,22 +86,14 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - - std::array a_strides = {1, - static_cast(nchw[0]), - static_cast(nchw[0] * nchw[1]), - static_cast(nchw[0] * nchw[1] * nchw[2])}; - - std::array b_strides = {1, - static_cast(nhwc[0] * nhwc[1] * nhwc[2]), - static_cast(nhwc[0]), - static_cast(nhwc[0] * nhwc[1])}; - ck::ranges::copy(nchw, ab_lengths.begin()); - auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -129,10 +123,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + 
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index aebdb37d9b..a0c416318a 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -5,9 +5,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,19 +40,6 @@ using 
DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -55,18 +47,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - float scale = 2.f; - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -80,9 +60,28 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, 
input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -112,10 +111,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp new file mode 100644 index 0000000000..050300eed2 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using BinaryAdd = ck::tensor_operation::element_wise::Add; +// B = alpha * A0 * A0 + beta * A1 * A1 + gamma * A2 * A2 +using TrinaryAddUnaryScaleSquare = + ck::tensor_operation::element_wise::TrinaryWithUnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + TrinaryAddUnaryScaleSquare, // ElementwiseOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8, 8, 8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::array ab_lengths; + std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 3> as = {Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, 
ab_strides)}; + Tensor& a0 = as[0]; + Tensor& a1 = as[1]; + Tensor& a2 = as[2]; + Tensor b(ab_lengths, ab_strides); + float alpha = 3.f; + float beta = 2.f; + float gamma = 4.f; + a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a2.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize()); + DeviceMem a2_device_buf(sizeof(ADataType) * a2.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0.mData.data()); + a1_device_buf.ToDevice(a1.mData.data()); + a2_device_buf.ToDevice(a2.mData.data()); + + std::array inputs = {a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + a2_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}}; + auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}}; + auto unary_scale_op_a2 = UnaryScaleSquare{UnarySquare{}, UnaryScale{gamma}}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, + {ab_strides, ab_strides, ab_strides}, + {ab_strides}, + inputs, + output, + TrinaryAddUnaryScaleSquare{ + BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A0 (nchw): " << a0.mDesc << std::endl; + std::cout << "A1 (nchw): " << a1.mDesc << std::endl; + std::cout << "A2 (nchw): " << a2.mDesc << std::endl; + std::cout << "B (nchw): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = 
broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, ab_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, + host_b, + TrinaryAddUnaryScaleSquare{ + BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2}); + ref_invoker.Run(ref_argument); + + const double threshold = std::pow(2, -10) * 2; + b_device_buf.FromDevice(b.mData.data()); + pass &= ck::utils::check_err( + b.mData, host_b.mData, "Error: Incorrect results b", threshold, threshold); + } + + return pass ? 
0 : 1; +} diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index ba2e0057d9..f6e57aad09 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -92,6 +92,110 @@ struct Add }; }; +struct Max +{ + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::max(x0_converted, x1_converted); + } +}; + +struct Min +{ + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::min(x0_converted, x1_converted); + } +}; + +struct Multiply +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = x0 * type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 * x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, 
const float& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + y = x0 * x1_tmp; + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 * x1; + }; +}; + struct ScaleAdd { __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {} diff --git a/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp new file mode 100644 index 0000000000..6d1d6b57c5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// y = UnaryOp0(UnaryOp1(...(x))) +template +struct UnaryCombinedOp +{ + __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) 
{} + + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // Execute first unary op to copy data to y + unary_ops_.At(Number<0>{})(y, x); + + static_for<1, Tuple::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); }); + }; + + Tuple unary_ops_; +}; + +// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1)) +template +struct BinaryWithUnaryCombinedOp +{ + __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1) + : binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result); + }; + + private: + BinaryOp binary_op_; + UnaryOp0 unary_op0_; + UnaryOp1 unary_op1_; +}; + +// y = BinaryOp0(BinaryOp1(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2)) +template +struct TrinaryWithUnaryCombinedOp +{ + __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0, + BinaryOp0 binary_op1, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1, + UnaryOp2 unary_op2) + : binary_op0_(binary_op0), + binary_op1_(binary_op1), + unary_op0_(unary_op0), + unary_op1_(unary_op1), + unary_op2_(unary_op2) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const + { + + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + Y unary_x2_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + unary_op2_(unary_x2_tmp_result, x2); + binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result); + binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result); + }; + + private: + BinaryOp0 binary_op0_{}; + BinaryOp1 binary_op1_{}; + UnaryOp0 unary_op0_{}; + UnaryOp1 unary_op1_{}; + UnaryOp2 unary_op2_{}; +}; + +} // namespace element_wise +} // namespace 
tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 9c64ad4dfa..1add81e69e 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -12,10 +12,6 @@ namespace ck { namespace tensor_operation { namespace element_wise { -#if CK_WORKAROUND_SWDEV_383542 -extern "C" __device__ float __ocml_native_recip_f32(float); -#endif - struct PassThroughPack2 { template @@ -449,11 +445,7 @@ struct FastGelu const float u = x * (c1 * x * x + c2); const float emu = __expf(u); -#if !CK_WORKAROUND_SWDEV_383542 - y = x * __frcp_rn(1.f + emu); -#else - y = x * __ocml_native_recip_f32(1.f + emu); -#endif + y = x * ck::math::rcp(1.f + emu); } template <> @@ -559,6 +551,244 @@ struct TanH }; }; +struct ACos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::acos(x); + }; +}; + +struct Neg +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::neg(x); + }; +}; + +struct ATan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::atan(x); + }; +}; + +struct Sin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + 
is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sin(x); + }; +}; + +struct ASinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::asinh(x); + }; +}; + +struct Cos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::cos(x); + }; +}; + +struct ACosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::acosh(x); + }; +}; + +struct Tan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::tan(x); + }; +}; + +struct ATanH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::atanh(x); + }; +}; + +struct SinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sinh(x); + }; +}; + +struct Ceil +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + 
is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::ceil(x); + }; +}; + +struct Exp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::exp(x); + }; +}; + +struct CosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::cosh(x); + }; +}; + +struct Floor +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::floor(x); + }; +}; + +struct Log +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::log(x); + }; +}; + +struct ASin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::asin(x); + }; +}; + +struct Rcp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::rcp(x); + }; +}; + struct Swish { Swish(float beta = 1.0f) : beta_(beta) {} diff --git 
a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp index 2a906a1432..4d1a09b445 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp @@ -118,8 +118,16 @@ struct GridwiseElementwise __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock); const index_t m1_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock); - const auto thread_grid_offset = - make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + const auto input_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + }, + Number{}); + const auto output_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + }, + Number{}); using ThisThreadBlock = ThisThreadBlock; // If src and dst have same vector dim, then: @@ -157,9 +165,9 @@ struct GridwiseElementwise uniform_sequence_gen_t, uniform_sequence_gen_t, uniform_sequence_gen_t>{in_grid_desc_tuple, - thread_grid_offset, + input_thread_grid_offset, out_grid_desc_tuple, - thread_grid_offset, + output_thread_grid_offset, elementwise_op}; global_to_global_transfer.Run( in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0); diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index a07fde3da3..2b921cdc7c 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -14,6 +14,10 @@ namespace ck { namespace math { +#if CK_WORKAROUND_SWDEV_383542 +extern "C" __device__ float __ocml_native_recip_f32(float); +#endif + // math functions for the host, some are implemented by calling C++ std functions static inline __host__ float abs(float x) { return std::abs(x); }; @@ -111,6 +115,276 @@ inline __host__ double tanh(double x) return std::tanh(x); }; +template +inline __host__ T acos(T x) +{ + return ck::type_convert(std::acosf(ck::type_convert(x))); +}; + +template <> +inline __host__ float acos(float x) +{ + return std::acosf(x); +}; + +template <> +inline __host__ double acos(double x) +{ + return std::acos(x); +}; + +template +inline __host__ T neg(T x) +{ + return ck::type_convert(-(ck::type_convert(x))); +}; + +template <> +inline __host__ float neg(float x) +{ + return -x; +}; + +template <> +inline __host__ double neg(double x) +{ + return -x; +}; + +template <> +inline __host__ int32_t neg(int32_t x) +{ + return -x; +}; + +template <> +inline __host__ int8_t neg(int8_t x) +{ + return -x; +}; + +template +inline __host__ T atan(T x) +{ + return ck::type_convert(std::atanf(ck::type_convert(x))); +}; + +template <> +inline __host__ float atan(float x) +{ + return std::atanf(x); +}; + +template <> +inline __host__ double atan(double x) +{ + return std::atan(x); +}; + +template +inline __host__ T sin(T x) +{ + return ck::type_convert(std::sinf(ck::type_convert(x))); +}; + +template <> +inline __host__ float sin(float x) +{ + return std::sinf(x); +}; + +template <> +inline __host__ double sin(double x) +{ + return std::sin(x); +}; + +template +inline __host__ T asin(T x) +{ + return ck::type_convert(std::asinf(ck::type_convert(x))); +}; + +template <> +inline __host__ float asin(float x) +{ + return std::asinf(x); +}; + +template <> +inline __host__ double asin(double x) +{ + return std::asin(x); +}; + +template +inline __host__ T asinh(T x) +{ + return ck::type_convert(std::asinhf(ck::type_convert(x))); +}; + 
+template <> +inline __host__ float asinh(float x) +{ + return std::asinhf(x); +}; + +template <> +inline __host__ double asinh(double x) +{ + return std::asinh(x); +}; + +template +inline __host__ T cos(T x) +{ + return ck::type_convert(std::cosf(ck::type_convert(x))); +}; + +template <> +inline __host__ float cos(float x) +{ + return std::cosf(x); +}; + +template <> +inline __host__ double cos(double x) +{ + return std::cos(x); +}; + +template +inline __host__ T acosh(T x) +{ + return ck::type_convert(std::acoshf(ck::type_convert(x))); +}; + +template <> +inline __host__ float acosh(float x) +{ + return std::acoshf(x); +}; + +template <> +inline __host__ double acosh(double x) +{ + return std::acosh(x); +}; + +template +inline __host__ T tan(T x) +{ + return ck::type_convert(std::tanf(ck::type_convert(x))); +}; + +template <> +inline __host__ float tan(float x) +{ + return std::tanf(x); +}; + +template <> +inline __host__ double tan(double x) +{ + return std::tan(x); +}; + +template +inline __host__ T atanh(T x) +{ + return ck::type_convert(std::atanhf(ck::type_convert(x))); +}; + +template <> +inline __host__ float atanh(float x) +{ + return std::atanhf(x); +}; + +template <> +inline __host__ double atanh(double x) +{ + return std::atanh(x); +}; + +template +inline __host__ T sinh(T x) +{ + return ck::type_convert(std::sinhf(ck::type_convert(x))); +}; + +template <> +inline __host__ float sinh(float x) +{ + return std::sinhf(x); +}; + +template <> +inline __host__ double sinh(double x) +{ + return std::sinh(x); +}; + +template +inline __host__ T ceil(T x) +{ + return ck::type_convert(std::ceilf(ck::type_convert(x))); +}; + +template <> +inline __host__ float ceil(float x) +{ + return std::ceilf(x); +}; + +template <> +inline __host__ double ceil(double x) +{ + return std::ceil(x); +}; + +template +inline __host__ T cosh(T x) +{ + return ck::type_convert(std::coshf(ck::type_convert(x))); +}; + +template <> +inline __host__ float cosh(float x) +{ + return 
std::coshf(x); +}; + +template <> +inline __host__ double cosh(double x) +{ + return std::cosh(x); +}; + +template +inline __host__ T floor(T x) +{ + return ck::type_convert(std::floorf(ck::type_convert(x))); +}; + +template <> +inline __host__ float floor(float x) +{ + return std::floorf(x); +}; + +template <> +inline __host__ double floor(double x) +{ + return std::floor(x); +}; + +template +inline __host__ T rcp(T x) +{ + return ck::type_convert(1.f / ck::type_convert(x)); +}; + template inline __host__ T exp(T x) { @@ -282,6 +556,286 @@ inline __device__ double tanh(double x) return ::tanh(x); }; +template +inline __device__ T acos(T x) +{ + return ck::type_convert(::acosf(ck::type_convert(x))); +}; + +template <> +inline __device__ float acos(float x) +{ + return ::acosf(x); +}; + +template <> +inline __device__ double acos(double x) +{ + return ::acos(x); +}; + +template +inline __device__ T neg(T x) +{ + return ck::type_convert(-(ck::type_convert(x))); +}; + +template <> +inline __device__ float neg(float x) +{ + return -x; +}; + +template <> +inline __device__ double neg(double x) +{ + return -x; +}; + +template <> +inline __device__ int32_t neg(int32_t x) +{ + return -x; +}; + +template <> +inline __device__ int8_t neg(int8_t x) +{ + return -x; +}; + +template <> +inline __device__ half_t neg(half_t x) +{ + return __hneg(x); +}; + +template +inline __device__ T atan(T x) +{ + return ck::type_convert(::atanf(ck::type_convert(x))); +}; + +template <> +inline __device__ float atan(float x) +{ + return ::atanf(x); +}; + +template <> +inline __device__ double atan(double x) +{ + return ::atan(x); +}; + +template +inline __device__ T sin(T x) +{ + return ck::type_convert(::sinf(ck::type_convert(x))); +}; + +template <> +inline __device__ float sin(float x) +{ + return ::sinf(x); +}; + +template <> +inline __device__ double sin(double x) +{ + return ::sin(x); +}; + +template <> +inline __device__ half_t sin(half_t x) +{ + return ::hsin(x); +}; + +template +inline 
__device__ T asin(T x) +{ + return ck::type_convert(::asinf(ck::type_convert(x))); +}; + +template <> +inline __device__ float asin(float x) +{ + return ::asinf(x); +}; + +template <> +inline __device__ double asin(double x) +{ + return ::asin(x); +}; + +template +inline __device__ T asinh(T x) +{ + return ck::type_convert(::asinhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float asinh(float x) +{ + return ::asinhf(x); +}; + +template <> +inline __device__ double asinh(double x) +{ + return ::asinh(x); +}; + +template +inline __device__ T acosh(T x) +{ + return ck::type_convert(::acoshf(ck::type_convert(x))); +}; + +template <> +inline __device__ float acosh(float x) +{ + return ::acoshf(x); +}; + +template <> +inline __device__ double acosh(double x) +{ + return ::acosh(x); +}; + +template +inline __device__ T tan(T x) +{ + return ck::type_convert(::tanf(ck::type_convert(x))); +}; + +template <> +inline __device__ float tan(float x) +{ + return ::tanf(x); +}; + +template <> +inline __device__ double tan(double x) +{ + return ::tan(x); +}; + +template +inline __device__ T atanh(T x) +{ + return ck::type_convert(::atanhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float atanh(float x) +{ + return ::atanhf(x); +}; + +template <> +inline __device__ double atanh(double x) +{ + return ::atanh(x); +}; + +template +inline __device__ T sinh(T x) +{ + return ck::type_convert(::sinhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float sinh(float x) +{ + return ::sinhf(x); +}; + +template <> +inline __device__ double sinh(double x) +{ + return ::sinh(x); +}; + +template +inline __device__ T ceil(T x) +{ + return ck::type_convert(::ceilf(ck::type_convert(x))); +}; + +template <> +inline __device__ float ceil(float x) +{ + return ::ceilf(x); +}; + +template <> +inline __device__ double ceil(double x) +{ + return ::ceil(x); +}; + +template <> +inline __device__ half_t ceil(half_t x) +{ + return ::hceil(x); +}; + +template +inline 
__device__ T cosh(T x) +{ + return ck::type_convert(::coshf(ck::type_convert(x))); +}; + +template <> +inline __device__ float cosh(float x) +{ + return ::coshf(x); +}; + +template <> +inline __device__ double cosh(double x) +{ + return ::cosh(x); +}; + +template +inline __device__ T floor(T x) +{ + return ck::type_convert(::floorf(ck::type_convert(x))); +}; + +template <> +inline __device__ float floor(float x) +{ + return ::floorf(x); +}; + +template <> +inline __device__ double floor(double x) +{ + return ::floor(x); +}; + +template <> +inline __device__ half_t floor(half_t x) +{ + return ::hfloor(x); +}; + +template +inline __device__ T rcp(T x) +{ +#if !CK_WORKAROUND_SWDEV_383542 + return __frcp_rn(x); +#else + return __ocml_native_recip_f32(x); +#endif +}; + template inline __device__ T exp(T x) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp new file mode 100644 index 0000000000..470641fff7 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceElementwise : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const std::array, NumATensors>& a_tensors, + Tensor& b_tensor, + ElementOp element_op) + : a_tensors_{a_tensors}, b_tensor_{b_tensor}, element_op_{element_op} + { + } + + const std::array, NumATensors>& a_tensors_; + Tensor& b_tensor_; + ElementOp element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceElementwise::Argument; + + float Run(const Argument& arg) + { + if constexpr(NumATensors == 1) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), arg.a_tensors_[0](idx)); + }); + } + else if constexpr(NumATensors == 2) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), arg.a_tensors_[0](idx), arg.a_tensors_[1](idx)); + }); + } + else if constexpr(NumATensors == 3) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), + arg.a_tensors_[0](idx), + arg.a_tensors_[1](idx), + arg.a_tensors_[2](idx)); + }); + } + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const std::array, NumATensors>& a_tensors, + Tensor& b_tensor, + ElementOp element_op) + { + return Argument{a_tensors, b_tensor, element_op}; + } + + static auto 
MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceElementwise" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_permute_scale_impl.hpp b/profiler/include/profiler/profile_permute_scale_impl.hpp index c69e36142d..186a24501e 100644 --- a/profiler/include/profiler/profile_permute_scale_impl.hpp +++ b/profiler/include/profiler/profile_permute_scale_impl.hpp @@ -14,6 +14,8 @@ #include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -21,14 +23,6 @@ #include "ck/library/utility/literals.hpp" namespace ck { -template -void reference_permute_scale(HostTensorB& b_tensor, - const HostTensorA& a_tensor, - ElementOp tensor_op) -{ - b_tensor.ForEach([&](auto& self, auto idx) { tensor_op(self(idx), a_tensor(idx)); }); -} - namespace profiler { template @@ -46,7 +40,8 @@ bool profile_permute_scale_impl(int do_verification, using ElementOp = ck::tensor_operation::element_wise::Scale; float scale = 2.f; - Tensor a(lengths_vector, input_strides_vector); + std::array, 1> as = {Tensor(lengths_vector, input_strides_vector)}; + Tensor& a = as[0]; Tensor b(lengths_vector, output_strides_vector); Tensor host_b(lengths_vector, output_strides_vector); @@ -83,7 +78,14 @@ bool profile_permute_scale_impl(int do_verification, if(do_verification) { - reference_permute_scale(host_b, a, ElementOp{scale}); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, 
ElementOp>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, ElementOp{scale}); + + ref_invoker.Run(ref_argument); } auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };