diff --git a/example/19_binary_elementwise/broadcast_add.cpp b/example/19_binary_elementwise/broadcast_add.cpp index 9dbb38da43..55d1e130bf 100644 --- a/example/19_binary_elementwise/broadcast_add.cpp +++ b/example/19_binary_elementwise/broadcast_add.cpp @@ -26,7 +26,7 @@ using EltwiseComputeDataType = F32; using Add = ck::tensor_operation::binary_element_wise::Add; using DeviceElementwiseAddInstance = ck::tensor_operation::device:: - DeviceBinaryElementwise; + DeviceBinaryElementwise; template struct DeviceBinaryElementwise : public BaseOperator { static constexpr auto I0 = Number<0>{}; - static auto MakeDescriptor_M0(const std::vector& shape, - const std::vector& stride, - index_t gridSize, - index_t threadPerBlock) + static auto MakeDescriptor_M0_2d(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t threadPerBlock) { const int m = shape[0]; const int n = shape[1]; @@ -51,6 +52,17 @@ struct DeviceBinaryElementwise : public BaseOperator return desc_m0_pad; } + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t threadPerBlock) + { + if constexpr(Dim == 2) + return MakeDescriptor_M0_2d(shape, stride, gridSize, threadPerBlock); + else + return make_naive_tensor_descriptor(make_tuple(0), make_tuple(0)); + } + using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); using GridwiseBinEltwise = GridwiseBinaryElementwise_1D