diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp
index 1e20a7534d..dd3cd21c6c 100644
--- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp
+++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp
@@ -1,5 +1,5 @@
-#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN
-#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP
 
 #include "common_header.hpp"
 #include "ConstantTensorDescriptor.hpp"
@@ -79,21 +79,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");
 
-        constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
         constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
         constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
         constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
+        constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
 
         constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
-            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
+            Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
 
         const auto block_work_multi_id =
             block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
 
-        const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
-        const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
-        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
-        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
+        const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
+        const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock;
 
         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
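For reference, the reordered [K, Ho, Wo, N] decomposition above behaves like the following standalone sketch (a minimal host-side model of GetMultiIndexFrom1dIndex for a packed descriptor; helper names are illustrative, not the CK API):

    #include <array>
    #include <cstdio>

    // Minimal model: decompose a 1-d block id into [K, Ho, Wo, N] work
    // coordinates for a packed (row-major) work grid, mimicking
    // block_work_desc.GetMultiIndexFrom1dIndex() in the hunk above.
    std::array<int, 4> decompose_block_id(int bid, std::array<int, 4> work)
    {
        std::array<int, 4> id{};
        for(int i = 3; i >= 0; --i) // last dimension varies fastest
        {
            id[i] = bid % work[i];
            bid /= work[i];
        }
        return id;
    }

    int main()
    {
        // e.g. K/KPerBlock = 2, Ho/HoPerBlock = 4, Wo/WoPerBlock = 4, N/NPerBlock = 4
        const auto id = decompose_block_id(37, {2, 4, 4, 4});
        std::printf("k %d, ho %d, wo %d, n %d\n", id[0], id[1], id[2], id[3]);
        return 0;
    }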
diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp
index 4f297fac3d..9c816bf21d 100644
--- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp
+++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp
@@ -1,5 +1,5 @@
-#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
-#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP
 
 #include "common_header.hpp"
 #include "ConstantTensorDescriptor.hpp"
@@ -74,14 +74,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
         constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
         constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
 
-        constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
-        constexpr index_t WiPerBlock = WoPerBlock + X - 1;
-
-        // assert for LDS double buffer
-        static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");
-
         // divide block work: [K, Ho, Wo, N]
-        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
+        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % (2 * CPerBlock) == 0 &&
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");
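Folding the old standalone double-buffer assert into the divisibility check works because the double-buffered main loop consumes two CPerBlock slices per iteration, one per LDS buffer. A host-side model of that ping-pong schedule (illustrative only, not the actual kernel):

    #include <cstdio>

    // Host-side model of an LDS double-buffered reduction over C:
    // each main-loop iteration processes 2 * CPerBlock channels,
    // loading into one buffer while computing from the other.
    int main()
    {
        constexpr int C = 32, CPerBlock = 8;
        static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");

        int buf[2] = {0, 0}; // stand-ins for the two LDS buffers
        for(int c = 0; c < C; c += 2 * CPerBlock)
        {
            buf[0] = c;             // "load" slice c             into buffer 0
            buf[1] = c + CPerBlock; // "load" slice c + CPerBlock into buffer 1
            std::printf("compute on slice %d, then slice %d\n", buf[0], buf[1]);
        }
        return 0;
    }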
diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp
new file mode 100644
index 0000000000..3985bbf3a7
--- /dev/null
+++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp
@@ -0,0 +1,420 @@
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_PADDED_HPP
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_PADDED_HPP
+
+#include "common_header.hpp"
+#include "ConstantTensorDescriptor.hpp"
+#include "ConstantMatrixDescriptor.hpp"
+#include "blockwise_generic_tensor_slice_copy.hpp"
+#include "threadwise_generic_tensor_slice_copy.hpp"
+#include "blockwise_batched_gemm.hpp"
+
+namespace ck {
+
+template <index_t GridSize,
+          index_t BlockSize,
+          class Float,
+          class InGlobalDesc,
+          class WeiGlobalDesc,
+          class OutGlobalDesc,
+          class LowerPads,
+          class UpperPads,
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t CPerBlock,
+          index_t HoPerBlock,
+          index_t WoPerBlock,
+          index_t NPerThread,
+          index_t KPerThread,
+          index_t HoPerThread,
+          index_t WoPerThread,
+          index_t GemmMPerThreadSubC,
+          index_t GemmNPerThreadSubC,
+          index_t GemmMLevel0Cluster,
+          index_t GemmNLevel0Cluster,
+          index_t GemmMLevel1Cluster,
+          index_t GemmNLevel1Cluster,
+          index_t GemmKPerThreadLoop,
+          index_t GemmDataPerReadA,
+          index_t GemmDataPerReadB,
+          class InBlockCopySubLengths_CHWN,
+          class InBlockCopyClusterLengths_CHWN,
+          index_t InBlockCopyDataPerAccess_N,
+          class WeiBlockCopySubLengths_CK,
+          class WeiBlockCopyClusterLengths_CK,
+          index_t WeiBlockCopyDataPerAccess_K,
+          index_t OutThreadCopyDataPerAccess_N>
+struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
+{
+    __device__ void Run(const Float* const __restrict__ p_in_global,
+                        const Float* const __restrict__ p_wei_global,
+                        Float* const __restrict__ p_out_global) const
+    {
+        // be careful of this assertion
+        static_assert(
+            NPerBlock % NPerThread == 0 &&
+                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
+                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
+                  GemmNPerThreadSubC % NPerThread == 0)),
+            "wrong!");
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
+        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
+        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
+
+        constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
+
+        constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
+        constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
+        constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
+        constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
+
+        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
+        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
+
+        // divide block work: [K, Ho, Wo, N]
+        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
+                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
+                      "wrong! cannot evenly divide work for workgroup ");
+
+        constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
+        constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
+
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
+            Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
+
+        const auto block_work_multi_id =
+            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
+
+        const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
+        const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock;
+
+        const index_t hi_block_data_begin = ho_block_data_begin;
+        const index_t wi_block_data_begin = wo_block_data_begin;
+
+        // global tensor view
+        constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
+
+        // LDS tensor view
+        // be careful of alignment
+        constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N,
+                                                WeiBlockCopyDataPerAccess_K,
+                                                GemmDataPerReadA,
+                                                GemmDataPerReadB);
+
+        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence<CPerBlock, HoPerBlock + Y - 1, WoPerBlock + X - 1, NPerBlock>{},
+            Number<max_align>{});
+
+        // this check is ad-hoc
+        // TODO: need to properly implement tensor descriptor with alignment
+        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
+                      "GemmDataPerReadB alignment requirement is not met");
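max_align above is the least common multiple of every vectorized access width, so one LDS row alignment serves all of them. The same lcm-based padding rule as a standalone sketch (round_up is a hypothetical helper, not the CK API):

    #include <cstdio>
    #include <numeric> // std::lcm (C++17)

    // Pad a row length up to a multiple of `align` so each row of an
    // LDS tile starts on a boundary usable by every vector load width.
    constexpr int round_up(int len, int align) { return ((len + align - 1) / align) * align; }

    int main()
    {
        // e.g. InBlockCopyDataPerAccess_N = 4, WeiBlockCopyDataPerAccess_K = 4,
        //      GemmDataPerReadA = 4, GemmDataPerReadB = 4
        constexpr int max_align = std::lcm(std::lcm(4, 4), std::lcm(4, 4)); // = 4
        std::printf("row of 18 elements padded to %d\n", round_up(18, max_align)); // 20
        return 0;
    }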
+
+        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
+
+        // tensor view of threadwise output in register
+        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
+            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
+
+        // blockwise copy
+        // input: format is [C, Hi, Wi, N]
+        auto blockwise_in_copy =
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+                                               Float,
+                                               decltype(in_c_h_w_n_global_desc),
+                                               decltype(in_c_h_w_n_block_desc),
+                                               decltype(in_c_h_w_n_block_desc.GetLengths()),
+                                               InBlockCopySubLengths_CHWN,
+                                               InBlockCopyClusterLengths_CHWN,
+                                               Sequence<0, 1, 2, 3>,
+                                               Sequence<0, 1, 2, 3>,
+                                               Sequence<0, 1, 2, 3>,
+                                               3,
+                                               3,
+                                               InBlockCopyDataPerAccess_N,
+                                               InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
+                                                                           {0, 0, 0, 0});
+
+        // blockwise wei copy
+        // format is [CPerBlock, KPerBlock]
+        const auto blockwise_wei_copy =
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+                                               Float,
+                                               decltype(wei_c_k_global_desc),
+                                               decltype(wei_c_k_block_desc),
+                                               decltype(wei_c_k_block_desc.GetLengths()),
+                                               WeiBlockCopySubLengths_CK,
+                                               WeiBlockCopyClusterLengths_CK,
+                                               Sequence<0, 1>,
+                                               Sequence<0, 1>,
+                                               Sequence<0, 1>,
+                                               1,
+                                               1,
+                                               WeiBlockCopyDataPerAccess_K,
+                                               WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
+
+        // a series of blockwise batched GEMM
+        // C_matrix += transpose(A_matrix) * B_matrix
+        // A_matrix and B_matrix saved in LDS, C_matrix saved in register
+        // A_matrix[C,K] is a sub-matrix of wei_block[C,K]
+        // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
+        // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
+        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
+            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
+
+        constexpr auto b_c_wn_block_mtx_desc =
+            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
+                                          Number<WoPerBlock * NPerBlock>{},
+                                          Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
+
+        constexpr auto c_k_wn_thread_mtx_desc =
+            make_ConstantMatrixDescriptor(Number<KPerThread>{},
+                                          Number<WoPerThread * NPerThread>{},
+                                          Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
+
+        const auto blockwise_batch_gemm =
+            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
+                BlockSize,
+                decltype(a_c_k_block_mtx_desc),
+                decltype(b_c_wn_block_mtx_desc),
+                decltype(c_k_wn_thread_mtx_desc),
+                0,
+                in_c_h_w_n_block_desc.GetStride(I1),
+                out_k_h_w_n_thread_desc.GetStride(I1),
+                HoPerBlock,
+                GemmMPerThreadSubC,
+                GemmNPerThreadSubC,
+                GemmMLevel0Cluster,
+                GemmNLevel0Cluster,
+                GemmMLevel1Cluster,
+                GemmNLevel1Cluster,
+                GemmKPerThreadLoop,
+                HoPerThread,
+                GemmDataPerReadA,
+                GemmDataPerReadB>{};
+
+        // LDS: be careful of alignment
+        constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace();
+        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace();
+
+        __shared__ Float p_in_block[in_block_space];
+        __shared__ Float p_wei_block[wei_block_space];
+
+        // register
+        // a C++ lambda doesn't capture arrays; use a pointer instead
+        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
+        Float* const p_out_thread = p_out_thread_data;
+
+#if 0
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
+            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
+
+            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
+            print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
+
+            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
+        }
+#endif
+
+        // set threadwise output tensor to 0
+        threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread);
+
+        for(index_t y = 0; y < Y; ++y)
+        {
+            for(index_t x = 0; x < X; ++x)
+            {
+                const Float* p_in_global_block_offset =
+                    p_in_global +
+                    in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
+                        0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin);
+
+                const Float* p_wei_global_block_offset =
+                    p_wei_global +
+                    wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
+
+                for(index_t c_block_data_begin = 0; c_block_data_begin < C;
+                    c_block_data_begin += CPerBlock,
+                    p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
+                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
+                {
+                    blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
+
+                    blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
+
+                    __syncthreads();
+
+                    blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
+
+                    __syncthreads();
+                }
+            }
+        }
+
+        // output: register to global mem
+        const auto c_thread_mtx_begin =
+            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+        const index_t k_thread_data_begin = c_thread_mtx_begin.row;
+        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
+        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
+        const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
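The y/x loop above realizes the implicit GEMM: each (y, x) filter tap contributes one C-reduction GEMM over a shifted input window. A tiny host reference of that decomposition, with illustrative dimensions (plain C++, not the CK code path):

    #include <cstdio>
    #include <vector>

    // Convolution as Y*X accumulated GEMMs over shifted input windows:
    // out[k][ho][wo] += sum_c wei[c][y][x][k] * in[c][ho+y][wo+x]
    int main()
    {
        constexpr int C = 2, K = 2, Y = 3, X = 3, Ho = 4, Wo = 4;
        constexpr int Hi = Ho + Y - 1, Wi = Wo + X - 1;
        std::vector<float> in(C * Hi * Wi, 1.0f), wei(C * Y * X * K, 1.0f), out(K * Ho * Wo, 0.0f);

        for(int y = 0; y < Y; ++y)
            for(int x = 0; x < X; ++x) // one "GEMM" per filter tap
                for(int k = 0; k < K; ++k)
                    for(int ho = 0; ho < Ho; ++ho)
                        for(int wo = 0; wo < Wo; ++wo)
                            for(int c = 0; c < C; ++c) // the C reduction
                                out[(k * Ho + ho) * Wo + wo] +=
                                    wei[((c * Y + y) * X + x) * K + k] *
                                    in[(c * Hi + ho + y) * Wi + wo + x];

        std::printf("out[0] = %f (expect %d)\n", out[0], C * Y * X); // 18
        return 0;
    }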
+
+        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
+            // fwd does nothing but perfect forwarding.
+            // Using this trick to make this lambda a generic lambda, so it won't be compiled
+            // until it is instantiated here
+            static_assert(
+                (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
+                "wrong!");
+
+            // output is a 10d tensor
+            constexpr index_t N2 = GemmNPerThreadSubC;
+            constexpr index_t N1 = NPerBlock / N2;
+
+            constexpr index_t W2 =
+                (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
+            constexpr index_t W1 = WoPerBlock / W2;
+
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = KPerBlock / KPerThread;
+
+            constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc)
+                                                     .Fold(I3, Number<N1>{}, Number<N2>{})
+                                                     .Fold(I2, Number<W1>{}, Number<W2>{})
+                                                     .Fold(I0, Number<K1>{}, Number<K2>{});
+
+            constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
+                                                     .Fold(I3, Number<1>{}, Number<N2>{})
+                                                     .Fold(I2, Number<W1>{}, Number<1>{})
+                                                     .Fold(I0, Number<1>{}, Number<K2>{});
+
+#if 0
+            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+            {
+                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
+                                               "a: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
+
+                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
+                                               "a: out_k_h_w_n_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
+            }
+#endif
+
+            Float* p_out_thread_on_global = p_out_global +
+                                            out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
+                                                k_block_data_begin + k_thread_data_begin,
+                                                ho_block_data_begin + ho_thread_data_begin,
+                                                wo_block_data_begin + wo_thread_data_begin,
+                                                n_block_data_begin + n_thread_data_begin);
+
+#if 1
+            ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+                                                  decltype(out_10d_global_desc),
+                                                  decltype(out_10d_thread_desc.GetLengths()),
+                                                  arithmetic_sequence_gen<0, 10, 1>::type,
+                                                  9,
+                                                  OutThreadCopyDataPerAccess_N,
+                                                  OutThreadCopyDataPerAccess_N>(
+                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
+            ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
+                                                  decltype(out_10d_global_desc),
+                                                  arithmetic_sequence_gen<0, 10, 1>::type,
+                                                  arithmetic_sequence_gen<0, 10, 1>::type,
+                                                  9,
+                                                  9,
+                                                  OutThreadCopyDataPerAccess_N,
+                                                  OutThreadCopyDataPerAccess_N>(
+                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                .Run(p_out_thread, p_out_thread_on_global);
+#endif
+        }).Else([&](auto fwd) {
+            static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+                              GemmNPerThreadSubC % NPerThread == 0,
+                          "wrong!");
+
+            // output is a 10d tensor
+            constexpr index_t N1 = NPerBlock;
+
+            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
+            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
+            constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
+
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = KPerBlock / KPerThread;
+
+            constexpr auto out_10d_global_desc =
+                fwd(out_k_h_w_n_global_desc)
+                    .Fold(I3, Number<N1>{})
+                    .Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
+                    .Fold(I0, Number<K1>{}, Number<K2>{});
+
+            constexpr auto out_10d_thread_desc =
+                fwd(out_k_h_w_n_thread_desc)
+                    .Fold(I3, Number<N1>{})
+                    .Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
+                    .Fold(I0, Number<1>{}, Number<K2>{});
+
+#if 0
+            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+            {
+                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
+                                               "b: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
+
+                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
+                                               "b: out_k_h_w_n_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
+            }
+#endif
+
+            Float* p_out_thread_on_global = p_out_global +
+                                            out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
+                                                k_block_data_begin + k_thread_data_begin,
+                                                ho_block_data_begin + ho_thread_data_begin,
+                                                wo_block_data_begin + wo_thread_data_begin,
+                                                n_block_data_begin + n_thread_data_begin);
+
+#if 1
+            ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+                                                  decltype(out_10d_global_desc),
+                                                  decltype(out_10d_thread_desc.GetLengths()),
+                                                  arithmetic_sequence_gen<0, 10, 1>::type,
+                                                  9,
+                                                  OutThreadCopyDataPerAccess_N,
+                                                  OutThreadCopyDataPerAccess_N>(
+                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
+            ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
+                                                  decltype(out_10d_global_desc),
+                                                  arithmetic_sequence_gen<0, 10, 1>::type,
+                                                  arithmetic_sequence_gen<0, 10, 1>::type,
+                                                  9,
+                                                  9,
+                                                  OutThreadCopyDataPerAccess_N,
+                                                  OutThreadCopyDataPerAccess_N>(
+                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                .Run(p_out_thread, p_out_thread_on_global);
+#endif
+        });
+    }
+};
+
+} // namespace ck
+#endif
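The Fold calls in both branches only re-describe the same linear memory: splitting a length-N dimension into [N1][N2] is plain mixed-radix indexing. A quick standalone check of that identity (illustrative factors):

    #include <cstdio>

    // Folding N into (N1, N2): a flat index n maps to n1 = n / N2, n2 = n % N2,
    // so a [K, Ho, Wo, N] tensor viewed as [..., N1, N2] addresses the same data.
    int main()
    {
        constexpr int N = 16, N2 = 4, N1 = N / N2; // e.g. NPerBlock = 16, GemmNPerThreadSubC = 4
        for(int n = 0; n < N; ++n)
        {
            const int n1 = n / N2, n2 = n % N2;
            if(n != n1 * N2 + n2)
            {
                std::printf("mismatch\n");
                return 1;
            }
        }
        std::printf("fold N=%d into [N1=%d][N2=%d] is index-preserving\n", N, N1, N2);
        return 0;
    }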
diff --git a/composable_kernel/include/tensor_description/tensor_coordinate.hpp b/composable_kernel/include/tensor_description/tensor_coordinate.hpp
index 8fa701ccee..77ed7c052b 100644
--- a/composable_kernel/include/tensor_description/tensor_coordinate.hpp
+++ b/composable_kernel/include/tensor_description/tensor_coordinate.hpp
@@ -301,14 +301,14 @@ struct TensorCoordinate
     private:
     template
     __host__ __device__ static constexpr auto
-    MakeDummyTensorCoordinate(ConstantTensorDescriptor)
+    MakeDummyTensorCoordinate(ConstantTensorDescriptor)
     {
         return NormalTensorCoordinate>();
     }
 
     template
     __host__ __device__ static constexpr auto
-    MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor)
+    MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor)
     {
         return MergedTensorCoordinate>();
     }
diff --git a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
index 8fb5c79e07..67d36fb79d 100644
--- a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
+++ b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
@@ -472,55 +472,54 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
 #endif
 
     constexpr index_t GridSize =
-        ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
-        ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
+        (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
 
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
+    constexpr auto gridwise_conv =
+#if 0
+        GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
+#elif 0
+        GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
+#elif 0
+        GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
+#elif 1
+        GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
+#endif
+        {};
+
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        constexpr auto gridwise_conv =
-#if 0
-            GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
-#elif 0
-            GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
-#elif 0
-            GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
-#elif 1
-            GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
-#endif
-            {};
-
         float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                    dim3(GridSize),
                                    dim3(BlockSize),
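Dropping the ceil-divides from GridSize is safe because the kernel's static_assert already guarantees exact divisibility, so one launch covers exactly the work grid. A standalone check of the same arithmetic (illustrative sizes, not the benchmark config):

    #include <cstdio>

    int main()
    {
        // Illustrative sizes satisfying the kernel's divisibility static_assert.
        constexpr int N = 64, K = 256, Ho = 28, Wo = 28;
        constexpr int NPerBlock = 16, KPerBlock = 128, HoPerBlock = 2, WoPerBlock = 2;
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && Ho % HoPerBlock == 0 &&
                          Wo % WoPerBlock == 0,
                      "exact divide is required once the ceil is removed");
        constexpr int GridSize =
            (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
        std::printf("GridSize = %d\n", GridSize); // 4 * 2 * 14 * 14 = 1568
        return 0;
    }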
diff --git a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp
new file mode 100644
index 0000000000..f73be2ffa5
--- /dev/null
+++ b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp
@@ -0,0 +1,184 @@
+#pragma once
+#include <unistd.h>
+#include "device.hpp"
+#include "tensor.hpp"
+#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp"
+
+using namespace ck;
+
+template <class T,
+          class InDesc,
+          class WeiDesc,
+          class OutDesc,
+          class LowerPads,
+          class UpperPads>
+void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
+                                                               const Tensor<T>& in_nchw,
+                                                               WeiDesc,
+                                                               const Tensor<T>& wei_kcyx,
+                                                               OutDesc,
+                                                               Tensor<T>& out_nkhw,
+                                                               LowerPads,
+                                                               UpperPads,
+                                                               index_t nrepeat)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+
+    constexpr auto in_nchw_desc = InDesc{};
+    constexpr auto wei_kcyx_desc = WeiDesc{};
+    constexpr auto out_nkhw_desc = OutDesc{};
+
+    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+    constexpr index_t Wi = in_nchw_desc.GetLength(I3);
+
+    constexpr index_t N = out_nkhw_desc.GetLength(I0);
+    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
+
+    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_desc.GetLength(I3);
+
+    // reorder weight
+    auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
+    ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
+
+    Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
+
+    auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
+        wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
+    };
+
+    make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
+        std::thread::hardware_concurrency());
+
+    // reorder input
+    auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Hi, Wi, N>{});
+    ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
+
+    Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
+
+    auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
+        in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
+    };
+
+    make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
+        std::thread::hardware_concurrency());
+
+    // output
+    auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence<K, Ho, Wo, N>{});
+    ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
+
+    Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
+
+    std::size_t data_sz = sizeof(T);
+    DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
+    DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
+    DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
+
+    in_chwn_device_buf.ToDevice(in_chwn.mData.data());
+    wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
+    out_khwn_device_buf.ToDevice(out_khwn.mData.data());
+
+#if 1
+    // v1r3, 3x3, 32x32, 1x1 pad
+    constexpr index_t BlockSize = 128;
+
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA = 4;
+    constexpr index_t GemmDataPerReadB = 4;
+
+    using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
+    using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
+    constexpr index_t InBlockCopyDataPerAccess_N = 4;
+
+    using WeiBlockCopySubLengths_CK = Sequence<2, 4>;
+    using WeiBlockCopyClusterLengths_CK = Sequence<4, 32>;
+    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
+
+    constexpr index_t OutThreadCopyDataPerAccess_N = 2;
+#endif
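These copy tunings have to tile exactly: per-thread sub-lengths times cluster lengths give the tile each blockwise copy moves per pass, and the cluster product must match BlockSize. A compile-time sanity check of the numbers above (standalone sketch; the sub-times-cluster rule is assumed from CK's blockwise copy design):

    // Sanity-check the blockwise-copy tunings used above.
    constexpr int in_sub[4] = {1, 1, 1, 4};     // InBlockCopySubLengths_CHWN
    constexpr int in_cluster[4] = {8, 2, 2, 4}; // InBlockCopyClusterLengths_CHWN

    // Product of the first n entries, written as a C++11-style constexpr recursion.
    constexpr int product(const int* v, int n) { return n == 0 ? 1 : v[n - 1] * product(v, n - 1); }

    static_assert(product(in_cluster, 4) == 128, "cluster product must equal BlockSize (128)");
    static_assert(in_sub[0] * in_cluster[0] == 8, "C tile per pass = CPerBlock (8)");
    static_assert(in_sub[3] * in_cluster[3] == 16, "N tile per pass = NPerBlock (16)");

    int main() { return 0; }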
+
+    constexpr index_t GridSize =
+        (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
+
+    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
+
+    constexpr auto gridwise_conv =
+        GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded<
+            GridSize,
+            BlockSize,
+            T,
+            decltype(in_chwn_desc),
+            decltype(wei_cyxk_desc),
+            decltype(out_khwn_desc),
+            LowerPads,
+            UpperPads,
+            NPerBlock,
+            KPerBlock,
+            CPerBlock,
+            HoPerBlock,
+            WoPerBlock,
+            NPerThread,
+            KPerThread,
+            HoPerThread,
+            WoPerThread,
+            GemmMPerThreadSubC,
+            GemmNPerThreadSubC,
+            GemmMLevel0Cluster,
+            GemmNLevel0Cluster,
+            GemmMLevel1Cluster,
+            GemmNLevel1Cluster,
+            GemmKPerThreadLoop,
+            GemmDataPerReadA,
+            GemmDataPerReadB,
+            InBlockCopySubLengths_CHWN,
+            InBlockCopyClusterLengths_CHWN,
+            InBlockCopyDataPerAccess_N,
+            WeiBlockCopySubLengths_CK,
+            WeiBlockCopyClusterLengths_CK,
+            WeiBlockCopyDataPerAccess_K,
+            OutThreadCopyDataPerAccess_N>{};
+
+    for(index_t i = 0; i < nrepeat; ++i)
+    {
+        float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+                                   dim3(GridSize),
+                                   dim3(BlockSize),
+                                   0,
+                                   static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
+                                   static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
+                                   static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
+
+        printf("Elapsed time : %f ms, %f TFlop/s\n",
+               time,
+               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
+                   (std::size_t(1000) * 1000 * 1000) / time);
+        usleep(std::min(time * 1000, float(10000)));
+    }
+
+    out_khwn_device_buf.FromDevice(out_khwn.mData.data());
+
+    // reorder output
+    auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
+        out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
+    };
+
+    make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
+        std::thread::hardware_concurrency());
+}
diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp
index 7ea05e243e..ea961d3564 100644
--- a/driver/src/driver.cpp
+++ b/driver/src/driver.cpp
@@ -10,6 +10,7 @@
 #include "host_conv.hpp"
 #include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
+#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp"
 //#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
 //#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
 //#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
@@ -71,7 +72,7 @@ int main(int argc, char* argv[])
 {
     using namespace ck;
 
-#if 1
+#if 0
     constexpr index_t N = 64;
     constexpr index_t C = 1536;
     constexpr index_t HI = 8;
@@ -367,9 +368,19 @@ int main(int argc, char* argv[])
 #if 0
     device_convolution_direct_v2_nchw_kcyx_nkhw(
         in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 1
+#elif 0
     device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(
         in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
+#elif 1
+    device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(in_nchw_desc,
+                                                              in_nchw,
+                                                              wei_kcyx_desc,
+                                                              wei_kcyx,
+                                                              out_nkhw_desc,
+                                                              out_nkhw_device,
+                                                              lower_pads,
+                                                              upper_pads,
+                                                              nrepeat);
 #elif 0
     device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(
         in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
@@ -419,16 +430,6 @@ int main(int argc, char* argv[])
         ConvStrides{},
         ConvDilations{},
         nrepeat);
-#elif 0
-    device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
-                                                             in_nchw,
-                                                             wei_kcyx_desc,
-                                                             wei_kcyx,
-                                                             out_nkhw_desc,
-                                                             out_nkhw_device,
-                                                             lower_pads,
-                                                             upper_pads,
-                                                             nrepeat);
 #endif
 
     if(do_verification)