From 8133713e969941252f0ac1218317eb5c41ee8fa5 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 5 Jul 2019 15:35:21 -0500 Subject: [PATCH] adding implicit gemm v4r2 --- ...tion_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 354 +++++++++++++++ ..._v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp | 415 +++++++++++++++++ ..._v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp | 428 ++++++++++++++++++ ...tion_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 232 ++++++++++ ...tion_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp | 159 +++++++ 5 files changed, 1588 insertions(+) create mode 100644 composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp create mode 100644 composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp create mode 100644 composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp create mode 100644 driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp create mode 100644 driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..1c5d6ef5e4 --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,354 @@ +#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW +#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW + +#include "common_header.hpp" +#include "ConstantTensorDescriptor.hpp" +#include "ConstantMergedTensorDescriptor.hpp" +#include "ConstantMatrixDescriptor.hpp" +#include "blockwise_generic_tensor_slice_copy.hpp" +#include "blockwise_gemm.hpp" +#include "threadwise_generic_tensor_slice_copy.hpp" + +namespace ck { + +// define B = merge(N0, Ho, Wo) +template +struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const + { + // this is a mess + // TODO: find more elegent way of specifying (or calculating) performance parameters + static_assert(N2 == GemmNPerThreadSubC, "wrong!"); + static_assert((N1 * N2 * BPerBlock) % + (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == + 0, + "wrong!"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + + constexpr auto True = integral_constant{}; + + constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; + constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; + constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; + + constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); + constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); + constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2); + constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3); + + constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1); + constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); + constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); + + constexpr index_t Y = 
wei_k_c_y_x_global_desc.GetLength(I2);
+        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
+
+        static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
+
+        constexpr index_t N0 = N / (N1 * N2);
+
+        constexpr index_t B = N0 * Ho * Wo;
+
+        constexpr index_t E = C * Y * X;
+
+        // divide block work by [K, B]
+        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
+                      "wrong! cannot divide work evenly among blocks");
+
+        constexpr index_t KBlockWork = K / KPerBlock;
+        constexpr index_t BBlockWork = B / BPerBlock;
+
+        constexpr auto block_work_desc =
+            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
+
+        const auto block_work_multi_id =
+            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
+
+        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
+        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
+
+        // input tensor
+        //   tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
+        constexpr auto in_n0_n1_n2_h_w_global_desc =
+            in_n_c_h_w_global_desc.Slice(I2, Number<Ho>{})
+                .Slice(I3, Number<Wo>{})
+                .Fold(I0, Number<N1>{}, Number<N2>{})
+                .Extract(Sequence<0, 1, 2, 4, 5>{});
+
+        //   batch descriptor for device memory
+        constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
+                                                  .Slice(I3, Number<X>{})
+                                                  .Extract(Sequence<1, 2, 3>{});
+
+        //   merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
+        constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
+            in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
+            Sequence<0, 1, 2>{},
+            Sequence<4>{},
+            Sequence<3, 6, 7>{},
+            Sequence<5>{});
+
+#if 0
+        if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
+        {
+            print_ConstantTensorDescriptor(in_n0_n1_n2_h_w_global_desc,
+                                           "in_n0_n1_n2_h_w_global_desc: ");
+            print_ConstantTensorDescriptor(in_c_y_x_global_desc, "in_c_y_x_global_desc: ");
+            print_ConstantMergedTensorDescriptor(in_e_n1_b_n2_global_merged_desc,
+                                                 "in_e_n1_b_n2_global_merged_desc: ");
+        }
+#endif
+
+        //   memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
+
+        // this check is ad-hoc
+        // TODO: need to properly implement tensor descriptor with multiple alignment
+        // requirements
+        static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
+                      "GemmDataPerReadB alignment requirement is not satisfied");
+
+        // input blockwise copy
+        //   slice a merged tensor, reorder and copy it into a normal tensor
+        //   this copy operator already has the blockwise offset built in
+        auto blockwise_in_copy =
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+                                               Float,
+                                               decltype(in_e_n1_b_n2_global_merged_desc),
+                                               decltype(in_e_n1_b_n2_block_desc),
+                                               decltype(in_e_n1_b_n2_block_desc.GetLengths()),
+                                               InBlockCopySubLengths_E_N1_B_N2,
+                                               InBlockCopyClusterLengths_E_N1_B_N2,
+                                               InBlockCopyThreadClusterArrangeOrder,
+                                               InBlockCopySrcAccessOrder,
+                                               InBlockCopyDstAccessOrder,
+                                               InBlockCopySrcDataPerRead_B,
+                                               InBlockCopyDstDataPerWrite_N2>(
+                {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
+
+        // weight tensor
+        //   tensor descriptor in device memory, src of blockwise copy
+        constexpr auto wei_e_k_global_desc =
+            wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
+
+        //   tensor descriptor in LDS, dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence<EPerBlock, KPerBlock>{}, Number<WeiBlockCopyDstDataPerWrite_K>{});
+
+        // operator for blockwise copy of weight into LDS
+        //   slice a tensor, and copy it into another tensor
+        //   this copy operator already has the blockwise offset built in
+        auto blockwise_wei_copy =
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+                                               Float,
+                                               decltype(wei_e_k_global_desc),
+                                               decltype(wei_e_k_block_desc),
+                                               decltype(wei_e_k_block_desc.GetLengths()),
+                                               WeiBlockCopySubLengths_E_K,
+                                               WeiBlockCopyClusterLengths_E_K,
+                                               WeiBlockCopyThreadClusterArrangeOrder,
+                                               WeiBlockCopySrcAccessOrder,
+                                               WeiBlockCopyDstAccessOrder,
+                                               WeiBlockCopySrcDataPerRead_E,
+                                               WeiBlockCopyDstDataPerWrite_K>(
+                {0, k_block_data_on_global}, {0, 0});
+
+        // GEMM definition
+        //   c_mtx += transpose(a_mtx) * b_mtx
+        //   a_mtx[EPerBlock, KPerBlock] is in LDS
+        //   b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
+        //   c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
+        //   register
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(
+            Number<EPerBlock>{}, Number<KPerBlock>{}, Number<wei_e_k_block_desc.GetStride(I0)>{});
+
+        constexpr auto b_e_n1bn2_block_mtx_desc =
+            make_ConstantMatrixDescriptor(Number<EPerBlock>{},
+                                          Number<N1 * BPerBlock * N2>{},
+                                          Number<in_e_n1_b_n2_block_desc.GetStride(I0)>{});
+
+        // sanity check
+        static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
+                          0,
+                      "wrong!");
+
+        constexpr index_t GemmMRepeat =
+            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
+
+        // c_thread_mtx definition: this is a mess
+        // TODO: find a more elegant way of defining c_thread_mtx
+        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
+            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
+
+        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
+            BlockSize,
+            decltype(a_e_k_block_mtx_desc),
+            decltype(b_e_n1bn2_block_mtx_desc),
+            decltype(c_k0k2_n1n2_thread_mtx_desc),
+            GemmMPerThreadSubC,
+            GemmNPerThreadSubC,
+            GemmMLevel0Cluster,
+            GemmNLevel0Cluster,
+            GemmMLevel1Cluster,
+            GemmNLevel1Cluster,
+            GemmKPerThreadLoop,
+            GemmDataPerReadA,
+            GemmDataPerReadB>{};
+
+        // choose GEMM implementation here
+        const auto run_blockwise_gemm = [&](auto... Xs) {
+#if 1
+            return blockwise_gemm.Run(Xs...);
+#else
+            return blockwise_gemm.Run_amd_asm(Xs...);
+#endif
+        };
+
+        // LDS allocation for input and weight: be careful of alignment
+        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
+                                                WeiBlockCopyDstDataPerWrite_K,
+                                                GemmDataPerReadA,
+                                                GemmDataPerReadB);
+
+        constexpr index_t in_block_space =
+            in_e_n1_b_n2_block_desc.GetElementSpace(Number<max_align>{});
+
+        constexpr index_t wei_block_space =
+            wei_e_k_block_desc.GetElementSpace(Number<max_align>{});
+
+        __shared__ Float p_in_block[in_block_space];
+        __shared__ Float p_wei_block[wei_block_space];
+
+        // register allocation for output
+        Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
+
+        // zero out threadwise output
+        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
+
+        // do work
+        for(index_t e = 0; e < E; e += EPerBlock)
+        {
+            // march the slicing window along the E dimension: load, GEMM, then move
+            blockwise_in_copy.Run(p_in_global, p_in_block);
+            blockwise_wei_copy.Run(p_wei_global, p_wei_block);
+
+            __syncthreads();
+
+            run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread);
+
+            __syncthreads();
+
+            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+            blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+        }
+
+        // copy output: register to global memory
+        {
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
+            constexpr index_t K0 = K / (K1 * K2);
+
+            // define tensor descriptor for threadwise copy
+            //   output memory layout descriptor in register
+            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
+                make_ConstantTensorDescriptor_packed(
+                    Sequence<GemmMRepeat, 1, K2, N1, 1, 1, 1, N2>{});
+
+            //   output tensor descriptor in register, src of threadwise copy
+            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
+                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
+                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
+
+            //   output memory layout descriptor in device memory, dst of threadwise copy
+            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
+                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
+                    .Fold(I0, Number<N1>{}, Number<N2>{});
+
+            // calculate origin of thread output tensor on 
global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const index_t k_thread_data_on_global = + k_block_data_on_global + c_thread_mtx_on_block.row; + + const index_t b_thread_data_on_global = + b_block_data_on_global + c_thread_mtx_on_block.col / N2; + + // output merged global tensor descriptor, for calculating origin of thread tensor + // in global memory + constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( + out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5), + Sequence<3>{}, + Sequence<1>{}, + Sequence<0, 4, 5>{}, + Sequence<2>{}); + + // origin of dst in device memory + Float* p_out_thread_on_global = + p_out_global + + out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( + k_thread_data_on_global, 0, b_thread_data_on_global, 0); + + threadwise_generic_tensor_slice_copy_v1( + out_n0_n1_n2_k0_k1_k2_h_w_thread_desc, + p_out_thread, + {0, 0, 0, 0, 0, 0, 0, 0}, + out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc, + p_out_thread_on_global, + {0, 0, 0, 0, 0, 0, 0, 0}, + out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), + arithmetic_sequence_gen<0, 8, 1>::type{}, + Number<1>{}); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp new file mode 100644 index 0000000000..07ac2d8a88 --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -0,0 +1,415 @@ +#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER +#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER + +#include "common_header.hpp" +#include "ConstantTensorDescriptor.hpp" +#include "ConstantMergedTensorDescriptor.hpp" +#include "ConstantMatrixDescriptor.hpp" +#include "blockwise_generic_tensor_slice_copy.hpp" +#include "blockwise_gemm.hpp" +#include "threadwise_generic_tensor_slice_copy.hpp" + +namespace ck { + +// define B = merge(N0, Ho, Wo) +template +struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const + { + // this is a mess + // TODO: find more elegent way of specifying (or calculating) performance parameters + static_assert(N2 == GemmNPerThreadSubC, "wrong!"); + static_assert((N1 * N2 * BPerBlock) % + (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == + 0, + "wrong!"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I5 = Number<5>{}; + + constexpr auto True = integral_constant{}; + + constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; + constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; + constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; + + constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); + constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); + + constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1); + constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); + constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); + + constexpr 
index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
+        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
+
+        constexpr index_t ConvStrideH = ConvStrides{}[0];
+        constexpr index_t ConvStrideW = ConvStrides{}[1];
+
+        constexpr index_t ConvDilationH = ConvDilations{}[0];
+        constexpr index_t ConvDilationW = ConvDilations{}[1];
+
+        static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
+
+        constexpr index_t N0 = N / (N1 * N2);
+
+        constexpr index_t B = N0 * Ho * Wo;
+
+        constexpr index_t E = C * Y * X;
+
+        // sanity-check for vectorized memory load
+        static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
+                      "wrong! global vector load of input tensor is wrong");
+
+        static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
+                      "wrong! alignment requirement for vectorized global load of input tensor "
+                      "will be violated");
+
+        // divide block work by [K, B]
+        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
+                      "wrong! cannot divide work evenly among blocks");
+
+        constexpr index_t KBlockWork = K / KPerBlock;
+        constexpr index_t BBlockWork = B / BPerBlock;
+
+        constexpr auto block_work_desc =
+            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
+
+        const auto block_work_multi_id =
+            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
+
+        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
+        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
+
+        // input tensor
+        //   tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
+        constexpr auto in_n0_n1_n2_h_w_global_desc =
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
+                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
+                .Fold(I0, Number<N1>{}, Number<N2>{})
+                .Extract(Sequence<0, 1, 2, 4, 5>{});
+
+        //   batch descriptor for device memory
+        constexpr auto in_c_y_x_global_desc =
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
+                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
+                .Extract(Sequence<1, 2, 3>{});
+
+        //   merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
+        constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
+            in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
+            Sequence<0, 1, 2>{},
+            Sequence<4>{},
+            Sequence<3, 6, 7>{},
+            Sequence<5>{});
+
+        //   memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
+
+        // this check is ad-hoc
+        // TODO: need to properly implement tensor descriptor with multiple alignment
+        // requirements
+        static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
+                      "GemmDataPerReadB alignment requirement is not satisfied");
+
+        // input blockwise copy
+        //   slice a merged tensor, reorder and copy it into a normal tensor
+        //   this copy operator already has the blockwise offset built in
+        auto blockwise_in_copy =
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+                                               Float,
+                                               decltype(in_e_n1_b_n2_global_merged_desc),
+                                               decltype(in_e_n1_b_n2_block_desc),
+                                               decltype(in_e_n1_b_n2_block_desc.GetLengths()),
+                                               InBlockCopySubLengths_E_N1_B_N2,
+                                               InBlockCopyClusterLengths_E_N1_B_N2,
+                                               InBlockCopyThreadClusterArrangeOrder,
+                                               InBlockCopySrcAccessOrder,
+                                               InBlockCopyDstAccessOrder,
+                                               InBlockCopySrcDataPerRead_B,
+                                               InBlockCopyDstDataPerWrite_N2>(
+                {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
+
+        // weight tensor
+        //   tensor descriptor in device memory, src of blockwise copy
+        constexpr auto wei_e_k_global_desc =
+            wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
+
+        //   tensor descriptor in LDS, dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence<EPerBlock, KPerBlock>{}, Number<WeiBlockCopyDstDataPerWrite_K>{});
+
+        // operator for blockwise copy of weight into LDS
+        //   slice a tensor, and copy it into another tensor
+        //   this copy operator already has the blockwise offset built in
+        auto blockwise_wei_copy =
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+                                               Float,
+                                               decltype(wei_e_k_global_desc),
+                                               decltype(wei_e_k_block_desc),
+                                               decltype(wei_e_k_block_desc.GetLengths()),
+                                               WeiBlockCopySubLengths_E_K,
+                                               WeiBlockCopyClusterLengths_E_K,
+                                               WeiBlockCopyThreadClusterArrangeOrder,
+                                               WeiBlockCopySrcAccessOrder,
+                                               WeiBlockCopyDstAccessOrder,
+                                               WeiBlockCopySrcDataPerRead_E,
+                                               WeiBlockCopyDstDataPerWrite_K>(
+                {0, k_block_data_on_global}, {0, 0});
+
+        // GEMM definition
+        //   c_mtx += transpose(a_mtx) * b_mtx
+        //   a_mtx[EPerBlock, KPerBlock] is in LDS
+        //   b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
+        //   c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
+        //   register
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(
+            Number<EPerBlock>{}, Number<KPerBlock>{}, Number<wei_e_k_block_desc.GetStride(I0)>{});
+
+        constexpr auto b_e_n1bn2_block_mtx_desc =
+            make_ConstantMatrixDescriptor(Number<EPerBlock>{},
+                                          Number<N1 * BPerBlock * N2>{},
+                                          Number<in_e_n1_b_n2_block_desc.GetStride(I0)>{});
+
+        // sanity check
+        static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
+                          0,
+                      "wrong!");
+
+        constexpr index_t GemmMRepeat =
+            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
+
+        // c_thread_mtx definition: this is a mess
+        // TODO: find a more elegant way of defining c_thread_mtx
+        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
+            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
+
+        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
+            BlockSize,
+            decltype(a_e_k_block_mtx_desc),
+            decltype(b_e_n1bn2_block_mtx_desc),
+            decltype(c_k0k2_n1n2_thread_mtx_desc),
+            GemmMPerThreadSubC,
+            GemmNPerThreadSubC,
+            GemmMLevel0Cluster,
+            GemmNLevel0Cluster,
+            GemmMLevel1Cluster,
+            GemmNLevel1Cluster,
+            GemmKPerThreadLoop,
+            GemmDataPerReadA,
+            GemmDataPerReadB>{};
+
+        // LDS allocation for input and weight: be careful of alignment
+        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
+                                                WeiBlockCopyDstDataPerWrite_K,
+                                                GemmDataPerReadA,
+                                                GemmDataPerReadB);
+
+        constexpr index_t in_block_space =
+            math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);
+
+        constexpr index_t wei_block_space =
+            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
+
+        __shared__ Float p_in_block_double[2 * in_block_space];
+        __shared__ Float p_wei_block_double[2 * wei_block_space];
+
+        // register allocation for output
+        Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
+
+        // zero out threadwise output
+        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
+
+        const Float* p_wei_block_on_global = p_wei_global;
+
+        // LDS double buffer: preload data into LDS
+        {
+            blockwise_in_copy.Run(p_in_global, p_in_block_double);
+            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
+        }
+
+        // LDS double buffer: main body
+        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
+            e_block_data_begin += 2 * EPerBlock)
+        {
+#pragma unroll
+            for(index_t iloop = 0; iloop < 2; ++iloop)
+            {
+                const bool even_loop = (iloop % 2 == 0);
+
+                Float* p_in_block_now =
+                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
+                Float* p_wei_block_now =
+                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
+
+                Float* p_in_block_next =
+                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
+                Float* p_wei_block_next =
+                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
+
+                Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
+                Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
+
+                blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+                p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
+
+                __syncthreads();
+
+                // LDS double buffer: load next data from device mem
+                blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
+                blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
+                                                            p_wei_register_clipboard);
+
+                // LDS double buffer: GEMM on current data
+                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
+
+                // LDS double buffer: store next data to LDS
+                blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
+                                                            p_in_block_next);
+                blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
+                                                             p_wei_block_next);
+            }
+        }
+
+        // LDS double buffer: tail
+        {
+            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
+            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
+
+            // even iteration
+            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+            p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
+
+            __syncthreads();
+
+            // LDS double buffer: load next data from device mem
+            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
+            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
+                                                        p_wei_register_clipboard);
+
+            // LDS double buffer: GEMM on current data
+            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
+
+            // LDS double buffer: store next data to LDS
+            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
+                                                        p_in_block_double + in_block_space);
+            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
+                                                         p_wei_block_double + wei_block_space);
+
+            // odd iteration
+            __syncthreads();
+
+            // LDS double buffer: GEMM on current data
+            blockwise_gemm.Run(p_wei_block_double + wei_block_space,
+                               p_in_block_double + in_block_space,
+                               p_out_thread);
+        }
+
+        // copy output: register to global memory
+        {
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
+
+            // define tensor descriptor for threadwise copy
+            //   output memory layout descriptor in register
+            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
+                make_ConstantTensorDescriptor_packed(
+                    Sequence<GemmMRepeat, 1, K2, N1, 1, 1, 1, N2>{});
+
+            //   output tensor descriptor in register, src of threadwise copy
+            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
+                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
+                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
+
+            //   output memory layout descriptor in device memory, dst of threadwise copy
+            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
+                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
+                    .Fold(I0, Number<N1>{}, Number<N2>{});
+
+            // calculate origin of thread output tensor on global memory
+            //   blockwise GEMM c matrix starting index
+            const auto c_thread_mtx_on_block =
+                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+            const index_t k_thread_data_on_global =
+                k_block_data_on_global + c_thread_mtx_on_block.row;
+
+            const index_t b_thread_data_on_global =
+                b_block_data_on_global + c_thread_mtx_on_block.col / N2;
+
+            // output merged global tensor descriptor, for calculating origin of thread tensor
+            // 
in global memory + constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( + out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5), + Sequence<3>{}, + Sequence<1>{}, + Sequence<0, 4, 5>{}, + Sequence<2>{}); + + // origin of dst in device memory + Float* p_out_thread_on_global = + p_out_global + + out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( + k_thread_data_on_global, 0, b_thread_data_on_global, 0); + + threadwise_generic_tensor_slice_copy_v1( + out_n0_n1_n2_k0_k1_k2_h_w_thread_desc, + p_out_thread, + {0, 0, 0, 0, 0, 0, 0, 0}, + out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc, + p_out_thread_on_global, + {0, 0, 0, 0, 0, 0, 0, 0}, + out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), + arithmetic_sequence_gen<0, 8, 1>::type{}, + Number<1>{}); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp new file mode 100644 index 0000000000..13a523b521 --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -0,0 +1,428 @@ +#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER +#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER + +#include "common_header.hpp" +#include "ConstantTensorDescriptor.hpp" +#include "ConstantMergedTensorDescriptor.hpp" +#include "ConstantMatrixDescriptor.hpp" +#include "blockwise_generic_tensor_slice_copy.hpp" +#include "blockwise_gemm.hpp" +#include "threadwise_generic_tensor_slice_copy.hpp" + +namespace ck { + +// define B = merge(N0, Ho, Wo) +template +struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const + { + // this is a mess + // TODO: find more elegent way of specifying (or calculating) performance parameters + static_assert(N2 == GemmNPerThreadSubC, "wrong!"); + static_assert((N1 * N2 * BPerBlock) % + (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == + 0, + "wrong!"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I5 = Number<5>{}; + + constexpr auto True = integral_constant{}; + + constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; + constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; + constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; + + constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0]; + constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1]; + + constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1]; + constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2]; + constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3]; + + constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; + constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t E = C * Y * X; + + constexpr index_t N1 = N / (N0 * N2); + constexpr index_t Ho1 = Ho / 
(Ho0 * Ho2);
+        constexpr index_t Wo1 = Wo / (Wo0 * Wo2);
+
+        constexpr index_t B1 = N1 * Ho1 * Wo1;
+
+        static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_Wo2 == 0),
+                      "wrong! alignment requirement for vectorized global load of input tensor "
+                      "will be violated");
+
+        // divide block work by [K, B1]
+        static_assert(K % KPerBlock == 0 && B1 % B1PerBlock == 0 && E % (2 * EPerBlock) == 0,
+                      "wrong! cannot divide work evenly among blocks");
+
+        constexpr index_t KBlockWork  = K / KPerBlock;
+        constexpr index_t B1BlockWork = B1 / B1PerBlock;
+
+        constexpr auto block_work_desc =
+            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, B1BlockWork>{});
+
+        const auto block_work_multi_id =
+            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
+
+        const index_t k_block_data_on_global  = block_work_multi_id[0] * KPerBlock;
+        const index_t b1_block_data_on_global = block_work_multi_id[1] * B1PerBlock;
+
+        // input tensor
+        //   tensor descriptor in device memory [N0, N1, N2, Ho0, Ho1, Ho2, Wo0, Wo1, Wo2]
+        constexpr auto in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc =
+            in_n_c_h_w_global_desc.Extract(I0, I2, I3)
+                .StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
+                .StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{})
+                .Fold(I2, Number<Wo1>{}, Number<Wo2>{})
+                .Fold(I1, Number<Ho1>{}, Number<Ho2>{})
+                .Fold(I0, Number<N1>{}, Number<N2>{});
+
+        constexpr auto in_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_global_desc =
+            in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc.ReorderGivenNew2Old(
+                Sequence<0, 3, 6, 1, 4, 7, 2, 5, 8>{});
+
+        //   batch descriptor for device memory
+        constexpr auto in_c_y_x_global_desc =
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
+                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
+                .Extract(Sequence<1, 2, 3>{});
+
+        //   merged tensor descriptor in device memory [E, N0, Ho0, Wo0, B1, N2, Ho2, Wo2],
+        //   src of blockwise copy
+        constexpr auto in_e_n0_ho0_wo0_b1_n2_ho2_wo2_global_merged_desc =
+            make_ConstantMergedTensorDescriptor(
+                in_c_y_x_global_desc.Embed(in_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_global_desc),
+                Sequence<0, 1, 2>{},
+                Sequence<3>{},
+                Sequence<4>{},
+                Sequence<5>{},
+                Sequence<6, 7, 8>{},
+                Sequence<9>{},
+                Sequence<10>{},
+                Sequence<11>{});
+
+        //   memory layout descriptor in LDS [E, N0, Ho0, Wo0, B1, N2, Ho2, Wo2],
+        //   dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto in_e_n0_ho0_wo0_b1_n2_ho2_wo2_block_desc =
+            in_e_n0_ho0_wo0_b1_n2_ho2_wo2_global_merged_desc.Pack();
+
+        // input blockwise copy
+        //   slice a merged tensor, reorder and copy it into a normal tensor
+        //   this copy operator already has the blockwise offset built in
+        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
+            BlockSize,
+            Float,
+            decltype(in_e_n0_ho0_wo0_b1_n2_ho2_wo2_global_merged_desc),
+            decltype(in_e_n0_ho0_wo0_b1_n2_ho2_wo2_block_desc),
+            decltype(in_e_n0_ho0_wo0_b1_n2_ho2_wo2_block_desc.GetLengths()),
+            InBlockCopySubLengths_E_N0_Ho0_Wo0_B1_N2_Ho2_Wo2,
+            InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B1_N2_Ho2_Wo2,
+            InBlockCopyThreadClusterArrangeOrder,
+            InBlockCopySrcAccessOrder,
+            InBlockCopyDstAccessOrder,
+            InBlockCopyDataPerAccess_Wo2,
+            InBlockCopyDataPerAccess_Wo2>({0, 0, 0, 0, b1_block_data_on_global, 0, 0, 0},
+                                          {0, 0, 0, 0, 0, 0, 0, 0});
+
+        // weight tensor
+        //   tensor descriptor in device memory, src of blockwise copy
+        constexpr auto wei_e_k_global_desc =
+            wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
+
+        //   tensor descriptor in LDS, dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence<EPerBlock, KPerBlock>{}, Number<WeiBlockCopyDstDataPerWrite_K>{});
+
+        // operator for blockwise copy of weight into LDS
+        //   slice a tensor, and copy it into another tensor
+        //   this copy operator already has the blockwise offset built in
+        auto blockwise_wei_copy =
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+                                               Float,
+                                               decltype(wei_e_k_global_desc),
+                                               decltype(wei_e_k_block_desc),
+                                               decltype(wei_e_k_block_desc.GetLengths()),
+                                               WeiBlockCopySubLengths_E_K,
+                                               WeiBlockCopyClusterLengths_E_K,
+                                               WeiBlockCopyThreadClusterArrangeOrder,
+                                               WeiBlockCopySrcAccessOrder,
+                                               WeiBlockCopyDstAccessOrder,
+                                               WeiBlockCopySrcDataPerRead_E,
+                                               WeiBlockCopyDstDataPerWrite_K>(
+                {0, k_block_data_on_global}, {0, 0});
+
+        // GEMM definition
+        //   c_mtx += transpose(a_mtx) * b_mtx
+        //   a_mtx[EPerBlock, KPerBlock] is in LDS
+        //   b_mtx[EPerBlock, N0 * Ho0 * Wo0 * B1PerBlock * N2 * Ho2 * Wo2] is in LDS
+        //   c_mtx[KPerBlock, ...] is distributed among threads, and saved in
+        //   register
+        constexpr auto a_e_k_block_mtx_desc =
+            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+
+        // this check is ad-hoc
+        // TODO: need to properly implement tensor descriptor with multiple alignment
+        // requirements
+        static_assert(in_e_n0_ho0_wo0_b1_n2_ho2_wo2_block_desc.GetStrides()[3] %
+                              GemmDataPerReadB ==
+                          0,
+                      "GemmDataPerReadB alignment requirement is not satisfied");
+
+        constexpr auto b_e_n0ho0wo0b1n2ho2wo2_block_mtx_desc =
+            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
+                in_e_n0_ho0_wo0_b1_n2_ho2_wo2_block_desc.Unfold(I1, Number<7>{}));
+
+        // sanity check
+        static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
+                          0,
+                      "wrong!");
+
+        constexpr index_t GemmMRepeat =
+            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
+
+        // c_thread_mtx definition: this is a mess
+        // TODO: find a more elegant way of defining c_thread_mtx
+        constexpr auto c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc = make_ConstantMatrixDescriptor(
+            Number<GemmMRepeat * GemmMPerThreadSubC>{},
+            Number<N0 * Ho0 * Wo0 * N2 * Ho2 * Wo2>{});
+
+        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
+            BlockSize,
+            decltype(a_e_k_block_mtx_desc),
+            decltype(b_e_n0ho0wo0b1n2ho2wo2_block_mtx_desc),
+            decltype(c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc),
+            GemmMPerThreadSubC,
+            GemmNPerThreadSubC,
+            GemmMLevel0Cluster,
+            GemmNLevel0Cluster,
+            GemmMLevel1Cluster,
+            GemmNLevel1Cluster,
+            GemmKPerThreadLoop,
+            GemmDataPerReadA,
+            GemmDataPerReadB>{};
+
+        // LDS allocation for input and weight: be careful of alignment
+        constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_Wo2,
+                                                WeiBlockCopyDstDataPerWrite_K,
+                                                GemmDataPerReadA,
+                                                GemmDataPerReadB);
+
+        constexpr index_t in_block_space = math::integer_least_multiple(
+            in_e_n0_ho0_wo0_b1_n2_ho2_wo2_block_desc.GetElementSpace(), max_align);
+
+        constexpr index_t wei_block_space =
+            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
+
+        __shared__ Float p_in_block_double[2 * in_block_space];
+        __shared__ Float p_wei_block_double[2 * wei_block_space];
+
+        // register allocation for output
+        Float p_out_thread[c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc.GetElementSpace()];
+
+        // zero out threadwise output
+        threadwise_matrix_set_zero(c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc, p_out_thread);
+
+        const Float* p_wei_block_on_global = p_wei_global;
+
+        // LDS double buffer: preload data into LDS
+        {
+            blockwise_in_copy.Run(p_in_global, p_in_block_double);
+            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
+        }
+
+        // LDS double buffer: main body
+        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
+            e_block_data_begin += 2 * EPerBlock)
+        {
+#pragma unroll
+            for(index_t iloop = 0; iloop < 2; ++iloop)
+            {
+                const bool even_loop = (iloop % 2 == 0);
+
+                Float* p_in_block_now =
+                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
+                Float* p_wei_block_now =
+                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
+
+                Float* p_in_block_next =
+                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
+                Float* p_wei_block_next =
+                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
+
+                Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
+                Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
+
+                blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+                p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
+
+                __syncthreads();
+
+                // LDS double buffer: load next data from device mem
+                blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
+                blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
+                                                            p_wei_register_clipboard);
+
+                // LDS double buffer: GEMM on current data
+                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
+
+                // LDS double buffer: store next data to LDS
+                blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
+                                                            p_in_block_next);
+                blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
+                                                             p_wei_block_next);
+            }
+        }
+
+        // LDS double buffer: tail
+        {
+            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
+            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
+
+            // even iteration
+            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+            p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
+
+            __syncthreads();
+
+            // LDS double buffer: load next data from device mem
+            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
+            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
+                                                        p_wei_register_clipboard);
+
+            // LDS double buffer: GEMM on current data
+            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
+
+            // LDS double buffer: store next data to LDS
+            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
+                                                        p_in_block_double + in_block_space);
+            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
+                                                         p_wei_block_double + wei_block_space);
+
+            // odd iteration
+            __syncthreads();
+
+            // LDS double buffer: GEMM on current data
+            blockwise_gemm.Run(p_wei_block_double + wei_block_space,
+                               p_in_block_double + in_block_space,
+                               p_out_thread);
+        }
+
+        // copy output: register to global memory
+        {
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
+
+            // define tensor descriptor for threadwise copy
+            //   output memory layout descriptor in register
+            constexpr auto out_k0_k1_k2_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_thread_mem_desc =
+                make_ConstantTensorDescriptor_packed(
+                    Sequence<GemmMRepeat, 1, K2, N0, Ho0, Wo0, 1, 1, 1, N2, Ho2, Wo2>{});
+
+            //   output tensor descriptor in register, src of threadwise copy
+            constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc =
+                out_k0_k1_k2_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_thread_mem_desc.ReorderGivenNew2Old(
+                    Sequence<3, 6, 9, 0, 1, 2, 4, 7, 10, 5, 8, 11>{});
+
+            //   output memory layout descriptor in device memory, dst of threadwise copy
+            constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc =
+                out_n_k_h_w_global_desc.Fold(I3, Sequence<Wo1, Wo2>{})
+                    .Fold(I2, Sequence<Ho1, Ho2>{})
+                    .Fold(I1, Sequence<K1, K2>{})
+                    .Fold(I0, Sequence<N1, N2>{});
+
+            // calculate origin of thread output tensor on global memory
+            //   blockwise GEMM c matrix starting index
+            const auto c_thread_mtx_on_block =
+                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+            const index_t k_thread_data_on_global =
+                k_block_data_on_global + 
c_thread_mtx_on_block.row; + + const index_t b_thread_data_on_global = + b_block_data_on_global + c_thread_mtx_on_block.col / N2; + + // output merged global tensor descriptor, for calculating origin of thread tensor + // in global memory + constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( + out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5), + Sequence<3>{}, + Sequence<1>{}, + Sequence<0, 4, 5>{}, + Sequence<2>{}); + + // origin of dst in device memory + Float* p_out_thread_on_global = + p_out_global + + out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( + k_thread_data_on_global, 0, b_thread_data_on_global, 0); + + threadwise_generic_tensor_slice_copy_v1( + out_n0_n1_n2_k0_k1_k2_h_w_thread_desc, + p_out_thread, + {0, 0, 0, 0, 0, 0, 0, 0}, + out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc, + p_out_thread_on_global, + {0, 0, 0, 0, 0, 0, 0, 0}, + out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), + arithmetic_sequence_gen<0, 8, 1>::type{}, + Number<1>{}); + } + } +}; + +} // namespace ck +#endif diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..b3e9ef44d5 --- /dev/null +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,232 @@ +#pragma once +#include +#include "device.hpp" +#include "tensor.hpp" +#include "gridwise_convolution_kernel_wrapper.hpp" +#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" +#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp" + +using namespace ck; + +template +void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcyx, + OutDesc, + Tensor& out_nkhw, + ConvStrides, + ConvDilations, + index_t nrepeat) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_desc = InDesc{}; + constexpr auto wei_kcyx_desc = WeiDesc{}; + constexpr auto out_nkhw_desc = OutDesc{}; + + constexpr index_t Hi = in_nchw_desc.GetLength(I2); + constexpr index_t Wi = in_nchw_desc.GetLength(I3); + + constexpr index_t N = out_nkhw_desc.GetLength(I0); + constexpr index_t Ho = out_nkhw_desc.GetLength(I2); + constexpr index_t Wo = out_nkhw_desc.GetLength(I3); + + constexpr index_t K = wei_kcyx_desc.GetLength(I0); + constexpr index_t C = wei_kcyx_desc.GetLength(I1); + constexpr index_t Y = wei_kcyx_desc.GetLength(I2); + constexpr index_t X = wei_kcyx_desc.GetLength(I3); + + std::size_t data_sz = sizeof(T); + DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); + DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); + DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); + + in_nchw_device_buf.ToDevice(in_nchw.mData.data()); + wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); + out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); + + constexpr index_t N1 = 2; + constexpr index_t N2 = 4; + + constexpr index_t B = (N * Ho * Wo) / (N1 * N2); + +#if 1 + constexpr index_t BlockSize = 256; + + constexpr index_t BPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t EPerBlock = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr 
index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 0 + constexpr index_t BlockSize = 256; + + constexpr index_t BPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t EPerBlock = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 4, 1>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 4, 4>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 4; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1; + + using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 1 + constexpr index_t BlockSize = 256; + + constexpr index_t BPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t EPerBlock = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 2, 2>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 2>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 2; + 
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; + + using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#endif + + constexpr index_t GridSize = + ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + for(index_t i = 0; i < nrepeat; ++i) + { + constexpr auto gridwise_conv = +#if 0 + GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw +#else + GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer +#endif + {}; + + float time = launch_kernel(run_gridwise_convolution_kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + + printf("Elapsed time : %f ms, %f TFlop/s\n", + time, + (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / time); + usleep(std::min(time * 1000, float(10000))); + } + + out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); +} diff --git a/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..5f6c2ec3f2 --- /dev/null +++ b/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp @@ -0,0 +1,159 @@ +#pragma once +#include +#include "device.hpp" +#include "tensor.hpp" +#include "gridwise_convolution_kernel_wrapper.hpp" +#include "gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp" + +using namespace ck; + +template +void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcyx, + OutDesc, + Tensor& out_nkhw, + ConvStrides, + ConvDilations, + index_t nrepeat) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_desc = InDesc{}; + constexpr auto wei_kcyx_desc = WeiDesc{}; + constexpr auto out_nkhw_desc = OutDesc{}; + + constexpr index_t Hi = in_nchw_desc.GetLength(I2); + constexpr index_t Wi = in_nchw_desc.GetLength(I3); + + constexpr index_t N = out_nkhw_desc.GetLength(I0); + constexpr index_t Ho = out_nkhw_desc.GetLength(I2); + constexpr index_t Wo = out_nkhw_desc.GetLength(I3); + + constexpr index_t K = wei_kcyx_desc.GetLength(I0); + constexpr index_t C = wei_kcyx_desc.GetLength(I1); + constexpr index_t Y = wei_kcyx_desc.GetLength(I2); + constexpr index_t X = wei_kcyx_desc.GetLength(I3); + + std::size_t data_sz = sizeof(T); + DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); + DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); + DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); + + in_nchw_device_buf.ToDevice(in_nchw.mData.data()); + wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); + out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); + + constexpr index_t N1 = 2; + constexpr index_t N2 = 4; + + constexpr index_t B = (N * Ho * Wo) / (N1 * N2); + +#if 1 + 
constexpr index_t BlockSize = 256; + + constexpr index_t BPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t EPerBlock = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#endif + + constexpr index_t GridSize = + ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + for(index_t i = 0; i < nrepeat; ++i) + { + constexpr auto gridwise_conv = + GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer< + GridSize, + BlockSize, + T, + decltype(in_nchw_desc), + decltype(wei_kcyx_desc), + decltype(out_nkhw_desc), + ConvStrides, + ConvDilations, + BPerBlock, + KPerBlock, + EPerBlock, + N1, + N2, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + GemmDataPerReadA, + GemmDataPerReadB, + InBlockCopySubLengths_E_N1_B_N2, + InBlockCopyClusterLengths_E_N1_B_N2, + InBlockCopyThreadClusterArrangeOrder, + InBlockCopySrcAccessOrder, + InBlockCopyDstAccessOrder, + InBlockCopySrcDataPerRead_B, + InBlockCopyDstDataPerWrite_N2, + WeiBlockCopySubLengths_E_K, + WeiBlockCopyClusterLengths_E_K, + WeiBlockCopyThreadClusterArrangeOrder, + WeiBlockCopySrcAccessOrder, + WeiBlockCopyDstAccessOrder, + WeiBlockCopySrcDataPerRead_E, + WeiBlockCopyDstDataPerWrite_K>{}; + + float time = launch_kernel(run_gridwise_convolution_kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + + printf("Elapsed time : %f ms, %f TFlop/s\n", + time, + (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / time); + usleep(std::min(time * 1000, float(10000))); + } + + out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); +}
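
Reviewer note (sketch, not part of the patch): a minimal host-side example of how the new v4r2 driver entry point is intended to be called, mirroring the existing v4r1 driver harness above. The descriptor helper and the Tensor type come from this patch's driver code; the concrete problem size, the stride/dilation values, and the make_TensorDescriptor-style construction of the host tensors are illustrative assumptions (the exact Tensor constructor is harness-specific).

    // Sketch only -- assumes the harness in driver/include above.
    // Problem size is illustrative; N must be divisible by N1 * N2 = 8.
    using namespace ck;

    constexpr index_t N = 64, C = 256, Hi = 30, Wi = 30;
    constexpr index_t K = 128, Y = 3, X = 3;
    constexpr index_t Ho = 28, Wo = 28; // Hi - Y + 1, Wi - X + 1 for stride 1, dilation 1, no padding

    constexpr auto in_nchw_desc  = make_ConstantTensorDescriptor_packed(Sequence<N, C, Hi, Wi>{});
    constexpr auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
    constexpr auto out_nkhw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, K, Ho, Wo>{});

    // Host-side tensors; make_TensorDescriptor is a hypothetical stand-in for the
    // harness's constructor -- only mDesc/mData are relied upon by the driver.
    Tensor<float> in_nchw(make_TensorDescriptor(in_nchw_desc));
    Tensor<float> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
    Tensor<float> out_nkhw(make_TensorDescriptor(out_nkhw_desc));

    device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(in_nchw_desc,
                                                         in_nchw,
                                                         wei_kcyx_desc,
                                                         wei_kcyx,
                                                         out_nkhw_desc,
                                                         out_nkhw,
                                                         Sequence<1, 1>{}, // ConvStrides
                                                         Sequence<1, 1>{}, // ConvDilations
                                                         1);               // nrepeat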