diff --git a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp index 54a9370999..5ad7c0ca93 100644 --- a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp @@ -38,7 +38,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad typename InRightPads> __host__ void Run(const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const DynamicTensorDescriptor& out_n_k0_ho_wo_k1_global_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -51,17 +51,21 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); + const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); + const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); + + const auto K = wei_k_c_y_x_global_desc.GetLength(I0); const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); const auto X = wei_k_c_y_x_global_desc.GetLength(I3); @@ -78,7 +82,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad const auto InRightPadW = in_right_pads[I1]; // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -104,7 +108,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - const auto in_gemmk_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( + const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( in_n_c_y_ho_x_wo_global_desc, make_tuple(make_merge_transform(make_tuple(C, Y, X)), make_pass_through_transform(N), @@ -114,31 +118,31 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); // output tensor - const auto out_gemmm_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)), - make_tuple(make_pass_through_transform(K), + const auto 
out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)), + make_tuple(make_merge_transform(make_tuple(K0, K1)), make_pass_through_transform(N), make_pass_through_transform(Ho), make_pass_through_transform(Wo)), - make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); const auto E = C * Y * X; - if(!(K % KPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0 && - E % EPerBlock == 0)) + if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 && + (E % EPerBlock) == 0)) { throw std::runtime_error("wrong! GEMM size no divisible"); } // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_k_m_global_iterator_hacks = + constexpr auto a_e_k_global_iterator_hacks = make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; - constexpr auto b_k_n_global_iterator_hacks = + constexpr auto b_e_n_ho_wo_global_iterator_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, @@ -148,17 +152,17 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - constexpr auto b_k_n_global_move_slice_window_iterator_hack = + constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // hack for NKHW format - constexpr auto c_k_n_h_w_global_tensor_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, + make_tuple(Sequence<0, 2, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); @@ -171,9 +175,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad FloatAcc, FloatC, InMemoryDataOperation::Set, - decltype(wei_gemmk_gemmm_global_desc), - decltype(in_gemmk_n_ho_wo_global_desc), - decltype(out_gemmm_n_ho_wo_global_desc), + decltype(wei_e_k_global_desc), + decltype(in_e_n_ho_wo_global_desc), + decltype(out_k_n_ho_wo_global_desc), KPerBlock, HoPerBlock, WoPerBlock, @@ -196,13 +200,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad false, // don't move back src coordinate after threadwise copy, which will be fused with // MoveSrcSliceWindow() to save addr computation Sequence<0, 2, 3, 1>, - 3, + 0, CThreadTransferDstScalarPerVector_W, - decltype(a_k_m_global_iterator_hacks), - decltype(b_k_n_global_iterator_hacks), - decltype(c_k_n_h_w_global_tensor_iterator_hacks), - decltype(a_k_m_global_move_slice_window_iterator_hack), - decltype(b_k_n_global_move_slice_window_iterator_hack)>; + decltype(a_e_k_global_iterator_hacks), + 
decltype(b_e_n_ho_wo_global_iterator_hacks), + decltype(c_k_n_ho_wo_global_tensor_iterator_hacks), + decltype(a_e_k_global_move_slice_window_iterator_hack), + decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>; const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N; @@ -226,108 +230,104 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad { if(has_main_k_block_loop && has_double_tail_k_block_loop) { - const auto kernel = - run_gridwise_operation, - integral_constant>; + const auto kernel = run_gridwise_operation, + integral_constant>; launch_kernel(kernel, dim3(GridSize), dim3(BlockSize), 0, 0, - wei_gemmk_gemmm_global_desc, + wei_e_k_global_desc, p_wei_global, - in_gemmk_n_ho_wo_global_desc, + in_e_n_ho_wo_global_desc, p_in_global, - out_gemmm_n_ho_wo_global_desc, + out_k_n_ho_wo_global_desc, p_out_global, integral_constant{}, integral_constant{}); } else if(has_main_k_block_loop && !has_double_tail_k_block_loop) { - const auto kernel = - run_gridwise_operation, - integral_constant>; + const auto kernel = run_gridwise_operation, + integral_constant>; launch_kernel(kernel, dim3(GridSize), dim3(BlockSize), 0, 0, - wei_gemmk_gemmm_global_desc, + wei_e_k_global_desc, p_wei_global, - in_gemmk_n_ho_wo_global_desc, + in_e_n_ho_wo_global_desc, p_in_global, - out_gemmm_n_ho_wo_global_desc, + out_k_n_ho_wo_global_desc, p_out_global, integral_constant{}, integral_constant{}); } else if(!has_main_k_block_loop && has_double_tail_k_block_loop) { - const auto kernel = - run_gridwise_operation, - integral_constant>; + const auto kernel = run_gridwise_operation, + integral_constant>; launch_kernel(kernel, dim3(GridSize), dim3(BlockSize), 0, 0, - wei_gemmk_gemmm_global_desc, + wei_e_k_global_desc, p_wei_global, - in_gemmk_n_ho_wo_global_desc, + in_e_n_ho_wo_global_desc, p_in_global, - out_gemmm_n_ho_wo_global_desc, + out_k_n_ho_wo_global_desc, p_out_global, integral_constant{}, integral_constant{}); } else { - const auto kernel = - run_gridwise_operation, - integral_constant>; + const auto kernel = run_gridwise_operation, + integral_constant>; launch_kernel(kernel, dim3(GridSize), dim3(BlockSize), 0, 0, - wei_gemmk_gemmm_global_desc, + wei_e_k_global_desc, p_wei_global, - in_gemmk_n_ho_wo_global_desc, + in_e_n_ho_wo_global_desc, p_in_global, - out_gemmm_n_ho_wo_global_desc, + out_k_n_ho_wo_global_desc, p_out_global, integral_constant{}, integral_constant{}); @@ -340,7 +340,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, wei_k_c_y_x_global_desc, - out_n_k_ho_wo_global_desc) / + out_n_k0_ho_wo_k1_global_desc) / (std::size_t(1000) * 1000 * 1000) / ave_time; std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" diff --git a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp new file mode 100644 index 0000000000..f7c24ead4d --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp @@ -0,0 +1,368 @@ +#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP +#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include 
"dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_v2.hpp" +#include "gridwise_operation_wrapper.hpp" + +namespace ck { + +template +struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad +{ + template + __host__ void Run(const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_wei_global, + const FloatAB* __restrict__ p_in_global, + FloatC* __restrict__ p_out_global) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); + const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); + + const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); + + const auto K = wei_k_c_y_x_global_desc.GetLength(I0); + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; + const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; + + const auto OutRightPadH = Hop - Ho; + const auto OutRightPadW = Wop - Wo; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; + const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; + + std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW + << std::endl; + std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW + << std::endl; + + // weight tensor + const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wop), 
make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_pass_through_transform(N), + make_pass_through_transform(Hop), + make_pass_through_transform(Wop)), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // output tensor + const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)), + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_pad_transform(Ho, 0, OutRightPadH), + make_pad_transform(Wo, 0, OutRightPadW)), + make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto E = C * Y * X; + + std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; + + if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && + (E % EPerBlock) == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto a_e_k_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + + constexpr auto b_e_n_ho_wo_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + // hack for NKHW format + constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + // GEMM + using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperation::Set, + decltype(wei_e_k_global_desc), + decltype(in_e_n_ho_wo_global_desc), + decltype(out_k_n_hop_wop_global_desc), + KPerBlock, + HoPerBlock, + WoPerBlock, + EPerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + ABlockTransferSrcScalarPerVector_E, + ABlockTransferDstScalarPerVector_K, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 3, 1>, + 3, + BThreadTransferSrcScalarPerVector_W, + false, // don't 
move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<0, 2, 3, 1>, + 0, + CThreadTransferDstScalarPerVector_W, + decltype(a_e_k_global_iterator_hacks), + decltype(b_e_n_ho_wo_global_iterator_hacks), + decltype(c_k_n_ho_wo_global_tensor_iterator_hacks), + decltype(a_e_k_global_move_slice_window_iterator_hack), + decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>; + + const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; + + const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0; + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." << std::endl; + + KernelTimer timer; + timer.Start(); + std::cout << "has_main_k_block_loop: " << has_main_k_block_loop + << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop + << std::endl; + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k0_ho_wo_k1_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } +}; +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp index 76f50bc811..33d770092c 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp @@ -134,9 +134,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 constexpr auto KPerThreadSubC = 4; + constexpr auto HoPerThreadSubC = 2; + constexpr auto WoPerThreadSubC = 2; + static_assert(KPerThread % KPerThreadSubC == 0, ""); - static_assert(HPerThread 
% 2 == 0, ""); - static_assert(WPerThread % 2 == 0, ""); + static_assert(HPerThread % HoPerThreadSubC == 0, ""); + static_assert(WPerThread % WoPerThreadSubC == 0, ""); // thread A, B for GEMM constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( @@ -158,7 +161,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3{}; + decltype(c_thread_mtx), + HoPerThreadSubC, + WoPerThreadSubC>{}; // loop over k #pragma unroll for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop) @@ -171,10 +176,11 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 mMyThreadOffsetA, p_a_thread); - for(index_t h_begin = 0; h_begin < HPerThread; h_begin += 2) +#pragma unroll + for(index_t h_begin = 0; h_begin < HPerThread; h_begin += HoPerThreadSubC) { - - for(index_t w_begin = 0; w_begin < WPerThread; w_begin += 2) +#pragma unroll + for(index_t w_begin = 0; w_begin < WPerThread; w_begin += WoPerThreadSubC) { threadwise_gemm.Run(p_a_thread, p_b_thread + b_thread_mtx.CalculateOffset(make_tuple( diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp index 96d0afa892..54a4932f4d 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp @@ -37,6 +37,8 @@ __device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread template ::type = false> @@ -54,11 +56,6 @@ struct ThreadwiseGemm_km_kn_mn_v3 constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - // constexpr auto H = BDesc{}.GetLength(I2); - // constexpr auto W = BDesc{}.GetLength(I3); - constexpr auto H = 2; - constexpr auto W = 2; - constexpr auto E = ADesc{}.GetLength(I0); constexpr auto K = ADesc{}.GetLength(I1); diff --git a/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp b/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp index 5c4e153869..ff8d57993f 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp @@ -59,7 +59,26 @@ __llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); +// half +__device__ half_t +__llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); +__device__ half2_t +__llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + +__device__ half4_t +__llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); + +// float __device__ float __llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, index_t voffset, @@ -114,6 +133,28 @@ __llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); +// half +__device__ void +__llvm_amdgcn_raw_buffer_store_fp16(half_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); + +__device__ void +__llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16"); + +__device__ void 
+__llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16"); +// float __device__ void __llvm_amdgcn_raw_buffer_store_fp32(float vdata, int32x4_t rsrc, @@ -142,7 +183,13 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, index_t src_wave_addr_offset) { static_assert((is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)), + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1)) || + (is_same::value && (N == 1)) || + (is_same::value && (N == 1)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1)) || + (is_same::value && (N == 1)), "wrong! not implemented"); if constexpr(is_same::value) @@ -169,8 +216,63 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(float), 0); + tmp.Vectors(Number<4>{})(Number<1>{}) = + __llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + 0); + + return tmp.Vector(); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return __llvm_amdgcn_raw_buffer_load_fp16( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return __llvm_amdgcn_raw_buffer_load_fp16x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return __llvm_amdgcn_raw_buffer_load_fp16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return __llvm_amdgcn_raw_buffer_load_fp16x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return __llvm_amdgcn_raw_buffer_load_fp16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + vector_type tmp; + + tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.Vectors(Number<4>{})(Number<1>{}) = + __llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(half_t), + 0); return tmp.Vector(); } @@ -199,12 +301,31 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(int32_t), 0); + tmp.Vectors(Number<4>{})(Number<1>{}) = + __llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + 0); return tmp.Vector(); } } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return __llvm_amdgcn_raw_buffer_load_i32x2( + 
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return __llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + } } template @@ -213,10 +334,12 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type::type index_t dst_thread_addr_offset, index_t dst_wave_addr_offset) { - static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || - (is_same::value && (N == 1 || N == 2 || N == 4)) || - (is_same::value && (N == 1 || N == 2 || N == 4)), - "wrong! not implemented"); + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)), + "wrong! not implemented"); if constexpr(is_same::value) { @@ -298,6 +421,65 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type::type dst_wave_addr_offset, 0); } + else if constexpr(N == 8) + { + __llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 16) + { + __llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + __llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + __llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + __llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(half_t), + 0); + } } } diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp index 6260fdc5bb..fa0f76e630 100644 --- a/composable_kernel/include/utility/amd_inline_asm.hpp +++ b/composable_kernel/include/utility/amd_inline_asm.hpp @@ -166,6 +166,53 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a, "3"(c3)); } +__device__ void amd_assembly_outer_product_1x4(half8_t a, + half8_t b0, + half8_t b1, + half8_t b2, + half8_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + + const half4_t* p_a_half4 = reinterpret_cast(&a); + const half4_t* p_b0_half4 = reinterpret_cast(&b0); + const half4_t* p_b1_half4 = reinterpret_cast(&b1); + const half4_t* p_b2_half4 = reinterpret_cast(&b2); + const half4_t* p_b3_half4 = reinterpret_cast(&b3); + + amd_assembly_outer_product_1x4( + p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3); + + amd_assembly_outer_product_1x4( + p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3); +} + +__device__ void 
amd_assembly_outer_product_1x4(half16_t a, + half16_t b0, + half16_t b1, + half16_t b2, + half16_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + const half8_t* p_a_half8 = reinterpret_cast(&a); + const half8_t* p_b0_half8 = reinterpret_cast(&b0); + const half8_t* p_b1_half8 = reinterpret_cast(&b1); + const half8_t* p_b2_half8 = reinterpret_cast(&b2); + const half8_t* p_b3_half8 = reinterpret_cast(&b3); + + amd_assembly_outer_product_1x4( + p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3); + + amd_assembly_outer_product_1x4( + p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3); +} + // c0 += inner_product(a, b0) // c1 += inner_product(a, b1) __device__ void @@ -215,5 +262,82 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, #endif } +__device__ void amd_assembly_outer_product_1x4(int8x8_t a, + int8x8_t b0, + int8x8_t b1, + int8x8_t b2, + int8x8_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) +{ + + const int8x4_t* p_a_int8x4_t = reinterpret_cast(&a); + const int8x4_t* p_b0_int8x4_t = reinterpret_cast(&b0); + const int8x4_t* p_b1_int8x4_t = reinterpret_cast(&b1); + const int8x4_t* p_b2_int8x4_t = reinterpret_cast(&b2); + const int8x4_t* p_b3_int8x4_t = reinterpret_cast(&b3); + + amd_assembly_outer_product_1x4(p_a_int8x4_t[0], + p_b0_int8x4_t[0], + p_b1_int8x4_t[0], + p_b2_int8x4_t[0], + p_b3_int8x4_t[0], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(p_a_int8x4_t[1], + p_b0_int8x4_t[1], + p_b1_int8x4_t[1], + p_b2_int8x4_t[1], + p_b3_int8x4_t[1], + c0, + c1, + c2, + c3); +} + +__device__ void amd_assembly_outer_product_1x4(int8x16_t a, + int8x16_t b0, + int8x16_t b1, + int8x16_t b2, + int8x16_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) + +{ + + const int8x8_t* p_a_int8x8_t = reinterpret_cast(&a); + const int8x8_t* p_b0_int8x8_t = reinterpret_cast(&b0); + const int8x8_t* p_b1_int8x8_t = reinterpret_cast(&b1); + const int8x8_t* p_b2_int8x8_t = reinterpret_cast(&b2); + const int8x8_t* p_b3_int8x8_t = reinterpret_cast(&b3); + + amd_assembly_outer_product_1x4(p_a_int8x8_t[0], + p_b0_int8x8_t[0], + p_b1_int8x8_t[0], + p_b2_int8x8_t[0], + p_b3_int8x8_t[0], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(p_a_int8x8_t[1], + p_b0_int8x8_t[1], + p_b1_int8x8_t[1], + p_b2_int8x8_t[1], + p_b3_int8x8_t[1], + c0, + c1, + c2, + c3); +} + } // namespace ck #endif diff --git a/composable_kernel/include/utility/float_type.amd.hpp.in b/composable_kernel/include/utility/float_type.amd.hpp.in index 6c0aedb6d4..7ce0d18d61 100644 --- a/composable_kernel/include/utility/float_type.amd.hpp.in +++ b/composable_kernel/include/utility/float_type.amd.hpp.in @@ -168,6 +168,84 @@ struct vector_type __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; } }; +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + + using type = d16_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + __host__ __device__ static constexpr index_t Size() { 
return 16; } + + __host__ __device__ constexpr const auto& Vector() const { return data_.d16_; } + + __host__ __device__ constexpr auto& Vector() { return data_.d16_; } + + __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; } + + __host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; } + + __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; } + + __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; } + + __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; } + + __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; } + + __host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; } + + __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; } + + __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; } + + __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; } + + __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; } + + __host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; } +}; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; + +// fp16 +using half_t = _Float16; +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; + +// bfp16 +using ushort2_t = typename vector_type::type; +using ushort4_t = typename vector_type::type; +using ushort8_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; + template <> struct vector_type { @@ -250,31 +328,118 @@ struct vector_type __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; } }; -// fp32 -using float2_t = typename vector_type::type; -using float4_t = typename vector_type::type; -using float8_t = typename vector_type::type; +template <> +struct vector_type +{ + using d1_t = int8_t; + typedef int16_t d2_t; + typedef int32_t d4_t; + typedef int32x2_t d8_t; -// fp16 -using half_t = _Float16; -using half2_t = typename vector_type::type; -using half4_t = typename vector_type::type; -using half8_t = typename vector_type::type; + using type = d8_t; -// bfp16 -using ushort2_t = typename vector_type::type; -using ushort4_t = typename vector_type::type; -using ushort8_t = typename vector_type::type; + union + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; -// i32 -using int32x2_t = typename vector_type::type; -using int32x4_t = typename vector_type::type; -using int32x8_t = typename vector_type::type; + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + __host__ __device__ static constexpr index_t Size() { return 8; } + + __host__ __device__ constexpr const auto& Vector() const { return data_.d8_; } + + __host__ __device__ constexpr auto& Vector() { return data_.d8_; } + + __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x8_; } + + __host__ __device__ constexpr auto& Scalars() { return data_.d1x8_; } + + __host__ __device__ constexpr const 
auto& Vectors(Number<1>) const { return data_.d1x8_; } + + __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x4_; } + + __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x2_; } + + __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x1_; } + + __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x8_; } + + __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x4_; } + + __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x2_; } + + __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; } +}; + +template <> +struct vector_type +{ + using d1_t = int8_t; + typedef int16_t d2_t; + typedef int32_t d4_t; + typedef int32x2_t d8_t; + typedef int32x4_t d16_t; + + using type = d16_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + __host__ __device__ static constexpr index_t Size() { return 16; } + + __host__ __device__ constexpr const auto& Vector() const { return data_.d16_; } + + __host__ __device__ constexpr auto& Vector() { return data_.d16_; } + + __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; } + + __host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; } + + __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; } + + __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; } + + __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; } + + __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; } + + __host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; } + + __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; } + + __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; } + + __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; } + + __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; } + + __host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; } +}; // i8 // hack for int8x4_t, because compiler does not have native support for int8x4_t // int8x4_t is defined as int32_t -using int8x4_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; // data type conversion template @@ -339,6 +504,34 @@ struct inner_product_with_conversion return acc; } + + __device__ T operator()(int8x8_t a, int8x8_t b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, 8, 1>{}([&](auto i) { + acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); + }); + + return acc; + } + + __device__ T operator()(int8x16_t a, int8x16_t b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, 16, 1>{}([&](auto i) { + acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); + }); + + return acc; + } }; } // namespace ck diff --git 
a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp index 3bce677665..8d3c0d10b1 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp @@ -2,6 +2,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" +#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp" template {}; constexpr auto C1 = Number{}; + constexpr auto K0 = K / Number{}; + constexpr auto K1 = Number{}; + #if 0 // run-time variables const auto in_n_c_hi_wi_desc = @@ -76,19 +80,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi)); const auto wei_k_c0_y_x_desc = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X)); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)); + const auto out_n_k0_ho_wo_k1_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); - const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); - const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); - const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{}); - const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); + const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); + const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); + const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{}); + const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); #endif Tensor in_n_c0_hi_wi_c1(make_HostTensorDescriptor( make_native_tensor_descriptor_packed(Sequence{}))); Tensor wei_k_c0_y_x_c1(make_HostTensorDescriptor( make_native_tensor_descriptor_packed(Sequence{}))); + Tensor out_n_k0_ho_wo_k1(make_HostTensorDescriptor( + make_native_tensor_descriptor_packed(Sequence{}))); auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) { in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) = @@ -106,13 +112,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); +#if 1 // cdata = 64, BlockSize = 64, 16x8x32x4 constexpr index_t BlockSize = 64; + constexpr index_t KPerBlock = K; + constexpr index_t HoPerBlock = 8; + constexpr index_t WoPerBlock = 32; + constexpr index_t EPerBlock = C0; + + constexpr index_t KPerThread = KPerBlock; + constexpr index_t HoPerThread = 2; + constexpr index_t WoPerThread = 2; + constexpr index_t EPerThread = EPerBlock; + + using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>; + using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>; + + constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K = 1; + + constexpr index_t BThreadTransferSrcScalarPerVector_W = 1; + + constexpr index_t CThreadTransferDstScalarPerVector_W = K1; + + static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, ""); +#else + constexpr index_t BlockSize = 64; + constexpr index_t KPerBlock = 16; constexpr index_t 
HoPerBlock = 8; constexpr index_t WoPerBlock = 32; - constexpr index_t EPerBlock = 4; + constexpr index_t EPerBlock = 1; constexpr index_t KPerThread = 16; constexpr index_t HoPerThread = 2; @@ -127,32 +158,28 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( constexpr index_t BThreadTransferSrcScalarPerVector_W = 1; - constexpr index_t CThreadTransferDstScalarPerVector_W = 1; + constexpr index_t CThreadTransferDstScalarPerVector_W = K1; + + static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, ""); +#endif constexpr auto conv_driver = +#if 0 DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad< - BlockSize, - typename vector_type::type, - TAcc, - TOut, - KPerBlock, - HoPerBlock, - WoPerBlock, - EPerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E_K, - ABlockTransferThreadClusterLengths_E_K, - ABlockTransferSrcScalarPerVector_E, - ABlockTransferDstScalarPerVector_K, - BThreadTransferSrcScalarPerVector_W, - CThreadTransferDstScalarPerVector_W>{}; +#else + DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad< +#endif + BlockSize, + typename vector_type::type, TAcc, TOut, KPerBlock, + HoPerBlock, WoPerBlock, EPerBlock, KPerThread, HoPerThread, WoPerThread, + EPerThread, ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, ABlockTransferSrcScalarPerVector_E, + ABlockTransferDstScalarPerVector_K, BThreadTransferSrcScalarPerVector_W, + CThreadTransferDstScalarPerVector_W > {}; conv_driver.Run(wei_k_c0_y_x_desc, in_n_c0_hi_wi_desc, - out_n_k_ho_wo_desc, + out_n_k0_ho_wo_k1_desc, conv_strides, conv_dilations, in_left_pads, @@ -163,5 +190,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( in_n_c_hi_wi_device_buf.GetDeviceBuffer()), static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer())); - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); + out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); + + auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) { + out_n_k_ho_wo(n, k, ho, wo) = + out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize); + }; + + make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)(); } diff --git a/driver/src/conv_driver.cpp b/driver/src/conv_driver.cpp index 8d5bc24c8c..1e9487287d 100644 --- a/driver/src/conv_driver.cpp +++ b/driver/src/conv_driver.cpp @@ -48,8 +48,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 constexpr index_t N = 1; constexpr index_t C = 16; @@ -62,8 +62,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 1 constexpr index_t N = 1; constexpr index_t C = 16; @@ -642,7 +642,7 @@ int main(int argc, char* argv[]) using out_data_t = int8_t; #elif 1 using in_data_t = int8_t; - constexpr index_t in_vector_size = 4; + constexpr index_t in_vector_size = 16; using acc_data_t = int32_t; using out_data_t = int8_t; #endif @@ -741,7 +741,7 @@ int main(int argc, char* argv[]) LeftPads{}, RightPads{}, nrepeat); -#elif 1 +#elif 0 device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk
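
The driver's output tensor moves from N-K-Ho-Wo to a vectorized N-K0-Ho-Wo-K1 layout with K = K0 * K1 (K1 = InWeiVectorSize), which is what allows CThreadTransferDstScalarPerVector_W to be set to K1; the host converts back to NKHW through the f_nk0hwk1_to_nkhw lambda. Below is a minimal host-side sketch of that index mapping, assuming a dense packed layout; OutLayoutNK0HWK1 and nk0hwk1_to_nkhw are illustrative names, not part of the patch.

```cpp
#include <cstddef>
#include <vector>

// Dense, packed N-K0-Ho-Wo-K1 layout; the logical channel k splits into
// k0 = k / K1 (slow axis) and k1 = k % K1 (fastest-varying axis).
struct OutLayoutNK0HWK1
{
    std::size_t N, K0, Ho, Wo, K1;

    std::size_t offset(std::size_t n, std::size_t k, std::size_t ho, std::size_t wo) const
    {
        const std::size_t k0 = k / K1;
        const std::size_t k1 = k % K1;
        return (((n * K0 + k0) * Ho + ho) * Wo + wo) * K1 + k1;
    }
};

// Host-side conversion back to packed NKHW, mirroring f_nk0hwk1_to_nkhw.
void nk0hwk1_to_nkhw(const std::vector<float>& src, std::vector<float>& dst,
                     const OutLayoutNK0HWK1& l)
{
    const std::size_t K = l.K0 * l.K1;
    for(std::size_t n = 0; n < l.N; ++n)
        for(std::size_t k = 0; k < K; ++k)
            for(std::size_t ho = 0; ho < l.Ho; ++ho)
                for(std::size_t wo = 0; wo < l.Wo; ++wo)
                    dst[((n * K + k) * l.Ho + ho) * l.Wo + wo] = src[l.offset(n, k, ho, wo)];
}
```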
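The new _outpad driver removes the Ho % HoPerBlock == 0 and Wo % WoPerBlock == 0 requirements by rounding the output up to the block tile: the overhanging output rows/columns are made writable by pad transforms on the output descriptor, and the input right-padding grows to cover them (each extra output row consumes ConvStrideH extra input rows; dilation cancels out of this delta, so only the strides enter). A sketch of that arithmetic follows; compute_out_pad is an illustrative helper, and plain ints stand in for index_t.

```cpp
#include <cassert>

struct OutPad
{
    int Hop, Wop;             // padded output sizes, multiples of the block tile
    int InRightPadH, InRightPadW;
};

OutPad compute_out_pad(int Ho, int Wo, int HoPerBlock, int WoPerBlock,
                       int ConvStrideH, int ConvStrideW,
                       int InRightPadH0, int InRightPadW0)
{
    // round output up to the block tile
    const int Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
    const int Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;

    // each extra output row needs ConvStrideH extra (zero-padded) input rows
    return {Hop, Wop,
            InRightPadH0 + (Hop - Ho) * ConvStrideH,
            InRightPadW0 + (Wop - Wo) * ConvStrideW};
}

int main()
{
    // e.g. Ho = 17 with HoPerBlock = 8 rounds up to Hop = 24 (OutRightPadH = 7)
    const auto p = compute_out_pad(17, 30, 8, 32, 1, 1, 0, 0);
    assert(p.Hop == 24 && p.Wop == 32);
    assert(p.InRightPadH == 7 && p.InRightPadW == 2);
    return 0;
}
```

By construction Hop % HoPerBlock == 0 and Wop % WoPerBlock == 0, so the divisibility check in the new driver can only fail on K or E, and GridSize is likewise computed from Hop and Wop rather than Ho and Wo.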
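Both drivers pick one of four gridwise-GEMM instantiations from two predicates on the reduction length E = C * Y * X: has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1 and has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0, since the E loop is peeled two block-slices at a time. The predicates below are copied verbatim from the driver; the enumeration harness around them is only for illustration.

```cpp
#include <cstdio>

int main()
{
    const int EPerBlock = 4;
    for(int E = EPerBlock; E <= 8 * EPerBlock; E += EPerBlock)
    {
        // verbatim from the driver: which of the four kernel variants runs
        const bool has_main_k_block_loop        = (E + EPerBlock) / (2 * EPerBlock) > 1;
        const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;

        std::printf("E=%2d main=%d double_tail=%d\n",
                    E, has_main_k_block_loop, has_double_tail_k_block_loop);
    }
    return 0;
}
```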
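The added amd_assembly_outer_product_1x4 overloads for half8_t/half16_t and int8x8_t/int8x16_t simply reinterpret each wide vector as a pair of the next-narrower type and invoke the existing 4-wide overload twice into the same accumulators; this is exact for integer accumulation and merely reassociates the sum for float, just as the vector instructions do. A scalar reference of the assumed semantics (c_j += dot(a, b_j) for j = 0..3, matching the "c0 += inner_product(a, b0)" comment in the file); outer_product_1x4_ref is an illustrative name.

```cpp
#include <array>

// Reference: c[j] += dot(a, b[j]) for j = 0..3 over N-wide inputs. Splitting a
// 2N-wide call into two N-wide calls only splits the inner sum in half, which
// is why the wide overloads can delegate to the narrow one twice.
template <int N, typename T, typename Acc>
void outer_product_1x4_ref(const std::array<T, N>& a,
                           const std::array<std::array<T, N>, 4>& b,
                           std::array<Acc, 4>& c)
{
    for(int j = 0; j < 4; ++j)
        for(int i = 0; i < N; ++i)
            c[j] += static_cast<Acc>(a[i]) * static_cast<Acc>(b[j][i]);
}
```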
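The new vector_type<T, 16> template and the int8 specializations keep one wide ext_vector in a union with statically indexed arrays of narrower views, so Vectors(Number<4>{}) hands back the i-th 4-wide slice of the same registers without any shuffling; the N == 8 half_t buffer load/store paths rely on exactly this to issue two 4-wide intrinsics, the second at a byte offset of 4 * sizeof(half_t). A stripped-down stand-in follows, using a plain array instead of StaticallyIndexedArray; Vec16 is illustrative, it requires clang's ext_vector_type extension (which the repo already uses), and the union type-punning mirrors the patch's own pattern — accepted by the target compiler though not blessed by ISO C++.

```cpp
#include <cstdio>

template <typename T>
struct Vec16
{
    typedef T d4_t __attribute__((ext_vector_type(4)));
    typedef T d16_t __attribute__((ext_vector_type(16)));

    union
    {
        d16_t d16_;     // the full 16-lane vector
        d4_t  d4x4_[4]; // four 4-wide views of the same 16 lanes
    } data_;
};

int main()
{
    Vec16<float> v{};
    for(int i = 0; i < 16; ++i)
        v.data_.d16_[i] = static_cast<float>(i);

    // sub-vector 2, lane 1 aliases full-vector lane 9
    std::printf("%f\n", static_cast<double>(v.data_.d4x4_[2][1]));
    return 0;
}
```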