Add elementwise with dynamic vector dim (#1198)

* Add elementwise with dynamic vector dim

* Reduce number of instaces

* Fixes

* Fixes
This commit is contained in:
Bartłomiej Kocot
2024-03-22 10:40:43 +01:00
committed by GitHub
parent fd0d093e78
commit 9c052804a7
28 changed files with 2157 additions and 359 deletions

View File

@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
@@ -20,15 +20,20 @@ using F32 = float;
using ADataType = F16;
using BDataType = F16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // Elementwise op
4, // NumDim
8, // MPerThread
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // Elementwise
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)