diff --git a/example/19_binary_elementwise/broadcast_add_2d.cpp b/example/19_binary_elementwise/broadcast_add_2d.cpp index 4a6d2038e3..0f67499984 100644 --- a/example/19_binary_elementwise/broadcast_add_2d.cpp +++ b/example/19_binary_elementwise/broadcast_add_2d.cpp @@ -101,8 +101,7 @@ int main() {Stride, 1}, {0, 1}, // broadcast in first dimension {Stride, 1}, - Add{}, - 256); + Add{}); if(!broadcastAdd.IsSupportedArgument(argument.get())) { diff --git a/example/19_binary_elementwise/elementwise_add_1d.cpp b/example/19_binary_elementwise/elementwise_add_1d.cpp index 5a04d479b6..602e055290 100644 --- a/example/19_binary_elementwise/elementwise_add_1d.cpp +++ b/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -80,8 +80,7 @@ int main() {1}, {1}, {1}, - Add{}, - 256); + Add{}); if(!broadcastAdd.IsSupportedArgument(argument.get())) { diff --git a/example/19_binary_elementwise/elementwise_add_4d.cpp b/example/19_binary_elementwise/elementwise_add_4d.cpp index 92d2878c5a..9d468771f2 100644 --- a/example/19_binary_elementwise/elementwise_add_4d.cpp +++ b/example/19_binary_elementwise/elementwise_add_4d.cpp @@ -82,8 +82,7 @@ int main() ck::to_int_vector(a_m.mDesc.GetStrides()), ck::to_int_vector(b_m.mDesc.GetStrides()), ck::to_int_vector(c_m.mDesc.GetStrides()), - Add{}, - 256); + Add{}); if(!broadcastAdd.IsSupportedArgument(argument.get())) { diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp index 198bf42ce7..a3a2c89eb7 100644 --- a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -19,6 +19,11 @@ template struct DeviceBinaryElementwise : public BaseOperator { + DeviceBinaryElementwise(index_t threadPerBlock = 256) + : BaseOperator(), threadPerBlock_(threadPerBlock) + { + } + static constexpr auto I0 = Number<0>{}; template @@ -85,12 +90,11 @@ struct DeviceBinaryElementwise : public BaseOperator p_b_(p_b), p_c_(p_c), functor_(functor), - threadPerBlock_(threadPerBlock), gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future { - a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock_); - b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock_); - c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock_); + a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock); + b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock); + c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock); } const ADataType* p_a_; @@ -100,12 +104,13 @@ struct DeviceBinaryElementwise : public BaseOperator GridDesc_M0 b_grid_desc_m0_; GridDesc_M0 c_grid_desc_m0_; ElementwiseFunctor functor_; - index_t threadPerBlock_; index_t gridSize_; }; struct Invoker : public BaseInvoker { + Invoker(index_t threadPerBlock) : BaseInvoker(), threadPerBlock_(threadPerBlock) {} + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { const auto kernel = kernel_elementwise_1d(p_arg), stream_config); } + + index_t threadPerBlock_; }; bool IsSupportedArgument(const BaseArgument* p_arg) override @@ -161,8 +168,7 @@ struct DeviceBinaryElementwise : public BaseOperator std::vector stride_a, std::vector stride_b, std::vector stride_c, - ElementwiseFunctor functor, - index_t threadPerBlock) + ElementwiseFunctor functor) { return std::make_unique(static_cast(p_a), static_cast(p_b), @@ -172,12 +178,12 @@ struct DeviceBinaryElementwise : public BaseOperator stride_b, stride_c, functor, - threadPerBlock); + threadPerBlock_); } std::unique_ptr MakeInvokerPointer() { - return std::make_unique(Invoker{}); + return std::make_unique(Invoker{threadPerBlock_}); } std::string GetTypeString() const override @@ -193,6 +199,8 @@ struct DeviceBinaryElementwise : public BaseOperator return str.str(); } + + index_t threadPerBlock_; }; } // namespace device