From a61f34f70f484c3cc871392ec633e57832d92925 Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 17 May 2022 04:36:47 +0800 Subject: [PATCH 1/7] Add elementwise operation kernel and example --- example/19_binary_elementwise/CMakeLists.txt | 1 + .../19_binary_elementwise/broadcast_add.cpp | 133 ++++++++++++ example/CMakeLists.txt | 1 + .../gpu/device/device_binary_elementwise.hpp | 190 ++++++++++++++++++ .../element/binary_element_wise_operation.hpp | 19 ++ .../grid/gridwise_binary_elementwise_1d.hpp | 150 ++++++++++++++ include/ck/utility/get_id.hpp | 4 + 7 files changed, 498 insertions(+) create mode 100644 example/19_binary_elementwise/CMakeLists.txt create mode 100644 example/19_binary_elementwise/broadcast_add.cpp create mode 100644 include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp create mode 100644 include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp diff --git a/example/19_binary_elementwise/CMakeLists.txt b/example/19_binary_elementwise/CMakeLists.txt new file mode 100644 index 0000000000..143e31c196 --- /dev/null +++ b/example/19_binary_elementwise/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_broadcast_add broadcast_add.cpp) \ No newline at end of file diff --git a/example/19_binary_elementwise/broadcast_add.cpp b/example/19_binary_elementwise/broadcast_add.cpp new file mode 100644 index 0000000000..7b06ec8b28 --- /dev/null +++ b/example/19_binary_elementwise/broadcast_add.cpp @@ -0,0 +1,133 @@ +#include +#include +#include +#include +#include +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" + +#include "device_tensor.hpp" +#include "binary_element_wise_operation.hpp" + +#include "device_binary_elementwise.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::binary_element_wise::Add; + +using DeviceElementwiseAddInstance = ck::tensor_operation::device:: + DeviceBinaryElementwise; + +template +void host_broadcast2D( + HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor) +{ + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + ComputeDataType Amn = static_cast(A(m, n)); + ComputeDataType Cmn = 0; + if constexpr(broadcastDim == 0) + { + ComputeDataType Bn = static_cast(B(n)); + functor(Cmn, Amn, Bn); + } + else + { + ComputeDataType Bm = static_cast(B(m)); + functor(Cmn, Amn, Bm); + } + C(m, n) = static_cast(Cmn); + } + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t Stride = 1024; + + auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + }; + + Tensor a_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + + Tensor b_n(std::vector({static_cast(N)}), + std::vector({1})); + + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + + a_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpace()); + DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); + + a_m_n_device_buf.ToDevice(a_m_n.mData.data()); + b_n_device_buf.ToDevice(b_n.mData.data()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer(a_m_n_device_buf.GetDeviceBuffer(), + b_n_device_buf.GetDeviceBuffer(), + c_m_n_device_buf.GetDeviceBuffer(), + {M, N}, + {Stride, 1}, + {0, 1}, + {Stride, 1}, + Add{}, + 256); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise_2D instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); + Tensor host_c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + + host_broadcast2D, + Tensor, + Tensor, + EltwiseComputeDataType, + Add, + 0>(host_c_m_n, a_m_n, b_n, M, N, Add{}); + + pass &= ck::utils::check_err( + c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 4d81e84c01..1f4ed01de1 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -51,3 +51,4 @@ add_subdirectory(17_convnd_bwd_data_xdl) add_subdirectory(15_grouped_gemm) add_subdirectory(16_gemm_reduce) add_subdirectory(18_batched_gemm_reduce) +add_subdirectory(19_binary_elementwise) diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp new file mode 100644 index 0000000000..6a789aa356 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -0,0 +1,190 @@ +#pragma once +#include +#include + +#include "device.hpp" +#include "device_base.hpp" +#include "gridwise_binary_elementwise_1d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBinaryElementwise : public BaseOperator +{ + static constexpr auto I0 = Number<0>{}; + + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t threadPerBlock) + { + const int m = shape[0]; + const int n = shape[1]; + + // 2d desc - [m, n] + const auto desc_m_n = + make_naive_tensor_descriptor(make_tuple(m, n), make_tuple(stride[0], stride[1])); + + // 1d desc - [m * n] + const auto desc_m0 = + transform_tensor_descriptor(desc_m_n, + make_tuple(make_merge_transform(make_tuple(m, n))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + // pad + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + + using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); + using GridwiseBinEltwise = GridwiseBinaryElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const std::vector& shape, + const std::vector& stride_a, + const std::vector& stride_b, + const std::vector& stride_c, + ElementwiseFunctor functor, + index_t threadPerBlock) + : p_a_(p_a), + p_b_(p_b), + p_c_(p_c), + functor_(functor), + threadPerBlock_(threadPerBlock), + gridSize_(128) // FIXME - Calculate the grid size by number of CU in the future + { + a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock_); + b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock_); + c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock_); + } + + const ADataType* p_a_; + const BDataType* p_b_; + CDataType* p_c_; + GridDesc_M0 a_grid_desc_m0_; + GridDesc_M0 b_grid_desc_m0_; + GridDesc_M0 c_grid_desc_m0_; + ElementwiseFunctor functor_; + index_t threadPerBlock_; + index_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.threadPerBlock_), + 0, + arg.p_a_, + arg.p_b_, + arg.p_c_, + arg.a_grid_desc_m0_, + arg.b_grid_desc_m0_, + arg.c_grid_desc_m0_, + arg.functor_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg == nullptr) + return false; + + // m * n + const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0); + + if(m0 % ScalarPerVector != 0) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + std::vector shape, + std::vector stride_a, + std::vector stride_b, + std::vector stride_c, + ElementwiseFunctor functor, + index_t threadPerBlock) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + shape, + stride_a, + stride_b, + stride_c, + functor, + threadPerBlock); + } + + std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBinaryElementwise" + << "<" + << "ScalarPerVector = " << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp new file mode 100644 index 0000000000..ebec6b5b50 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -0,0 +1,19 @@ +#pragma once +#include "data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace binary_element_wise { + +struct Add +{ + __host__ __device__ constexpr void + operator()(float& dst, const float& src1, const float& src2) const + { + dst = src1 + src2; + } +}; + +} // namespace binary_element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp new file mode 100644 index 0000000000..aea54ff53c --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp @@ -0,0 +1,150 @@ +#pragma once + +#include "cluster_descriptor.hpp" +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + CDataType* __restrict__ p_c_global, + const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0, + const GridDesc_M0 c_grid_desc_m0, + const ElementwiseFunctor functor) +{ + GridwiseBinEltwise::Run(p_a_global, + p_b_global, + p_c_global, + a_grid_desc_m0, + b_grid_desc_m0, + c_grid_desc_m0, + functor); +} + +template +struct GridwiseBinaryElementwise_1D +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto thread_desc_m0 = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static __device__ __host__ auto CalculateElementwiseIndex() + { + const index_t global_thread_id = get_thread_global_1d_id(); + return make_multi_index(global_thread_id * ScalarPerVector); + } + + __device__ static void Run(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + CDataType* __restrict__ p_c_global, + const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0, + const GridDesc_M0 c_grid_desc_m0, + const ElementwiseFunctor functor) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_grid_desc_m0.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_grid_desc_m0.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_grid_desc_m0.GetElementSpaceSize()); + + StaticBuffer a_thread_buf; + StaticBuffer b_thread_buf; + StaticBuffer c_thread_buf; + + const auto thread_to_global_offset = CalculateElementwiseIndex(); + + auto a_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + ScalarPerVector, + 1, // SrcScalarStrideInVector + false>{a_grid_desc_m0, thread_to_global_offset}; + + auto b_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + ScalarPerVector, + 1, // SrcScalarStrideInVector + false>{b_grid_desc_m0, thread_to_global_offset}; + + auto c_global_write = + ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // DstVectorDim + ScalarPerVector, + InMemoryDataOperationEnum::Set, + 1, // DstScalarStrideInVector + false>{ + c_grid_desc_m0, thread_to_global_offset, PassThrough{}}; + + const index_t threadPerBlock = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto m0 = c_grid_desc_m0.GetLength(I0); + const index_t loop_step = blockPerGrid * threadPerBlock * ScalarPerVector; + const auto loop_step_index = make_multi_index(loop_step); + + index_t num_iter = m0 / (loop_step); + do + { + // read and process ScalarPerVector elements + a_global_load.Run( + a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf); + + b_global_load.Run( + b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf); + + static_for<0, ScalarPerVector, 1>{}([&](auto m) { + constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m)); + functor(c_thread_buf(Number{}), + a_thread_buf(Number{}), + b_thread_buf(Number{})); + }); + + c_global_write.Run(thread_desc_m0, + make_tuple(I0), // SrcSliceOriginIdx + c_thread_buf, + c_grid_desc_m0, + c_global_buf); + + a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index); + b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index); + c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp index d1288a2274..7c62b890c7 100644 --- a/include/ck/utility/get_id.hpp +++ b/include/ck/utility/get_id.hpp @@ -11,10 +11,14 @@ __host__ __device__ constexpr index_t get_warp_size() __device__ index_t get_thread_local_1d_id() { return threadIdx.x; } +__device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; } + __device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); } __device__ index_t get_block_1d_id() { return blockIdx.x; } __device__ index_t get_grid_size() { return gridDim.x; } +__device__ index_t get_block_size() { return blockDim.x; } + } // namespace ck From c2626122afea6757b8940e89a65250befa9350b8 Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 17 May 2022 08:23:43 +0800 Subject: [PATCH 2/7] Add comment --- example/19_binary_elementwise/broadcast_add.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/19_binary_elementwise/broadcast_add.cpp b/example/19_binary_elementwise/broadcast_add.cpp index 7b06ec8b28..9dbb38da43 100644 --- a/example/19_binary_elementwise/broadcast_add.cpp +++ b/example/19_binary_elementwise/broadcast_add.cpp @@ -95,7 +95,7 @@ int main() c_m_n_device_buf.GetDeviceBuffer(), {M, N}, {Stride, 1}, - {0, 1}, + {0, 1}, // broadcast in first dimension {Stride, 1}, Add{}, 256); From b456d5e53ec59c9fe90ee68a6b1c575934fb508d Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 17 May 2022 20:34:21 +0800 Subject: [PATCH 3/7] Add template argument of dim . Prepare to support multiple dimension --- .../19_binary_elementwise/broadcast_add.cpp | 2 +- .../gpu/device/device_binary_elementwise.hpp | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/example/19_binary_elementwise/broadcast_add.cpp b/example/19_binary_elementwise/broadcast_add.cpp index 9dbb38da43..55d1e130bf 100644 --- a/example/19_binary_elementwise/broadcast_add.cpp +++ b/example/19_binary_elementwise/broadcast_add.cpp @@ -26,7 +26,7 @@ using EltwiseComputeDataType = F32; using Add = ck::tensor_operation::binary_element_wise::Add; using DeviceElementwiseAddInstance = ck::tensor_operation::device:: - DeviceBinaryElementwise; + DeviceBinaryElementwise; template struct DeviceBinaryElementwise : public BaseOperator { static constexpr auto I0 = Number<0>{}; - static auto MakeDescriptor_M0(const std::vector& shape, - const std::vector& stride, - index_t gridSize, - index_t threadPerBlock) + static auto MakeDescriptor_M0_2d(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t threadPerBlock) { const int m = shape[0]; const int n = shape[1]; @@ -51,6 +52,17 @@ struct DeviceBinaryElementwise : public BaseOperator return desc_m0_pad; } + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t threadPerBlock) + { + if constexpr(Dim == 2) + return MakeDescriptor_M0_2d(shape, stride, gridSize, threadPerBlock); + else + return make_naive_tensor_descriptor(make_tuple(0), make_tuple(0)); + } + using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); using GridwiseBinEltwise = GridwiseBinaryElementwise_1D Date: Tue, 17 May 2022 20:36:05 +0800 Subject: [PATCH 4/7] Rename example --- .../{broadcast_add.cpp => broadcast_add_2d.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename example/19_binary_elementwise/{broadcast_add.cpp => broadcast_add_2d.cpp} (100%) diff --git a/example/19_binary_elementwise/broadcast_add.cpp b/example/19_binary_elementwise/broadcast_add_2d.cpp similarity index 100% rename from example/19_binary_elementwise/broadcast_add.cpp rename to example/19_binary_elementwise/broadcast_add_2d.cpp From 4af77e1f0e1c7103d8e8535735e8e2df7fe16732 Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 17 May 2022 20:50:03 +0800 Subject: [PATCH 5/7] Support 1 dimension --- example/19_binary_elementwise/CMakeLists.txt | 3 +- .../broadcast_add_2d.cpp | 8 +- .../elementwise_add_1d.cpp | 119 ++++++++++++++++++ .../gpu/device/device_binary_elementwise.hpp | 25 +++- 4 files changed, 151 insertions(+), 4 deletions(-) create mode 100644 example/19_binary_elementwise/elementwise_add_1d.cpp diff --git a/example/19_binary_elementwise/CMakeLists.txt b/example/19_binary_elementwise/CMakeLists.txt index 143e31c196..202a9e1fcb 100644 --- a/example/19_binary_elementwise/CMakeLists.txt +++ b/example/19_binary_elementwise/CMakeLists.txt @@ -1 +1,2 @@ -add_example_executable(example_broadcast_add broadcast_add.cpp) \ No newline at end of file +add_example_executable(example_broadcast_add_2d broadcast_add_2d.cpp) +add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp) \ No newline at end of file diff --git a/example/19_binary_elementwise/broadcast_add_2d.cpp b/example/19_binary_elementwise/broadcast_add_2d.cpp index 55d1e130bf..4a6d2038e3 100644 --- a/example/19_binary_elementwise/broadcast_add_2d.cpp +++ b/example/19_binary_elementwise/broadcast_add_2d.cpp @@ -67,6 +67,11 @@ int main() ck::index_t N = 1024; ck::index_t Stride = 1024; + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { return HostTensorDescriptor(std::vector({row, col}), std::vector({stride, 1})); @@ -74,8 +79,7 @@ int main() Tensor a_m_n(f_host_tensor_descriptor2d(M, N, Stride)); - Tensor b_n(std::vector({static_cast(N)}), - std::vector({1})); + Tensor b_n(f_host_tensor_descriptor1d(N, 1)); Tensor c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); diff --git a/example/19_binary_elementwise/elementwise_add_1d.cpp b/example/19_binary_elementwise/elementwise_add_1d.cpp new file mode 100644 index 0000000000..c9d0f77724 --- /dev/null +++ b/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -0,0 +1,119 @@ +#include +#include +#include +#include +#include +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" + +#include "device_tensor.hpp" +#include "binary_element_wise_operation.hpp" + +#include "device_binary_elementwise.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::binary_element_wise::Add; + +using DeviceElementwiseAddInstance = ck::tensor_operation::device:: + DeviceBinaryElementwise; + +template +void host_elementwise1D( + HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor) +{ + for(int m = 0; m < M; ++m) + { + ComputeDataType Am = static_cast(A(m)); + ComputeDataType Bm = static_cast(B(m)); + ComputeDataType Cm = 0; + functor(Cm, Am, Bm); + C(m) = static_cast(Cm); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + ck::index_t M = 1024; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + Tensor a_m(f_host_tensor_descriptor1d(M, 1)); + + Tensor b_m(f_host_tensor_descriptor1d(M, 1)); + + Tensor c_m(f_host_tensor_descriptor1d(M, 1)); + + a_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace()); + DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpace()); + DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpace()); + + a_m_device_buf.ToDevice(a_m.mData.data()); + b_m_device_buf.ToDevice(b_m.mData.data()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(), + b_m_device_buf.GetDeviceBuffer(), + c_m_device_buf.GetDeviceBuffer(), + {M}, + {1}, + {1}, + {1}, + Add{}, + 256); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise_2D instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_device_buf.FromDevice(c_m.mData.data()); + Tensor host_c_m(f_host_tensor_descriptor1d(M, 1)); + + host_elementwise1D, + Tensor, + Tensor, + EltwiseComputeDataType, + Add, + 0>(host_c_m, a_m, b_m, M, Add{}); + + pass &= ck::utils::check_err( + c_m.mData, host_c_m.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp index 60ca2895b5..bc3fe61dc4 100644 --- a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -21,6 +21,27 @@ struct DeviceBinaryElementwise : public BaseOperator { static constexpr auto I0 = Number<0>{}; + static auto MakeDescriptor_M0_1d(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t threadPerBlock) + { + // 1d desc - [m] + const auto desc_m0 = + make_naive_tensor_descriptor(make_tuple(shape[0]), make_tuple(stride[0])); + + // pad + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + static auto MakeDescriptor_M0_2d(const std::vector& shape, const std::vector& stride, index_t gridSize, @@ -57,7 +78,9 @@ struct DeviceBinaryElementwise : public BaseOperator index_t gridSize, index_t threadPerBlock) { - if constexpr(Dim == 2) + if constexpr(Dim == 1) + return MakeDescriptor_M0_1d(shape, stride, gridSize, threadPerBlock); + else if constexpr(Dim == 2) return MakeDescriptor_M0_2d(shape, stride, gridSize, threadPerBlock); else return make_naive_tensor_descriptor(make_tuple(0), make_tuple(0)); From 492da45969690306d5a4003339c23659ca54bf6d Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 17 May 2022 20:53:45 +0800 Subject: [PATCH 6/7] Add static assert --- .../tensor_operation/gpu/device/device_binary_elementwise.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp index bc3fe61dc4..2cbb3f2f24 100644 --- a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -78,6 +78,9 @@ struct DeviceBinaryElementwise : public BaseOperator index_t gridSize, index_t threadPerBlock) { + static_assert(Dim == 1 || Dim == 2, + "wrong! DeviceBinaryElementwise not support this dimension"); + if constexpr(Dim == 1) return MakeDescriptor_M0_1d(shape, stride, gridSize, threadPerBlock); else if constexpr(Dim == 2) From ecdfe960921032c1aae6dc2c4a3e0ad1b8bba559 Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 17 May 2022 20:56:16 +0800 Subject: [PATCH 7/7] Add comment --- .../ck/tensor_operation/gpu/device/device_binary_elementwise.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp index 2cbb3f2f24..0b08b818a3 100644 --- a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -81,6 +81,7 @@ struct DeviceBinaryElementwise : public BaseOperator static_assert(Dim == 1 || Dim == 2, "wrong! DeviceBinaryElementwise not support this dimension"); + // TODO - 3D, 4D, 5D if constexpr(Dim == 1) return MakeDescriptor_M0_1d(shape, stride, gridSize, threadPerBlock); else if constexpr(Dim == 2)