mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Add template argument of dim . Prepare to support multiple dimension
This commit is contained in:
@@ -26,7 +26,7 @@ using EltwiseComputeDataType = F32;
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
|
||||
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 8>;
|
||||
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 2, 8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
|
||||
@@ -15,15 +15,16 @@ template <typename ADataType,
|
||||
typename CDataType,
|
||||
typename ComputeDataType,
|
||||
typename ElementwiseFunctor,
|
||||
index_t Dim,
|
||||
index_t ScalarPerVector>
|
||||
struct DeviceBinaryElementwise : public BaseOperator
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
static auto MakeDescriptor_M0(const std::vector<int>& shape,
|
||||
const std::vector<int>& stride,
|
||||
index_t gridSize,
|
||||
index_t threadPerBlock)
|
||||
static auto MakeDescriptor_M0_2d(const std::vector<int>& shape,
|
||||
const std::vector<int>& stride,
|
||||
index_t gridSize,
|
||||
index_t threadPerBlock)
|
||||
{
|
||||
const int m = shape[0];
|
||||
const int n = shape[1];
|
||||
@@ -51,6 +52,17 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
return desc_m0_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M0(const std::vector<int>& shape,
|
||||
const std::vector<int>& stride,
|
||||
index_t gridSize,
|
||||
index_t threadPerBlock)
|
||||
{
|
||||
if constexpr(Dim == 2)
|
||||
return MakeDescriptor_M0_2d(shape, stride, gridSize, threadPerBlock);
|
||||
else
|
||||
return make_naive_tensor_descriptor(make_tuple(0), make_tuple(0));
|
||||
}
|
||||
|
||||
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
|
||||
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
|
||||
BDataType,
|
||||
|
||||
Reference in New Issue
Block a user