diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 915193dc40..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,354 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW - -#include "common_header.hpp" -#include "ConstantTensorDescriptor.hpp" -#include "ConstantMergedTensorDescriptor.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" - -namespace ck { - -// define B = merge(N0, Ho, Wo) -template -struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // this is a mess - // TODO: find more elegent way of specifying (or calculating) performance parameters - static_assert(N2 == GemmNPerThreadSubC, "wrong!"); - static_assert((N1 * N2 * BPerBlock) % - (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == - 0, - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - - constexpr auto True = integral_constant{}; - - constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); - constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); - constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2); - constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3); - - constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1); - constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); - constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2); - constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3); - - static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread"); - - constexpr index_t N0 = N / (N1 * N2); - - constexpr index_t B = N0 * Ho * Wo; - - constexpr index_t E = C * Y * X; - - // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t KBlockWork = K / KPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_ConstantTensorDescriptor_packed(Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; - const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; - - // input tensor - // tensor descriptor in device memory [N0, N1, N2, Ho, Wo] - constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number{}) - .Slice(I3, Number{}) - .Fold(I0, Number{}, Number{}) - .Extract(Sequence<0, 1, 2, 4, 5>{}); - - // batch descritpor for device memory - constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number{}) - .Slice(I3, Number{}) - .Extract(Sequence<1, 2, 3>{}); - - // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy - constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( - in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc), - Sequence<0, 1, 2>{}, - Sequence<4>{}, - Sequence<3, 6, 7>{}, - Sequence<5>{}); - -#if 0 - if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_n0_n1_n2_h_w_global_desc, - "in_n0_n1_n2_h_w_global_desc: "); - print_ConstantTensorDescriptor(in_c_y_x_global_desc, "in_c_y_x_global_desc: "); - print_ConstantMergedTensorDescriptor(in_e_n1_b_n2_global_merged_desc, - "in_e_n1_b_n2_global_merged_desc: "); - } -#endif - - // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy - // be careful of LDS alignment - constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with multiple alignment - // requirements - static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not satisfied"); - - // input blockwise copy - // slice a merged tensor, reorder and copy to a normal tensor - // this copy operator already has blockwise offset built-in - auto blockwise_in_copy = - BlockwiseGenericTensorSliceCopy_v1( - {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); - - // weight tensor - // tensor descriptor in device memory, src of blockwise copy - constexpr auto wei_e_k_global_desc = - wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{}); - - // tensor descriptor in LDS, dst of blockwise copy - // be careful of LDS alignment - constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // operator for blockwise copy of weight into LDS - // slice a tensor, and copy it into another tensor - // this copy operator already have blockwise offset built-in - auto blockwise_wei_copy = - BlockwiseGenericTensorSliceCopy_v1( - {0, k_block_data_on_global}, {0, 0}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[EPerBlock, KPerBlock] is in LDS - // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS - // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in - // register - constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_e_n1bn2_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - // sanity check - static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == - 0, - "wrong!"); - - constexpr index_t GemmMRepeat = - KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_e_k_block_mtx_desc), - decltype(b_e_n1bn2_block_mtx_desc), - decltype(c_k0k2_n1n2_thread_mtx_desc), - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // choose GEMM implementation here - const auto run_blockwise_gemm = [&](auto... Xs) { -#if 1 - return blockwise_gemm.Run(Xs...); -#else - return blockwise_gemm.Run_amd_asm(Xs...); -#endif - }; - - // LDS allocation for input and weight: be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2, - WeiBlockCopyDstDataPerWrite_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr index_t in_block_space = - in_e_n1_b_n2_block_desc.GetElementSpace(Number{}); - - constexpr index_t wei_block_space = wei_e_k_block_desc.GetElementSpace(Number{}); - - __shared__ Float p_in_block[in_block_space]; - __shared__ Float p_wei_block[wei_block_space]; - - // register allocation for output - Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread); - - // do work - for(index_t e = 0; e < E; e += EPerBlock) - { - // marching slicing window - blockwise_in_copy.Run(p_in_global, p_in_block); - blockwise_wei_copy.Run(p_wei_global, p_wei_block); - - __syncthreads(); - - run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread); - - __syncthreads(); - - blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); - blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); - } - - // copy output: register to global memory - { - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster; - constexpr index_t K0 = K / (K1 * K2); - - // define tensor descriptor for threadwise copy - // output memory layout descriptor in register - constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc = - make_ConstantTensorDescriptor_packed( - Sequence{}); - - // output tensor descriptor in register, src of threadwise copy - constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc = - out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old( - Sequence<4, 3, 7, 0, 1, 2, 5, 6>{}); - - // output memory layout descriptor in device memory, dst of threadwise copy - constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc = - out_n_k_h_w_global_desc.Fold(I1, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_on_global = - k_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t b_thread_data_on_global = - b_block_data_on_global + c_thread_mtx_on_block.col / N2; - - // output merged global tensor descriptor, for calculating origin of thread tensor - // in global memory - constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( - out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5), - Sequence<3>{}, - Sequence<1>{}, - Sequence<0, 4, 5>{}, - Sequence<2>{}); - - // origin of dst in device memory - Float* p_out_thread_on_global = - p_out_global + - out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( - k_thread_data_on_global, 0, b_thread_data_on_global, 0); - - threadwise_generic_tensor_slice_copy_v1( - out_n0_n1_n2_k0_k1_k2_h_w_thread_desc, - p_out_thread, - {0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc, - p_out_thread_on_global, - {0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), - arithmetic_sequence_gen<0, 8, 1>::type{}, - Number<1>{}); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp deleted file mode 100644 index 65c397564f..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ /dev/null @@ -1,415 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER - -#include "common_header.hpp" -#include "ConstantTensorDescriptor.hpp" -#include "ConstantMergedTensorDescriptor.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" - -namespace ck { - -// define B = merge(N0, Ho, Wo) -template -struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // this is a mess - // TODO: find more elegent way of specifying (or calculating) performance parameters - static_assert(N2 == GemmNPerThreadSubC, "wrong!"); - static_assert((N1 * N2 * BPerBlock) % - (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == - 0, - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I5 = Number<5>{}; - - constexpr auto True = integral_constant{}; - - constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); - constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); - - constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1); - constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); - constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2); - constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3); - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread"); - - constexpr index_t N0 = N / (N1 * N2); - - constexpr index_t B = N0 * Ho * Wo; - - constexpr index_t E = C * Y * X; - - // sanity-check for vectorized memory load - static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1, - "wrong! global vector load of input tensor is wrong"); - - static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0), - "wrong! aligment requirement for vectorized global load of input tensor will " - "be violated"); - - // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t KBlockWork = K / KPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_ConstantTensorDescriptor_packed(Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; - const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; - - // input tensor - // tensor descriptor in device memory [N0, N1, N2, Ho, Wo] - constexpr auto in_n0_n1_n2_h_w_global_desc = - in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) - .StridedSlice(I3, Number{}, Number{}) - .Fold(I0, Number{}, Number{}) - .Extract(Sequence<0, 1, 2, 4, 5>{}); - - // batch descritpor for device memory - constexpr auto in_c_y_x_global_desc = - in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) - .StridedSlice(I3, Number{}, Number{}) - .Extract(Sequence<1, 2, 3>{}); - - // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy - constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( - in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc), - Sequence<0, 1, 2>{}, - Sequence<4>{}, - Sequence<3, 6, 7>{}, - Sequence<5>{}); - - // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy - // be careful of LDS alignment - constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with multiple alignment - // requirements - static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not satisfied"); - - // input blockwise copy - // slice a merged tensor, reorder and copy to a normal tensor - // this copy operator already has blockwise offset built-in - auto blockwise_in_copy = - BlockwiseGenericTensorSliceCopy_v1( - {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); - - // weight tensor - // tensor descriptor in device memory, src of blockwise copy - constexpr auto wei_e_k_global_desc = - wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{}); - - // tensor descriptor in LDS, dst of blockwise copy - // be careful of LDS alignment - constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // operator for blockwise copy of weight into LDS - // slice a tensor, and copy it into another tensor - // this copy operator already have blockwise offset built-in - auto blockwise_wei_copy = - BlockwiseGenericTensorSliceCopy_v1( - {0, k_block_data_on_global}, {0, 0}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[EPerBlock, KPerBlock] is in LDS - // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS - // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in - // register - constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_e_n1bn2_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - // sanity check - static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == - 0, - "wrong!"); - - constexpr index_t GemmMRepeat = - KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_e_k_block_mtx_desc), - decltype(b_e_n1bn2_block_mtx_desc), - decltype(c_k0k2_n1n2_thread_mtx_desc), - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS allocation for input and weight: be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2, - WeiBlockCopyDstDataPerWrite_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr index_t in_block_space = - math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align); - - constexpr index_t wei_block_space = - math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align); - - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - // register allocation for output - Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread); - - const Float* p_wei_block_on_global = p_wei_global; - - // LDS double buffer: preload data into LDS - { - blockwise_in_copy.Run(p_in_global, p_in_block_double); - blockwise_wei_copy.Run(p_wei_global, p_wei_block_double); - } - - // LDS double buffer: main body - for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E; - e_block_data_begin += 2 * EPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; - - Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - - blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); - p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard); - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global, - p_wei_register_clipboard); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, - p_in_block_next); - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, - p_wei_block_next); - } - } - - // LDS double buffer: tail - { - Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - - // even iteration - blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); - p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard); - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global, - p_wei_register_clipboard); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, - p_wei_block_double + wei_block_space); - - // odd iteration - __syncthreads(); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double + wei_block_space, - p_in_block_double + in_block_space, - p_out_thread); - } - - // copy output: register to global memory - { - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster; - - // define tensor descriptor for threadwise copy - // output memory layout descriptor in register - constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc = - make_ConstantTensorDescriptor_packed( - Sequence{}); - - // output tensor descriptor in register, src of threadwise copy - constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc = - out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old( - Sequence<4, 3, 7, 0, 1, 2, 5, 6>{}); - - // output memory layout descriptor in device memory, dst of threadwise copy - constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc = - out_n_k_h_w_global_desc.Fold(I1, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_on_global = - k_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t b_thread_data_on_global = - b_block_data_on_global + c_thread_mtx_on_block.col / N2; - - // output merged global tensor descriptor, for calculating origin of thread tensor - // in global memory - constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( - out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5), - Sequence<3>{}, - Sequence<1>{}, - Sequence<0, 4, 5>{}, - Sequence<2>{}); - - // origin of dst in device memory - Float* p_out_thread_on_global = - p_out_global + - out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( - k_thread_data_on_global, 0, b_thread_data_on_global, 0); - - threadwise_generic_tensor_slice_copy_v1( - out_n0_n1_n2_k0_k1_k2_h_w_thread_desc, - p_out_thread, - {0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc, - p_out_thread_on_global, - {0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), - arithmetic_sequence_gen<0, 8, 1>::type{}, - Number<1>{}); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp b/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp index af7bc1d354..26bcc4b77f 100644 --- a/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp +++ b/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp @@ -2,6 +2,7 @@ #define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP #include "common_header.hpp" +#include "ConstantTensorDescriptor.hpp" namespace ck { @@ -39,7 +40,7 @@ struct ConstantMatrixDescriptor }; template -__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number, Number) +__host__ __device__ constexpr auto make_ConstantMatrixDescriptor_packed(Number, Number) { return ConstantMatrixDescriptor{}; } @@ -51,6 +52,13 @@ __host__ __device__ constexpr auto return ConstantMatrixDescriptor{}; } +template +__host__ __device__ constexpr auto + make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(ConstantTensorDescriptor, Sequence> +{ + return ConstantMatrixDescriptor{}; +} + template __host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s) { diff --git a/composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp b/composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp index 700f80845e..59810953aa 100644 --- a/composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp +++ b/composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp @@ -174,6 +174,12 @@ struct ConstantMergedTensorDescriptor return packed_desc.GetMultiIndexFrom1dIndex(id); } + + __host__ __device__ static constexpr auto Pack() + { + using Strides = decltype(calculate_tensor_strides_packed(GetLengths())); + return ConstantTensorDescriptor{}; + } }; template diff --git a/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp b/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp index 5fad7a46a1..c2828d7ac8 100644 --- a/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp +++ b/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp @@ -371,6 +371,12 @@ struct ConstantTensorDescriptor return ConstantTensorDescriptor{}; } + template + __host__ __device__ static constexpr auto Fold(Number, Sequence) + { + return Fold(Number{}, Number{}...); + } + // this function unfold dimension [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension template __host__ __device__ static constexpr auto Unfold(Number, Number) @@ -409,21 +415,18 @@ struct ConstantTensorDescriptor return ConstantTensorDescriptor{}; } + __host__ __device__ static constexpr auto Pack() + { + using Strides = decltype(calculate_tensor_strides_packed(Lengths{})); + return ConstantTensorDescriptor{}; + } + template __host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old) { return ConstantTensorDescriptor{}; } - -#if 0 // require sequence_sort, which is not implemented yet - template - __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) - { - return ConstantTensorDescriptor{} - } -#endif }; template diff --git a/driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp deleted file mode 100644 index af5711a2dc..0000000000 --- a/driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,232 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_convolution_kernel_wrapper.hpp" -#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp" -#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp" - -using namespace ck; - -template -void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - ConvStrides, - ConvDilations, - index_t nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcyx_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr index_t Hi = in_nchw_desc.GetLength(I2); - constexpr index_t Wi = in_nchw_desc.GetLength(I3); - - constexpr index_t N = out_nkhw_desc.GetLength(I0); - constexpr index_t Ho = out_nkhw_desc.GetLength(I2); - constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - - constexpr index_t K = wei_kcyx_desc.GetLength(I0); - constexpr index_t C = wei_kcyx_desc.GetLength(I1); - constexpr index_t Y = wei_kcyx_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_desc.GetLength(I3); - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - - constexpr index_t N1 = 2; - constexpr index_t N2 = 4; - - constexpr index_t B = (N * Ho * Wo) / (N1 * N2); - -#if 1 - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 4, 1>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 4, 4>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 4; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 1 - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 2, 2>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 2>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 2; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#endif - - constexpr index_t GridSize = - ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(index_t i = 0; i < nrepeat; ++i) - { - constexpr auto gridwise_conv = -#if 0 - GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw -#else - GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer -#endif - {}; - - float time = launch_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); - } - - out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); -} diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index e8977fe6f8..b13dcc2055 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -12,7 +12,8 @@ #include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp" #include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp" #include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp" -#include "device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp" +#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" +#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp" using namespace ck; @@ -870,8 +871,10 @@ int main(int argc, char* argv[]) device_convolution_implicit_gemm_v2_chwn_cyxk_khwn #elif 0 device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw +#elif 0 + device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw #elif 1 - device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw + device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw #endif (in_nchw_desc, in_nchw,