From b456d5e53ec59c9fe90ee68a6b1c575934fb508d Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 17 May 2022 20:34:21 +0800 Subject: [PATCH] Add template argument of dim . Prepare to support multiple dimension --- .../19_binary_elementwise/broadcast_add.cpp | 2 +- .../gpu/device/device_binary_elementwise.hpp | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/example/19_binary_elementwise/broadcast_add.cpp b/example/19_binary_elementwise/broadcast_add.cpp index 9dbb38da43..55d1e130bf 100644 --- a/example/19_binary_elementwise/broadcast_add.cpp +++ b/example/19_binary_elementwise/broadcast_add.cpp @@ -26,7 +26,7 @@ using EltwiseComputeDataType = F32; using Add = ck::tensor_operation::binary_element_wise::Add; using DeviceElementwiseAddInstance = ck::tensor_operation::device:: - DeviceBinaryElementwise; + DeviceBinaryElementwise; template struct DeviceBinaryElementwise : public BaseOperator { static constexpr auto I0 = Number<0>{}; - static auto MakeDescriptor_M0(const std::vector& shape, - const std::vector& stride, - index_t gridSize, - index_t threadPerBlock) + static auto MakeDescriptor_M0_2d(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t threadPerBlock) { const int m = shape[0]; const int n = shape[1]; @@ -51,6 +52,17 @@ struct DeviceBinaryElementwise : public BaseOperator return desc_m0_pad; } + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t threadPerBlock) + { + if constexpr(Dim == 2) + return MakeDescriptor_M0_2d(shape, stride, gridSize, threadPerBlock); + else + return make_naive_tensor_descriptor(make_tuple(0), make_tuple(0)); + } + using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); using GridwiseBinEltwise = GridwiseBinaryElementwise_1D