diff --git a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 06a755d13b..5c7f065042 100644 --- a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -4,1532 +4,441 @@ #include "common_header.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm.hpp" -#include "gridwise_operation_wrapper.hpp" +#include "driver_dynamic_gemm_v1.hpp" namespace ck { // GemmM = K // GemmN = N * Ho * Wo // GemmK = C * Y * X -template -struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad + index_t GemmM1, + index_t GemmN1, + typename... Wei, + typename... In, + typename... Out, + typename ConvStrides, + typename ConvDilations, + typename InLeftPads, + typename InRightPads> +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( + const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) { - template - __host__ void Run(const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const FloatAB* __restrict__ p_wei_global, - const FloatAB* __restrict__ p_in_global, - FloatC* __restrict__ p_out_global) const - 
{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; + const auto InRightPadH = 
in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); - // input tensor - const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // input tensor + const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple( - make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), 
make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( - in_n_c_y_ho_x_wo_global_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmk_gemmn_global_desc = + transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), - make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, 
Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); - const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); - const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); + const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); + const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); + const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); - if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && - GemmK % GemmKPerBlock == 0)) - { - throw std::runtime_error("wrong! GEMM size no divisible"); - } + assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); - constexpr auto GemmM1 = Number{}; - constexpr auto GemmN1 = Number{}; + const auto GemmM0 = GemmM / Number{}; + const auto GemmN0 = GemmN / Number{}; - const auto GemmM0 = GemmM / GemmM1; - const auto GemmN0 = GemmN / GemmN1; + const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), + make_unmerge_transform(make_tuple(GemmN0, GemmN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = - transform_dynamic_tensor_descriptor( - out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), - make_unmerge_transform(make_tuple(GemmN0, GemmN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + // out_gemm_block_cluster_desc + const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( + make_tuple(GemmM / Number{}, GemmN / Number{})); - // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_k_m_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 
0, 0>{}, Sequence<0, 0, 0>{})); + // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor + constexpr auto wei_gemmk_gemmm_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - // hack to control index calculation when iterating over b_k_n_global tensor - constexpr auto b_k_n_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); + // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor + constexpr auto in_gemmk_gemmn_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); - constexpr auto b_k_n_global_move_slice_window_iterator_hack = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; + constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - // hack for NKHW format - constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); + // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global + // tensor hack for NKHW 
format + constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); - // GEMM - using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperation::Set, - decltype(wei_gemmk_gemmm_global_desc), - decltype(in_gemmk_gemmn_global_desc), - decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmABlockTransferThreadSliceLengths_GemmK_GemmM, - GemmABlockTransferThreadClusterLengths_GemmK_GemmM, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_GemmM, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, - GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, - Sequence<0, 1>, - Sequence<0, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmN, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 3, - GemmCThreadTransferDstScalarPerVector_GemmN1, - decltype(a_k_m_global_iterator_hacks), - decltype(b_k_n_global_iterator_hacks), - decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks), - decltype(a_k_m_global_move_slice_window_iterator_hack), - decltype(b_k_n_global_move_slice_window_iterator_hack)>; - - const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); - - const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * 
GemmKPerBlock) > 1; - - const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." << std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - } - - 
timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k_ho_wo_global_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - using ADesc = decltype(wei_gemmk_gemmm_global_desc); - using BDesc = decltype(in_gemmk_gemmn_global_desc); - using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); - - DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); - DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); - DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); - - wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); - in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( - &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); - - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void 
__CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k_ho_wo_global_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#endif - } -}; + return make_tuple(wei_gemmk_gemmm_global_desc, + in_gemmk_gemmn_global_desc, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + out_gemm_block_cluster_desc, + wei_gemmk_gemmm_global_iterator_hacks, + in_gemmk_gemmn_global_iterator_hacks, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks, + wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, + in_gemmk_gemmn_global_move_slice_window_iterator_hacks); +} // GemmM = K // GemmN = N * Ho * Wo // GemmK = C * Y * X -template -struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad + index_t GemmM1, + index_t GemmN1, + typename... Wei, + typename... In, + typename... 
Out, + typename ConvStrides, + typename ConvDilations, + typename InLeftPads, + typename InRightPads> +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( + const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) { - template - __host__ void Run(const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const FloatAB* __restrict__ p_wei_global, - const FloatAB* __restrict__ p_in_global, - FloatC* __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + const 
auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; - if(!(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0)) - { - throw std::runtime_error("wrong! 
no padding"); - } + assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0); - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); - // input tensor - const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple( - make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + // input tensor + const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( - in_n_c_y_ho_x_wo_global_desc, - 
make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmk_gemmn_global_desc = + transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), - make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); - const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); - const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); + const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); + const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); + const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); - if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && - GemmK % GemmKPerBlock == 0)) - { - throw std::runtime_error("wrong! 
GEMM size no divisible"); - } + assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); - constexpr auto GemmM1 = Number{}; - constexpr auto GemmN1 = Number{}; + const auto GemmM0 = GemmM / Number{}; + const auto GemmN0 = GemmN / Number{}; - const auto GemmM0 = GemmM / GemmM1; - const auto GemmN0 = GemmN / GemmN1; + const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), + make_unmerge_transform(make_tuple(GemmN0, GemmN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = - transform_dynamic_tensor_descriptor( - out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), - make_unmerge_transform(make_tuple(GemmN0, GemmN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + // out_gemm_block_cluster_desc + const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( + make_tuple(GemmM / Number{}, GemmN / Number{})); - // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_k_m_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor + constexpr auto wei_gemmk_gemmm_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - // hack to control index calculation when iterating over b_k_n_global tensor - constexpr auto 
b_k_n_global_iterator_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 2, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 2>{})); + // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor + constexpr auto in_gemmk_gemmn_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 2, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 2>{})); - constexpr auto b_k_n_global_move_slice_window_iterator_hack = - Sequence<0, 0, 0, 0, 0, 1, 2>{}; + constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 1, 2>{}; - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - // hack for NKHW format - constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); + // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global tensor + // hack for NKHW format + constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); - // GEMM - using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperation::Set, - decltype(wei_gemmk_gemmm_global_desc), - decltype(in_gemmk_gemmn_global_desc), - decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - 
GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmABlockTransferThreadSliceLengths_GemmK_GemmM, - GemmABlockTransferThreadClusterLengths_GemmK_GemmM, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_GemmM, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, - GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, - Sequence<0, 1>, - Sequence<0, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmN, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 3, - GemmCThreadTransferDstScalarPerVector_GemmN1, - decltype(a_k_m_global_iterator_hacks), - decltype(b_k_n_global_iterator_hacks), - decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks), - decltype(a_k_m_global_move_slice_window_iterator_hack), - decltype(b_k_n_global_move_slice_window_iterator_hack)>; + return make_tuple(wei_gemmk_gemmm_global_desc, + in_gemmk_gemmn_global_desc, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + out_gemm_block_cluster_desc, + wei_gemmk_gemmm_global_iterator_hacks, + in_gemmk_gemmn_global_iterator_hacks, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks, + wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, + in_gemmk_gemmn_global_move_slice_window_iterator_hacks); +} - const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); - - const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1; - - const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k_ho_wo_global_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average 
time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - using ADesc = decltype(wei_gemmk_gemmm_global_desc); - using BDesc = decltype(in_gemmk_gemmn_global_desc); - using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); - - DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); - DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); - DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); - - wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); - in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( - &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); - - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." << std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - 
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k_ho_wo_global_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#endif - } -}; - -template -struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 + index_t GemmM1, + index_t GemmN1, + typename... Wei, + typename... In, + typename... 
Out, + typename ConvStrides, + typename ConvDilations, + typename InLeftPads, + typename InRightPads> +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1( + const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) { - template - __host__ void Run(const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const FloatAB* __restrict__ p_wei_global, - const FloatAB* __restrict__ p_in_global, - FloatC* __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + const auto Ho 
= out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; - if(!(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && - ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && - InRightPadW == 0)) - { - throw std::runtime_error("wrong! 
1x1, stride 1, no padding"); - } + assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && + ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && + InRightPadW == 0); - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); - // input tensor - const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + // input tensor + const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), - make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + 
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); - const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); - const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); + const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); + const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); + const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); - if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && - GemmK % GemmKPerBlock == 0)) - { - throw std::runtime_error("wrong! GEMM size no divisible"); - } + assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); - constexpr auto GemmM1 = Number{}; - constexpr auto GemmN1 = Number{}; + const auto GemmM0 = GemmM / Number{}; + const auto GemmN0 = GemmN / Number{}; - const auto GemmM0 = GemmM / GemmM1; - const auto GemmN0 = GemmN / GemmN1; + const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), + make_unmerge_transform(make_tuple(GemmN0, GemmN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = - transform_dynamic_tensor_descriptor( - out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), - make_unmerge_transform(make_tuple(GemmN0, GemmN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + // out_gemm_block_cluster_desc + const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( + make_tuple(GemmM / Number{}, GemmN / Number{})); - // hack to control 
index calculation when iterating over a_k_m_global tensor - constexpr auto a_k_m_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto wei_gemmk_gemmm_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - // hack to control index calculation when iterating over b_k_n_global tensor - constexpr auto b_k_n_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 1, 0>{}, Sequence<0, 0, 1>{}), - make_tuple(Sequence<0, 2, 0>{}, Sequence<0, 0, 2>{})); + // hack to control index calculation when iterating over b_k_n_global tensor + constexpr auto in_gemmk_gemmn_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0>{}, Sequence<0, 0, 1>{}), + make_tuple(Sequence<0, 2, 0>{}, Sequence<0, 0, 2>{})); - constexpr auto b_k_n_global_move_slice_window_iterator_hack = Sequence<0, 1, 2>{}; + constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = Sequence<0, 1, 2>{}; - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + 
Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); - // GEMM - using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperation::Set, - decltype(wei_gemmk_gemmm_global_desc), - decltype(in_gemmk_gemmn_global_desc), - decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmABlockTransferThreadSliceLengths_GemmK_GemmM, - GemmABlockTransferThreadClusterLengths_GemmK_GemmM, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_GemmM, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, - GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, - Sequence<0, 1>, - Sequence<0, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmN, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 3, - GemmCThreadTransferDstScalarPerVector_GemmN1, - decltype(a_k_m_global_iterator_hacks), - decltype(b_k_n_global_iterator_hacks), - decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks), - decltype(a_k_m_global_move_slice_window_iterator_hack), - decltype(b_k_n_global_move_slice_window_iterator_hack)>; - - const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); - - const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1; - - const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; - -#if 
CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." << std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = 
(float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k_ho_wo_global_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - using ADesc = decltype(wei_gemmk_gemmm_global_desc); - using BDesc = decltype(in_gemmk_gemmn_global_desc); - using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); - - DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); - DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); - DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); - - wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); - in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( - &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); - - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void 
__CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k_ho_wo_global_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#endif - } -}; + return make_tuple(wei_gemmk_gemmm_global_desc, + in_gemmk_gemmn_global_desc, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + out_gemm_block_cluster_desc, + wei_gemmk_gemmm_global_iterator_hacks, + in_gemmk_gemmn_global_iterator_hacks, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks, + wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, + in_gemmk_gemmn_global_move_slice_window_iterator_hacks); +} } // namespace ck #endif diff --git a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp index 922a036013..98b0e87119 100644 --- a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -4,1015 +4,297 @@ #include "common_header.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm.hpp" -#include "gridwise_operation_wrapper.hpp" +#include "driver_dynamic_gemm_v1.hpp" namespace ck { // GemmM = K // GemmN = N * Ho * Wo -// GemmK = Y * X * C -template -struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad + index_t GemmM1, + index_t GemmN1, + typename... Wei, + typename... In, + typename... 
Out, + typename ConvStrides, + typename ConvDilations, + typename InLeftPads, + typename InRightPads> +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad( + const DynamicTensorDescriptor& wei_k_y_x_c_global_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_global_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) { - template - __host__ void Run(const DynamicTensorDescriptor& wei_k_y_x_c_global_desc, - const DynamicTensorDescriptor& in_n_hi_wi_c_global_desc, - const DynamicTensorDescriptor& out_n_ho_wo_k_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const FloatAB* __restrict__ p_wei_global, - const FloatAB* __restrict__ p_in_global, - FloatC* __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; - const auto N = in_n_hi_wi_c_global_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); + const auto N = in_n_hi_wi_c_global_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); - const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); + const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); - const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); + const auto Ho 
= out_n_ho_wo_k_global_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); - const auto Y = wei_k_y_x_c_global_desc.GetLength(I1); - const auto X = wei_k_y_x_c_global_desc.GetLength(I2); + const auto Y = wei_k_y_x_c_global_desc.GetLength(I1); + const auto X = wei_k_y_x_c_global_desc.GetLength(I2); - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); - // input tensor - const auto in_n_hip_wip_c_global_desc = transform_dynamic_tensor_descriptor( - in_n_hi_wi_c_global_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - 
make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // input tensor + const auto in_n_hip_wip_c_global_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_global_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_y_ho_x_wo_c_global_desc = transform_dynamic_tensor_descriptor( - in_n_hip_wip_c_global_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + const auto in_n_y_ho_x_wo_c_global_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_global_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( - in_n_y_ho_x_wo_c_global_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - 
make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmk_gemmn_global_desc = + transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_global_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); - const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); - const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); + const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); + const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); + const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); - if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && - GemmK % GemmKPerBlock == 0)) - { - throw std::runtime_error("wrong! 
GEMM size no divisible"); - } + assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); - constexpr auto GemmM1 = Number{}; - constexpr auto GemmN1 = Number{}; + const auto GemmM0 = GemmM / Number{}; + const auto GemmN0 = GemmN / Number{}; - const auto GemmM0 = GemmM / GemmM1; - const auto GemmN0 = GemmN / GemmN1; + const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), + make_unmerge_transform(make_tuple(GemmN0, GemmN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = - transform_dynamic_tensor_descriptor( - out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), - make_unmerge_transform(make_tuple(GemmN0, GemmN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + // out_gemm_block_cluster_desc + const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( + make_tuple(GemmM / Number{}, GemmN / Number{})); - // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_k_m_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto wei_gemmk_gemmm_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - // hack to control index calculation when iterating over b_k_n_global tensor - constexpr auto 
b_k_n_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); + // hack to control index calculation when iterating over b_k_n_global tensor + constexpr auto in_gemmk_gemmn_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); - constexpr auto b_k_n_global_move_slice_window_iterator_hack = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; + constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - // hack for NKHW format - constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + // hack for NKHW format + constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); - // GEMM - using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperation::Set, - decltype(wei_gemmk_gemmm_global_desc), - decltype(in_gemmk_gemmn_global_desc), - 
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmABlockTransferThreadSliceLengths_GemmK_GemmM, - GemmABlockTransferThreadClusterLengths_GemmK_GemmM, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_GemmM, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, - GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - GemmBBlockTransferSrcScalarPerVector_GemmK, - GemmBBlockTransferDstScalarPerVector_GemmN, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 1, - GemmCThreadTransferDstScalarPerVector_GemmM1, - decltype(a_k_m_global_iterator_hacks), - decltype(b_k_n_global_iterator_hacks), - decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks), - decltype(a_k_m_global_move_slice_window_iterator_hack), - decltype(b_k_n_global_move_slice_window_iterator_hack)>; + return make_tuple(wei_gemmk_gemmm_global_desc, + in_gemmk_gemmn_global_desc, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + out_gemm_block_cluster_desc, + wei_gemmk_gemmm_global_iterator_hacks, + in_gemmk_gemmn_global_iterator_hacks, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks, + wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, + in_gemmk_gemmn_global_move_slice_window_iterator_hacks); +} - const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); - - const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1; - - const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; - - printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, 
GridSize); - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." << std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = 
(float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - using ADesc = decltype(wei_gemmk_gemmm_global_desc); - using BDesc = decltype(in_gemmk_gemmn_global_desc); - using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); - - DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); - DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); - DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); - - wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); - in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( - &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); - - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void 
__CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#endif - } -}; - -template -struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 + index_t GemmM1, + index_t GemmN1, + typename... Wei, + typename... In, + typename... Out, + typename ConvStrides, + typename ConvDilations, + typename InLeftPads, + typename InRightPads> +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1( + const DynamicTensorDescriptor& wei_k_y_x_c_global_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_global_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) { - template - __host__ void Run(const DynamicTensorDescriptor& wei_k_y_x_c_global_desc, - const DynamicTensorDescriptor& in_n_hi_wi_c_global_desc, - const DynamicTensorDescriptor& out_n_ho_wo_k_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const FloatAB* __restrict__ p_wei_global, - const FloatAB* __restrict__ p_in_global, - FloatC* __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; - const auto N = in_n_hi_wi_c_global_desc.GetLength(I0); - const auto C = 
in_n_hi_wi_c_global_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); + const auto N = in_n_hi_wi_c_global_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); - const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); + const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); - const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); + const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); - const auto Y = wei_k_y_x_c_global_desc.GetLength(I1); - const auto X = wei_k_y_x_c_global_desc.GetLength(I2); + const auto Y = wei_k_y_x_c_global_desc.GetLength(I1); + const auto X = wei_k_y_x_c_global_desc.GetLength(I2); - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; - if(!(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && - ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && - InRightPadW == 0)) - { - throw std::runtime_error("wrong! 
1x1, stride 1, no padding"); - } + assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && + ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && + InRightPadW == 0); - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); - // input tensor - const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); + // input tensor + const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); + // output tensor + const auto 
out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); - const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); - const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); + const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); + const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); + const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); - if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && - GemmK % GemmKPerBlock == 0)) - { - throw std::runtime_error("wrong! GEMM size no divisible"); - } + assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); - constexpr auto GemmM1 = Number{}; - constexpr auto GemmN1 = Number{}; + const auto GemmM0 = GemmM / Number{}; + const auto GemmN0 = GemmN / Number{}; - const auto GemmM0 = GemmM / GemmM1; - const auto GemmN0 = GemmN / GemmN1; + const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), + make_unmerge_transform(make_tuple(GemmN0, GemmN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = - transform_dynamic_tensor_descriptor( - out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), - make_unmerge_transform(make_tuple(GemmN0, GemmN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + // out_gemm_block_cluster_desc + const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( + 
make_tuple(GemmM / Number{}, GemmN / Number{})); - // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_k_m_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + // hack to control index calculation when iterating over wei_gemmk_gemmm_global_iterator_hacks + // tensor + constexpr auto wei_gemmk_gemmm_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - // hack to control index calculation when iterating over b_k_n_global tensor - constexpr auto b_k_n_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + // hack to control index calculation when iterating over b_k_n_global tensor + constexpr auto in_gemmk_gemmn_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - constexpr auto b_k_n_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + constexpr auto 
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); - // GEMM - using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperation::Set, - decltype(wei_gemmk_gemmm_global_desc), - decltype(in_gemmk_gemmn_global_desc), - decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmABlockTransferThreadSliceLengths_GemmK_GemmM, - GemmABlockTransferThreadClusterLengths_GemmK_GemmM, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_GemmM, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, - GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - GemmBBlockTransferSrcScalarPerVector_GemmK, - GemmBBlockTransferDstScalarPerVector_GemmN, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 1, - GemmCThreadTransferDstScalarPerVector_GemmM1, - decltype(a_k_m_global_iterator_hacks), - decltype(b_k_n_global_iterator_hacks), - decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks), - decltype(a_k_m_global_move_slice_window_iterator_hack), - decltype(b_k_n_global_move_slice_window_iterator_hack)>; - - const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); - - const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1; - - const 
bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; - - printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize); - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." << std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - wei_gemmk_gemmm_global_desc, - p_wei_global, - in_gemmk_gemmn_global_desc, - p_in_global, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - p_out_global, - 
integral_constant{}, - integral_constant{}); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - using ADesc = decltype(wei_gemmk_gemmm_global_desc); - using BDesc = decltype(in_gemmk_gemmn_global_desc); - using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); - - DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); - DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); - DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); - - wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); - in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( - &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); - - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void __CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - else - { - const auto kernel = run_gridwise_dynamic_gemm_v1; - - launch_kernel( - kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - (void __CONSTANT__*) - wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), - p_wei_global, - (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), - p_in_global, - (void 
__CONSTANT__*) - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf - .GetDeviceBuffer(), - p_out_global); - } - } - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#endif - } -}; + return make_tuple(wei_gemmk_gemmm_global_desc, + in_gemmk_gemmn_global_desc, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + out_gemm_block_cluster_desc, + wei_gemmk_gemmm_global_iterator_hacks, + in_gemmk_gemmn_global_iterator_hacks, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks, + wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, + in_gemmk_gemmn_global_move_slice_window_iterator_hacks); +} } // namespace ck #endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp new file mode 100644 index 0000000000..0106377a4f --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp @@ -0,0 +1,396 @@ +#ifndef CK_DRIVER_DYNAMIC_GEMM_V1 +#define CK_DRIVER_DYNAMIC_GEMM_V1 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm.hpp" +#include "gridwise_operation_wrapper.hpp" + +namespace ck { + +template +__host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, + const FloatAB* p_b_global, + FloatC* p_c_global, + const AGlobalDesc& a_k_m_global_desc, + const BGlobalDesc& b_k_n_global_desc, + const CGlobalDesc& c_m0_m1_n0_n1_global_desc, + const CBlockClusterDesc& c_block_cluster_desc, + AGlobalIteratorHacks, + BGlobalIteratorHacks, + CGlobalIteratorHacks, + AGlobalMoveSliceWindowIteratorHacks, + BGlobalMoveSliceWindowIteratorHacks, + index_t nrepeat) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = 
Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto M = a_k_m_global_desc.GetLength(I1); + const auto N = b_k_n_global_desc.GetLength(I1); + const auto K = a_k_m_global_desc.GetLength(I0); + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + // GEMM + using gridwise_gemm = + GridwiseDynamicGemm_km_kn_m0m1n0n1_v1; + + const auto GridSize = (M / MPerBlock) * (N / NPerBlock); + + const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + const FloatAB*, + remove_reference_t, + const FloatAB*, + remove_reference_t, + FloatC*, + remove_reference_t, + integral_constant, + integral_constant>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + a_k_m_global_desc, + p_a_global, + b_k_n_global_desc, + p_b_global, + c_m0_m1_n0_n1_global_desc, + p_c_global, + c_block_cluster_desc, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + const FloatAB*, + remove_reference_t, + const FloatAB*, + remove_reference_t, + FloatC*, + remove_reference_t, + integral_constant, + integral_constant>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + a_k_m_global_desc, + p_a_global, + b_k_n_global_desc, + p_b_global, + c_m0_m1_n0_n1_global_desc, + p_c_global, + 
c_block_cluster_desc, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + const FloatAB*, + remove_reference_t, + const FloatAB*, + remove_reference_t, + FloatC*, + remove_reference_t, + integral_constant, + integral_constant>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + a_k_m_global_desc, + p_a_global, + b_k_n_global_desc, + p_b_global, + c_m0_m1_n0_n1_global_desc, + p_c_global, + c_block_cluster_desc, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = run_gridwise_operation, + const FloatAB*, + remove_reference_t, + const FloatAB*, + remove_reference_t, + FloatC*, + remove_reference_t, + integral_constant, + integral_constant>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + a_k_m_global_desc, + p_a_global, + b_k_n_global_desc, + p_b_global, + c_m0_m1_n0_n1_global_desc, + p_c_global, + c_block_cluster_desc, + integral_constant{}, + integral_constant{}); + } + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc)); + DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc)); + DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc)); + DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc)); + + a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc); + b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc); + c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc); + c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + const FloatAB*, + remove_reference_t, + const FloatAB*, + remove_reference_t, + FloatC*, + 
remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + p_a_global, + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + p_b_global, + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + p_c_global, + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_v1, + FloatAB, + remove_reference_t, + FloatAB, + remove_reference_t, + FloatC, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + p_a_global, + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + p_b_global, + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + p_c_global, + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_v1, + FloatAB, + remove_reference_t, + FloatAB, + remove_reference_t, + FloatC, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + p_a_global, + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + p_b_global, + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + p_c_global, + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + else + { + const auto kernel = kernel_dynamic_gemm_v1, + FloatAB, + remove_reference_t, + FloatAB, + remove_reference_t, + FloatC, + remove_reference_t, + false, 
+ false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + p_a_global, + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + p_b_global, + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + p_c_global, + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + + return ave_time; +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/cluster_descriptor.hpp b/composable_kernel/include/tensor_description/cluster_descriptor.hpp index 96dbe07073..7793dc242a 100644 --- a/composable_kernel/include/tensor_description/cluster_descriptor.hpp +++ b/composable_kernel/include/tensor_description/cluster_descriptor.hpp @@ -5,6 +5,7 @@ // TODO remove dependency on deprecated tensor descriptor #include "tensor_descriptor.hpp" +#include "tensor_adaptor.hpp" namespace ck { @@ -44,5 +45,30 @@ __host__ __device__ constexpr auto make_cluster_descriptor( return ClusterDescriptor{}; } +#if 1 +template ::type> +__host__ __device__ constexpr auto make_cluster_descriptor_v2( + const Lengths& lengths, + ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{}) +{ + constexpr index_t ndim_low = Lengths::Size(); + + const auto reordered_lengths = container_reorder_given_new2old(lengths, order); + + const auto low_lengths = generate_tuple( + [&](auto idim_low) { return reordered_lengths[idim_low]; }, Number{}); + + const auto transform = make_merge_transform(low_lengths); + + constexpr auto low_dim_old_top_ids = ArrangeOrder{}; + + constexpr auto up_dim_new_top_ids = Sequence<0>{}; + + return make_single_stage_tensor_adaptor( + make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids)); +} +#endif + } // namespace ck #endif diff --git 
a/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp b/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp index 0f1f0d5c29..23748dad59 100644 --- a/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp @@ -1282,7 +1282,7 @@ struct DynamicFreeze __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const { - static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0, "wrong! inconsistent # of dimension"); idx_low = low_idx_; @@ -1299,7 +1299,7 @@ struct DynamicFreeze const UpIdx& idx_up_new, Number) { - idx_diff_low(Number<0>{}) = index_t{Number<0>{}}; + idx_diff_low(Number<0>{}) = 0; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; } @@ -1328,5 +1328,90 @@ struct DynamicFreeze } }; +template +struct DynamicVectorize +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(UpLength{})); + + UpLengths up_lengths_; + VectorSize vector_size_; + + __host__ __device__ constexpr DynamicVectorize() = default; + + __host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size, + const UpLength& up_length) + : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}]; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = vector_size_ * idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicVectorize, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + } // namespace ck #endif diff --git a/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp index 591cedb76b..342be83d17 100644 --- a/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp +++ b/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp @@ -74,5 +74,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i return DynamicFreeze{low_idx}; } +template +__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size, + const UpLength& up_length) +{ + return DynamicVectorize{vector_size, up_length}; +} + } // namespace ck #endif diff --git 
a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp index e2121f1f3e..03c2fccb2e 100644 --- a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp @@ -12,25 +12,6 @@ struct DynamicTensorCoordinate; template struct DynamicTensorCoordinateIterator; -template -__host__ __device__ constexpr index_t GetNumOfHiddenDimension(LowerDimensionIdss, - UpperDimensionIdss) -{ - constexpr auto all_low_dim_ids = - unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{}); - - constexpr auto all_up_dim_ids = - unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{}); - - constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); - - using unique_sort_all_dim_ids = typename sequence_unique_sort, - math::equal>::type; - - return unique_sort_all_dim_ids::Size(); -} - // Transforms: Tuple // LowerDimensionIdss : Tuple, ...> // UpperDimensionIdss : Tuple, ...> @@ -374,13 +355,13 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); // put everything together - const auto all_transforms = container_cat(old_tensor_desc.GetTransforms(), new_transforms); + const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms); constexpr auto all_low_dim_hidden_idss = - container_cat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss); + container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss); constexpr auto all_up_dim_hidden_idss = - container_cat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss); + container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss); const auto 
element_space_size = old_tensor_desc.GetElementSpaceSize(); diff --git a/composable_kernel/include/tensor_description/tensor_adaptor.hpp b/composable_kernel/include/tensor_description/tensor_adaptor.hpp new file mode 100644 index 0000000000..3623c92f21 --- /dev/null +++ b/composable_kernel/include/tensor_description/tensor_adaptor.hpp @@ -0,0 +1,456 @@ +#ifndef CK_TENSOR_ADAPTOR_HPP +#define CK_TENSOR_ADAPTOR_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// Transforms: Tuple +// LowerDimensionHiddenIdss : Tuple, ...> +// UpperDimensionHiddenIdss : Tuple, ...> +// BottomDimensionHiddenIds : Sequence<...> +// TopDimensionHiddenIds : Sequence<...> +template +struct TensorAdaptor +{ + __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } + + __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + + __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss() + { + return LowerDimensionHiddenIdss{}; + } + + __host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss() + { + return UpperDimensionHiddenIdss{}; + } + + __host__ __device__ static constexpr auto GetTopDimensionHiddenIds() + { + return TopDimensionHiddenIds{}; + } + + __host__ __device__ static constexpr auto GetBottomDimensionHiddenIds() + { + return BottomDimensionHiddenIds{}; + } + + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_top) { + constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! 
not found matching transformation and upper-dimension"); + + const auto length = + transforms[Number{}].GetUpperLengths()[Number{}]; + + return length; + }, + Number{}); + + // TODO: make container_reduce support tuple of Number and index_t + return container_reduce(lengths, math::multiplies_v2{}, Number<1>{}); + } + + template + __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number) + { + constexpr auto idim_top = Number{}; + + constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran]; + + static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == idim_hidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + __host__ __device__ static constexpr index_t GetNumOfBottomDimension() + { + return BottomDimensionHiddenIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfTopDimension() + { + return TopDimensionHiddenIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() + { + constexpr auto all_low_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); + + constexpr auto all_up_dim_ids = + unpack([](auto&&... 
xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + math::equal>::type; + + return unique_sort_all_dim_ids::Size(); + } + + constexpr static index_t ntransform_ = GetNumOfTransform(); + constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension(); + constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension(); + constexpr static index_t ndim_top_ = GetNumOfTopDimension(); + + using HiddenIndex = MultiIndex; + using BottomIndex = MultiIndex; + using TopIndex = MultiIndex; + + // may be index_t or Number<> + using ElementSize = remove_cv_t; + + public: + __host__ __device__ constexpr TensorAdaptor() = default; + + __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms) + : transforms_{transforms}, element_size_{InitializeElementSize(transforms)} + { + static_assert(Transforms::Size() == ntransform_ && + LowerDimensionHiddenIdss::Size() == ntransform_ && + UpperDimensionHiddenIdss::Size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(), + "wrong! 
# of dimension inconsistent"); + + constexpr index_t ntransform = GetNumOfTransform(); + constexpr index_t ndim_hidden = GetNumOfHiddenDimension(); + + MultiIndex idx_hidden; + + // initialize uppest index + set_container_subset(idx_hidden, GetTopDimensionHiddenIds(), idx_top); + + // calculate hidden index + static_for{}([&](auto itran_p1) { + auto itran = itran_p1 - Number<1>{}; + const auto& tran = GetTransforms().At(itran); + constexpr auto dims_low = GetLowerDimensionHiddenIdss().At(itran); + constexpr auto dims_up = GetUpperDimensionHiddenIdss().At(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + MultiIndex idx_low; + + tran.CalculateLowerIndex(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("TensorAdaptor, "); + static_for<0, ntransform_, 1>{}([&](auto i) { + printf("transforms: "); + transforms_[i].Print(); + printf("LowerDimensionHiddenIds:"); + LowerDimensionHiddenIdss{}.At(i).Print(); + printf("UpperDimensionHiddenIds:"); + UpperDimensionHiddenIdss{}.At(i).Print(); + }); + + printf("BottomDimensionHiddenIds:"); + BottomDimensionHiddenIds::Print(); + printf("TopDimensionHiddenIds:"); + TopDimensionHiddenIds::Print(); + + printf("}"); + } + + private: + Transforms transforms_; + ElementSize element_size_; +}; + +template +__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0, + const TensorAdaptor1& adaptor1) +{ + static_assert(TensorAdaptor0::GetNumOfTopDimension() == + TensorAdaptor1::GetNumOfBottomDimension(), + "wrong!"); + + // all_transforms = transform0 + transform1 + const auto all_transforms = + container_concat(adaptor0.GetTransforms(), adaptor1.GetTransforms()); + + // shift + constexpr index_t adaptor0_max_hidden_id = [&]() { + index_t adaptor0_max_hidden_id = NumericLimits::Min(); + + static_for<0, 
TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension(); + + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + adaptor0_max_hidden_id = + math::max(adaptor0_max_hidden_id, + TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value); + }); + + constexpr index_t ndim_up = + TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension(); + + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor0_max_hidden_id = + math::max(adaptor0_max_hidden_id, + TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value); + }); + }); + + return adaptor0_max_hidden_id; + }(); + + constexpr index_t adaptor1_min_hidden_id = [&]() { + index_t adaptor1_min_hidden_id = NumericLimits::Max(); + + static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension(); + + // get the min of all lower dimenions, but not bottom dimension (because their id will + // be matched with top id from adaptor0) + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + constexpr index_t low_dim_hidden_id = + TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value; + + bool is_bottom_dim = false; + static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) { + if constexpr(low_dim_hidden_id == + TensorAdaptor1::GetBottomDimensionHiddenIds()[i]) + { + is_bottom_dim = true; + } + }); + + if(!is_bottom_dim) + { + adaptor1_min_hidden_id = math::min(adaptor1_min_hidden_id, low_dim_hidden_id); + } + }); + + constexpr index_t ndim_up = + TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension(); + + // get the min of all upper dimensions + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor1_min_hidden_id = + math::min(adaptor1_min_hidden_id, + TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value); + }); + }); + + return 
adaptor1_min_hidden_id; + }(); + + constexpr index_t adaptor1_hidden_id_shift = + adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id; + + constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension(); + + // all_low_dim_hidden_idss = + // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1)) + constexpr auto low_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size(); + + constexpr auto low_dim_hidden_ids_1 = + TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran]; + + // sequence in, sequence out + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr + { + auto low_dim_hidden_ids_1_mod = to_multi_index(low_dim_hidden_ids_1); + + // shift hidden id so every dim id is unique + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + low_dim_hidden_ids_1_mod(idim_low_1) += adaptor1_hidden_id_shift; + }); + + // match hidden id + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) { + // if this low dim is bottom dim, then do id matching + if constexpr(low_dim_hidden_ids_1[idim_low_1] == + TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) + { + low_dim_hidden_ids_1_mod(idim_low_1) = + TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1]; + } + }); + }); + + return low_dim_hidden_ids_1_mod; + } + (); + + return generate_sequence_v2( + [&](auto i) constexpr { return Number{}; }, + Number{}); + }, + Number{}); + + constexpr auto all_low_dim_hidden_idss = + container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1); + + // all_up_dim_hidden_idss = + // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1) + constexpr auto up_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_up_1 = 
TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size(); + + constexpr auto up_dim_hidden_ids_1 = + TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran]; + + // sequence in, constexpr tuple out + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr + { + auto up_dim_hidden_ids_1_mod = to_multi_index(up_dim_hidden_ids_1); + + // shift hidden id + static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) { + up_dim_hidden_ids_1_mod(idim_up_1) += adaptor1_hidden_id_shift; + }); + + return up_dim_hidden_ids_1_mod; + } + (); + + // constexpr tuple to sequence + return generate_sequence_v2( + [&](auto i) constexpr { return Number{}; }, + Number{}); + }, + Number{}); + + constexpr auto all_up_dim_hidden_idss = + container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1); + + // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0 + constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds(); + + // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1) + constexpr auto top_dim_hidden_ids = + TensorAdaptor1::GetTopDimensionHiddenIds() + Number{}; + + // put everything together + return TensorAdaptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{all_transforms}; +} + +// Transforms: Tuple +// LowerDimensionOldTopIdss: Tuple, ...> +// UpperDimensionNewTopIdss: Tuple, ...> +template +__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms, + LowerDimensionOldTopIdss, + UpperDimensionNewTopIdss) +{ + constexpr index_t ntransform = Transforms::Size(); + + static_assert(LowerDimensionOldTopIdss::Size() == ntransform && + UpperDimensionNewTopIdss::Size() == ntransform, + "wrong!"); + + // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss + constexpr auto all_low_dim_old_top_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionOldTopIdss{}); + + constexpr auto all_up_dim_new_top_ids = + unpack([](auto&&... 
xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionNewTopIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + + constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size(); + constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size(); + + // low_dim_hidden_idss + constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{}; + + // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom + constexpr auto up_dim_hidden_idss = generate_tuple( + [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number{}; }, + Number{}); + + // bottom_dim_hidden_ids + constexpr auto bottom_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{}; + + // top_dim_hidden_ids + constexpr auto top_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number{}; + + return TensorAdaptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms}; +} + +template = 2, bool>::type = false> +__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... 
xs) +{ + return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp index 5aac3f9d19..62961c0328 100644 --- a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp @@ -67,26 +67,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - const auto thread_cluster_id = - thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id()); + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); - const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{}; + const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; threadwise_transfer_.SetSrcSliceOrigin(src_desc, - src_block_slice_origin + thread_data_id_begin); + src_block_slice_origin + thread_data_idx_begin); threadwise_transfer_.SetDstSliceOrigin(dst_desc, - dst_block_slice_origin + thread_data_id_begin); + dst_block_slice_origin + thread_data_idx_begin); } } - __device__ static constexpr auto CalculateThreadDataBegin() - { - const auto thread_cluster_id = - thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id()); - - return thread_cluster_id * ThreadSliceLengths{}; - } - template __device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src, @@ -141,8 +133,9 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 } } + private: static constexpr auto thread_cluster_desc_ = - make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); 
using ThreadwiseTransfer = ThreadwiseDynamicTensorSliceTransfer_v3::type = false> -struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1 +struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1 { - struct MatrixIndex - { - index_t row; - index_t col; - }; + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; - private: - static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, ThreadMatrixC{}.GetLength(Number<0>{}))); - - static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, ThreadMatrixC{}.GetLength(Number<1>{}))); - - using AThreadCopy = - ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1>, - 1, - ThreadGemmADataPerRead_M, - AddressSpace::Generic, - AddressSpace::Vgpr, - 1>; - - using BThreadCopy = - ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1>, - 1, - ThreadGemmBDataPerRead_N, - AddressSpace::Generic, - AddressSpace::Vgpr, - 1>; - - MatrixIndex c_thread_begin_mtx_idx_; - - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; public: - __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1() - : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, - a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.row)}, - b_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.col)} + __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1() + : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())}, + a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])} { - static_assert(BlockMatrixA::IsKnownAtCompileTime() && - BlockMatrixB::IsKnownAtCompileTime() && - 
ThreadMatrixC::IsKnownAtCompileTime(), + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() && + CThreadDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; + static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster * + NLevel0ThreadCluster * NLevel1ThreadCluster, + "wrong! blocksize and cluster size not consistent"); - constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster * - MLevel1ThreadCluster * NLevel1ThreadCluster; - - static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n"); - - static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), + static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), "wrong! K dimension not consistent"); - - constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed - constexpr index_t N = BlockMatrixB{}.GetLength(I1); - - static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 && - N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0, - "wrong! Cannot evenly divide work among"); - - static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] && - ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1], - "wrong! 
ThreadMatrixC lengths is wrong"); } - __device__ static constexpr auto GetThreadMatrixCLengths() + __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id) { - constexpr auto I1 = Number<1>{}; + constexpr index_t M0 = ABlockDesc{}.GetLength(I1); + constexpr index_t N0 = BBlockDesc{}.GetLength(I1); + constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + constexpr index_t N1 = BBlockDesc{}.GetLength(I2); - constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed - constexpr index_t N = BlockMatrixB{}.GetLength(I1); + // 4-d data space into 4-d thread space + constexpr auto adaptor0 = make_single_stage_tensor_adaptor( + make_tuple(make_vectorize_transform(M0, 1), + make_vectorize_transform(M1PerThread, M1 / M1PerThread), + make_vectorize_transform(N0, 1), + make_vectorize_transform(N1PerThread, N1 / N1PerThread)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - constexpr index_t MRepeat = - M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster); - constexpr index_t NRepeat = - N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster); + // thread position 4-d thread space + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple( + make_freeze_transform(make_multi_index(0)), + make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)), + make_freeze_transform(make_multi_index(0)), + make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{})); - return Sequence{}; - } + // 4-d thread space to 1-d thread space + constexpr auto adaptor2 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster, + NLevel1ThreadCluster, + MLevel0ThreadCluster, + NLevel0ThreadCluster))), 
+ make_tuple(Sequence<0, 2, 1, 3>{}), + make_tuple(Sequence<0>{})); - __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) - { - constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster; + constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2); - index_t level1_id = thread_id / ThreadPerLevel0Cluster; - index_t level1_m_id = level1_id / NLevel1ThreadCluster; - index_t level1_n_id = level1_id % NLevel1ThreadCluster; - - index_t level0_id = thread_id % ThreadPerLevel0Cluster; - index_t level0_m_id = level0_id / NLevel0ThreadCluster; - index_t level0_n_id = level0_id % NLevel0ThreadCluster; - - constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster; - constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster; - - return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, - level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; - } - - template - __device__ void Run_pipelined_2x2(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! 
inconsistent type"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx_desc = ThreadMatrixC{}; - - constexpr auto K = a_block_mtx.GetLength(I0); - - constexpr auto MPerThread = c_thread_mtx_desc.GetLength(I0); - constexpr auto NPerThread = c_thread_mtx_desc.GetLength(I1); - - constexpr index_t MPerLevel1Cluster = - MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; - - constexpr index_t NPerLevel1Cluster = - NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster; - - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - - static_assert(MRepeat == 2 && NRepeat == 2, "wrong! only support 2x2 pipeline"); - - // thread A-sub, B-sub - constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2( - make_tuple(Number{}, Number{}), - make_tuple(Number{}, Number<1>{})); - - constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2( - make_tuple(Number{}, Number{}), - make_tuple(Number{}, Number<1>{})); - - constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2( - make_tuple(Number{}, Number{}), - make_tuple(Number{}, Number<1>{})); - - auto a_thread_buf = make_static_buffer(a_thread_mtx_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer(b_thread_mtx_desc_.GetElementSpaceSize()); - - constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1{}; - - // read A_sub_0 - a_thread_copy_.Run(BlockMatrixA{}, - make_tuple(I0, I0), - a_block_buf, - a_thread_mtx_desc_, - make_tuple(I0, I0), - a_thread_buf); - - // read B_sub_0 - b_thread_copy_.Run(BlockMatrixB{}, - make_tuple(I0, I0), - b_block_buf, - b_thread_mtx_desc_, - make_tuple(I0, I0), - b_thread_buf); - - // read B_sub_1 - b_thread_copy_.Run(BlockMatrixB{}, - make_tuple(I0, Number{}), - b_block_buf, - b_thread_mtx_desc_, - 
make_tuple(I0, Number{}), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(BlockMatrixA{}, - make_tuple(I0, Number{}), - a_block_buf, - a_thread_mtx_desc_, - make_tuple(I0, Number{}), - a_thread_buf); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0), - b_thread_buf, - make_tuple(I0, I0), - c_thread_buf, - make_tuple(I0, I0)); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0), - b_thread_buf, - make_tuple(I0, Number{}), - c_thread_buf, - make_tuple(I0, Number{})); - - // loop over rest of k - static_for{}([&](auto k) { - // read A_sub_0 - a_thread_copy_.Run(BlockMatrixA{}, - make_tuple(k, I0), - a_block_buf, - a_thread_mtx_desc_, - make_tuple(I0, I0), - a_thread_buf); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, Number{}), - b_thread_buf, - make_tuple(I0, I0), - c_thread_buf, - make_tuple(Number{}, I0)); - - // read B_sub_0 - b_thread_copy_.Run(BlockMatrixB{}, - make_tuple(k, I0), - b_block_buf, - b_thread_mtx_desc_, - make_tuple(I0, I0), - b_thread_buf); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, Number{}), - b_thread_buf, - make_tuple(I0, Number{}), - c_thread_buf, - make_tuple(Number{}, Number{})); - - // read B_sub_1 - b_thread_copy_.Run(BlockMatrixB{}, - make_tuple(k, Number{}), - b_block_buf, - b_thread_mtx_desc_, - make_tuple(I0, Number{}), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(BlockMatrixA{}, - make_tuple(k, Number{}), - a_block_buf, - a_thread_mtx_desc_, - make_tuple(I0, Number{}), - a_thread_buf); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0), - b_thread_buf, - make_tuple(I0, I0), - c_thread_buf, - make_tuple(I0, I0)); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0), - b_thread_buf, - 
make_tuple(I0, Number{}), - c_thread_buf, - make_tuple(I0, Number{})); - }); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, Number{}), - b_thread_buf, - make_tuple(I0, I0), - c_thread_buf, - make_tuple(Number{}, I0)); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, Number{}), - b_thread_buf, - make_tuple(I0, Number{}), - c_thread_buf, - make_tuple(Number{}, Number{})); + return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); } template @@ -349,28 +115,394 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { -#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; + auto a_thread_buf = make_static_buffer(a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer(b_thread_desc_.GetElementSpaceSize()); - constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0); - constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1); + constexpr auto threadwise_gemm = + ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1, + Sequence, + Sequence>{}; - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; + constexpr index_t K = ABlockDesc{}.GetLength(I0); - if constexpr(MRepeat == 2 && NRepeat == 2) - { - Run_pipelined_2x2(a_block_buf, b_block_buf, c_thread_buf); - } - else - { - Run_naive(a_block_buf, b_block_buf, c_thread_buf); - } -#else - Run_naive(a_block_buf, b_block_buf, c_thread_buf); -#endif + static_for<0, K, KPerThread>{}([&](auto k) { + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + threadwise_gemm.Run(a_thread_buf, 
+ make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + }); } + + private: + static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1); + static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1); + + // A[K, M0, M1] + static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{})); + + // B[K, N0, N1] + static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{})); + + using AThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + AThreadCopyScalarPerVector_M1, + AddressSpace::Generic, + AddressSpace::Vgpr, + 1>; + + using BThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + BThreadCopyScalarPerVector_N1, + AddressSpace::Generic, + AddressSpace::Vgpr, + 1>; + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; }; + +// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1] +// A and B are visable to the whole block, C is distributed among each thread +// Assume: +// 1. A: +// 1. ABlockDesc is known at compile-time +// 2. ABlockBuffer is DynamicBuffer +// 2. B: +// 1. ABlockDesc is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. CThreadDesc is known at compile-time +// 2. 
CThreadBuffer is StaticBuffer +template ::type = false> +struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2 +{ + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + public: + __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2() + : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())}, + a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])} + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() && + CThreadDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster * + NLevel0ThreadCluster * NLevel1ThreadCluster, + "wrong! blocksize and cluster size not consistent"); + + static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), + "wrong! 
K dimension not consistent"); + + // TODO: remove this restriction + static_assert(ABlockDesc{}.GetLength(I1) == 2 && BBlockDesc{}.GetLength(I1) == 2 && + CThreadDesc{}.GetLength(I0) == 2 && CThreadDesc{}.GetLength(I2) == 2, + "wrong"); + } + + __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id) + { + constexpr index_t M0 = ABlockDesc{}.GetLength(I1); + constexpr index_t N0 = BBlockDesc{}.GetLength(I1); + constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + constexpr index_t N1 = BBlockDesc{}.GetLength(I2); + + // 4-d data space into 4-d thread space + constexpr auto adaptor0 = make_single_stage_tensor_adaptor( + make_tuple(make_vectorize_transform(M0, 1), + make_vectorize_transform(M1PerThread, M1 / M1PerThread), + make_vectorize_transform(N0, 1), + make_vectorize_transform(N1PerThread, N1 / N1PerThread)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // thread position 4-d thread space + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple( + make_freeze_transform(make_multi_index(0)), + make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)), + make_freeze_transform(make_multi_index(0)), + make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{})); + + // 4-d thread space to 1-d thread space + constexpr auto adaptor2 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster, + NLevel1ThreadCluster, + MLevel0ThreadCluster, + NLevel0ThreadCluster))), + make_tuple(Sequence<0, 2, 1, 3>{}), + make_tuple(Sequence<0>{})); + + constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2); + + return 
cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer(a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer(b_thread_desc_.GetElementSpaceSize()); + + constexpr auto threadwise_gemm = + ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1, + Sequence<1, M1PerThread>, + Sequence<1, N1PerThread>>{}; + + constexpr index_t K = ABlockDesc{}.GetLength(I0); + + // read A_sub_0 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(I0, I1, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(I0, I1, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + + // loop over rest of k + static_for{}([&](auto k) { + // read A_sub_0 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, 
I0, I0, I0)); + + // read B_sub_0 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + + // read B_sub_1 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I1, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I1, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + } + + private: + static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1); + static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1); + + // A[K, M0, M1] + static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{})); + + // B[K, N0, N1] + static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{})); + + using AThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + 
Sequence<0, 1, 2>, + 2, + AThreadCopyScalarPerVector_M1, + AddressSpace::Generic, + AddressSpace::Vgpr, + 1>; + + using BThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + BThreadCopyScalarPerVector_N1, + AddressSpace::Generic, + AddressSpace::Vgpr, + 1>; + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + } // namespace ck #endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp index 0f94f67bbc..f2fd245dfa 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp @@ -12,7 +12,36 @@ namespace ck { -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc, + const FloatA* __restrict__ p_a_global, + const BGlobalDesc b_k_n_global_desc, + const FloatB* __restrict__ p_b_global, + const CGlobalDesc c_m0_m1_n0_n1_global_desc, + FloatC* __restrict__ p_c_global, + const CBlockClusterDesc c_block_cluster_desc) +{ + GridwiseGemm{}.Run(a_k_m_global_desc, + p_a_global, + b_k_n_global_desc, + p_b_global, + c_m0_m1_n0_n1_global_desc, + p_c_global, + c_block_cluster_desc, + integral_constant{}, + integral_constant{}); +} +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER // pass tensor descriptor by __CONSTANT__ void pointer // __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to // non-modifiable parameter address space, so compiler can enable corresponding optimization @@ -23,16 +52,18 @@ template -__global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc, - const FloatA* __restrict__ p_a_global, - const void __CONSTANT__* p_b_k_n_global_desc, - const FloatB* 
__restrict__ p_b_global, - const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc, - FloatC* __restrict__ p_c_global) +__global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc, + const FloatA* __restrict__ p_a_global, + const void __CONSTANT__* p_b_k_n_global_desc, + const FloatB* __restrict__ p_b_global, + const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc, + FloatC* __restrict__ p_c_global, + const void __CONSTANT__* p_c_block_cluster_desc) { - // first cast void __CONSTANT__* to void* + // first cast void __CONSTANT__ void* to void* // second cast void* to Desc* // the copy constructor of tensor descriptor doesn't take address_space(4) const auto a_k_m_global_desc = @@ -42,12 +73,16 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl const auto c_m0_m1_n0_n1_global_desc = *reinterpret_cast((const void*)p_c_m0_m1_n0_n1_global_desc); + const auto c_block_cluster_desc = + *reinterpret_cast((const void*)p_c_block_cluster_desc); + GridwiseGemm{}.Run(a_k_m_global_desc, p_a_global, b_k_n_global_desc, p_b_global, c_m0_m1_n0_n1_global_desc, p_c_global, + c_block_cluster_desc, integral_constant{}, integral_constant{}); } @@ -61,6 +96,7 @@ template , integral_constant) const { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; const auto K = a_k_m_global_desc.GetLength(I0); const auto M = a_k_m_global_desc.GetLength(I1); const auto N = b_k_n_global_desc.GetLength(I1); // divide block work by [M, N] -#if 0 - const auto m_block_work_num = M / Number{}; - const auto n_block_work_num = N / Number{}; + const auto block_work_idx = + c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - const index_t m_block_work_id = get_block_1d_id() / n_block_work_num; - const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num; + // HACK: this force m/n_block_data_idx_on_global into SGPR + const 
index_t m_block_data_idx_on_global = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); -#else - // Hack: this force result into SGPR - const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock); - const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock); - - const index_t m_block_work_id = - __builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num); - const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num; -#endif - - const index_t m_block_data_on_global = m_block_work_id * MPerBlock; - const index_t n_block_data_on_global = n_block_work_id * NPerBlock; + const index_t n_block_data_idx_on_global = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); // lds max alignment constexpr auto max_lds_align = math::lcm(Number{}, @@ -204,7 +233,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 AThreadTransferSrcResetCoordinateAfterRun, true>( a_k_m_global_desc, - make_multi_index(0, m_block_data_on_global), + make_multi_index(0, m_block_data_idx_on_global), a_k_m_block_desc, make_multi_index(0, 0)); @@ -233,7 +262,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 BThreadTransferSrcResetCoordinateAfterRun, true>( b_k_n_global_desc, - make_multi_index(0, n_block_data_on_global), + make_multi_index(0, n_block_data_idx_on_global), b_k_n_block_desc, make_multi_index(0, 0)); @@ -251,28 +280,45 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster); constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster); - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{})); + constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor( + a_k_m_block_desc, + make_tuple( + 
make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple( + Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor( + b_k_n_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple( + Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + constexpr auto c_m0_m1_n0_n1_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number{}, Number{}, Number{})); const auto blockwise_gemm = - BlockwiseGemm_km_kn_m0m1n0n1_v1r1{}; + BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2{}; // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size = @@ -286,12 +332,12 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 // register allocation for output auto c_thread_buf = - make_static_buffer(c_m0m1_n0n1_thread_desc.GetElementSpaceSize()); + make_static_buffer(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize()); ThreadwiseDynamicTensorSliceSet_v1>{} - .Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0}); + decltype(c_m0_m1_n0_n1_thread_desc), + Sequence>{} + .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0); @@ -427,30 +473,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 constexpr auto M1 = Number{}; constexpr auto N1 = Number{}; - // define input tensor descriptor for threadwise copy - // thread input tensor, src of threadwise copy - constexpr auto c_m0_m1_n0_n1_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number{}, - Number{}, - Number{}, - Number{})); - - // calculate origin of thread input tensor on global memory - // blockwise 
GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t m_thread_data_on_global = - m_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t n_thread_data_on_global = - n_block_data_on_global + c_thread_mtx_on_block.col; - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; - constexpr auto tmp = make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}, Number{})); + const auto c_thread_data_idx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id()); ThreadwiseDynamicTensorSliceTransfer_v1r3< FloatAcc, @@ -465,11 +492,12 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 AddressSpace::Global, CGlobalMemoryDataOperation, 1, - true>(c_m0_m1_n0_n1_global_desc, - make_multi_index(m_thread_data_on_global / M1, - m_thread_data_on_global % M1, - n_thread_data_on_global / N1, - n_thread_data_on_global % N1)) + true>{ + c_m0_m1_n0_n1_global_desc, + make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0], + c_thread_data_idx_on_block[I1], + n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2], + c_thread_data_idx_on_block[I3])} .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, @@ -486,6 +514,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 const FloatAB* __restrict__ p_b_global, const CGlobalDesc& c_m0_m1_n0_n1_global_desc, FloatC* __restrict__ p_c_global, + const CBlockClusterDesc& c_block_cluster_desc, integral_constant, integral_constant) const { @@ -499,6 +528,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 p_b_global, c_m0_m1_n0_n1_global_desc, p_c_global, + c_block_cluster_desc, p_shared_block, integral_constant{}, integral_constant{}); diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp 
b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp index 34b6cfec79..0aa2da0240 100644 --- a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp @@ -1376,6 +1376,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4 { static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong!"); } template ::type = false> +struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 +{ + __device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1() + { + static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && + CDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths + + // TODO remove this restriction + static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2, + "wrong!"); + } + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert( + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! 
inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto K = KLengths{}[I0]; + constexpr auto M0 = MLengths{}[I0]; + constexpr auto M1 = MLengths{}[I1]; + constexpr auto N0 = NLengths{}[I0]; + constexpr auto N1 = NLengths{}[I1]; + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, K, 1>{}([&](auto k) { + static_for<0, M0, 1>{}([&](auto m0) { + static_for<0, M1, 1>{}([&](auto m1) { + static_for<0, N0, 1>{}([&](auto n0) { + static_for<0, N1, 1>{}([&](auto n1) { + + constexpr index_t a_offset = + ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1)); + constexpr index_t b_offset = + BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1)); + constexpr index_t c_offset = CDesc{}.CalculateOffset( + c_origin_idx + make_multi_index(m0, m1, n0, n1)); + + amd_assembly_inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); + }); + }); + }); + }); + }); + } +}; + } // namespace ck #endif diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index 6afe465800..dd89918275 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -6,6 +6,7 @@ #include "container_helper.hpp" #include "statically_indexed_array.hpp" #include "container_element_picker.hpp" +#include "data_type.hpp" #include "float_type.hpp" #include "buffer.hpp" #include "functional.hpp" diff --git a/composable_kernel/include/utility/container_element_picker.hpp b/composable_kernel/include/utility/container_element_picker.hpp index f71086f6cb..54915125ac 100644 --- a/composable_kernel/include/utility/container_element_picker.hpp +++ 
b/composable_kernel/include/utility/container_element_picker.hpp @@ -20,7 +20,8 @@ struct ContainerElementPicker __host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array} { - constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer{}, Number<0>{}); + constexpr index_t imax = + reduce_on_sequence(Picks{}, math::maximize{}, Number<0>{}); static_assert(imax < Arr::Size(), "wrong! exceeding # array element"); } @@ -85,7 +86,8 @@ struct ConstantContainerElementPicker __host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array} { - constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer{}, Number<0>{}); + constexpr index_t imax = + reduce_on_sequence(Picks{}, math::maximize{}, Number<0>{}); static_assert(imax < Arr::Size(), "wrong! exceeding # array element"); } diff --git a/composable_kernel/include/utility/container_helper.hpp b/composable_kernel/include/utility/container_helper.hpp index f47a29d058..74cd600cae 100644 --- a/composable_kernel/include/utility/container_helper.hpp +++ b/composable_kernel/include/utility/container_helper.hpp @@ -26,13 +26,13 @@ __host__ __device__ constexpr auto container_push_back(const Array template __host__ __device__ constexpr auto container_push_front(const Tuple& a, const T& x) { - return container_cat(make_tuple(x), a); + return container_concat(make_tuple(x), a); } template __host__ __device__ constexpr auto container_push_back(const Tuple& a, const T& x) { - return container_cat(a, make_tuple(x)); + return container_concat(a, make_tuple(x)); } template @@ -158,6 +158,7 @@ __host__ __device__ constexpr auto container_reduce_impl( } // rocm-4.1 compiler would crash for recursive lambda +// container reduce with initial value template & x, Reduce f, TData init) } template -__host__ __device__ constexpr auto container_cat(const X& x, const Ys&... ys) +__host__ __device__ constexpr auto container_concat(const X& x, const Ys&... 
ys) { - return container_cat(x, container_cat(ys...)); + return container_concat(x, container_concat(ys...)); } template -__host__ __device__ constexpr auto container_cat(const Array& ax, const Array& ay) +__host__ __device__ constexpr auto container_concat(const Array& ax, const Array& ay) { return unpack2( [&](auto&&... zs) { return make_array(std::forward(zs)...); }, ax, ay); } template -__host__ __device__ constexpr auto container_cat(const Tuple& tx, const Tuple& ty) +__host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) { return unpack2( [&](auto&&... zs) { return make_tuple(std::forward(zs)...); }, tx, ty); } template -__host__ __device__ constexpr auto container_cat(const Container& x) +__host__ __device__ constexpr auto container_concat(const Container& x) { return x; } diff --git a/composable_kernel/include/utility/data_type.hpp b/composable_kernel/include/utility/data_type.hpp new file mode 100644 index 0000000000..66d2a88be4 --- /dev/null +++ b/composable_kernel/include/utility/data_type.hpp @@ -0,0 +1,24 @@ +#ifndef CK_DATA_TYPE_HPP +#define CK_DATA_TYPE_HPP + +namespace ck { + +template +struct NumericLimits; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int32_t Min() + { + return std::numeric_limits::min(); + } + + __host__ __device__ static constexpr int32_t Max() + { + return std::numeric_limits::max(); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index 5738030732..639d4157e6 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -43,11 +43,17 @@ struct multiplies_v2 }; template -struct maxer +struct maximize { __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; } }; +template +struct minimize +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? 
a : b; } +}; + template struct integer_divide_ceiler { diff --git a/driver/include/device.hpp b/driver/include/device.hpp index 869d52d794..b68e07c85f 100644 --- a/driver/include/device.hpp +++ b/driver/include/device.hpp @@ -46,6 +46,7 @@ void launch_kernel(F kernel, template float launch_and_time_kernel(F kernel, + int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, @@ -54,15 +55,32 @@ float launch_and_time_kernel(F kernel, { KernelTimer timer; + printf("%s: block_dim {%d, %d, %d}, grid_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up\n"); + + // warm up + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); + + printf("Start running %d times...\n", nrepeat); + timer.Start(); - hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); + for(int i = 0; i < nrepeat; ++i) + { + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); + } timer.End(); - hipGetLastError(); - - return timer.GetElapsedTime(); + return timer.GetElapsedTime() / nrepeat; } #elif CK_DEVICE_BACKEND_NVIDIA diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 1aa187dfcf..520003d038 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -29,8 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( { using namespace ck; - std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw" - << std::endl; + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto 
I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); @@ -459,50 +468,94 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; #endif - constexpr auto conv_driver = + constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster; + constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster; + + const auto descs = #if 1 - DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad + transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad #elif 0 - DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad -#elif 1 - DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 + transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad +#else + transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1 #endif - ::type, - TAcc, - TOut, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmABlockTransferThreadSliceLengths_GemmK_GemmM, - GemmABlockTransferThreadClusterLengths_GemmK_GemmM, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_GemmM, - GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, - GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmN, - GemmCThreadTransferDstScalarPerVector_GemmN1>{}; + (wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); - 
conv_driver.Run(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - static_cast::type*>( - wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast::type*>( - in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer())); + for(index_t i = 0; i < 5; ++i) + { + float ave_time = launch_kernel_dynamic_gemm_v1< + BlockSize, + typename vector_type::type, + TAcc, + TOut, + InMemoryDataOperation::Set, + decltype(descs[I0]), + decltype(descs[I1]), + decltype(descs[I2]), + decltype(descs[I3]), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmABlockTransferThreadSliceLengths_GemmK_GemmM, + GemmABlockTransferThreadClusterLengths_GemmK_GemmM, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + GemmABlockTransferSrcScalarPerVector_GemmK, + GemmABlockTransferDstScalarPerVector_GemmM, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, + GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, + Sequence<0, 1>, + Sequence<0, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmN, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<2, 3, 0, 1>, + 3, + GemmCThreadTransferDstScalarPerVector_GemmN1, + decltype(descs[I4]), + decltype(descs[I5]), + decltype(descs[I6]), + decltype(descs[I7]), + decltype(descs[I8])>(static_cast::type*>( + wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast::type*>( + in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + descs[I0], + descs[I1], + descs[I2], + descs[I3], + descs[I4], + descs[I5], + descs[I6], + descs[I7], + descs[I8], + 
nrepeat); + float perf = (float)calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); } diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp index 65c4a60dbb..50f720b2f1 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -29,13 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( { using namespace ck; - std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" - << std::endl; + std::cout << __func__ << std::endl; constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; constexpr auto N = OutDesc::GetLengths()[I0]; constexpr auto K = OutDesc::GetLengths()[I1]; @@ -53,7 +57,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( constexpr auto C0 = C / Number{}; constexpr auto C1 = Number{}; -#if 0 +#if 1 // run-time variables constexpr auto in_n_hi_wi_c0_desc = make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0)); @@ -112,7 +116,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); -#if 1 +#if 0 // cdata = 16, 
BlockSize = 64, 16x64x4 constexpr index_t BlockSize = 64; @@ -372,51 +376,92 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; #endif - constexpr auto conv_driver = + constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster; + constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster; + + const auto descs = #if 1 - DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad -#elif 0 - DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad -#elif 1 - DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 + transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad +#else + transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1 #endif - ::type, - TAcc, - TOut, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmABlockTransferThreadSliceLengths_GemmK_GemmM, - GemmABlockTransferThreadClusterLengths_GemmK_GemmM, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_GemmM, - GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, - GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, - GemmBBlockTransferSrcScalarPerVector_GemmK, - GemmBBlockTransferDstScalarPerVector_GemmN, - GemmCThreadTransferDstScalarPerVector_GemmM1>{}; + (wei_k_y_x_c0_desc, + in_n_hi_wi_c0_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); - conv_driver.Run(wei_k_y_x_c0_desc, - in_n_hi_wi_c0_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - static_cast::type*>( - wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast::type*>( - in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer())); + for(index_t 
i = 0; i < 5; ++i) + { + float ave_time = launch_kernel_dynamic_gemm_v1< + BlockSize, + typename vector_type::type, + TAcc, + TOut, + InMemoryDataOperation::Set, + decltype(descs[I0]), + decltype(descs[I1]), + decltype(descs[I2]), + decltype(descs[I3]), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmABlockTransferThreadSliceLengths_GemmK_GemmM, + GemmABlockTransferThreadClusterLengths_GemmK_GemmM, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + GemmABlockTransferSrcScalarPerVector_GemmK, + GemmABlockTransferDstScalarPerVector_GemmM, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, + GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + GemmBBlockTransferSrcScalarPerVector_GemmK, + GemmBBlockTransferDstScalarPerVector_GemmN, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<2, 3, 0, 1>, + 1, + GemmCThreadTransferDstScalarPerVector_GemmM1, + decltype(descs[I4]), + decltype(descs[I5]), + decltype(descs[I6]), + decltype(descs[I7]), + decltype(descs[I8])>(static_cast::type*>( + wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast::type*>( + in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + descs[I0], + descs[I1], + descs[I2], + descs[I3], + descs[I4], + descs[I5], + descs[I6], + descs[I7], + descs[I8], + nrepeat); + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); auto f_nhwk2nkhw = [&](auto n, auto k, 
auto ho, auto wo) { diff --git a/driver/src/conv_driver.cpp b/driver/src/conv_driver.cpp index ab9de5b661..539eca228a 100644 --- a/driver/src/conv_driver.cpp +++ b/driver/src/conv_driver.cpp @@ -48,8 +48,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 constexpr index_t N = 1; constexpr index_t C = 16; @@ -62,9 +62,9 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 1 + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 0 constexpr index_t N = 1; constexpr index_t C = 16; constexpr index_t HI = 1080; @@ -92,7 +92,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>; -#elif 0 +#elif 1 constexpr index_t N = 1; constexpr index_t C = 16; constexpr index_t HI = 540; @@ -210,7 +210,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>; -#elif 0 +#elif 1 // 3x3, 71x71 constexpr index_t N = 128; constexpr index_t C = 192; @@ -225,7 +225,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>; -#elif 1 +#elif 0 // 7x1, 17x17 constexpr index_t N = 128; constexpr index_t C = 128; diff --git a/script/docker-rocm3.7.sh b/script/docker-rocm3.7.sh new file mode 100644 index 0000000000..e9aab49447 --- /dev/null +++ b/script/docker-rocm3.7.sh @@ -0,0 +1,14 @@ +WORKSPACE=$1 +echo "workspace: " $WORKSPACE + +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v $WORKSPACE:/root/workspace \ +asroy/tensorflow:rocm3.7-tf2.3-dev-omp \ +/bin/bash + +#--network host \