Add template argument of dim . Prepare to support multiple dimension

This commit is contained in:
rocking
2022-05-17 20:34:21 +08:00
parent c2626122af
commit b456d5e53e
2 changed files with 17 additions and 5 deletions

View File

@@ -26,7 +26,7 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 8>;
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 2, 8>;
template <typename HostTensorA,
typename HostTensorB,

View File

@@ -15,15 +15,16 @@ template <typename ADataType,
typename CDataType,
typename ComputeDataType,
typename ElementwiseFunctor,
index_t Dim,
index_t ScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
{
static constexpr auto I0 = Number<0>{};
static auto MakeDescriptor_M0(const std::vector<int>& shape,
const std::vector<int>& stride,
index_t gridSize,
index_t threadPerBlock)
static auto MakeDescriptor_M0_2d(const std::vector<int>& shape,
const std::vector<int>& stride,
index_t gridSize,
index_t threadPerBlock)
{
const int m = shape[0];
const int n = shape[1];
@@ -51,6 +52,17 @@ struct DeviceBinaryElementwise : public BaseOperator
return desc_m0_pad;
}
static auto MakeDescriptor_M0(const std::vector<int>& shape,
const std::vector<int>& stride,
index_t gridSize,
index_t threadPerBlock)
{
if constexpr(Dim == 2)
return MakeDescriptor_M0_2d(shape, stride, gridSize, threadPerBlock);
else
return make_naive_tensor_descriptor(make_tuple(0), make_tuple(0));
}
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
BDataType,