diff --git a/composable_kernel/include/tensor_description/static_tensor.hpp b/composable_kernel/include/tensor_description/static_tensor.hpp index e71980b818..b1a816167a 100644 --- a/composable_kernel/include/tensor_description/static_tensor.hpp +++ b/composable_kernel/include/tensor_description/static_tensor.hpp @@ -1,8 +1,6 @@ #ifndef CK_STATIC_TENSOR_HPP #define CK_STATIC_TENSOR_HPP -#include "ignore.hpp" - namespace ck { // StaticTensor for Scalar @@ -17,10 +15,10 @@ struct StaticTensor static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension(); static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize(); - __host__ __device__ constexpr StaticTensor() : invalid_element_value_{0} {} + __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {} __host__ __device__ constexpr StaticTensor(T invalid_element_value) - : invalid_element_value_{invalid_element_value} + : invalid_element_scalar_value_{invalid_element_value} { } @@ -44,11 +42,11 @@ struct StaticTensor { if constexpr(InvalidElementUseNumericalZeroValue) { - return T{0}; + return zero_scalar_value_; } else { - return invalid_element_value_; + return invalid_element_scalar_value_; } } } @@ -71,12 +69,14 @@ struct StaticTensor } else { - return ignore; + return ignored_element_scalar_; } } StaticBuffer data_; - T invalid_element_value_ = T{0}; + static constexpr T zero_scalar_value_ = T{0}; + const T invalid_element_scalar_value_; + T ignored_element_scalar_; }; // StaticTensor for vector @@ -97,10 +97,13 @@ struct StaticTensorTupleOfVectorBuffer using V = vector_type; - __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_value_{0} {} + __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() + : invalid_element_scalar_value_{0} + { + } __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value) - : invalid_element_value_{invalid_element_value} + : invalid_element_scalar_value_{invalid_element_value} { } @@ -125,11 +128,11 @@ struct StaticTensorTupleOfVectorBuffer { if constexpr(InvalidElementUseNumericalZeroValue) { - return S{0}; + return zero_scalar_value_; } else { - return invalid_element_value_; + return invalid_element_scalar_value_; } } } @@ -153,7 +156,7 @@ struct StaticTensorTupleOfVectorBuffer } else { - return ignore; + return ignored_element_scalar_; } } @@ -186,7 +189,7 @@ struct StaticTensorTupleOfVectorBuffer else { // TODO: is this right way to initialize a vector? - return X{invalid_element_value_}; + return X{invalid_element_scalar_value_}; } } } @@ -237,7 +240,9 @@ struct StaticTensorTupleOfVectorBuffer } StaticBufferTupleOfVector data_; - S invalid_element_value_ = S{0}; + static constexpr S zero_scalar_value_ = S{0}; + const S invalid_element_scalar_value_ = S{0}; + S ignored_element_scalar_; }; template -struct BlockwiseTensorSliceTransfer_v4 +struct BlockwiseTensorSliceTransfer_v4r1 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + using Index = MultiIndex; - __device__ constexpr BlockwiseTensorSliceTransfer_v4( + __device__ constexpr BlockwiseTensorSliceTransfer_v4r1( const SrcDesc& src_desc, const Index& src_block_slice_origin, + const SrcElementwiseOperation& src_element_op, const DstDesc& dst_desc, const Index& dst_block_slice_origin, - const SrcElementwiseOperation& src_element_op) + const DstElementwiseOperation& dst_element_op) : threadwise_transfer_(src_desc, make_zero_multi_index(), + src_element_op, dst_desc, make_zero_multi_index(), - src_element_op) + dst_element_op) { static_assert(nDim == remove_reference_t>::GetNumOfDimension() && nDim == remove_reference_t>::GetNumOfDimension() && - nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() && nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), "wrong! nDim not consistent"); static_assert( - is_same{}, + is_same{}, "wrong! threads should be mapped to cover entire slicing window"); static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), @@ -74,7 +77,7 @@ struct BlockwiseTensorSliceTransfer_v4 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( make_multi_index(get_thread_local_1d_id())); - const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; threadwise_transfer_.SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin); @@ -114,6 +117,16 @@ struct BlockwiseTensorSliceTransfer_v4 } } + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + RunRead(src_desc, src_buf); + RunWrite(dst_desc, dst_buf); + } + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) { if(BlockSize == thread_cluster_desc_.GetElementSize() or @@ -152,8 +165,9 @@ struct BlockwiseTensorSliceTransfer_v4 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); using ThreadwiseTransfer = - ThreadwiseTensorSliceTransfer_v3r2 -struct BlockwiseTensorSliceTransfer_v4r1 +struct BlockwiseTensorSliceTransfer_v5r1 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); using Index = MultiIndex; - __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc, + __device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc, const Index& src_block_slice_origin, const DstDesc& dst_desc, const Index& dst_block_slice_origin) @@ -134,7 +134,7 @@ struct BlockwiseTensorSliceTransfer_v4r1 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); using ThreadwiseTransfer = - ThreadwiseTensorSliceTransfer_v3r1 +struct BlockwiseTensorSliceTransfer_v6r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), + "wrong! BlockSize too small"); + + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r2.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r2.hpp new file mode 100644 index 0000000000..c92681fe91 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r2.hpp @@ -0,0 +1,157 @@ +#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP +#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_tensor_slice_transfer_v6r2.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. It does not keep reference to tensor descriptor +// 3. Run() does not construct new tensor coordinate +template +struct BlockwiseTensorSliceTransfer_v6r2 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src0_desc, + make_zero_multi_index(), + src1_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), + "wrong! BlockSize too small"); + + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrc0SliceOrigin( + src0_desc, src0_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc1SliceOrigin( + src1_desc, src1_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); + } + } + + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r3.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r3.hpp new file mode 100644 index 0000000000..f9840b4a20 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r3.hpp @@ -0,0 +1,182 @@ +#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP +#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_tensor_slice_transfer_v6r3.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct BlockwiseTensorSliceTransfer_v6r3 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src0_desc, + make_zero_multi_index(), + src1_desc, + make_zero_multi_index(), + src2_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), + "wrong! BlockSize too small"); + + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrc0SliceOrigin( + src0_desc, src0_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc1SliceOrigin( + src1_desc, src1_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc2SliceOrigin( + src2_desc, src2_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const Src2Desc& src2_desc, + const Src2Buffer& src2_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run( + src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); + } + } + + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); + } + } + + __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/element_wise_operation.hpp b/composable_kernel/include/tensor_operation/element_wise_operation.hpp new file mode 100644 index 0000000000..306102f4fb --- /dev/null +++ b/composable_kernel/include/tensor_operation/element_wise_operation.hpp @@ -0,0 +1,185 @@ +#ifndef CK_ELEMENT_WISE_OPERATION_HPP +#define CK_ELEMENT_WISE_OPERATION_HPP + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +struct PassThrough +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + y = x; + } + + // TODO remove this + template + __host__ __device__ constexpr T operator()(T v) const + { + return v; + } +}; + +struct AddRelu +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const + { + T a = x0 + x1; + y = a > 0 ? a : 0; + } + + // TODO remove this + template + __host__ constexpr float operator()(float v0, T1 v1) const + { + float b = v0 + v1; + float c = b > 0 ? b : 0; + + return c; + } + + // TODO remove this + template + __device__ constexpr float operator()(float v0, T1 v1) const + { +#if 0 + float a = v1 + v0; + float b = max(a, float(0)); + + return b; +#else + float b = v1 + v0; + float c = b > 0 ? b : 0; + + return c; +#endif + } +}; + +struct AddReluAdd +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1, const T& x2) const + { + T a = x0 + x1; + T b = a > 0 ? a : 0; + y = b + x2; + } + + // TODO remove this + template + __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const + { + float b = v0 + v1; + float c = b > 0 ? b : 0; + float d = c + v2; + + return d; + } + + // TODO remove this + template + __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const + { +#if 0 + float a = v1 + v0; + float b = max(a, float(0)); + float c = b + v2; + + return c; +#else + float b = v1 + v2; + float c = (v0 > -v1) ? b + v0 : v2; + + return c; +#endif + } +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +struct AddLeakyReluAdd +{ + template + __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const + { + float a = v0 + v1; + float b = 0.1 * a; + float c = b > 0 ? b : 0; + float d = c + v2; + + return d; + } + + template + __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const + { +#if 0 + // this use not too many registers, but use fp64 mul + float a = v0 + v1; + float b = 0.1 * a; + float c = b > 0 ? b : 0; + float d = c + v2; + + return d; +#elif 0 + // this spill register + float a = v0 + v1; + float b = float(0.1) * a; + float c = b > 0 ? b : 0; + float d = c + v2; + + return d; +#elif 0 + // this use lots of registers (but no spill) + constexpr float alpha = 0.1; + constexpr float alpha_inv = 1.0 / alpha; + + float a = v2 * alpha_inv; + float b = v1 + v0; + float c = b > 0 ? b : 0; + float d = alpha * (a + c); + + return d; +#elif 1 + // this use lots of registers (but no spill), 89 Tflops + constexpr float alpha = 0.1; + constexpr float alpha_inv = 1.0 / alpha; + + float a = v2 * alpha_inv; + float b = v1 + v0; + float c = max(b, float(0)); + float d = alpha * (a + c); + + return d; +#elif 1 + // this spill registers, 89 Tflops + float a = v0 + v1; + float alpha = 0.1; + + float b; + asm volatile("\n \ + v_mul_f32_e32 %0, %1, %2 \n \ + " + : "=v"(b) + : "s"(alpha), "v"(a)); + + float c = b > 0 ? b : 0; + float d = c + v2; + + return d; +#endif + } +}; +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp index fe56d0d813..50e8f52c59 100644 --- a/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp @@ -381,7 +381,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN "wrong!"); // A matrix blockwise copy - auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, InMemoryDataOperationEnum_t::Set, Sequence, @@ -405,7 +405,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN make_multi_index(0, 0, 0, 0, 0)); // B matrix blockwise copy - auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, InMemoryDataOperationEnum_t::Set, Sequence, diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp index 2653dd4340..32b6c31200 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp @@ -6,7 +6,7 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_dlops_v2r3.hpp" -#include "blockwise_tensor_slice_transfer_v2.hpp" +#include "blockwise_tensor_slice_transfer_v5r1.hpp" #include "threadwise_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_set.hpp" @@ -380,7 +380,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 "wrong!"); // A matrix blockwise copy - auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, InMemoryDataOperationEnum_t::Set, Sequence, @@ -404,7 +404,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 make_multi_index(0, 0, 0, 0)); // B matrix blockwise copy - auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, InMemoryDataOperationEnum_t::Set, Sequence, diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index b312491bb0..0db11aedef 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -6,9 +6,8 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer.hpp" +#include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" namespace ck { @@ -40,15 +39,12 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, - p_shared_block, + p_shared, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, @@ -83,9 +79,6 @@ __global__ void const void CONSTANT* p_c_element_op, const void CONSTANT* p_block_2_ctile_map) { - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - const auto a_grid_desc_k0_m_k1 = *reinterpret_cast( cast_pointer_to_generic_address_space(p_a_grid_desc_k0_m_k1)); const auto b_grid_desc_k0_n_k1 = *reinterpret_cast( @@ -102,12 +95,12 @@ __global__ void const auto c_element_op = *reinterpret_cast( cast_pointer_to_generic_address_space(p_c_element_op)); - __shared__ FloatAB p_shared_block[shared_block_size]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, - p_shared_block, + p_shared, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, @@ -135,9 +128,8 @@ template + index_t CThreadTransferDstScalarPerVector> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { static constexpr auto I0 = Number<0>{}; @@ -178,7 +163,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // K1 should be Number<...> static constexpr auto K1 = Number{}; - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { constexpr auto max_lds_align = K1; @@ -197,6 +182,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } }(); + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_k0_n_k1 = [&]() { if constexpr(BBlockLdsExtraN) @@ -212,14 +204,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } }(); + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - constexpr auto b_block_space_size = + constexpr auto b_block_space_size_aligned = math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} @@ -233,8 +236,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, "Invalid tuning param!"); const auto M = a_grid_desc_k0_m_k1.GetLength(I1); @@ -324,8 +327,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 decltype(b_block_desc_k0_n_k1), MPerXDL, NPerXDL, - MRepeat, - NRepeat, + MXdlPerWave, + NXdlPerWave, K1>; return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); @@ -376,7 +379,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, + void* __restrict__ p_shared, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, @@ -409,90 +412,70 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>(a_grid_desc_k0_m_k1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_block_desc_k0_m_k1, - make_multi_index(0, 0, 0), - a_element_op); + BlockwiseTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_block_desc_k0_n_k1, - make_multi_index(0, 0, 0), - b_element_op); + BlockwiseTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -510,68 +493,53 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 decltype(b_block_desc_k0_n_k1), MPerXDL, NPerXDL, - MRepeat, - NRepeat, + MXdlPerWave, + NXdlPerWave, K1>{}; auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = + constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; - constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; - - auto a_block_buf = make_dynamic_buffer( - p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); - // preload data into LDS { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks); + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); } - // main body - index_t k0_block_data_begin = 0; - + // Initialize C c_thread_buf.Clear(); + // main body if constexpr(HasMainKBlockLoop) { + index_t k0_block_data_begin = 0; + do { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, - a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, - b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_step_hack); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - a_blockwise_copy.RunRead( - a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks); + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); block_sync_lds(); - b_blockwise_copy.RunRead( - b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); @@ -619,8 +587,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), @@ -668,11 +634,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); + c_grid_buf); } } -}; // namespace ck +}; } // namespace ck #endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp index 9d524a55bc..39a910a6ff 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp @@ -6,9 +6,8 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer.hpp" +#include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" namespace ck { @@ -19,6 +18,9 @@ template __global__ void @@ -31,6 +33,9 @@ __global__ void const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { constexpr index_t shared_block_size = @@ -45,6 +50,9 @@ __global__ void a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + a_element_op, + b_element_op, + c_element_op, c_block_cluster_adaptor); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER @@ -129,11 +137,6 @@ template @@ -371,6 +374,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); const index_t k_batch_id = block_work_idx[I0]; + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); @@ -447,57 +451,65 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 }(); // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_b_k0_m_k1_grid_desc), - decltype(a_b_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - ABlockTransferSrcVectorDim, - 3, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( + BlockwiseTensorSliceTransfer_v4r1, + ABlockTransferThreadSliceLengths_K0_M_K1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( a_b_k0_m_k1_grid_desc, make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, a_b_k0_m_k1_block_desc, - make_multi_index(0, 0, 0, 0)); + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_b_k0_n_k1_grid_desc), - decltype(b_b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - BBlockTransferSrcVectorDim, - 3, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( + BlockwiseTensorSliceTransfer_v4r1, + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( b_b_k0_n_k1_grid_desc, make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, b_b_k0_n_k1_block_desc, - make_multi_index(0, 0, 0, 0)); + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -531,15 +543,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; - constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; - auto a_block_buf = make_dynamic_buffer( p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( @@ -547,33 +550,31 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // preload data into LDS { - a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); - b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); } + // Initialize C + c_thread_buf.Clear(); + // main body - index_t k_block_data_begin = 0; if constexpr(HasMainKBlockLoop) { + index_t k0_block_data_begin = 0; + do { - a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, - a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, - b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_step_hack); + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step); - a_blockwise_copy.RunRead( - a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); block_sync_lds(); - b_blockwise_copy.RunRead( - b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); @@ -622,8 +623,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), @@ -648,6 +647,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 FloatC, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc), + CElementwiseOperation, Sequence, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, @@ -664,14 +664,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 m_thread_data_on_grid_idx[I2], m_thread_data_on_grid_idx[I3], m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2])}; + n_thread_data_on_grid_idx[I2]), + c_element_op}; c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); + c_grid_buf); } } }; // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp index a181f4b106..986809de9c 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp @@ -6,9 +6,8 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer.hpp" +#include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "threadwise_tensor_slice_transfer_v1r4.hpp" -#include "threadwise_tensor_slice_set.hpp" namespace ck { @@ -88,7 +87,6 @@ template + index_t CThreadTransferDstScalarPerVector> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 { static constexpr auto I0 = Number<0>{}; @@ -410,59 +401,63 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>(a_grid_desc_k0_m_k1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_block_desc_k0_m_k1, - make_multi_index(0, 0, 0), - a_element_op); + BlockwiseTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_block_desc_k0_n_k1, - make_multi_index(0, 0, 0), - b_element_op); + BlockwiseTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -496,15 +491,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; - constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; - auto a_block_buf = make_dynamic_buffer( p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( @@ -512,34 +498,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 // preload data into LDS { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks); + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); } - // main body - index_t k0_block_data_begin = 0; + // Initialize C + c_thread_buf.Clear(); + // main body if constexpr(HasMainKBlockLoop) { + index_t k0_block_data_begin = 0; + do { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, - a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, - b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_step_hack); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - a_blockwise_copy.RunRead( - a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks); + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); block_sync_lds(); - b_blockwise_copy.RunRead( - b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); @@ -588,8 +571,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), @@ -642,14 +623,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 c_thread_buf, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c0_grid_buf, c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c1_grid_buf); } } -}; // namespace ck +}; } // namespace ck #endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp new file mode 100644 index 0000000000..a96cd6e74a --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp @@ -0,0 +1,617 @@ +#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R6_HPP +#define CK_GRIDWISE_GEMM_XDLOPS_V2R6_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_tensor_slice_transfer_v4r1.hpp" +#include "threadwise_tensor_slice_transfer_v1r5.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r6( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_shared_block, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + + return has_main_k0_block_loop; + } + + // TODO fix this + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N_any& c_grid_desc_m_n) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + + using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{})); + + using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); + + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); + + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r5, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_grid_buf, + c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c0_grid_buf); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp new file mode 100644 index 0000000000..3022f3f0fc --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp @@ -0,0 +1,744 @@ +#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R1_HPP +#define CK_GRIDWISE_GEMM_XDLOPS_V3R1_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_tensor_slice_transfer_v4r1.hpp" +#include "blockwise_tensor_slice_transfer_v6r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerXdl, + index_t NPerXdl, + index_t K1Value, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + + return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using Block2CTileMap = remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum_t::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< + BlockSize, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp new file mode 100644 index 0000000000..30059525c7 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp @@ -0,0 +1,784 @@ +#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R2_HPP +#define CK_GRIDWISE_GEMM_XDLOPS_V3R2_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_tensor_slice_transfer_v4r1.hpp" +#include "blockwise_tensor_slice_transfer_v6r2.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename C0GridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerXdl, + index_t NPerXdl, + index_t K1Value, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + + return has_main_k0_block_loop; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using Block2CTileMap = remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum_t::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r2< + BlockSize, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename Src0Data, + FloatC, // typename Src1Data, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp new file mode 100644 index 0000000000..7601aa6a07 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp @@ -0,0 +1,823 @@ +#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP +#define CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_tensor_slice_transfer_v4r1.hpp" +#include "blockwise_tensor_slice_transfer_v6r3.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_c1_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename C0GridDesc_M_N, + typename C1GridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerXdl, + index_t NPerXdl, + index_t K1Value, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + + return has_main_k0_block_loop; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using Block2CTileMap = remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c1_grid_buf = make_dynamic_buffer( + p_c1_grid, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum_t::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r3< + BlockSize, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename Src0Data, + FloatC, // typename Src1Data, + FloatC, // typename Src2Data, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc2ResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_buf, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index 3302ff6bef..a58855aa35 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -290,7 +290,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 const DstDesc& dst_desc, DstBuffer& dst_buf) { - constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + constexpr index_t ntransform_dst = remove_cvref_t::GetNumOfTransform(); constexpr auto zeros = typename uniform_sequence_gen::type{}; @@ -326,7 +326,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_access_lengths[I0] - 1; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; }); @@ -506,7 +506,7 @@ struct ThreadwiseTensorSliceTransfer_v2 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_access_idx[I0]; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); @@ -638,7 +638,7 @@ struct ThreadwiseTensorSliceTransfer_v2 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_access_lengths[I0] - 1; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; }); @@ -835,7 +835,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_idx[I0]; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; }); @@ -992,7 +992,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_idx[I0]; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; }); @@ -1136,7 +1136,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_lengths[I0] - 1; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; }); @@ -1196,7 +1196,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_lengths[I0] - 1; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; }); diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp index c52787dafc..c669427896 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp @@ -116,9 +116,6 @@ struct ThreadwiseTensorSliceTransfer_v1r4 constexpr auto dst_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - constexpr auto dst_scalar_step_in_vector = - generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; constexpr auto dim_access_order = DimAccessOrder{}; @@ -141,7 +138,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4 Number{}); // make forward steps: dst0 - // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector // TODO: fix this const auto dst0_forward_steps = generate_tuple( [&](auto i) { @@ -157,7 +155,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4 Number{}); // make forward steps: dst1 - // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector // TODO: fix this const auto dst1_forward_steps = generate_tuple( [&](auto i) { @@ -187,7 +186,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4 Number{}); // make backward steps: dst0 - // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector // TODO: fix this const auto dst0_backward_steps = generate_tuple( [&](auto i) { @@ -203,7 +203,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4 Number{}); // make backward steps: dst1 - // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector // TODO: fix this const auto dst1_backward_steps = generate_tuple( [&](auto i) { @@ -229,7 +230,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_access_idx[I0]; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); @@ -397,14 +398,12 @@ struct ThreadwiseTensorSliceTransfer_v1r4 typename SrcBuffer, typename DstBuffer, typename Dst0Buffer, - typename Dst1Buffer, - typename DstStepHacks> + typename Dst1Buffer> __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf, - const DstStepHacks& dst_step_hacks, const Dst0Desc& dst0_desc, const Dst0Buffer& dst0_buf, const Dst1Desc& dst1_desc, @@ -427,7 +426,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4 src_buf, dst_desc, dst_buf, - dst_step_hacks, + f_step_hacks(dst_desc), dst0_desc, dst0_buf, f_step_hacks(dst0_desc), @@ -461,7 +460,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_access_lengths[I0] - 1; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; }); diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r5.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r5.hpp new file mode 100644 index 0000000000..6389680c5f --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r5.hpp @@ -0,0 +1,453 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V1R5_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V1R5_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 +// TODO: fix this +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is StaticBuffer +// 3. SrcSliceOrginIdx is known at compile-time +// 2. dst: +// 1. DstDesc is not known at compile-time +// 2. DstBuffer is DynamicBuffer +// 3. DstSliceOrginIdx is not known at compile time +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v1r5 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{})); + + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v1r5( + const DstDesc& dst_desc, + const Dst0Desc& dst0_desc, + const Index& dst_slice_origin_idx, + const DstElementwiseOperation& dst_element_op) + : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)), + dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin_idx)), + dst_element_op_{dst_element_op} + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + const DstStepHacks& dst_step_hacks, + const Dst0Desc& dst0_desc, + const Dst0Buffer& dst0_buf, + const Dst0StepHacks& dst0_step_hacks) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // make forward steps: dst + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); + }, + Number{}); + + // make forward steps: dst0 + // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 + // TODO: fix this + const auto dst0_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst0_desc, forward_step_idx, dst0_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps: dst + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); + }, + Number{}); + + // make backward steps: dst0 + // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 + // TODO: fix this + const auto dst0_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst0_desc, backward_step_idx, dst0_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? ordered_access_idx[i] + : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + dst_scalar_per_access; + }(); + + typename vector_type_maker::type dst_vector; + + using dst_vector_t = + typename vector_type_maker::type::type; + + // load dst0 and apply elementwise operation + { + // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 + // TODO: fix this + static_assert(DstScalarPerVector == 1, "wrong!"); + + // copy data from src_buf into dst_vector_src_data + constexpr index_t src_offset = + src_desc.CalculateOffset(src_slice_origin_idx + dst_data_idx); + + const SrcData src_v = src_buf[Number{}]; + + // load dst0 + const bool is_dst0_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst0_desc, + dst0_coord_); + const DstData dst0_v = + dst0_buf.template Get(dst0_coord_.GetOffset(), is_dst0_valid); + +#if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE + // apply element-wise operation in SrcData type + const SrcData dst_v = dst_element_op_(src_v, type_convert(dst0_v)); + + // apply type convert + dst_vector.template AsType()(Number<0>{}) = type_convert(dst_v); +#else + // apply element-wise operation in DstData type + const DstData dst_v = dst_element_op_(src_v, dst0_v); + + dst_vector.template AsType()(Number<0>{}) = dst_v; +#endif + } + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) + { + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + } + else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) + { + dst_buf.template AtomicAdd( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + } + else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Add) + { + + typename vector_type_maker::type tmp; + tmp.template AsType()(Number<0>{}) = + dst_buf.template Get(dst_coord_.GetOffset(), is_dst_valid); + + static_for<0, DstScalarPerVector, 1>{}([&](auto t) { + dst_vector.template AsType()(t) += tmp.template AsType()[t]; + }); + + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + } + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]); + + // dst0 + move_tensor_coordinate( + dst0_desc, dst0_coord_, dst0_forward_steps[dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]); + + // dst0 + move_tensor_coordinate( + dst0_desc, dst0_coord_, dst0_backward_steps[dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + const Dst0Desc& dst0_desc, + const Dst0Buffer& dst0_buf) + { + auto f_step_hacks = [&](auto desc) { + constexpr index_t ntransform = decltype(desc)::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + return step_hacks; + }; + + Run(SrcDesc{}, + SrcSliceOriginIdx{}, + src_buf, + dst_desc, + dst_buf, + f_step_hacks(dst_desc), + dst0_desc, + dst0_buf, + f_step_hacks(dst0_desc)); + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in Run(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + DstCoord dst_coord_; + Dst0Coord dst0_coord_; + const DstElementwiseOperation dst_element_op_; +}; // namespace ck + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp similarity index 94% rename from composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp rename to composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp index f9f4fff63b..5497bb2e3d 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp @@ -1,5 +1,5 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R2_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R2_HPP +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP #include "common_header.hpp" #include "tensor_descriptor.hpp" @@ -47,6 +47,7 @@ struct lambda_scalar_per_access_for_src_and_dst // 4. Use thread buffer template // control whether to move back dst coordinate after each // RunWrite(), will be fused with MoveDstSliceWindow to // save addr computation -struct ThreadwiseTensorSliceTransfer_v3r2 +struct ThreadwiseTensorSliceTransfer_v3r1 { static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; @@ -77,15 +78,17 @@ struct ThreadwiseTensorSliceTransfer_v3r2 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2( + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( const SrcDesc& src_desc, const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, const DstDesc& dst_desc, const Index& dst_slice_origin, - const SrcElementwiseOperation& src_element_op) + const DstElementwiseOperation& dst_element_op) : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), - src_element_op_(src_element_op) + src_element_op_(src_element_op), + dst_element_op_(dst_element_op) { } @@ -165,7 +168,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_idx[I0]; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; }); @@ -412,7 +415,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_idx[I0]; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; }); @@ -442,13 +445,24 @@ struct ThreadwiseTensorSliceTransfer_v3r2 const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); - using dst_vector_t = typename vector_type_maker_t::type; + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; - // copy data from dst_thread_scratch_ to dst_buf + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + // apply DstElementwiseOperation on dst_vector_container + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + dst_vector_container.template AsType()(i) = + dst_element_op_(dst_vector_container.template AsType()[i]); + }); + + // copy data from dst_vector_container to dst_buf dst_buf.template Set( dst_coord_.GetOffset(), is_dst_valid, - dst_thread_scratch_.template GetAsType(dst_data_idx_seq)); + dst_vector_container.template AsType()[I0]); constexpr auto move_on_dim = [&]() constexpr { @@ -498,7 +512,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 template __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) { - constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + constexpr index_t ntransform_src = remove_cvref_t::GetNumOfTransform(); constexpr auto zeros = typename uniform_sequence_gen::type{}; @@ -512,7 +526,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2 template __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) { - constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + // TODO: why need remove_cvref_t ? + constexpr index_t ntransform_dst = remove_cvref_t::GetNumOfTransform(); constexpr auto zeros = typename uniform_sequence_gen::type{}; @@ -548,7 +563,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_lengths[I0] - 1; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; }); @@ -608,7 +623,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_lengths[I0] - 1; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; }); @@ -811,6 +826,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 SrcCoord src_coord_; DstCoord dst_coord_; const SrcElementwiseOperation src_element_op_; + const DstElementwiseOperation dst_element_op_; }; } // namespace ck diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r3.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r3.hpp new file mode 100644 index 0000000000..8f9d4fe281 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r3.hpp @@ -0,0 +1,883 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "static_tensor.hpp" + +namespace ck { + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseTensorSliceTransfer_v3r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{})); + using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{})); + using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r3( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Dst0Desc& dst0_desc, + const Dst1Desc& dst1_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin)), + dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin)), + src_element_op_(src_element_op), + dst_element_op_(dst_element_op) + { + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, + const Dst0Desc& dst0_desc, + const Dst1Desc& dst1_desc, + const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst0_coord_ = make_tensor_coordinate(dst0_desc, dst_slice_origin_idx); + dst1_coord_ = make_tensor_coordinate(dst1_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + // apply SrcElementwiseOperation on src_vector_container + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + src_vector_container.template AsType()(i) = + src_element_op_(src_vector_container.template AsType()[i]); + }); + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_.template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + __device__ void TransferDataFromSrcThreadScratchToDstThreadScratch() + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); + }); +#else + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + is_same>::value && + is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + // TODO type_convert is not used yet!!!!! + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + // TODO type_convert is not used yet!!!!! + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); + }); + } +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + const Dst0Desc& dst0_desc, + const Dst0Buffer& dst0_buf, + const Dst1Desc& dst1_desc, + const Dst1Buffer& dst1_buf) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make forward steps: dst0 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst0_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst0_desc, forward_step_idx); + }, + Number{}); + + // make forward steps: dst1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst1_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst1_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // make backward steps: dst0 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst0_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst0_desc, backward_step_idx); + }, + Number{}); + + // make backward steps: dst1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst1_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst1_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + // apply DstElementwiseOperation on dst_vector_container + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + dst_vector_container.template AsType()(i) = + dst_element_op_(dst_vector_container.template AsType()[i]); + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + // TODO: BUG: should start at 1 + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Dst0Desc dst0_desc, + const Dst1Desc dst1_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + move_tensor_coordinate(dst0_desc, dst0_coord_, adjusted_step); + move_tensor_coordinate(dst1_desc, dst1_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + StaticTensorTupleOfVectorBuffer + src_thread_scratch_; + + StaticTensorTupleOfVectorBuffer + dst_thread_scratch_; + + SrcCoord src_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const DstElementwiseOperation dst_element_op_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v4r1.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v4r1.hpp new file mode 100644 index 0000000000..2504c92856 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v4r1.hpp @@ -0,0 +1,174 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. SrcRefToOriginDisplacement is known at compile-time +// 5. use #-step +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. vector access on src +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v4r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) + : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!"); + }); + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * + src_vector_tensor_lengths; + + // src coordinate at starting point of src_vector + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + + // copy data from src_vector into dst_buf (also cast from SrcData to DstData) + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); + + dst_buf(Number{}) = type_convert( + src_vector.template AsType()[Number{}]); + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); + + move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + + private: + SrcCoord src_ref_coord_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v5r1.hpp similarity index 76% rename from composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp rename to composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v5r1.hpp index 9d996afbb0..bedea25874 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,5 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP #include "common_header.hpp" #include "tensor_descriptor.hpp" @@ -30,7 +30,7 @@ template // control whether to move back dst coordinate after each // RunWrite(), will be fused with MoveDstSliceWindow to // save addr computation -struct ThreadwiseTensorSliceTransfer_v3r1 +struct ThreadwiseTensorSliceTransfer_v5r1 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -44,7 +44,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc, + __device__ constexpr ThreadwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc, const Index& src_slice_origin, const DstDesc& dst_desc, const Index& dst_slice_origin) @@ -608,169 +608,5 @@ struct ThreadwiseTensorSliceTransfer_v3r1 DstCoord dst_coord_; }; -// Assume: -// 1. src: -// 1. SrcDesc is known at compile-time -// 2. SrcBuffer is DynamicBuffer -// 3. src_ref_idx is known at run-time -// 4. SrcRefToOriginDisplacement is known at compile-time -// 5. use #-step -// 2. dst: -// 1. DstDesc is known at compile-time -// 2. DstBuffer is StaticBuffer -// 3. DstOriginIdx is known at compile-time -// 4. use direct address calculation -// 3. vector access on src -template ::type = false> -struct ThreadwiseTensorSliceTransfer_v4r1 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t nDim = SliceLengths::Size(); - - using Index = MultiIndex; - - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - - __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) - : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc and DstDesc need to known at compile-time"); - - static_for<0, nDim, 1>{}([](auto i) { - static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!"); - }); - } - - template - __device__ void Run(const SrcDesc&, - const SrcRefToOriginDisplacement&, - const SrcBuffer& src_buf, - const DstDesc&, - const DstOriginIdx&, - DstBuffer& dst_buf) const - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc and DstDesc need to known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); - - static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " - "at compile-time"); - - // SrcDesc and DstDesc are known at compile-time - constexpr auto src_desc = remove_cvref_t{}; - constexpr auto dst_desc = remove_cvref_t{}; - - // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time - constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); - constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); - - // tensor descriptor for src_vector - constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; - - constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( - container_reverse_exclusive_scan( - container_reorder_given_new2old(src_vector_tensor_lengths, - SrcVectorTensorContiguousDimOrder{}), - math::multiplies{}, - I1), - SrcVectorTensorContiguousDimOrder{}); - - constexpr auto src_vector_desc = - make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), - sequence_to_tuple_of_number(src_vector_tensor_strides)); - - // access order and lengths - constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - static_ford{}([&](auto ordered_access_idx) { - // position in slice window - constexpr auto data_to_origin_disp_idx = - ordered_access_idx.ReorderGivenOld2New(dim_access_order) * - src_vector_tensor_lengths; - - // src coordinate at starting point of src_vector - constexpr auto src_ref_to_data_disp_idx = - src_ref_to_origin_disp_idx + data_to_origin_disp_idx; - - constexpr auto src_ref_to_data_disp_coord_step = - make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); - - auto src_data_coord = src_ref_coord_; - - move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); - - vector_type_maker_t src_vector; - - using src_vector_t = typename decltype(src_vector)::type; - - const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( - src_desc, src_data_coord); - - // copy data from src_buf into src_vector - src_vector.template AsType()(I0) = - src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); - - // copy data from src_vector into dst_buf (also cast from SrcData to DstData) - static_ford{}([&](auto src_vector_idx_) { - constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); - - constexpr index_t src_vector_offset = - src_vector_desc.CalculateOffset(src_vector_idx); - - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); - - dst_buf(Number{}) = type_convert( - src_vector.template AsType()[Number{}]); - }); - }); - } - - template - __device__ void MoveSrcSliceWindow(const SrcDesc&, - const SrcSliceMoveStepIdx& src_slice_move_step_idx) - { - constexpr auto src_desc = SrcDesc{}; - - const auto src_slice_move_step_iter = - make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); - - move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); - } - - private: - SrcCoord src_ref_coord_; -}; - } // namespace ck #endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r1.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r1.hpp new file mode 100644 index 0000000000..6cdb142e76 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r1.hpp @@ -0,0 +1,338 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + auto make_forward_steps = [&](auto desc) { + return generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(desc, forward_step_idx); + }, + Number{}); + }; + + auto make_backward_steps = [&](auto desc) { + return generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(desc, backward_step_idx); + }, + Number{}); + }; + + // make forward steps + const auto src_forward_steps = make_forward_steps(src_desc); + const auto dst_forward_steps = make_forward_steps(dst_desc); + + // make backward steps + const auto src_backward_steps = make_backward_steps(src_desc); + const auto dst_backward_steps = make_backward_steps(dst_desc); + + // loop over slice window + static_ford{}([&](auto ordered_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + element_op_(dst_vector_container.template AsType()(i), + src_vector_container.template AsType()[i]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) + { + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + } + else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) + { + dst_buf.template AtomicAdd( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + } + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move coordinate + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]); + } + } + }); + }); + + // move coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate data index after last iteration in Run(), if it has not being reset + constexpr auto data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + scalar_per_access; + }(); + + // + constexpr auto reset_data_step = [&]() { + Index reset_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; }); + + return reset_data_step_; + }(); + + return reset_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRun + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r2.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r2.hpp new file mode 100644 index 0000000000..a65c275744 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r2.hpp @@ -0,0 +1,397 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src0_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); + using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using Src0CoordStep = decltype(make_tensor_coordinate_step(Src0Desc{}, Index{})); + using Src1CoordStep = decltype(make_tensor_coordinate_step(Src1Desc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src0_coord_(make_tensor_coordinate(src0_desc, src0_slice_origin)), + src1_coord_(make_tensor_coordinate(src1_desc, src1_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrc0SliceOrigin(const Src0Desc& src0_desc, + const Index& src0_slice_origin_idx) + { + src0_coord_ = make_tensor_coordinate(src0_desc, src0_slice_origin_idx); + } + + __device__ void SetSrc1SliceOrigin(const Src1Desc& src1_desc, + const Index& src1_slice_origin_idx) + { + src1_coord_ = make_tensor_coordinate(src1_desc, src1_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + auto make_forward_steps = [&](auto desc) { + return generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(desc, forward_step_idx); + }, + Number{}); + }; + + auto make_backward_steps = [&](auto desc) { + return generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(desc, backward_step_idx); + }, + Number{}); + }; + + // make forward steps + const auto src0_forward_steps = make_forward_steps(src0_desc); + const auto src1_forward_steps = make_forward_steps(src1_desc); + const auto dst_forward_steps = make_forward_steps(dst_desc); + + // make backward steps + const auto src0_backward_steps = make_backward_steps(src0_desc); + const auto src1_backward_steps = make_backward_steps(src1_desc); + const auto dst_backward_steps = make_backward_steps(dst_desc); + + // loop over slice window + static_ford{}([&](auto ordered_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + using src0_vector_type = vector_type_maker_t; + using src0_vector_t = typename src0_vector_type::type; + + using src1_vector_type = vector_type_maker_t; + using src1_vector_t = typename src1_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src0_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_); + + const bool is_src1_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_); + + // copy data from src0_buf into src0_vector_container + auto src0_vector_container = src0_vector_type{ + src0_buf.template Get(src0_coord_.GetOffset(), is_src0_valid)}; + + auto src1_vector_container = src1_vector_type{ + src1_buf.template Get(src1_coord_.GetOffset(), is_src1_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + element_op_(dst_vector_container.template AsType()(i), + src0_vector_container.template AsType()[i], + src1_vector_container.template AsType()[i]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) + { + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + } + else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) + { + dst_buf.template AtomicAdd( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + } + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move coordinate + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src0_desc, src0_coord_, src0_forward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + src1_desc, src1_coord_, src1_forward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src0_desc, src0_coord_, src0_backward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + src1_desc, src1_coord_, src1_backward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]); + } + } + }); + }); + + // move coordinate back to slice origin (or not) + if constexpr(Src0ResetCoordinateAfterRun) + { + const auto src0_reset_step = + make_tensor_coordinate_step(src0_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src0_desc, src0_coord_, src0_reset_step); + } + + if constexpr(Src1ResetCoordinateAfterRun) + { + const auto src1_reset_step = + make_tensor_coordinate_step(src1_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src1_desc, src1_coord_, src1_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate data index after last iteration in Run(), if it has not being reset + constexpr auto data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + scalar_per_access; + }(); + + // + constexpr auto reset_data_step = [&]() { + Index reset_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; }); + + return reset_data_step_; + }(); + + return reset_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, + const Index& src0_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src0ResetCoordinateAfterRun + ? src0_slice_origin_step_idx + : src0_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); + + move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, + const Index& src1_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src1ResetCoordinateAfterRun + ? src1_slice_origin_step_idx + : src1_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); + + move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + Src0Coord src0_coord_; + Src1Coord src1_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r3.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r3.hpp new file mode 100644 index 0000000000..c7590d904c --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r3.hpp @@ -0,0 +1,455 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src0_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); + using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); + using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using Src0CoordStep = decltype(make_tensor_coordinate_step(Src0Desc{}, Index{})); + using Src1CoordStep = decltype(make_tensor_coordinate_step(Src1Desc{}, Index{})); + using Src2CoordStep = decltype(make_tensor_coordinate_step(Src2Desc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src0_coord_(make_tensor_coordinate(src0_desc, src0_slice_origin)), + src1_coord_(make_tensor_coordinate(src1_desc, src1_slice_origin)), + src2_coord_(make_tensor_coordinate(src2_desc, src2_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrc0SliceOrigin(const Src0Desc& src0_desc, + const Index& src0_slice_origin_idx) + { + src0_coord_ = make_tensor_coordinate(src0_desc, src0_slice_origin_idx); + } + + __device__ void SetSrc1SliceOrigin(const Src1Desc& src1_desc, + const Index& src1_slice_origin_idx) + { + src1_coord_ = make_tensor_coordinate(src1_desc, src1_slice_origin_idx); + } + + __device__ void SetSrc2SliceOrigin(const Src2Desc& src2_desc, + const Index& src2_slice_origin_idx) + { + src2_coord_ = make_tensor_coordinate(src2_desc, src2_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const Src2Desc& src2_desc, + const Src2Buffer& src2_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + auto make_forward_steps = [&](auto desc) { + return generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(desc, forward_step_idx); + }, + Number{}); + }; + + auto make_backward_steps = [&](auto desc) { + return generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(desc, backward_step_idx); + }, + Number{}); + }; + + // make forward steps + const auto src0_forward_steps = make_forward_steps(src0_desc); + const auto src1_forward_steps = make_forward_steps(src1_desc); + const auto src2_forward_steps = make_forward_steps(src2_desc); + const auto dst_forward_steps = make_forward_steps(dst_desc); + + // make backward steps + const auto src0_backward_steps = make_backward_steps(src0_desc); + const auto src1_backward_steps = make_backward_steps(src1_desc); + const auto src2_backward_steps = make_backward_steps(src2_desc); + const auto dst_backward_steps = make_backward_steps(dst_desc); + + // loop over slice window + static_ford{}([&](auto ordered_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + using src0_vector_type = vector_type_maker_t; + using src0_vector_t = typename src0_vector_type::type; + + using src1_vector_type = vector_type_maker_t; + using src1_vector_t = typename src1_vector_type::type; + + using src2_vector_type = vector_type_maker_t; + using src2_vector_t = typename src2_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src0_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_); + + const bool is_src1_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_); + + const bool is_src2_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src2_desc, src2_coord_); + + // copy data from src0_buf into src0_vector_container + auto src0_vector_container = src0_vector_type{ + src0_buf.template Get(src0_coord_.GetOffset(), is_src0_valid)}; + + auto src1_vector_container = src1_vector_type{ + src1_buf.template Get(src1_coord_.GetOffset(), is_src1_valid)}; + + auto src2_vector_container = src2_vector_type{ + src2_buf.template Get(src2_coord_.GetOffset(), is_src2_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + element_op_(dst_vector_container.template AsType()(i), + src0_vector_container.template AsType()[i], + src1_vector_container.template AsType()[i], + src2_vector_container.template AsType()[i]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) + { + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + } + else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) + { + dst_buf.template AtomicAdd( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + } + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move coordinate + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src0_desc, src0_coord_, src0_forward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + src1_desc, src1_coord_, src1_forward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + src2_desc, src2_coord_, src2_forward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src0_desc, src0_coord_, src0_backward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + src1_desc, src1_coord_, src1_backward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + src2_desc, src2_coord_, src2_backward_steps[dim_access_order[i]]); + + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]); + } + } + }); + }); + + // move coordinate back to slice origin (or not) + if constexpr(Src0ResetCoordinateAfterRun) + { + const auto src0_reset_step = + make_tensor_coordinate_step(src0_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src0_desc, src0_coord_, src0_reset_step); + } + + if constexpr(Src1ResetCoordinateAfterRun) + { + const auto src1_reset_step = + make_tensor_coordinate_step(src1_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src1_desc, src1_coord_, src1_reset_step); + } + + if constexpr(Src2ResetCoordinateAfterRun) + { + const auto src2_reset_step = + make_tensor_coordinate_step(src2_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src2_desc, src2_coord_, src2_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate data index after last iteration in Run(), if it has not being reset + constexpr auto data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + scalar_per_access; + }(); + + // + constexpr auto reset_data_step = [&]() { + Index reset_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; }); + + return reset_data_step_; + }(); + + return reset_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, + const Index& src0_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src0ResetCoordinateAfterRun + ? src0_slice_origin_step_idx + : src0_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); + + move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, + const Index& src1_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src1ResetCoordinateAfterRun + ? src1_slice_origin_step_idx + : src1_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); + + move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, + const Index& src2_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src2ResetCoordinateAfterRun + ? src2_slice_origin_step_idx + : src2_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src2_desc, adjusted_step_idx); + + move_tensor_coordinate(src2_desc, src2_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + Src0Coord src0_coord_; + Src1Coord src1_coord_; + Src2Coord src2_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp index 5f0257af26..773f7cff2c 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -31,7 +31,7 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_ return wave_buffer_resource.content; } -// load +// buffer load i8 __device__ int8_t llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, index_t voffset, @@ -50,6 +50,7 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); +// buffer load i16 __device__ ushort llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, index_t voffset, @@ -68,6 +69,7 @@ llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16"); +// buffer load i32 __device__ int32_t llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, index_t voffset, @@ -85,7 +87,7 @@ llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); -// half +// buffer load fp16 __device__ half_t llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, index_t voffset, @@ -104,7 +106,7 @@ llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); -// float +// buffer load fp32 __device__ float llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, index_t voffset, @@ -123,7 +125,7 @@ llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); -// store +// buffer store i8 __device__ void llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, int32x4_t rsrc, @@ -145,6 +147,7 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); +// buffer store i16 __device__ void llvm_amdgcn_raw_buffer_store_i16(ushort vdata, int32x4_t rsrc, @@ -166,6 +169,7 @@ llvm_amdgcn_raw_buffer_store_i16x4(ushort4_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); +// buffer store i32 __device__ void llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, int32x4_t rsrc, @@ -187,7 +191,7 @@ llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); -// half +// buffer store fp16 __device__ void llvm_amdgcn_raw_buffer_store_fp16(half_t vdata, int32x4_t rsrc, @@ -208,7 +212,7 @@ llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16"); -// float +// buffer store fp32 __device__ void llvm_amdgcn_raw_buffer_store_fp32(float vdata, int32x4_t rsrc, @@ -229,8 +233,15 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); -// atomic add -// int +// buffer atomic-add fp16 +__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16"); + +// buffer atomic-add i32 __device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( int32_t vdata, int32x4_t rsrc, @@ -238,7 +249,7 @@ __device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); -// float +// buffer atomic-add fp32 __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( float vdata, int32x4_t rsrc, @@ -752,6 +763,7 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type::typ index_t dst_wave_addr_offset) { static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4)), "wrong! not implemented"); @@ -810,6 +822,41 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type::typ 0); } } + else if constexpr(is_same::value) + { + if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + static_for<0, 2, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + static_for<0, 4, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + } else if constexpr(is_same::value) { if constexpr(N == 1) diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index 4afdc7d788..5915645be2 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -35,8 +35,8 @@ #include "dynamic_buffer.hpp" #include "is_known_at_compile_time.hpp" #include "transpose_vectors.hpp" - #include "inner_product.hpp" +#include "element_wise_operation.hpp" // TODO: remove this #if CK_USE_AMD_INLINE_ASM diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp index 0566048fc9..f29ab54660 100644 --- a/composable_kernel/include/utility/config.hpp +++ b/composable_kernel/include/utility/config.hpp @@ -24,12 +24,16 @@ #define CK_MIN_BLOCK_PER_CU 2 #endif -// buffer resourse +// GPU-specific parameters #if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) +// buffer resourse #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 +// wave size +#define CK_GPU_WAVE_SIZE 64 #elif defined(CK_AMD_GPU_GFX1030) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#define CK_GPU_WAVE_SIZE 32 #endif // FMA instruction @@ -141,6 +145,10 @@ #define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE 1 #endif +#ifndef CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE +#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE 1 +#endif + namespace ck { enum InMemoryDataOperationEnum_t @@ -152,7 +160,7 @@ enum InMemoryDataOperationEnum_t enum ActivTypeEnum_t { - None = 0, + None, LeakyRelu, Sigmoid }; diff --git a/composable_kernel/include/utility/utility.hpp b/composable_kernel/include/utility/utility.hpp index 9f34e044b7..c4cc717618 100644 --- a/composable_kernel/include/utility/utility.hpp +++ b/composable_kernel/include/utility/utility.hpp @@ -5,8 +5,12 @@ namespace ck { +__device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; } + __device__ index_t get_thread_local_1d_id() { return threadIdx.x; } +__device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size(); } + __device__ index_t get_block_1d_id() { return blockIdx.x; } } // namespace ck diff --git a/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000..dbfa6e2031 --- /dev/null +++ b/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,144 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_add_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::OddC; + +// arbitrary conv +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// Odd C +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +} + +} // namespace device_conv2d_fwd_bias_activation_add_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000..1c9a4b989c --- /dev/null +++ b/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,69 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_atomic_add_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +static constexpr auto InMemoryAtomicAdd = ck::InMemoryDataOperationEnum_t::AtomicAdd; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances( + std::vector>& + instance_container) +{ + using Instances = + device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances; + + const auto instances = Instances{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Instance = remove_cvref_t(instances))>; + + auto instance = Instance{}; + + instance_container.push_back(std::make_unique(instance)); + }); +} + +} // namespace device_conv2d_fwd_bias_activation_atomic_add_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000..075eddd117 --- /dev/null +++ b/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,144 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +static constexpr auto MemorySet = ck::InMemoryDataOperationEnum_t::Set; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::OddC; + +// arbitrary conv +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// Odd C +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +} + +} // namespace device_conv2d_fwd_bias_activation_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000..cd9ee30627 --- /dev/null +++ b/device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,139 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::OddC; + +// arbitrary conv +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000..beaad1d3b4 --- /dev/null +++ b/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,109 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 0000000000..402d65a6e0 --- /dev/null +++ b/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,108 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp b/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp deleted file mode 100644 index 5f8ba7904f..0000000000 --- a/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include -#include "config.hpp" -#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "device_conv_instance.hpp" -#include "element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_conv_instance { - -using F16 = ck::half_t; -using F32 = float; - -using NHWC = ck::tensor_layout::convolution::NHWC; -using KYXC = ck::tensor_layout::convolution::KYXC; -using NHWK = ck::tensor_layout::convolution::NHWK; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk = std::tuple< - // clang-format off - //##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //##############| | | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> - // clang-format on - >; - -template <> -void add_device_conv_fwd_instance<2, F16, F16, F16, NHWC, KYXC, NHWK>( - std::vector>& device_conv_instances) -{ - using DeviceConvs = device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk; - - const auto device_convs = DeviceConvs{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Conv = remove_cvref_t(device_convs))>; - - auto conv = Conv{}; - - device_conv_instances.push_back(std::make_unique(conv)); - }); -} - -} // namespace device_conv_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp b/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp deleted file mode 100644 index 90a92b7469..0000000000 --- a/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include -#include "config.hpp" -#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "device_conv_instance.hpp" -#include "element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_conv_instance { - -using F16 = ck::half_t; -using F32 = float; - -using NHWC = ck::tensor_layout::convolution::NHWC; -using KYXC = ck::tensor_layout::convolution::KYXC; -using NHWK = ck::tensor_layout::convolution::NHWK; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk = std::tuple< - // clang-format off - //##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //##############| | | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> - // clang-format on - >; - -template <> -void add_device_conv_fwd_instance<2, F32, F32, F32, NHWC, KYXC, NHWK>( - std::vector>& device_conv_instances) -{ - using DeviceConvs = device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk; - - const auto device_convs = DeviceConvs{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Conv = remove_cvref_t(device_convs))>; - - auto conv = Conv{}; - - device_conv_instances.push_back(std::make_unique(conv)); - }); -} - -} // namespace device_conv_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp index 26ebd2238c..78f5352f7e 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp @@ -21,22 +21,23 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple< - // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> - // clang-format on - >; +using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; template <> void add_device_gemm_instance( diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp index bd916b8271..786c4ab1e1 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp @@ -21,22 +21,23 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = std::tuple< - // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> - // clang-format on - >; +using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; template <> void add_device_gemm_instance( diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp index 09fdc7d059..44459ca4cb 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp @@ -21,22 +21,23 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = std::tuple< - // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> - // clang-format on - >; +using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; template <> void add_device_gemm_instance( diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp index 06362bdea0..7286dfe598 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp @@ -21,27 +21,28 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple< - // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> - // clang-format on - >; +using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; template <> void add_device_gemm_instance( diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp index da0b9fce52..344f182fa3 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp @@ -21,22 +21,23 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple< - // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> - // clang-format on - >; +using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; template <> void add_device_gemm_instance( diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp index 1557b1d114..fb17e0aaea 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp @@ -21,22 +21,23 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple< - // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> - // clang-format on - >; +using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; template <> void add_device_gemm_instance( diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp index c9ba29bfdc..7567a8c2ec 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp @@ -21,22 +21,23 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple< - // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> - // clang-format on - >; +using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; template <> void add_device_gemm_instance( diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp index e1d2296336..6c80f0d9f4 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp @@ -21,27 +21,28 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple< - // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> - // clang-format on - >; +using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; template <> void add_device_gemm_instance( diff --git a/device_operation/include/convolution_forward_specialization.hpp b/device_operation/include/convolution_forward_specialization.hpp new file mode 100644 index 0000000000..e047acee76 --- /dev/null +++ b/device_operation/include/convolution_forward_specialization.hpp @@ -0,0 +1,19 @@ +#ifndef CONVOLUTION_FORWARD_SPECIALIZATION +#define CONVOLUTION_FORWARD_SPECIALIZATION + +namespace ck { +namespace tensor_operation { +namespace device { + +enum ConvolutionForwardSpecialization_t +{ + Default, + Filter1x1Pad0, + Filter1x1Stride1Pad0, + OddC, +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_base.hpp b/device_operation/include/device_base.hpp index de47889f2a..cf48695ad0 100644 --- a/device_operation/include/device_base.hpp +++ b/device_operation/include/device_base.hpp @@ -1,6 +1,8 @@ #ifndef DEVICE_BASE_HPP #define DEVICE_BASE_HPP +#include + namespace ck { namespace tensor_operation { namespace device { @@ -32,6 +34,7 @@ struct BaseOperator BaseOperator& operator=(const BaseOperator&) = default; virtual bool IsSupportedArgument(const BaseArgument*) = 0; + virtual std::string GetTypeString() const = 0; virtual ~BaseOperator() {} }; diff --git a/device_operation/include/device_conv.hpp b/device_operation/include/device_conv.hpp deleted file mode 100644 index f521eecb9a..0000000000 --- a/device_operation/include/device_conv.hpp +++ /dev/null @@ -1,110 +0,0 @@ -#ifndef DEVICE_CONV_HPP -#define DEVICE_CONV_HPP - -#include -#include "device_base.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct DeviceConvFwd : public BaseOperator -{ - virtual std::unique_ptr - MakeArgumentPointer(const void* p_in, - const void* p_wei, - void* p_out, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - -template -struct DeviceConvBwd : public BaseOperator -{ - virtual std::unique_ptr - MakeArgumentPointer(void* p_in, - const void* p_wei, - const void* p_out, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - -template -struct DeviceConvWrw : public BaseOperator -{ - virtual std::unique_ptr - MakeArgumentPointer(const void* p_in, - void* p_wei, - const void* p_out, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - -template -using DeviceConvFwdPtr = std::unique_ptr< - DeviceConvFwd>; - -template -using DeviceConvBwdPtr = std::unique_ptr< - DeviceConvBwd>; - -template -using DeviceConvWrwPtr = std::unique_ptr< - DeviceConvWrw>; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..e9aa4fa42c --- /dev/null +++ b/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,944 @@ +#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_fwd_bias_activation_add.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v3r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K] +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + ConvolutionForwardSpecialization_t ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwdBiasActivationAdd +{ + using DeviceOp = + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // TODO make it support any # of spatial dimensions + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + using C1GridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + const OutDataType* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + p_c0_grid_{p_bias_grid}, + p_c1_grid_{p_resi_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c1_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + c1_grid_desc_m_n_ = descs[I4]; + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c1_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + const CDataType* p_c1_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + const OutDataType* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_bias_grid, + p_resi_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + const void* p_bias_grid, + const void* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + static_cast(p_bias_grid), + static_cast(p_resi_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..d915feab75 --- /dev/null +++ b/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,892 @@ +#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_fwd_bias_activation.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v3r2.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + InMemoryDataOperationEnum_t OutGlobalMemoryDataOperation, + ConvolutionForwardSpecialization_t ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwdBiasActivation +{ + using DeviceOp = + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // TODO make it support any # of spatial dimensions + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + OutGlobalMemoryDataOperation, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + p_c0_grid_{p_bias_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_bias_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + const void* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + static_cast(p_bias_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..43a10b1627 --- /dev/null +++ b/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,857 @@ +#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_fwd.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v3r1.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + ConvolutionForwardSpecialization_t ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXdl, + ck::index_t NPerXdl, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd +{ + using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout + << "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_" + "nwavenperxdl_{ " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I0) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I1) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I2) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I3) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I4) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I5) + << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v3r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp similarity index 53% rename from device_operation/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp rename to device_operation/include/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 87ab16f6f6..6093f31e49 100644 --- a/device_operation/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/device_operation/include/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,23 +1,23 @@ -#ifndef DEVICE_CONV_FWD_XDL_NHWC_KYXC_NHWK_HPP -#define DEVICE_CONV_FWD_XDL_NHWC_KYXC_NHWK_HPP +#ifndef DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP #include +#include #include "device.hpp" #include "device_base.hpp" -#include "device_conv.hpp" +#include "device_conv_fwd.hpp" +#include "convolution_forward_specialization.hpp" #include "common_header.hpp" #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp" -#include "device_conv.hpp" -#include "device_conv_fwd_xdl.hpp" namespace ck { namespace tensor_operation { namespace device { -// specialization for 2D conv: in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] template -struct DeviceConvFwdXdl< - 2, // ck::index_t NDimSpatial, - InDataType, // typename InDataType, - WeiDataType, // typename WeiDataType, - OutDataType, // typename OutDataType, - AccDataType, // typename AccDataType, - ck::tensor_layout::convolution::NHWC, // typename InLayout, - ck::tensor_layout::convolution::KYXC, // typename WeiLayout, - ck::tensor_layout::convolution::NHWK, // typename OutLayout, - InElementwiseOperation, // typename InElementwiseOperation, - WeiElementwiseOperation, // typename WeiElementwiseOperation, - OutElementwiseOperation, // typename OutElementwiseOperation, - BlockSize, // ck::index_t BlockSize, - MPerBlock, // ck::index_t MPerBlock, - NPerBlock, // ck::index_t NPerBlock, - K0PerBlock, // ck::index_t K0PerBlock, - K1, // ck::index_t K1, - MPerXDL, // ck::index_t MPerXDL, - NPerXDL, // ck::index_t NPerXDL, - MXdlPerWave, // ck::index_t MXdlPerWave, - NXdlPerWave, // ck::index_t NXdlPerWave, - ABlockTransferThreadSliceLengths_K0_M_K1, // typename ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, // typename - // ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, // typename ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, // typename ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, // ck::index_t ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, // ck::index_t ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, // ck::index_t ABlockTransferDstScalarPerVector_K1, - BBlockTransferThreadSliceLengths_K0_N_K1, // typename BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, // typename - // BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, // typename BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, // typename BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, // ck::index_t BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, // ck::index_t BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, // ck::index_t BBlockTransferDstScalarPerVector_K1, - CThreadTransferSrcDstVectorDim, // ck::index_t CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, // ck::index_t CThreadTransferDstScalarPerVector, - ABlockLdsAddExtraM, // bool ABlockLdsAddExtraM, - BBlockLdsAddExtraN // bool BBlockLdsAddExtraN> - > + ck::index_t CThreadTransferDstScalarPerVector> +struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K : public DeviceConvFwd { + using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + using ADataType = InDataType; using BDataType = WeiDataType; using CDataType = OutDataType; @@ -103,7 +63,6 @@ struct DeviceConvFwdXdl< // TODO make A/B datatype different using ABDataType = InDataType; - // TODO make it support any # of spatial dimensions static constexpr index_t NDimSpatial = 2; static constexpr auto I0 = Number<0>{}; @@ -159,88 +118,189 @@ struct DeviceConvFwdXdl< const index_t GemmK0 = GemmK / GemmK1Number; - // A: input tensor - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); - const auto in_gemmk_gemmmraw_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmk_gemmmraw_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), - make_pass_through_transform(GemmMRaw)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - const auto in_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmMRaw, GemmMPad), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - // B: weight tensor - const auto wei_k_yxc_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( - wei_k_yxc_grid_desc, - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - wei_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - // C: output tensor - const auto out_nhowo_k_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - const auto out_gemmmraw_gemmn_grid_desc = transform_tensor_descriptor( - out_nhowo_k_grid_desc, - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); - const auto out_gemmm_gemmn_grid_desc = - transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc); + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } } using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -250,46 +310,6 @@ struct DeviceConvFwdXdl< using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - // TODO remove these hacks - static constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: M - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: K1 - - static constexpr auto b_k0_n_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: N - Sequence<0, 0, 0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: N - Sequence<0, 0, 0, 0, 0>{})); // 2-: K1 - - static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; - - static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; - // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, @@ -311,7 +331,6 @@ struct DeviceConvFwdXdl< K1, MXdlPerWave, NXdlPerWave, - ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -319,30 +338,18 @@ struct DeviceConvFwdXdl< ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, // AThreadTransferSrcResetCoordinateAfterRun, - BBlockTransferThreadSliceLengths_K0_N_K1, + ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, 2, // BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, 7, // CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks, - decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks, - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks, - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks, - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks, - false, // CAccessOrderMRepeatNRepeat, - ABlockLdsAddExtraM, - BBlockLdsAddExtraN>; - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - - using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + CThreadTransferDstScalarPerVector>; // Argument struct Argument : public BaseArgument @@ -377,19 +384,26 @@ struct DeviceConvFwdXdl< N01_{N01}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, - out_element_op_{out_element_op} + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} { - const auto descs = DeviceConvFwdXdl::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); a_grid_desc_k0_m_k1_ = descs[I0]; b_grid_desc_k0_n_k1_ = descs[I1]; @@ -412,19 +426,28 @@ struct DeviceConvFwdXdl< AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - Block2CTileMap block_2_ctile_map_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; index_t M01_; index_t N01_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; }; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceConvFwdXdl::Argument; + using Argument = DeviceOp::Argument; float Run(const Argument& arg, int nrepeat = 1) { @@ -465,13 +488,13 @@ struct DeviceConvFwdXdl< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + remove_reference_t, true>; ave_time = launch_and_time_kernel(kernel, @@ -496,13 +519,13 @@ struct DeviceConvFwdXdl< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + remove_reference_t, false>; ave_time = launch_and_time_kernel(kernel, @@ -525,7 +548,6 @@ struct DeviceConvFwdXdl< return ave_time; } - // polymorphic float Run(const BaseArgument* p_arg, int nrepeat = 1) override { return Run(*dynamic_cast(p_arg), nrepeat); @@ -540,6 +562,45 @@ struct DeviceConvFwdXdl< static bool IsSupportedArgument(const Argument& arg) { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, @@ -547,7 +608,6 @@ struct DeviceConvFwdXdl< arg.N01_); } - // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { return IsSupportedArgument(*dynamic_cast(p_arg)); @@ -592,7 +652,6 @@ struct DeviceConvFwdXdl< static auto MakeInvoker() { return Invoker{}; } - // polymorphic std::unique_ptr MakeArgumentPointer(const void* p_in_grid, const void* p_wei_grid, @@ -631,11 +690,27 @@ struct DeviceConvFwdXdl< out_element_op); } - // polymorphic std::unique_ptr MakeInvokerPointer() override { return std::make_unique(Invoker{}); } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } }; // namespace device } // namespace device diff --git a/device_operation/include/device_conv_fwd.hpp b/device_operation/include/device_conv_fwd.hpp new file mode 100644 index 0000000000..d53e56f18b --- /dev/null +++ b/device_operation/include/device_conv_fwd.hpp @@ -0,0 +1,46 @@ +#ifndef DEVICE_CONV_FWD_HPP +#define DEVICE_CONV_FWD_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdPtr = std::unique_ptr< + DeviceConvFwd>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv_fwd_bias_activation.hpp b/device_operation/include/device_conv_fwd_bias_activation.hpp new file mode 100644 index 0000000000..77d4b7fb95 --- /dev/null +++ b/device_operation/include/device_conv_fwd_bias_activation.hpp @@ -0,0 +1,49 @@ +#ifndef DEVICE_CONV_FWD_BIAS_ACTIVATION_HPP +#define DEVICE_CONV_FWD_BIAS_ACTIVATION_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdBiasActivation : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const void* p_bias, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdBiasActivationPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv_fwd_bias_activation_add.hpp b/device_operation/include/device_conv_fwd_bias_activation_add.hpp new file mode 100644 index 0000000000..2f8e780b78 --- /dev/null +++ b/device_operation/include/device_conv_fwd_bias_activation_add.hpp @@ -0,0 +1,50 @@ +#ifndef DEVICE_CONV_FWD_BIAS_ACTIVATION_ADD_HPP +#define DEVICE_CONV_FWD_BIAS_ACTIVATION_ADD_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdBiasActivationAdd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const void* p_bias, + const void* p_resi, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdBiasActivationAddPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv_fwd_xdl.hpp b/device_operation/include/device_conv_fwd_xdl.hpp deleted file mode 100644 index f663e49fab..0000000000 --- a/device_operation/include/device_conv_fwd_xdl.hpp +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef DEVICE_CONV_FWD_XDL_HPP -#define DEVICE_CONV_FWD_XDL_HPP - -#include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct DeviceConvFwdXdl; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/device_operation/include/device_conv_instance.hpp b/device_operation/include/device_conv_instance.hpp deleted file mode 100644 index 1ea8265849..0000000000 --- a/device_operation/include/device_conv_instance.hpp +++ /dev/null @@ -1,52 +0,0 @@ -#ifndef DEVICE_CONV_INSTANTCE_HPP -#define DEVICE_CONV_INSTANTCE_HPP - -#include "device_conv.hpp" -#include "element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_conv_instance { - -template -void add_device_conv_fwd_instance( - std::vector>&); - -template -void add_device_conv_bwd_instance( - std::vector>&); - -template -void add_device_conv_wrw_instance( - std::vector>&); - -} // namespace device_conv_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/device_operation/include/device_gemm_xdl.hpp b/device_operation/include/device_gemm_xdl.hpp index f6c95c511d..9e5ee80381 100644 --- a/device_operation/include/device_gemm_xdl.hpp +++ b/device_operation/include/device_gemm_xdl.hpp @@ -2,6 +2,7 @@ #define DEVICE_GEMM_XDL_HPP #include +#include #include "device.hpp" #include "device_base.hpp" #include "device_gemm.hpp" @@ -34,24 +35,22 @@ template + ck::index_t CThreadTransferDstScalarPerVector> struct DeviceGemmXdl : public DeviceGemm { @@ -131,45 +130,6 @@ struct DeviceGemmXdl using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - // TODO remove these hacks - static constexpr auto a_k0_m_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: M - Sequence<0, 0, 0>{})); // 2-: K1 - - static constexpr auto b_k0_n_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: N - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: N - Sequence<0, 0, 0>{})); // 2-: K1 - - static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; - - static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; - // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, @@ -191,7 +151,6 @@ struct DeviceGemmXdl K1, MXdlPerWave, NXdlPerWave, - ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -199,30 +158,18 @@ struct DeviceGemmXdl ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, // AThreadTransferSrcResetCoordinateAfterRun, - BBlockTransferThreadSliceLengths_K0_N_K1, + ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks, - decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks, - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks, - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks, - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks, - false, // CAccessOrderMRepeatNRepeat, - ABlockLdsAddExtraM, - BBlockLdsAddExtraN>; - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - - using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + CThreadTransferDstScalarPerVector>; // Argument struct Argument : public BaseArgument @@ -276,8 +223,9 @@ struct DeviceGemmXdl AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - Block2CTileMap block_2_ctile_map_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; index_t M01_; index_t N01_; AElementwiseOperation a_element_op_; @@ -331,11 +279,11 @@ struct DeviceGemmXdl CDataType, remove_reference_t, remove_reference_t, - remove_reference_t, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, - remove_reference_t, + remove_reference_t, true>; ave_time = launch_and_time_kernel(kernel, @@ -362,11 +310,11 @@ struct DeviceGemmXdl CDataType, remove_reference_t, remove_reference_t, - remove_reference_t, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, - remove_reference_t, + remove_reference_t, false>; ave_time = launch_and_time_kernel(kernel, @@ -483,6 +431,24 @@ struct DeviceGemmXdl { return std::make_unique(Invoker{}); } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } }; } // namespace device diff --git a/device_operation/include/device_operation_instance.hpp b/device_operation/include/device_operation_instance.hpp new file mode 100644 index 0000000000..40fd7274ef --- /dev/null +++ b/device_operation/include/device_operation_instance.hpp @@ -0,0 +1,26 @@ +#ifndef CK_DEVICE_OPERATION_INSTANCE_HPP +#define CK_DEVICE_OPERATION_INSTANCE_HPP + +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template +void add_device_operation_instances(std::vector>& op_instances, + const NewOpInstances& new_op_instances) +{ + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + const auto new_op_instance = std::get(new_op_instances); + + using NewOpInstance = remove_cvref_t; + + op_instances.push_back(std::make_unique(new_op_instance)); + }); +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/element_wise_operation.hpp b/device_operation/include/element_wise_operation.hpp deleted file mode 100644 index b4ad0a4167..0000000000 --- a/device_operation/include/element_wise_operation.hpp +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef ELEMENT_WISE_OPERATION_HPP -#define ELEMENT_WISE_OPERATION_HPP - -namespace ck { -namespace tensor_operation { -namespace element_wise { - -struct PassThrough -{ - template - __host__ __device__ constexpr T operator()(T v) const - { - return v; - } -}; - -} // namespace element_wise -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index ff84b66d15..81d58b509b 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -13,24 +13,7 @@ #include "device_tensor.hpp" #include "device_base.hpp" #include "device_gemm_xdl.hpp" - -struct PassThrough -{ - template - __host__ __device__ constexpr T operator()(T v) const - { - return v; - } -}; - -struct Relu -{ - template - __host__ __device__ constexpr T operator()(T v) const - { - return v > 0 ? v : 0; - } -}; +#include "element_wise_operation.hpp" template using S = ck::Sequence; @@ -44,18 +27,18 @@ using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; -using AOp = PassThrough; -using BOp = PassThrough; -using COp = Relu; +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for NT problem // clang-format off using DeviceGemmInstance = - //#########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - ck::tensor_operation::device::DeviceGemmXdl< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>; + //#########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>; // clang-format on template , S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>; + //#################################################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //#################################################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //#################################################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //#################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl_two_extra_source_reduce< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>; // clang-format on template +#include #include "device.hpp" #include "device_base.hpp" #include "device_gemm.hpp" @@ -560,6 +561,23 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator { return std::make_unique(Invoker{}); } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdl_two_extra_source_reduce" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } }; } // namespace device diff --git a/example/3_conv_xdl/README.md b/example/4_conv2d_fwd_xdl/README.md similarity index 92% rename from example/3_conv_xdl/README.md rename to example/4_conv2d_fwd_xdl/README.md index 2db7487235..4114571afe 100644 --- a/example/3_conv_xdl/README.md +++ b/example/4_conv2d_fwd_xdl/README.md @@ -1,4 +1,4 @@ -# Instructions for ```conv_xdl``` Example +# Instructions for ```conv2d_fwd_xdl``` Example ## Docker script ```bash @@ -13,7 +13,7 @@ rocm/tensorflow:rocm4.3.1-tf2.6-dev \ /bin/bash ``` -## Build ```conv_xdl``` +## Build ```conv2d_fwd_xdl``` ```bash mkdir build && cd build ``` @@ -30,16 +30,16 @@ cmake \ ``` ```bash - make -j conv_xdl + make -j conv2d_fwd_xdl ``` -## Run ```conv_xdl``` +## Run ```conv2d_fwd_xdl``` ```bash #arg1: verification (0=no, 1=yes) #arg2: initialization (0=no init, 1=integer value, 2=decimal value) #arg3: run kernel # of times (>1) #arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx -./example/conv_xdl 0 1 5 +./example/conv2d_fwd_xdl 0 1 5 ``` Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) diff --git a/example/3_conv_xdl/conv_xdl.cpp b/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp similarity index 77% rename from example/3_conv_xdl/conv_xdl.cpp rename to example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp index 880c0db9ba..ad428e2ef2 100644 --- a/example/3_conv_xdl/conv_xdl.cpp +++ b/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp @@ -11,27 +11,8 @@ #include "host_tensor_generator.hpp" #include "device_tensor.hpp" #include "tensor_layout.hpp" -#include "device_conv_fwd_xdl.hpp" -#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp" - -struct PassThrough -{ - template - __host__ __device__ constexpr T operator()(T v) const - { - return v; - } -}; - -struct Relu -{ - template - __host__ __device__ constexpr T operator()(T v) const - { - T tmp = 0.1 * v; - return tmp > 0 ? tmp : 0; - } -}; +#include "device_operation/include/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -45,17 +26,21 @@ using InLayout = ck::tensor_layout::convolution::NHWC; using WeiLayout = ck::tensor_layout::convolution::KYXC; using OutLayout = ck::tensor_layout::convolution::NHWK; -using InElementOp = PassThrough; -using WeiElementOp = PassThrough; -using OutElementOp = Relu; +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; -using DeviceConvFwdInstance = +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +using DeviceConvFwdInstance = ck::tensor_operation::device:: + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K // clang-format off -//############################################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| -//############################################| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| -//############################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | -//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -ck::tensor_operation::device::DeviceConvFwdXdl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>; +// | InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +// | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; // clang-format on template & in, } } } - out(n, k, ho, wo) = out_element_op(v); + double v2 = out(n, k, ho, wo); + + out_element_op(v2, v); + + out(n, k, ho, wo) = v2; }; make_ParallelTensorFunctor(f_nchw, diff --git a/example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_bias_activation_add.hpp b/example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_bias_activation_add.hpp deleted file mode 100644 index d7164d4d5e..0000000000 --- a/example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_bias_activation_add.hpp +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef DEVICE_CONV_FWD_XDL_BIAS_ACTIVATION_ADD_HPP -#define DEVICE_CONV_FWD_XDL_BIAS_ACTIVATION_ADD_HPP - -#include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct DeviceConvFwdXdl_bias_activation_add; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp b/example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 49588b419a..0000000000 --- a/example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,669 +0,0 @@ -#ifndef DEVICE_CONV_FWD_XDL_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP -#define DEVICE_CONV_FWD_XDL_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP - -#include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r5.hpp" -#include "example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_bias_activation_add.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// specialization for 2D conv: in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -template -struct DeviceConvFwdXdl_bias_activation_add< - 2, // ck::index_t NDimSpatial, - InDataType, // typename InDataType, - WeiDataType, // typename WeiDataType, - OutDataType, // typename OutDataType, - AccDataType, // typename AccDataType, - ck::tensor_layout::convolution::NHWC, // typename InLayout, - ck::tensor_layout::convolution::KYXC, // typename WeiLayout, - ck::tensor_layout::convolution::NHWK, // typename OutLayout, - InElementwiseOperation, // typename InElementwiseOperation, - WeiElementwiseOperation, // typename WeiElementwiseOperation, - OutElementwiseOperation, // typename OutElementwiseOperation, - BlockSize, // ck::index_t BlockSize, - MPerBlock, // ck::index_t MPerBlock, - NPerBlock, // ck::index_t NPerBlock, - K0PerBlock, // ck::index_t K0PerBlock, - K1, // ck::index_t K1, - MPerXDL, // ck::index_t MPerXDL, - NPerXDL, // ck::index_t NPerXDL, - MXdlPerWave, // ck::index_t MXdlPerWave, - NXdlPerWave, // ck::index_t NXdlPerWave, - ABlockTransferThreadSliceLengths_K0_M_K1, // typename ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, // typename - // ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, // typename ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, // typename ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, // ck::index_t ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, // ck::index_t ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, // ck::index_t ABlockTransferDstScalarPerVector_K1, - BBlockTransferThreadSliceLengths_K0_N_K1, // typename BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, // typename - // BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, // typename BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, // typename BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, // ck::index_t BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, // ck::index_t BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, // ck::index_t BBlockTransferDstScalarPerVector_K1, - CThreadTransferSrcDstVectorDim, // ck::index_t CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, // ck::index_t CThreadTransferDstScalarPerVector, - ABlockLdsAddExtraM, // bool ABlockLdsAddExtraM, - BBlockLdsAddExtraN // bool BBlockLdsAddExtraN> - > : public BaseOperator -{ - using ADataType = InDataType; - using BDataType = WeiDataType; - using CDataType = OutDataType; - - // TODO make A/B datatype different - using ABDataType = InDataType; - - // TODO make it support any # of spatial dimensions - static constexpr index_t NDimSpatial = 2; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - - static constexpr auto K1Number = Number{}; - static constexpr auto GemmK1Number = K1Number; - - static auto - MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) - { - using namespace ck; - - const index_t Hi = input_spatial_lengths[0]; - const index_t Wi = input_spatial_lengths[1]; - - const index_t Ho = output_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[1]; - - const index_t Y = filter_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[1]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const index_t GemmMRaw = N * Ho * Wo; - const index_t GemmN = K; - const index_t GemmK = Y * X * C; - - const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; - - const auto GemmM = GemmMRaw + GemmMPad; - - assert(GemmK % GemmK1Number == 0); - - const index_t GemmK0 = GemmK / GemmK1Number; - - // A: input tensor - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmk_gemmmraw_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmk_gemmmraw_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), - make_pass_through_transform(GemmMRaw)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmMRaw, GemmMPad), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - // B: weight tensor - const auto wei_k_yxc_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - - const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( - wei_k_yxc_grid_desc, - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - wei_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C: output tensor - const auto out_nhowo_k_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - - const auto out_gemmmraw_gemmn_grid_desc = transform_tensor_descriptor( - out_nhowo_k_grid_desc, - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmm_gemmn_grid_desc = - transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // C0: bias tensor: assume a contiguous vector - const auto bias_grid_desc_gemmm_gemmn = - make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(0, 1)); - - // C1: residual tensor: assume same layout as output tensor - const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; - - return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - bias_grid_desc_gemmm_gemmn, - resi_grid_desc_gemmm_gemmn); - } - - using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); - - using AGridDesc_K0_M_K1 = remove_cvref_t; - using BGridDesc_K0_N_K1 = remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; - using C0GridDesc_M_N = remove_cvref_t; - using C1GridDesc_M_N = remove_cvref_t; - - // TODO remove these hacks - static constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: M - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: K1 - - static constexpr auto b_k0_n_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: N - Sequence<0, 0, 0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: N - Sequence<0, 0, 0, 0, 0>{})); // 2-: K1 - - static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; - - static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; - - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5< - BlockSize, - ABDataType, // TODO: distinguish A/B datatype - AccDataType, - CDataType, - InMemoryDataOperationEnum_t::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - C0GridDesc_M_N, - C1GridDesc_M_N, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - MPerBlock, - NPerBlock, - K0PerBlock, - MPerXDL, - NPerXDL, - K1, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, - Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, - Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, - 2, // ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, - Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, - 2, // BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, - 7, // CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks, - decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks, - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks, - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks, - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks, - false, // CAccessOrderMRepeatNRepeat, - ABlockLdsAddExtraM, - BBlockLdsAddExtraN>; - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - - using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{})); - - using C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C1GridDesc_M_N{})); - - using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - - // Argument - struct Argument : public BaseArgument - { - Argument(const InDataType* p_in_grid, - const WeiDataType* p_wei_grid, - OutDataType* p_out_grid, - const OutDataType* p_bias_grid, - const OutDataType* p_resi_grid, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) - : p_a_grid_{p_in_grid}, - p_b_grid_{p_wei_grid}, - p_c_grid_{p_out_grid}, - p_c0_grid_{p_bias_grid}, - p_c1_grid_{p_resi_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c0_grid_desc_m_n_{}, - c1_grid_desc_m_n_{}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - in_element_op_{in_element_op}, - wei_element_op_{wei_element_op}, - out_element_op_{out_element_op} - { - const auto descs = DeviceConvFwdXdl_bias_activation_add:: - MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); - - a_grid_desc_k0_m_k1_ = descs[I0]; - b_grid_desc_k0_n_k1_ = descs[I1]; - c_grid_desc_m_n_ = descs[I2]; - c0_grid_desc_m_n_ = descs[I3]; - c1_grid_desc_m_n_ = descs[I4]; - - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c0_grid_desc_m_n_); - - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c1_grid_desc_m_n_); - - block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - const CDataType* p_c0_grid_; - const CDataType* p_c1_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - C0GridDesc_M_N c0_grid_desc_m_n_; - C1GridDesc_M_N c1_grid_desc_m_n_; - CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - Block2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - InElementwiseOperation in_element_op_; - WeiElementwiseOperation wei_element_op_; - OutElementwiseOperation out_element_op_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceConvFwdXdl_bias_activation_add::Argument; - - float Run(const Argument& arg, int nrepeat = 1) - { - { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) - << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) - << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting"); - } - - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - float ave_time = 0; - - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v2r5< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceConvFwdXdl_bias_activation_add::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - remove_reference_t< - DeviceConvFwdXdl_bias_activation_add::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - remove_reference_t< - DeviceConvFwdXdl_bias_activation_add::C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_, - arg.p_c1_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.in_element_op_, - arg.wei_element_op_, - arg.out_element_op_, - arg.block_2_ctile_map_); - } - else - { - const auto kernel = kernel_gemm_xdlops_v2r5< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceConvFwdXdl_bias_activation_add::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - remove_reference_t< - DeviceConvFwdXdl_bias_activation_add::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - remove_reference_t< - DeviceConvFwdXdl_bias_activation_add::C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_, - arg.p_c1_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.in_element_op_, - arg.wei_element_op_, - arg.out_element_op_, - arg.block_2_ctile_map_); - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override - { - return Run(*dynamic_cast(p_arg), nrepeat); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const InDataType* p_in_grid, - const WeiDataType* p_wei_grid, - OutDataType* p_out_grid, - const OutDataType* p_bias_grid, - const OutDataType* p_resi_grid, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) - { - return Argument{p_in_grid, - p_wei_grid, - p_out_grid, - p_bias_grid, - p_resi_grid, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } -}; // namespace device - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/example/4_conv_xdl_bias_relu_add/README.md b/example/5_conv2d_fwd_xdl_bias_relu/README.md similarity index 100% rename from example/4_conv_xdl_bias_relu_add/README.md rename to example/5_conv2d_fwd_xdl_bias_relu/README.md diff --git a/example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp b/example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp new file mode 100644 index 0000000000..aa2605bbdf --- /dev/null +++ b/example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp @@ -0,0 +1,296 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + +static constexpr auto MemorySet = ck::InMemoryDataOperationEnum_t::Set; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +// clang-format off +using DeviceConvFwdInstance = ck::tensor_operation::device:: + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + // clang-format off +// | InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +// | | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; +// clang-format on + +template +void host_reference_calculation(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + const std::vector& conv_strides, + const std::vector& conv_dilations, + const std::vector& in_left_pads, + const std::vector& /* in_right_pads */, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + double v = 0; + for(int c = 0; c < wei_k_c_y_x.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < wei_k_c_y_x.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[0] + y * conv_dilations[0] - in_left_pads[0]; + for(int x = 0; x < wei_k_c_y_x.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[1] + x * conv_dilations[1] - in_left_pads[1]; + if(hi >= 0 && hi < in_n_c_hi_wi.mDesc.GetLengths()[2] && wi >= 0 && + wi < in_n_c_hi_wi.mDesc.GetLengths()[3]) + { + v += in_element_op(static_cast(in_n_c_hi_wi(n, c, hi, wi))) * + wei_element_op(static_cast(wei_k_c_y_x(k, c, y, x))); + } + } + } + } + + out_n_k_ho_wo(n, k, ho, wo) = out_element_op(v, bias_k(k)); + }; + + make_ParallelTensorFunctor(f_nchw, + out_n_k_ho_wo.mDesc.GetLengths()[0], + out_n_k_ho_wo.mDesc.GetLengths()[1], + out_n_k_ho_wo.mDesc.GetLengths()[2], + out_n_k_ho_wo.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); +} + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + } + else if(argc == 19) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + C = std::stoi(argv[6]); + Y = std::stoi(argv[7]); + X = std::stoi(argv[8]); + Hi = std::stoi(argv[9]); + Wi = std::stoi(argv[10]); + conv_stride_h = std::stoi(argv[11]); + conv_stride_w = std::stoi(argv[12]); + conv_dilation_h = std::stoi(argv[13]); + conv_dilation_w = std::stoi(argv[14]); + in_left_pad_h = std::stoi(argv[15]); + in_left_pad_w = std::stoi(argv[16]); + in_right_pad_h = std::stoi(argv[17]); + in_right_pad_w = std::stoi(argv[18]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const std::vector conv_filter_strides{{conv_stride_h, conv_stride_w}}; + const std::vector conv_filter_dilations{{conv_dilation_h, conv_dilation_w}}; + const std::vector input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::vector input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + // tensor layout + auto f_host_tensor_descriptor = [](std::size_t N_, + std::size_t C_, + std::size_t H, + std::size_t W, + auto layout) { + if constexpr(ck::is_same::value || + ck::is_same::value || + ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(ck::is_same::value || + ck::is_same::value || + ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + // bias: assume contiguous 1d vector + Tensor bias_k( + HostTensorDescriptor(std::vector({static_cast(K)}))); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::cout << "bias_k: " << bias_k.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + bias_device_buf.ToDevice(bias_k.mData.data()); + + auto conv = DeviceConvFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + N, + K, + C, + std::vector{{Hi, Wi}}, + std::vector{{Y, X}}, + std::vector{{Ho, Wo}}, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device operator with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, nrepeat); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + host_reference_calculation(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + } +} diff --git a/example/6_conv2d_fwd_xdl_bias_relu_add/README.md b/example/6_conv2d_fwd_xdl_bias_relu_add/README.md new file mode 100644 index 0000000000..eed5605a9e --- /dev/null +++ b/example/6_conv2d_fwd_xdl_bias_relu_add/README.md @@ -0,0 +1,61 @@ +# Instructions for ```conv_xdl_bias_relu_add``` Example + +## Docker script +```bash +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash +``` + +## Build ```conv_xdl_bias_relu_add``` +```bash +mkdir build && cd build +``` + +```bash +# Need to specify target ID, example below is gfx908 +cmake \ +-D BUILD_DEV=OFF \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +.. +``` + +```bash + make -j conv_xdl_bias_relu_add +``` + +## Run ```conv_xdl_bias_relu_add``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx +./example/conv_xdl_bias_relu_add 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +bias_k: dim 1, lengths {256}, strides {1} +resi_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +arg.c0_grid_desc_m_n_{ 165888, 256} +arg.c1_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.71779 ms, 85.4396 TFlops, 194.2 GB/s +``` diff --git a/example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp b/example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp similarity index 65% rename from example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp rename to example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp index 71f73a280f..1353b65248 100644 --- a/example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp +++ b/example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp @@ -11,148 +11,8 @@ #include "host_tensor_generator.hpp" #include "device_tensor.hpp" #include "tensor_layout.hpp" -#include "example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_bias_activation_add.hpp" -#include "example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp" - -struct PassThrough -{ - template - __host__ __device__ constexpr T operator()(T v) const - { - return v; - } -}; - -struct BiasLeakyReluAdd -{ - template - __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const - { - float a = v0 + v1; - float b = 0.1 * a; - float c = b > 0 ? b : 0; - float d = c + v2; - - return d; - } - - template - __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const - { -#if 0 - // this use not too many registers, but use fp64 mul - float a = v0 + v1; - float b = 0.1 * a; - float c = b > 0 ? b : 0; - float d = c + v2; - - return d; -#elif 0 - // this spill register - float a = v0 + v1; - float b = float(0.1) * a; - float c = b > 0 ? b : 0; - float d = c + v2; - - return d; -#elif 0 - // this use lots of registers (but no spill) - constexpr float alpha = 0.1; - constexpr float alpha_inv = 1.0 / alpha; - - float a = v2 * alpha_inv; - float b = v1 + v0; - float c = b > 0 ? b : 0; - float d = alpha * (a + c); - - return d; -#elif 1 - // this use lots of registers (but no spill), 89 Tflops - constexpr float alpha = 0.1; - constexpr float alpha_inv = 1.0 / alpha; - - float a = v2 * alpha_inv; - float b = v1 + v0; - float c = max(b, float(0)); - float d = alpha * (a + c); - - return d; -#elif 1 - // this spill registers, 89 Tflops - float a = v0 + v1; - float alpha = 0.1; - - float b; - asm volatile("\n \ - v_mul_f32_e32 %0, %1, %2 \n \ - " - : "=v"(b) - : "s"(alpha), "v"(a)); - - float c = b > 0 ? b : 0; - float d = c + v2; - - return d; -#endif - } -}; - -struct BiasReluAdd -{ - template - __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const - { - float b = v0 + v1; - float c = b > 0 ? b : 0; - float d = c + v2; - - return d; - } - - template - __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const - { -#if 0 - float a = v1 + v0; - float b = max(a, float(0)); - float c = b + v2; - - return c; -#else - float a = v1 + v2; - float b = v2; - - float c = (v0 > -v1) ? a + v0 : v2; - - return c; -#endif - } -}; - -struct BiasLeakyRelu -{ - template - __host__ constexpr float operator()(float v0, T1 v1, T2) const - { - float a = v0 + v1; - float b = 0.1 * a; - float c = b > 0 ? b : 0; - - return c; - } - - template - __device__ constexpr float operator()(float v0, T1 v1, T2) const - { - constexpr float alpha = 0.1; - - float b = v1 + v0; - float c = max(b, float(0)); - float d = alpha * c; - - return d; - } -}; +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -166,17 +26,21 @@ using InLayout = ck::tensor_layout::convolution::NHWC; using WeiLayout = ck::tensor_layout::convolution::KYXC; using OutLayout = ck::tensor_layout::convolution::NHWK; -using InElementOp = PassThrough; -using WeiElementOp = PassThrough; -using OutElementOp = BiasReluAdd; +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; // clang-format off -using DeviceConvFwdInstance = - //################################################################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //################################################################| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //################################################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - ck::tensor_operation::device::DeviceConvFwdXdl_bias_activation_add< 2, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>; +using DeviceConvFwdInstance = ck::tensor_operation::device:: + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K +// | InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +// | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; // clang-format on template & in_n_c_hi_wi, const std::vector& conv_strides, const std::vector& conv_dilations, const std::vector& in_left_pads, - const std::vector&, + const std::vector& /* in_right_pads */, const InElementOp& in_element_op, const WeiElementOp& wei_element_op, const OutElementOp& out_element_op) @@ -218,7 +82,14 @@ void host_reference_calculation(const Tensor& in_n_c_hi_wi, } } - out_n_k_ho_wo(n, k, ho, wo) = out_element_op(v, bias_k(k), resi_n_k_ho_wo(n, k, ho, wo)); + double v2 = out_n_k_ho_wo(n, k, ho, wo); + + out_element_op(v2, + v, + static_cast(bias_k(k)), + static_cast(resi_n_k_ho_wo(n, k, ho, wo))); + + out_n_k_ho_wo(n, k, ho, wo) = v2; }; make_ParallelTensorFunctor(f_nchw, @@ -358,8 +229,8 @@ int main(int argc, char* argv[]) default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - bias_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); @@ -399,8 +270,8 @@ int main(int argc, char* argv[]) if(!conv.IsSupportedArgument(argument)) { throw std::runtime_error( - "wrong! device_conv with the specified compilation parameters does " - "not support this Conv problem"); + "wrong! device operator with the specified compilation parameters does " + "not support this problem"); } float ave_time = invoker.Run(argument, nrepeat); diff --git a/example/7_conv2d_fwd_xdl_bias_relu_atomic_add/README.md b/example/7_conv2d_fwd_xdl_bias_relu_atomic_add/README.md new file mode 100644 index 0000000000..eed5605a9e --- /dev/null +++ b/example/7_conv2d_fwd_xdl_bias_relu_atomic_add/README.md @@ -0,0 +1,61 @@ +# Instructions for ```conv_xdl_bias_relu_add``` Example + +## Docker script +```bash +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash +``` + +## Build ```conv_xdl_bias_relu_add``` +```bash +mkdir build && cd build +``` + +```bash +# Need to specify target ID, example below is gfx908 +cmake \ +-D BUILD_DEV=OFF \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +.. +``` + +```bash + make -j conv_xdl_bias_relu_add +``` + +## Run ```conv_xdl_bias_relu_add``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx +./example/conv_xdl_bias_relu_add 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +bias_k: dim 1, lengths {256}, strides {1} +resi_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +arg.c0_grid_desc_m_n_{ 165888, 256} +arg.c1_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.71779 ms, 85.4396 TFlops, 194.2 GB/s +``` diff --git a/example/7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp b/example/7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp new file mode 100644 index 0000000000..c47c094385 --- /dev/null +++ b/example/7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp @@ -0,0 +1,299 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + +static constexpr auto MemoryAtomicAdd = ck::InMemoryDataOperationEnum_t::AtomicAdd; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +// clang-format off +using DeviceConvFwdInstance = ck::tensor_operation::device:: + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + // clang-format off +// | InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +// | | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1,32>, 2>; +// clang-format on + +template +void host_reference_calculation(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + const std::vector& conv_strides, + const std::vector& conv_dilations, + const std::vector& in_left_pads, + const std::vector& /* in_right_pads */, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + double v = 0; + for(int c = 0; c < wei_k_c_y_x.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < wei_k_c_y_x.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[0] + y * conv_dilations[0] - in_left_pads[0]; + for(int x = 0; x < wei_k_c_y_x.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[1] + x * conv_dilations[1] - in_left_pads[1]; + if(hi >= 0 && hi < in_n_c_hi_wi.mDesc.GetLengths()[2] && wi >= 0 && + wi < in_n_c_hi_wi.mDesc.GetLengths()[3]) + { + v += in_element_op(static_cast(in_n_c_hi_wi(n, c, hi, wi))) * + wei_element_op(static_cast(wei_k_c_y_x(k, c, y, x))); + } + } + } + } + + out_n_k_ho_wo(n, k, ho, wo) += out_element_op(v, bias_k(k)); + }; + + make_ParallelTensorFunctor(f_nchw, + out_n_k_ho_wo.mDesc.GetLengths()[0], + out_n_k_ho_wo.mDesc.GetLengths()[1], + out_n_k_ho_wo.mDesc.GetLengths()[2], + out_n_k_ho_wo.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); +} + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + } + else if(argc == 19) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + C = std::stoi(argv[6]); + Y = std::stoi(argv[7]); + X = std::stoi(argv[8]); + Hi = std::stoi(argv[9]); + Wi = std::stoi(argv[10]); + conv_stride_h = std::stoi(argv[11]); + conv_stride_w = std::stoi(argv[12]); + conv_dilation_h = std::stoi(argv[13]); + conv_dilation_w = std::stoi(argv[14]); + in_left_pad_h = std::stoi(argv[15]); + in_left_pad_w = std::stoi(argv[16]); + in_right_pad_h = std::stoi(argv[17]); + in_right_pad_w = std::stoi(argv[18]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const std::vector conv_filter_strides{{conv_stride_h, conv_stride_w}}; + const std::vector conv_filter_dilations{{conv_dilation_h, conv_dilation_w}}; + const std::vector input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::vector input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + // tensor layout + auto f_host_tensor_descriptor = [](std::size_t N_, + std::size_t C_, + std::size_t H, + std::size_t W, + auto layout) { + if constexpr(ck::is_same::value || + ck::is_same::value || + ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(ck::is_same::value || + ck::is_same::value || + ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + // bias: assume contiguous 1d vector + Tensor bias_k( + HostTensorDescriptor(std::vector({static_cast(K)}))); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::cout << "bias_k: " << bias_k.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + out_n_k_ho_wo_host_result.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + out_n_k_ho_wo_host_result.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_device_buf.ToDevice(out_n_k_ho_wo_host_result.mData.data()); + bias_device_buf.ToDevice(bias_k.mData.data()); + + auto conv = DeviceConvFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + N, + K, + C, + std::vector{{Hi, Wi}}, + std::vector{{Y, X}}, + std::vector{{Ho, Wo}}, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device operator with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, nrepeat); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + host_reference_calculation(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + } +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index e2fe23a063..6f231bcdf0 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -12,16 +12,22 @@ include_directories(BEFORE ) set(GEMM_XDL_SOURCE 1_gemm_xdl/gemm_xdl.cpp) -set(GEMM_XDL_BIAS_RELU_ADD_SOURCE 2_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp) -set(CONV_XDL_SOURCE 3_conv_xdl/conv_xdl.cpp) -set(CONV_XDL_BIAS_RELU_ADD_SOURCE 4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp) +set(GEMM_XDL_BIAS_RELU_ADD_SOURCE 3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp) +set(CONV2D_FWD_XDL_SOURCE 4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp) +set(CONV2D_FWD_XDL_BIAS_RELU_SOURCE 5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp) +set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp) +set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp) add_executable(gemm_xdl ${GEMM_XDL_SOURCE}) add_executable(gemm_xdl_bias_relu_add ${GEMM_XDL_BIAS_RELU_ADD_SOURCE}) -add_executable(conv_xdl ${CONV_XDL_SOURCE}) -add_executable(conv_xdl_bias_relu_add ${CONV_XDL_BIAS_RELU_ADD_SOURCE}) +add_executable(conv2d_fwd_xdl ${CONV2D_FWD_XDL_SOURCE}) +add_executable(conv2d_fwd_xdl_bias_relu ${CONV2D_FWD_XDL_BIAS_RELU_SOURCE}) +add_executable(conv2d_fwd_xdl_bias_relu_add ${CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE}) +add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE}) target_link_libraries(gemm_xdl PRIVATE host_tensor) target_link_libraries(gemm_xdl_bias_relu_add PRIVATE host_tensor) -target_link_libraries(conv_xdl PRIVATE host_tensor) -target_link_libraries(conv_xdl_bias_relu_add PRIVATE host_tensor) +target_link_libraries(conv2d_fwd_xdl PRIVATE host_tensor) +target_link_libraries(conv2d_fwd_xdl_bias_relu PRIVATE host_tensor) +target_link_libraries(conv2d_fwd_xdl_bias_relu_add PRIVATE host_tensor) +target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor) diff --git a/host/host_tensor/src/host_tensor.cpp b/host/host_tensor/src/host_tensor.cpp index 4e3cdbdccd..a0d4894339 100644 --- a/host/host_tensor/src/host_tensor.cpp +++ b/host/host_tensor/src/host_tensor.cpp @@ -1,4 +1,3 @@ -#include #include #include "host_tensor.hpp" @@ -26,8 +25,12 @@ std::size_t HostTensorDescriptor::GetElementSize() const std::size_t HostTensorDescriptor::GetElementSpace() const { - auto ls = mLens | boost::adaptors::transformed([](std::size_t v) { return v - 1; }); - return std::inner_product(ls.begin(), ls.end(), mStrides.begin(), std::size_t{0}) + 1; + std::size_t space = 1; + for(int i = 0; i < mLens.size(); ++i) + { + space += (mLens[i] - 1) * mStrides[i]; + } + return space; } const std::vector& HostTensorDescriptor::GetLengths() const { return mLens; } diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 62d8d30afc..6ef9cd6014 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -30,21 +30,65 @@ target_compile_features(device_gemm_instance PUBLIC) set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) -# device_conv_instance -set(DEVICE_CONV_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp; +# device_conv2d_fwd_instance +set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; ) -add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE}) -target_include_directories(device_conv_instance SYSTEM PUBLIC $) -target_compile_features(device_conv_instance PUBLIC) -set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv_instance LIBRARY DESTINATION lib) +add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) +target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $) +target_compile_features(device_conv2d_fwd_instance PUBLIC) +set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) + +# device_conv2d_fwd_bias_relu_instance +set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; +) + +add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) +target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $) +target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) +set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) + +# device_conv2d_fwd_bias_relu_add_instance +set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp; +) + +add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) +target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $) +target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) +set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) + +# device_conv2d_fwd_bias_relu_atomic_add_instance +set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; +) + +add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) +target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $) +target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC) +set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) # ck_profiler -set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp) +set(PROFILER_SOURCE + profiler.cpp + profile_gemm.cpp + profile_conv_fwd.cpp + profile_conv_fwd_bias_relu.cpp + profile_conv_fwd_bias_relu_add.cpp + profile_conv_fwd_bias_relu_atomic_add.cpp) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) -target_link_libraries(ckProfiler PRIVATE device_gemm_instance device_conv_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) diff --git a/profiler/gemm_profiler.cpp b/profiler/gemm_profiler.cpp deleted file mode 100644 index 018fe872d0..0000000000 --- a/profiler/gemm_profiler.cpp +++ /dev/null @@ -1,219 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_base.hpp" -#include "device_gemm_xdl.hpp" -#include "profile_gemm.hpp" - -enum GemmMatrixLayout -{ - MK_KN_MN, // 0 - MK_NK_MN, // 1 - KM_KN_MN, // 2 - KM_NK_MN, // 3 - MK_KN_NM, // 4 - MK_NK_NM, // 5 - KM_KN_NM, // 6 - KM_NK_NM, // 7 -}; - -enum GemmDataType -{ - F32_F32_F32, // 0 - F16_F16_F16, // 1 -}; - -int gemm_profiler(int argc, char* argv[]) -{ - if(argc != 14) - { - printf("arg1: tensor operation (gemm: GEMM)\n"); - printf("arg2: data type (0: fp32; 1: fp16)\n"); - printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); - printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); - printf(" 2: A[k, n] * B[k, n] = C[m, n];\n"); - printf(" 3: A[k, n] * B[n, k] = C[m, n])\n"); - printf("arg4: verification (0: no; 1: yes)\n"); - printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg7: run kernel # of times (>1)\n"); - printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); - exit(1); - } - - const int data_type = static_cast(std::stoi(argv[2])); - const int layout = static_cast(std::stoi(argv[3])); - const bool do_verification = std::stoi(argv[4]); - const int init_method = std::stoi(argv[5]); - const bool do_log = std::stoi(argv[6]); - const int nrepeat = std::stoi(argv[7]); - - const int M = std::stoi(argv[8]); - const int N = std::stoi(argv[9]); - const int K = std::stoi(argv[10]); - - const int StrideA = std::stoi(argv[11]); - const int StrideB = std::stoi(argv[12]); - const int StrideC = std::stoi(argv[13]); - - if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_gemm(do_verification, - init_method, - do_log, - nrepeat, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) - { - ck::profiler::profile_gemm(do_verification, - init_method, - do_log, - nrepeat, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) - { - ck::profiler::profile_gemm(do_verification, - init_method, - do_log, - nrepeat, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) - { - ck::profiler::profile_gemm(do_verification, - init_method, - do_log, - nrepeat, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); - } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_gemm(do_verification, - init_method, - do_log, - nrepeat, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); - } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) - { - ck::profiler::profile_gemm(do_verification, - init_method, - do_log, - nrepeat, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); - } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) - { - ck::profiler::profile_gemm(do_verification, - init_method, - do_log, - nrepeat, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); - } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) - { - ck::profiler::profile_gemm(do_verification, - init_method, - do_log, - nrepeat, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); - } - else - { - throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); - } - - return 1; -} diff --git a/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp new file mode 100644 index 0000000000..d665321879 --- /dev/null +++ b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp @@ -0,0 +1,305 @@ +#pragma once +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "device_conv_fwd_bias_activation_add.hpp" +#include "element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_add_instance { + +using DeviceConvFwdBiasReluAddPtr = + DeviceConvFwdBiasActivationAddPtr; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances( + std::vector&); + +} // namespace device_conv2d_fwd_bias_activation_add_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void host_reference_calculation(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + const Tensor& resi_n_k_ho_wo, + const std::vector& conv_strides, + const std::vector& conv_dilations, + const std::vector& in_left_pads, + const std::vector& /* in_right_pads */, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + double v = 0; + for(int c = 0; c < wei_k_c_y_x.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < wei_k_c_y_x.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[0] + y * conv_dilations[0] - in_left_pads[0]; + for(int x = 0; x < wei_k_c_y_x.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[1] + x * conv_dilations[1] - in_left_pads[1]; + if(hi >= 0 && hi < in_n_c_hi_wi.mDesc.GetLengths()[2] && wi >= 0 && + wi < in_n_c_hi_wi.mDesc.GetLengths()[3]) + { + v += in_element_op(static_cast(in_n_c_hi_wi(n, c, hi, wi))) * + wei_element_op(static_cast(wei_k_c_y_x(k, c, y, x))); + } + } + } + } + + out_n_k_ho_wo(n, k, ho, wo) = out_element_op(v, bias_k(k), resi_n_k_ho_wo(n, k, ho, wo)); + }; + + make_ParallelTensorFunctor(f_nchw, + out_n_k_ho_wo.mDesc.GetLengths()[0], + out_n_k_ho_wo.mDesc.GetLengths()[1], + out_n_k_ho_wo.mDesc.GetLengths()[2], + out_n_k_ho_wo.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); +} + +template +void profile_conv_fwd_bias_relu_add_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + // bias: assume contiguous 1d vector + Tensor bias_k( + HostTensorDescriptor(std::vector({static_cast(K)}))); + + // residual: assume same layout as output tensor + Tensor resi_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::cout << "bias_k: " << bias_k.mDesc << std::endl; + std::cout << "resi_n_k_ho_wo: " << resi_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + + if(do_verification) + { + host_reference_calculation(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + resi_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); + DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_k_ho_wo.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + bias_device_buf.ToDevice(bias_k.mData.data()); + resi_device_buf.ToDevice(resi_n_k_ho_wo.mData.data()); + + using DeviceConvFwdBiasReluAddPtr = ck::tensor_operation::device:: + DeviceConvFwdBiasActivationAddPtr; + + // add device operator instances + std::vector op_ptrs; + + if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::device_conv2d_fwd_bias_activation_add_instance:: + add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + } + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + static_cast(resi_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = op_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp new file mode 100644 index 0000000000..c17d184e84 --- /dev/null +++ b/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp @@ -0,0 +1,328 @@ +#pragma once +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "device_conv_fwd_bias_activation.hpp" +#include "element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_atomic_add_instance { + +using DeviceConvFwdBiasReluPtr = + DeviceConvFwdBiasActivationPtr; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances( + std::vector&); + +} // namespace device_conv2d_fwd_bias_activation_atomic_add_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +void cpu_conv_bias_relu_atomic_add(ck::half_t* in_ptr, + ck::half_t* weight_ptr, + ck::half_t* output_ptr, + ck::half_t* bias_ptr, + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const ck::index_t Y, + const ck::index_t X, + const ck::index_t Hi, + const ck::index_t Wi, + const ck::index_t Ho, + const ck::index_t Wo, + const ck::index_t Stride, + const ck::index_t Dilation, + const ck::index_t Pad) +{ + + const auto in_desc = + HostTensorDescriptor(std::vector{static_cast(N), + static_cast(Hi), + static_cast(Wi), + static_cast(C)}); + const auto wei_desc = + HostTensorDescriptor(std::vector{static_cast(K), + static_cast(Y), + static_cast(X), + static_cast(C)}); + const auto out_desc = + HostTensorDescriptor(std::vector{static_cast(N), + static_cast(Ho), + static_cast(Wo), + static_cast(K)}); + const auto bias_desc = + HostTensorDescriptor(std::vector{static_cast(K)}); + + auto f_k = [&](auto k) { + for(int n = 0; n < N; ++n) + { + for(int ho = 0; ho < Ho; ++ho) + { + for(int wo = 0; wo < Wo; ++wo) + { + double v = 0; + for(int c = 0; c < C; ++c) + { + for(int y = 0; y < Y; ++y) + { + int hi = ho * Stride + y * Dilation - Pad; + for(int x = 0; x < X; ++x) + { + int wi = wo * Stride + x * Dilation - Pad; + if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi) + { + double in = + in_ptr[in_desc.GetOffsetFromMultiIndex(n, hi, wi, c)]; + double wei = + weight_ptr[wei_desc.GetOffsetFromMultiIndex(k, y, x, c)]; + + v += in * wei; + } + } + } + } + + v += bias_ptr[bias_desc.GetOffsetFromMultiIndex(k)]; + + v = v > 0 ? v : 0; + + output_ptr[out_desc.GetOffsetFromMultiIndex(n, ho, wo, k)] = v; + } + } + } + }; + + make_ParallelTensorFunctor(f_k, K)(std::thread::hardware_concurrency()); +} + +template +void profile_conv_fwd_bias_relu_atomic_add_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + // bias: assume contiguous 1d vector + Tensor bias_k( + HostTensorDescriptor(std::vector({static_cast(K)}))); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::cout << "bias_k: " << bias_k.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + + if(do_verification) + { + cpu_conv_bias_relu_atomic_add(in_n_c_hi_wi.mData.data(), + wei_k_c_y_x.mData.data(), + out_n_k_ho_wo_host_result.mData.data(), + bias_k.mData.data(), + N, + K, + C, + Y, + X, + Hi, + Wi, + Ho, + Wo, + conv_filter_strides[0], + conv_filter_dilations[0], + input_left_pads[0]); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + bias_device_buf.ToDevice(bias_k.mData.data()); + + using DeviceConvFwdBiasReluPtr = ck::tensor_operation::device:: + DeviceConvFwdBiasActivationPtr; + + // add device operator instances + std::vector op_ptrs; + + if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::device_conv2d_fwd_bias_activation_atomic_add_instance:: + add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances( + op_ptrs); + } + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = op_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp new file mode 100644 index 0000000000..955861dcf8 --- /dev/null +++ b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp @@ -0,0 +1,327 @@ +#pragma once +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "device_conv_fwd_bias_activation.hpp" +#include "element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_instance { + +using DeviceConvFwdBiasReluPtr = + DeviceConvFwdBiasActivationPtr; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( + std::vector&); + +} // namespace device_conv2d_fwd_bias_activation_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +void cpu_conv_bias_relu(ck::half_t* in_ptr, + ck::half_t* weight_ptr, + ck::half_t* output_ptr, + ck::half_t* bias_ptr, + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const ck::index_t Y, + const ck::index_t X, + const ck::index_t Hi, + const ck::index_t Wi, + const ck::index_t Ho, + const ck::index_t Wo, + const ck::index_t Stride, + const ck::index_t Dilation, + const ck::index_t Pad) +{ + + const auto in_desc = + HostTensorDescriptor(std::vector{static_cast(N), + static_cast(Hi), + static_cast(Wi), + static_cast(C)}); + const auto wei_desc = + HostTensorDescriptor(std::vector{static_cast(K), + static_cast(Y), + static_cast(X), + static_cast(C)}); + const auto out_desc = + HostTensorDescriptor(std::vector{static_cast(N), + static_cast(Ho), + static_cast(Wo), + static_cast(K)}); + const auto bias_desc = + HostTensorDescriptor(std::vector{static_cast(K)}); + + auto f_k = [&](auto k) { + for(int n = 0; n < N; ++n) + { + for(int ho = 0; ho < Ho; ++ho) + { + for(int wo = 0; wo < Wo; ++wo) + { + double v = 0; + for(int c = 0; c < C; ++c) + { + for(int y = 0; y < Y; ++y) + { + int hi = ho * Stride + y * Dilation - Pad; + for(int x = 0; x < X; ++x) + { + int wi = wo * Stride + x * Dilation - Pad; + if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi) + { + double in = + in_ptr[in_desc.GetOffsetFromMultiIndex(n, hi, wi, c)]; + double wei = + weight_ptr[wei_desc.GetOffsetFromMultiIndex(k, y, x, c)]; + + v += in * wei; + } + } + } + } + + v += bias_ptr[bias_desc.GetOffsetFromMultiIndex(k)]; + + v = v > 0 ? v : 0; + + output_ptr[out_desc.GetOffsetFromMultiIndex(n, ho, wo, k)] = v; + } + } + } + }; + + make_ParallelTensorFunctor(f_k, K)(std::thread::hardware_concurrency()); +} + +template +void profile_conv_fwd_bias_relu_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + // bias: assume contiguous 1d vector + Tensor bias_k( + HostTensorDescriptor(std::vector({static_cast(K)}))); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::cout << "bias_k: " << bias_k.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + + if(do_verification) + { + cpu_conv_bias_relu(in_n_c_hi_wi.mData.data(), + wei_k_c_y_x.mData.data(), + out_n_k_ho_wo_host_result.mData.data(), + bias_k.mData.data(), + N, + K, + C, + Y, + X, + Hi, + Wi, + Ho, + Wo, + conv_filter_strides[0], + conv_filter_dilations[0], + input_left_pads[0]); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + bias_device_buf.ToDevice(bias_k.mData.data()); + + using DeviceConvFwdBiasReluPtr = ck::tensor_operation::device:: + DeviceConvFwdBiasActivationPtr; + + // add device operator instances + std::vector op_ptrs; + + if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance:: + add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + } + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = op_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_conv.hpp b/profiler/include/profile_conv_fwd_impl.hpp similarity index 75% rename from profiler/include/profile_conv.hpp rename to profiler/include/profile_conv_fwd_impl.hpp index e373d34c55..6e79bf4b4a 100644 --- a/profiler/include/profile_conv.hpp +++ b/profiler/include/profile_conv_fwd_impl.hpp @@ -6,40 +6,26 @@ #include "host_conv.hpp" #include "tensor_layout.hpp" #include "device_tensor.hpp" -#include "device_conv.hpp" -#include "device_conv_instance.hpp" +#include "device_conv_fwd.hpp" #include "element_wise_operation.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv_instance { +namespace device_conv2d_fwd_instance { using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr; -template <> -void add_device_conv_fwd_instance<2, - float, - float, - float, - ck::tensor_layout::convolution::NHWC, - ck::tensor_layout::convolution::KYXC, - ck::tensor_layout::convolution::NHWK>( +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); + +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( std::vector&); -template <> -void add_device_conv_fwd_instance<2, - ck::half_t, - ck::half_t, - ck::half_t, - ck::tensor_layout::convolution::NHWC, - ck::tensor_layout::convolution::KYXC, - ck::tensor_layout::convolution::NHWK>( - std::vector&); - -} // namespace device_conv_instance +} // namespace device_conv2d_fwd_instance } // namespace device } // namespace tensor_operation } // namespace ck @@ -54,20 +40,20 @@ template -void profile_conv(int do_verification, - int init_method, - bool do_log, - int nrepeat, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) +void profile_conv_fwd_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) { const ck::index_t Y = filter_spatial_lengths[0]; const ck::index_t X = filter_spatial_lengths[1]; @@ -146,20 +132,30 @@ void profile_conv(int do_verification, // add device Conv instances std::vector conv_ptrs; - ck::tensor_operation::device::device_conv_instance::add_device_conv_fwd_instance<2, - InDataType, - WeiDataType, - OutDataType, - InLayout, - WeiLayout, - OutLayout>( - conv_ptrs); + if constexpr(ck::is_same_v, float> && + ck::is_same_v, float> && + ck::is_same_v, float>) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } if(conv_ptrs.size() <= 0) { throw std::runtime_error("wrong! no device Conv instance found"); } + std::string best_conv_name; float best_ave_time = 0; float best_tflops = 0; float best_gb_per_sec = 0; @@ -189,6 +185,8 @@ void profile_conv(int do_verification, if(conv_ptr->IsSupportedArgument(argument_ptr.get())) { + std::string conv_name = conv_ptr->GetTypeString(); + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; @@ -202,10 +200,11 @@ void profile_conv(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; + << " GB/s, " << conv_name << std::endl; if(tflops > best_tflops) { + best_conv_name = conv_name; best_tflops = tflops; best_ave_time = ave_time; best_gb_per_sec = gb_per_sec; @@ -235,7 +234,7 @@ void profile_conv(int do_verification, } std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s" << std::endl; + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; } } // namespace profiler diff --git a/profiler/include/profile_gemm.hpp b/profiler/include/profile_gemm_impl.hpp similarity index 93% rename from profiler/include/profile_gemm.hpp rename to profiler/include/profile_gemm_impl.hpp index 8f92c78a13..3e99928fa4 100644 --- a/profiler/include/profile_gemm.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -88,16 +88,16 @@ template -void profile_gemm(int do_verification, - int init_method, - bool do_log, - int nrepeat, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideC) +void profile_gemm_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) { auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -164,6 +164,7 @@ void profile_gemm(int do_verification, throw std::runtime_error("wrong! no device GEMM instance found"); } + std::string best_gemm_name; float best_ave_time = 0; float best_tflops = 0; float best_gb_per_sec = 0; @@ -189,9 +190,12 @@ void profile_gemm(int do_verification, if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { + std::string gemm_name = gemm_ptr->GetTypeString(); + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N; @@ -200,10 +204,11 @@ void profile_gemm(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; + << " GB/s, " << gemm_name << std::endl; if(tflops > best_tflops) { + best_gemm_name = gemm_name; best_tflops = tflops; best_ave_time = ave_time; best_gb_per_sec = gb_per_sec; @@ -234,7 +239,7 @@ void profile_gemm(int do_verification, } std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s" << std::endl; + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; } } // namespace profiler diff --git a/profiler/conv_profiler.cpp b/profiler/profile_conv_fwd.cpp similarity index 80% rename from profiler/conv_profiler.cpp rename to profiler/profile_conv_fwd.cpp index 1d39d59e75..d3ca54f83a 100644 --- a/profiler/conv_profiler.cpp +++ b/profiler/profile_conv_fwd.cpp @@ -4,7 +4,7 @@ #include #include #include -#include "profile_conv.hpp" +#include "profile_conv_fwd_impl.hpp" enum ConvDataType { @@ -30,11 +30,11 @@ enum ConvOutputLayout NHWK, // 1 }; -int conv_profiler(int argc, char* argv[]) +int profile_conv_fwd(int argc, char* argv[]) { if(argc != 25) { - printf("arg1: tensor operation (conv: Convolution)\n"); + printf("arg1: tensor operation (conv_fwd: ForwardConvolution)\n"); printf("arg2: data type (0: fp32; 1: fp16)\n"); printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); @@ -83,13 +83,13 @@ int conv_profiler(int argc, char* argv[]) if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) { - ck::profiler::profile_conv<2, - float, - float, - float, - ck::tensor_layout::convolution::NHWC, - ck::tensor_layout::convolution::KYXC, - ck::tensor_layout::convolution::NHWK>( + ck::profiler::profile_conv_fwd_impl<2, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( do_verification, init_method, do_log, @@ -108,13 +108,13 @@ int conv_profiler(int argc, char* argv[]) else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) { - ck::profiler::profile_conv<2, - ck::half_t, - ck::half_t, - ck::half_t, - ck::tensor_layout::convolution::NHWC, - ck::tensor_layout::convolution::KYXC, - ck::tensor_layout::convolution::NHWK>( + ck::profiler::profile_conv_fwd_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( do_verification, init_method, do_log, diff --git a/profiler/profile_conv_fwd_bias_relu.cpp b/profiler/profile_conv_fwd_bias_relu.cpp new file mode 100644 index 0000000000..3390a9e472 --- /dev/null +++ b/profiler/profile_conv_fwd_bias_relu.cpp @@ -0,0 +1,114 @@ +#include +#include +#include +#include +#include +#include +#include "profile_conv_fwd_bias_relu_impl.hpp" + +enum ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +enum ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int profile_conv_fwd_bias_relu(int argc, char* argv[]) +{ + if(argc != 25) + { + printf("arg1: tensor operation (conv_fwd_bias_relu: ForwardConvolution+Bias+ReLu)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: run kernel # of times (>1)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const int data_type = static_cast(std::stoi(argv[2])); + const int in_layout = static_cast(std::stoi(argv[3])); + const int wei_layout = static_cast(std::stoi(argv[4])); + const int out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const int nrepeat = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_fwd_bias_relu_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + nrepeat, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else + { + throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); + } + + return 1; +} diff --git a/profiler/profile_conv_fwd_bias_relu_add.cpp b/profiler/profile_conv_fwd_bias_relu_add.cpp new file mode 100644 index 0000000000..b6b4822234 --- /dev/null +++ b/profiler/profile_conv_fwd_bias_relu_add.cpp @@ -0,0 +1,115 @@ +#include +#include +#include +#include +#include +#include +#include "profile_conv_fwd_bias_relu_add_impl.hpp" + +enum ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +enum ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) +{ + if(argc != 25) + { + printf( + "arg1: tensor operation (conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLu+Add)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: run kernel # of times (>1)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const int data_type = static_cast(std::stoi(argv[2])); + const int in_layout = static_cast(std::stoi(argv[3])); + const int wei_layout = static_cast(std::stoi(argv[4])); + const int out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const int nrepeat = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_fwd_bias_relu_add_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + nrepeat, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else + { + throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); + } + + return 1; +} diff --git a/profiler/profile_conv_fwd_bias_relu_atomic_add.cpp b/profiler/profile_conv_fwd_bias_relu_atomic_add.cpp new file mode 100644 index 0000000000..3c179d36b2 --- /dev/null +++ b/profiler/profile_conv_fwd_bias_relu_atomic_add.cpp @@ -0,0 +1,116 @@ +#include +#include +#include +#include +#include +#include +#include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp" + +enum ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +enum ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[]) +{ + if(argc != 25) + { + printf("arg1: tensor operation (conv_fwd_bias_relu_atomic_add: " + "ForwardConvolution+Bias+ReLu+AtomicAdd)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: run kernel # of times (>1)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const int data_type = static_cast(std::stoi(argv[2])); + const int in_layout = static_cast(std::stoi(argv[3])); + const int wei_layout = static_cast(std::stoi(argv[4])); + const int out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const int nrepeat = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_fwd_bias_relu_atomic_add_impl< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + nrepeat, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else + { + throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); + } + + return 1; +} diff --git a/profiler/profile_gemm.cpp b/profiler/profile_gemm.cpp new file mode 100644 index 0000000000..c34c3376f4 --- /dev/null +++ b/profiler/profile_gemm.cpp @@ -0,0 +1,227 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_base.hpp" +#include "device_gemm_xdl.hpp" +#include "profile_gemm_impl.hpp" + +enum GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +int profile_gemm(int argc, char* argv[]) +{ + if(argc != 14) + { + printf("arg1: tensor operation (gemm: GEMM)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, n] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, n] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg7: run kernel # of times (>1)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + exit(1); + } + + const int data_type = static_cast(std::stoi(argv[2])); + const int layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const int nrepeat = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else + { + throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); + } + + return 1; +} diff --git a/profiler/profiler.cpp b/profiler/profiler.cpp index fa69e9f1e0..a8d3322872 100644 --- a/profiler/profiler.cpp +++ b/profiler/profiler.cpp @@ -5,22 +5,42 @@ #include #include -int gemm_profiler(int, char*[]); -int conv_profiler(int, char*[]); +int profile_gemm(int, char*[]); +int profile_conv_fwd(int, char*[]); +int profile_conv_fwd_bias_relu(int, char*[]); +int profile_conv_fwd_bias_relu_add(int, char*[]); +int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); int main(int argc, char* argv[]) { if(strcmp(argv[1], "gemm") == 0) { - return gemm_profiler(argc, argv); + return profile_gemm(argc, argv); } - else if(strcmp(argv[1], "conv") == 0) + else if(strcmp(argv[1], "conv_fwd") == 0) { - return conv_profiler(argc, argv); + return profile_conv_fwd(argc, argv); + } + else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0) + { + return profile_conv_fwd_bias_relu(argc, argv); + } + else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0) + { + return profile_conv_fwd_bias_relu_add(argc, argv); + } + else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0) + { + return profile_conv_fwd_bias_relu_atomic_add(argc, argv); } else { - printf("arg1: tensor operation (gemm=GEMM, conv=Convolution)\n"); + printf("arg1: tensor operation (gemm: GEMM;\n" + " conv_fwd: ForwardConvolution;\n" + " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU)\n" + " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add)\n" + " conv_fwd_bias_relu_atomic_add: " + "ForwardConvolution+Bias+ReLU+AtomicAdd)\n"); return 0; } }