From fbdf4332c79a18454a553105ae5373911b2ba4ce Mon Sep 17 00:00:00 2001 From: zjing14 Date: Fri, 16 Jul 2021 23:27:08 -0500 Subject: [PATCH] Add xdlops v4r4r4 into online compilation (#48) * init for v4r4 xdlops olc * refactor wrap * init impl of v4r4 nchw xdlops olc * tuning * test perf * fixed v4r4 nhwc * tuned v4r4 nhwc * use gridwise_gemm_xdlops_v2r3 * swap a/b * add pointer support into offline v2r3 * debugging v4r4r4 transform for olc * change timer of olc * refactor v4r4 xdlops nchw olc * remove transform fun in v4r4 xdlops nhwc olc Co-authored-by: Chao Liu --- ...plicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 365 ----------- .../driver/driver_dynamic_gemm_xdlops_v1.hpp | 384 ------------ .../driver/driver_dynamic_gemm_xdlops_v2.hpp | 202 ------ .../driver_dynamic_gemm_xdlops_v2r2.hpp | 167 ----- .../driver_dynamic_gemm_xdlops_v2r3.hpp | 29 + ...lution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp | 129 ---- .../blockwise_gemm_xdlops.hpp | 44 +- .../gridwise_dynamic_gemm_xdlops.hpp | 585 ------------------ .../gridwise_dynamic_gemm_xdlops_v2.hpp | 498 --------------- .../gridwise_dynamic_gemm_xdlops_v2r2.hpp | 509 --------------- .../gridwise_dynamic_gemm_xdlops_v2r3.hpp | 56 +- .../include/utility/common_header.hpp | 1 - ...plicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp | 359 +++++++++++ ...plicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp | 359 +++++++++++ driver/CMakeLists.txt | 2 + driver/conv_bwd_data_driver_v2.cpp | 2 +- driver/conv_driver_v2.cpp | 70 +-- driver/conv_driver_v2_olc.cpp | 108 +++- driver/include/conv_tunables.hpp | 140 +++++ ...plicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp | 1 + ...icit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp | 1 + ...plicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 283 --------- ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 137 +--- ...icit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp | 240 ------- ...icit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp | 305 --------- ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 29 +- ...plicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 376 +++++++++++ ...plicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp | 379 ++++++++++++ 28 files changed, 1851 insertions(+), 3909 deletions(-) delete mode 100644 composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp delete mode 100644 composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp delete mode 100644 composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2.hpp delete mode 100644 composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r2.hpp delete mode 100644 composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp delete mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp delete mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2.hpp delete mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r2.hpp create mode 100644 composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp create mode 100644 composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp delete mode 100644 driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp delete mode 100644 driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp delete mode 100644 
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp create mode 100644 driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp diff --git a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index c2a67062c8..0000000000 --- a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,365 +0,0 @@ -#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP -#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "driver_dynamic_gemm_xdlops_v1.hpp" -#include "driver_dynamic_gemm_xdlops_v2.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( - const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = N * Ho * Wo; - const auto GemmK = C * Y * X; - const auto GemmK0 = GemmK / GemmKPack; - - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor( - wei_gemmk_gemmm_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // 
input tensor - const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_gemmk_gemmn_global_desc = - transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor( - in_gemmk_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0)); - assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1)); - assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0)); - assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0)); - - assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0); - - constexpr auto xdlops_gemm = XdlopsGemm{}; - - constexpr auto CLayout = xdlops_gemm.GetCLayout(); - - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); - - const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor( - out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{})); - - // out_gemm_block_cluster_desc - const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( - make_tuple(GemmM / Number{}, GemmN / Number{})); - - // hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor - constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - // hack to control index calculation 
when iterating over in_gemmk0_gemmn_gemmk1_global tensor - constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global - // tensor hack for NKHW format - constexpr auto out_m0_m1_m2_n_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc, - in_gemmk0_gemmn_gemmk1_global_desc, - out_m0_m1_m2_n_global_desc, - out_gemm_block_cluster_desc, - wei_gemmk0_gemmm_gemmk1_global_iterator_hacks, - in_gemmk0_gemmn_gemmk1_global_iterator_hacks, - out_m0_m1_m2_n_global_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks, - in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks); -} - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1( - const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = N * Ho * Wo; - const auto GemmK = C * Y * X; - const auto GemmK0 = GemmK / GemmKPack; - - assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && - ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && - InRightPadW == 0); - - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), - 
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor( - wei_gemmk_gemmm_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // input tensor - const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor( - in_gemmk_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0)); - assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1)); - assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0)); - assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0)); - - assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0); - - constexpr auto xdlops_gemm = XdlopsGemm{}; - - constexpr auto CLayout = xdlops_gemm.GetCLayout(); - - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); - - const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor( - out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{})); - - // out_gemm_block_cluster_desc - const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( - make_tuple(GemmM / Number{}, GemmN / Number{})); - - // hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor - constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - // hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor - constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks = - Sequence<0, 1, 2, 0, 0>{}; - - // hack to control index calculation 
when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global - // tensor hack for NKHW format - constexpr auto out_m0_m1_m2_n_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc, - in_gemmk0_gemmn_gemmk1_global_desc, - out_m0_m1_m2_n_global_desc, - out_gemm_block_cluster_desc, - wei_gemmk0_gemmm_gemmk1_global_iterator_hacks, - in_gemmk0_gemmn_gemmk1_global_iterator_hacks, - out_m0_m1_m2_n_global_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks, - in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks); -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp deleted file mode 100644 index 5cde70cba9..0000000000 --- a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp +++ /dev/null @@ -1,384 +0,0 @@ -#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V1 -#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V1 - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_xdlops.hpp" -#include "gridwise_operation_wrapper.hpp" - -namespace ck { - -template -__host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global, - const FloatAB* p_b_global, - FloatC* p_c_global, - const AGlobalDesc& a_k_m_global_desc, - const BGlobalDesc& b_k_n_global_desc, - const CGlobalDesc& c_m0_m1_n0_n1_global_desc, - const CBlockClusterDesc& c_block_cluster_desc, - AGlobalIteratorHacks, - BGlobalIteratorHacks, - CGlobalIteratorHacks, - AGlobalMoveSliceWindowIteratorHacks, - BGlobalMoveSliceWindowIteratorHacks, - index_t nrepeat) - -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto M = a_k_m_global_desc.GetLength(I1); - const auto N = b_k_n_global_desc.GetLength(I1); - const auto K = a_k_m_global_desc.GetLength(I0); - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) - { - throw std::runtime_error("wrong! GEMM size no divisible"); - } - - if(!(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0)) - { - throw std::runtime_error("wrong! 
GEMM size no divisible"); - } - - // GEMM - using gridwise_gemm = - GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1; - - const auto GridSize = (M / MPerBlock) * (N / NPerBlock); - - const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; - - const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; - - std::cerr << "has_main_k_block_loop = " << has_main_k_block_loop - << " has_double_tail_k_block_loop = " << has_double_tail_k_block_loop << std::endl; - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_xdlops_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_xdlops_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_xdlops_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - else - { - const auto kernel = kernel_dynamic_gemm_xdlops_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - - return ave_time; -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc)); - DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc)); - DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc)); - DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc)); - - a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc); - b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc); - c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc); - c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc); - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_xdlops_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void 
__CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_xdlops_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_xdlops_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - else - { - const auto kernel = kernel_dynamic_gemm_xdlops_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - - return ave_time; -#endif -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2.hpp deleted file mode 100644 index e7462b919c..0000000000 --- a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2.hpp +++ /dev/null @@ -1,202 +0,0 @@ -#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2 -#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2 - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_xdlops_v2.hpp" -#include "gridwise_operation_wrapper.hpp" - -namespace ck { - -template -__host__ float launch_kernel_dynamic_gemm_xdlops_v2(const FloatAB* p_a_global, - const FloatAB* p_b_global, - FloatC* p_c_global, - const AGlobalDesc& a_k_m_global_desc, - const BGlobalDesc& b_k_n_global_desc, - const CGlobalDesc& c_m0_m1_n0_n1_global_desc, - const CBlockClusterDesc& c_block_cluster_desc, - AGlobalIteratorHacks, - BGlobalIteratorHacks, - CGlobalIteratorHacks, - AGlobalMoveSliceWindowIteratorHacks, - BGlobalMoveSliceWindowIteratorHacks, - index_t nrepeat) - -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto M = a_k_m_global_desc.GetLength(I1); - const auto N = b_k_n_global_desc.GetLength(I1); - const auto K = a_k_m_global_desc.GetLength(I0); - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) - { - throw std::runtime_error("wrong! 
GEMM size no divisible"); - } - - if(!(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0)) - { - throw std::runtime_error("wrong! GEMM size no divisible"); - } - - // GEMM - using gridwise_gemm = - GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2; - - const auto GridSize = (M / MPerBlock) * (N / NPerBlock); - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = 0; - - const auto kernel = kernel_dynamic_gemm_xdlops_v2, - remove_reference_t, - remove_reference_t, - remove_reference_t>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - - return ave_time; -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc)); - DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc)); - DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc)); - DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc)); - - a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc); - b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc); - c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc); - c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc); - - float ave_time = 0; - - const auto kernel = kernel_dynamic_gemm_xdlops_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - - return ave_time; -#endif -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r2.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r2.hpp deleted file mode 100644 index 6bf53e06e2..0000000000 --- a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r2.hpp +++ /dev/null @@ -1,167 +0,0 @@ -#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2 -#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2 - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_xdlops_v2r2.hpp" - -namespace ck { - -template -__host__ float driver_dynamic_gemm_xdlops_v2r2(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, - index_t nrepeat) - -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - using GridwiseGemm = - GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2; - - { - std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " - << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << 
"b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", " - << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " - << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) - { - throw std::runtime_error( - "wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r2 has invalid setting"); - } - - const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - - using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); - - const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); - - const auto kernel = kernel_dynamic_gemm_xdlops_v2r2, - remove_reference_t, - remove_reference_t, - remove_reference_t>; - - float ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_block_cluster_adaptor); - - return ave_time; -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r3.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r3.hpp index ca65c2d073..f07a51d21d 100644 --- a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r3.hpp @@ -21,6 +21,7 @@ template , remove_reference_t>; +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE float ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(grid_size), @@ -162,6 +165,32 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, c_m0_m1_m2_n_grid_desc, c_block_cluster_adaptor); +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); + DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); + DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc)); + DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); + + a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); + b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); + c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); + c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); + + float ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void __CONSTANT__*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); +#endif return ave_time; } diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 5814e66766..0000000000 --- a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,129 +0,0 @@ -#ifndef 
CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( - const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, - const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, - const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = N * Ho * Wo; - const auto GemmK = C * Y * X; - const auto GemmK0 = GemmK / GemmK1; - - // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( - wei_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // input tensor - const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - 
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmk_gemmn_grid_desc = - transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( - in_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index 4b8133870e..f21983d5b5 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -13,7 +13,7 @@ template + index_t K1> struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 { @@ -32,7 +32,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = XdlopsGemm{}; static constexpr index_t MWaves = M1 / MPerWave; static constexpr index_t NWaves = N1 / NPerWave; @@ -119,18 +119,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 "wrong! K dimension not consistent"); static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), - "wrong! KPack dimension not consistent"); + "wrong! 
K1 dimension not consistent"); static_assert(BlockSize == MWaves * NWaves * WaveSize, "BlockSize != MWaves * NWaves * WaveSize\n"); - static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!"); + static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); - static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!"); + static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); } template @@ -194,11 +194,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 private: // A[K, M] static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(I1, Number{}, I1, Number{})); + make_tuple(I1, Number{}, I1, Number{})); // B[K, N] static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(I1, Number{}, I1, Number{})); + make_tuple(I1, Number{}, I1, Number{})); static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( make_tuple(Number{}, Number{})); @@ -207,20 +207,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 FloatAB, ABlockDesc, decltype(a_thread_desc_), - Sequence<1, MRepeat, 1, KPack>, + Sequence<1, MRepeat, 1, K1>, Sequence<0, 1, 2, 3>, 3, - KPack, + K1, 1>; using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<1, NRepeat, 1, K1>, Sequence<0, 1, 2, 3>, 3, - KPack, + K1, 1>; AThreadCopy a_thread_copy_; @@ -233,7 +233,7 @@ template + index_t K1> struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline { @@ -244,7 +244,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = XdlopsGemm{}; static constexpr index_t WaveSize = 64; @@ -339,18 +339,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline "wrong! K dimension not consistent"); static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), - "wrong! KPack dimension not consistent"); + "wrong! 
K1 dimension not consistent"); static_assert(BlockSize == MWaves * NWaves * WaveSize, "BlockSize != MWaves * NWaves * WaveSize\n"); - static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!"); + static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); - static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!"); + static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); } template @@ -491,11 +491,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline private: // A[K, M] static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(I1, Number{}, I1, Number{})); + make_tuple(I1, Number{}, I1, Number{})); // B[K, N] static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(I1, Number{}, I1, Number{})); + make_tuple(I1, Number{}, I1, Number{})); static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( make_tuple(Number{}, Number{})); @@ -504,20 +504,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline FloatAB, ABlockDesc, decltype(a_thread_desc_), - Sequence<1, 1, 1, KPack>, + Sequence<1, 1, 1, K1>, Sequence<0, 1, 2, 3>, 3, - 1, // KPack, + 1, // K1, 1>; using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<1, 1, 1, K1>, Sequence<0, 1, 2, 3>, 3, - 1, // KPack, + 1, // K1, 1>; AThreadCopy a_thread_copy_; diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp deleted file mode 100644 index 3fe9eb3b36..0000000000 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp +++ /dev/null @@ -1,585 +0,0 @@ -#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_HPP -#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_HPP - -#include "common_header.hpp" -#include "dynamic_multi_index_transform_helper.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_set.hpp" - -namespace ck { - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global, - const FloatB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const AGlobalDesc a_k0_m_k1_global_desc, - const BGlobalDesc b_k0_n_k1_global_desc, - const CGlobalDesc c_m0_m1_m2_n_global_desc, - const CBlockClusterDesc c_block_cluster_desc) -{ - GridwiseGemm::Run(p_a_global, - p_b_global, - p_c_global, - a_k0_m_k1_global_desc, - b_k0_n_k1_global_desc, - c_m0_m1_m2_n_global_desc, - c_block_cluster_desc, - integral_constant{}, - integral_constant{}); -} -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -// pass tensor descriptor by __CONSTANT__ void pointer -// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to -// non-modifiable parameter address space, so compiler can enable corresponding optimization -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - 
kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global, - const FloatB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const void __CONSTANT__* p_a_k0_m_k1_global_desc, - const void __CONSTANT__* p_b_k0_n_k1_global_desc, - const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc, - const void __CONSTANT__* p_c_block_cluster_desc) -{ - // first cast void __CONSTANT__ void* to void* - // second cast void* to Desc* - // the copy constructor of tensor descriptor doesn't take address_space(4) - const auto a_k0_m_k1_global_desc = - *reinterpret_cast((const void*)p_a_k0_m_k1_global_desc); - const auto b_k0_n_k1_global_desc = - *reinterpret_cast((const void*)p_b_k0_n_k1_global_desc); - const auto c_m0_m1_m2_n_global_desc = - *reinterpret_cast((const void*)p_c_m0_m1_m2_n_global_desc); - - const auto c_block_cluster_desc = - *reinterpret_cast((const void*)p_c_block_cluster_desc); - - GridwiseGemm::Run(p_a_global, - p_b_global, - p_c_global, - a_k0_m_k1_global_desc, - b_k0_n_k1_global_desc, - c_m0_m1_m2_n_global_desc, - c_block_cluster_desc, - integral_constant{}, - integral_constant{}); -} -#endif - -template -struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 -{ - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = Number{}; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - template - __device__ static void Run(const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const AGlobalDesc& a_k0_m_k1_global_desc, - const BGlobalDesc& b_k0_n_k1_global_desc, - const CGlobalDesc& c_m0_m1_m2_n_global_desc, - const CBlockClusterDesc& c_block_cluster_desc, - FloatAB* __restrict__ p_shared_block, - integral_constant, - integral_constant) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( - p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize()); - - const auto K0 = a_k0_m_k1_global_desc.GetLength(I0); - const auto M = a_k0_m_k1_global_desc.GetLength(I1); - const auto N = b_k0_n_k1_global_desc.GetLength(I1); - const auto K1 = b_k0_n_k1_global_desc.GetLength(I2); - - // divide block work by [M, N] - const auto block_work_idx = - c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this force m/n_block_data_idx_on_global into SGPR - const index_t 
m_block_data_idx_on_global = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_global = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = Number{}; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K_M_KPack, - ABlockTransferThreadClusterLengths_K_M_KPack, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k0_m_k1_global_desc), - decltype(a_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_KPack, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - a_k0_m_k1_global_desc, - make_multi_index(0, m_block_data_idx_on_global, 0), - a_k0_m_k1_block_desc, - make_multi_index(0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K_N_KPack, - BBlockTransferThreadClusterLengths_K_N_KPack, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k0_n_k1_global_desc), - decltype(b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_KPack, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_k0_n_k1_global_desc, - make_multi_index(0, n_block_data_idx_on_global, 0), - b_k0_n_k1_block_desc, - make_multi_index(0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlock] is in LDS - // b_mtx[KPerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - static_assert(MPerBlock % (MPerWave * MRepeat) == 0 && - NPerBlock % (NPerWave * NRepeat) == 0, - "wrong!"); - - constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor( - a_k0_m_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor( - b_k0_n_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto blockwise_gemm = - BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{}; - - constexpr auto CLayout = blockwise_gemm.GetCLayout(); - - constexpr index_t BlkSize = CLayout.GetBlkSize(); - constexpr index_t NumBlks = CLayout.GetNumBlks(); - constexpr index_t NumXdlops = 
CLayout.GetNumXdlops(); - - constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{}, Number{})); - - constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{})); - - StaticBuffer, - c_mr_nr_nx_desc.GetElementSpaceSize()> - c_thread_buf; - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block_double = p_shared_block; - FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size; - - // register allocation for output - // auto c_thread_buf = make_static_buffer( - // c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize()); - - // ThreadwiseDynamicTensorSliceSet_v1>{} - //.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{}; - constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack = - AGlobalMoveSliceWindowIteratorHacks{}; - constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack = - BGlobalMoveSliceWindowIteratorHacks{}; - - auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_k0_m_k1_block_desc.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_k0_n_k1_block_desc.GetElementSpaceSize()); - - auto a_block_odd_buf = make_dynamic_buffer( - p_a_block_double + a_block_space_size, a_k0_m_k1_block_desc.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( - p_b_block_double + b_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize()); - - // LDS double buffer: preload data into LDS - { - a_blockwise_copy.RunRead( - a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks); - - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf); - } - - if constexpr(HasMainKBlockLoop) - { - index_t k_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - a_blockwise_copy.MoveSrcSliceWindow( - a_k0_m_k1_global_desc, - a_block_slice_copy_step, - a_k0_m_k1_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow( - b_k0_n_k1_global_desc, - b_block_slice_copy_step, - b_k0_n_k1_global_move_slice_window_iterator_hack); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks); - - asm volatile("s_nop 0"); - - // LDS double buffer: GEMM on current data - 
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf); - - // odd iteration - a_blockwise_copy.MoveSrcSliceWindow( - a_k0_m_k1_global_desc, - a_block_slice_copy_step, - a_k0_m_k1_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow( - b_k0_n_k1_global_desc, - b_block_slice_copy_step, - b_k0_n_k1_global_move_slice_window_iterator_hack); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks); - - asm volatile("s_nop 0"); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf); - - k_block_data_begin += 2 * KPerBlock; - } while(k_block_data_begin < K0 - 2 * KPerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left - { - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc, - a_block_slice_copy_step, - a_k0_m_k1_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc, - b_block_slice_copy_step, - b_k0_n_k1_global_move_slice_window_iterator_hack); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead( - a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf); - } - - // output: register to global memory - { - - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); - - constexpr auto c_m0_m1_m2_n_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number<1>{}, Number{}, Number<1>{})); - - StaticBuffer c_blk_buf_; - - static_for<0, MRepeat, 1>{}([&](auto mr_i) { - static_for<0, NRepeat, 1>{}([&](auto nr_i) { - static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) { - static_for<0, NumBlks, 1>{}([&](auto blk_i) { - auto c_blk = c_thread_buf[Number{}]; - - static_for<0, BlkSize, 1>{}([&](auto j) { - c_blk_buf_(j) = c_blk.template AsType()[Number< - c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}]; - }); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex( - mr_i, nr_i, xdlops_i, blk_i); - - const index_t 
m_thread_data_on_global = - m_block_data_idx_on_global + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_global = - n_block_data_idx_on_global + c_thread_mtx_on_block[I1]; - - constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks = - CGlobalIteratorHacks{}; - - ThreadwiseDynamicTensorSliceTransfer_v1r3< - FloatC, - FloatC, - decltype(c_m0_m1_m2_n_thread_desc), - decltype(c_m0_m1_m2_n_global_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{c_m0_m1_m2_n_global_desc, - make_multi_index(m_thread_data_on_global / (M2 * M1), - m_thread_data_on_global % (M2 * M1) / M2, - m_thread_data_on_global % M2, - n_thread_data_on_global)} - .Run(c_m0_m1_m2_n_thread_desc, - make_tuple(I0, I0, I0, I0), - c_blk_buf_, - c_m0_m1_m2_n_global_desc, - c_global_buf, - c_m0_m1_m2_n_global_tensor_iterator_hacks); - }); - }); - }); - }); - } - } - - template - __device__ static void Run(const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const AGlobalDesc& a_k0_m_k1_global_desc, - const BGlobalDesc& b_k0_n_k1_global_desc, - const CGlobalDesc& c_m0_m1_m2_n_global_desc, - const CBlockClusterDesc& c_block_cluster_desc, - integral_constant, - integral_constant) - { - constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - Run(p_a_global, - p_b_global, - p_c_global, - a_k0_m_k1_global_desc, - b_k0_n_k1_global_desc, - c_m0_m1_m2_n_global_desc, - c_block_cluster_desc, - p_shared_block, - integral_constant{}, - integral_constant{}); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2.hpp deleted file mode 100644 index 4e1549355d..0000000000 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2.hpp +++ /dev/null @@ -1,498 +0,0 @@ -#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2_HPP -#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2_HPP - -#include "common_header.hpp" -#include "dynamic_multi_index_transform_helper.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_set.hpp" - -namespace ck { - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_dynamic_gemm_xdlops_v2(const FloatA* __restrict__ p_a_global, - const FloatB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const AGlobalDesc a_k0_m_k1_global_desc, - const BGlobalDesc b_k0_n_k1_global_desc, - const CGlobalDesc c_m0_m1_m2_n_global_desc, - const CBlockClusterDesc c_block_cluster_desc) -{ - GridwiseGemm::Run(p_a_global, - p_b_global, - p_c_global, - a_k0_m_k1_global_desc, - b_k0_n_k1_global_desc, - c_m0_m1_m2_n_global_desc, - c_block_cluster_desc); -} -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -// pass tensor descriptor by __CONSTANT__ void pointer -// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to -// non-modifiable parameter address space, so compiler can enable 
corresponding optimization -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_dynamic_gemm_xdlops_v2(const FloatA* __restrict__ p_a_global, - const FloatB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const void __CONSTANT__* p_a_k0_m_k1_global_desc, - const void __CONSTANT__* p_b_k0_n_k1_global_desc, - const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc, - const void __CONSTANT__* p_c_block_cluster_desc) -{ - // first cast void __CONSTANT__ void* to void* - // second cast void* to Desc* - // the copy constructor of tensor descriptor doesn't take address_space(4) - const auto a_k0_m_k1_global_desc = - *reinterpret_cast((const void*)p_a_k0_m_k1_global_desc); - const auto b_k0_n_k1_global_desc = - *reinterpret_cast((const void*)p_b_k0_n_k1_global_desc); - const auto c_m0_m1_m2_n_global_desc = - *reinterpret_cast((const void*)p_c_m0_m1_m2_n_global_desc); - - const auto c_block_cluster_desc = - *reinterpret_cast((const void*)p_c_block_cluster_desc); - - GridwiseGemm::Run(p_a_global, - p_b_global, - p_c_global, - a_k0_m_k1_global_desc, - b_k0_n_k1_global_desc, - c_m0_m1_m2_n_global_desc, - c_block_cluster_desc, - integral_constant{}, - integral_constant{}); -} -#endif - -template -struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2 -{ - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = Number{}; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - __device__ static void Run(const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const AGlobalDesc& a_k0_m_k1_global_desc, - const BGlobalDesc& b_k0_n_k1_global_desc, - const CGlobalDesc& c_m0_m1_m2_n_global_desc, - const CBlockClusterDesc& c_block_cluster_desc, - FloatAB* __restrict__ p_shared_block) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( - p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize()); - - const auto K0 = a_k0_m_k1_global_desc.GetLength(I0); - const auto M = a_k0_m_k1_global_desc.GetLength(I1); - const auto N = b_k0_n_k1_global_desc.GetLength(I1); - const auto K1 = b_k0_n_k1_global_desc.GetLength(I2); - - // divide block work by [M, N] - const auto block_work_idx = - c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - 
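// NOTE (illustrative, not from the original sources): __builtin_amdgcn_readfirstlane
// broadcasts lane 0's value across the wavefront and marks the result wave-uniform,
// letting the compiler keep it in an SGPR rather than a per-lane VGPR. Every thread
// in the block computes the same block_work_idx, so the "HACK" below is safe:
//   const index_t m_idx = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
// (m_idx is a hypothetical name; the real code assigns m_block_data_idx_on_global.)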
// HACK: this forces m/n_block_data_idx_on_global into SGPR - const index_t m_block_data_idx_on_global = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_global = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = Number{}; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K_M_KPack, - ABlockTransferThreadClusterLengths_K_M_KPack, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k0_m_k1_global_desc), - decltype(a_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_KPack, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - a_k0_m_k1_global_desc, - make_multi_index(0, m_block_data_idx_on_global, 0), - a_k0_m_k1_block_desc, - make_multi_index(0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K_N_KPack, - BBlockTransferThreadClusterLengths_K_N_KPack, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k0_n_k1_global_desc), - decltype(b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_KPack, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_k0_n_k1_global_desc, - make_multi_index(0, n_block_data_idx_on_global, 0), - b_k0_n_k1_block_desc, - make_multi_index(0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlock] is in LDS - // b_mtx[KPerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - static_assert(MPerBlock % (MPerWave * MRepeat) == 0 && - NPerBlock % (NPerWave * NRepeat) == 0, - "wrong!"); - - constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor( - a_k0_m_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor( - b_k0_n_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto blockwise_gemm = - BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{}; - - constexpr auto CLayout = blockwise_gemm.GetCLayout(); - - constexpr index_t BlkSize = CLayout.GetBlkSize(); - constexpr index_t NumBlks = 
CLayout.GetNumBlks(); - constexpr index_t NumXdlops = CLayout.GetNumXdlops(); - - constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{}, Number{})); - - constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{})); - - StaticBuffer, - c_mr_nr_nx_desc.GetElementSpaceSize()> - c_thread_buf; - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; - - // register allocation for output - // auto c_thread_buf = make_static_buffer( - // c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize()); - - // ThreadwiseDynamicTensorSliceSet_v1>{} - //.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{}; - constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack = - AGlobalMoveSliceWindowIteratorHacks{}; - constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack = - BGlobalMoveSliceWindowIteratorHacks{}; - - auto a_block_buf = make_dynamic_buffer( - p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); - - // preload data into LDS - { - a_blockwise_copy.RunRead( - a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks); - - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); - } - - // main body - index_t k_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc, - a_block_slice_copy_step, - a_k0_m_k1_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc, - b_block_slice_copy_step, - b_k0_n_k1_global_move_slice_window_iterator_hack); - - a_blockwise_copy.RunRead( - a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks); - - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); - - k_block_data_begin += KPerBlock; - } while(k_block_data_begin < (K0 - KPerBlock)); - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - - // output: register to global memory - { - - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); - - constexpr 
auto c_m0_m1_m2_n_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number<1>{}, Number{}, Number<1>{})); - - StaticBuffer c_blk_buf_; - - static_for<0, MRepeat, 1>{}([&](auto mr_i) { - static_for<0, NRepeat, 1>{}([&](auto nr_i) { - static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) { - static_for<0, NumBlks, 1>{}([&](auto blk_i) { - auto c_blk = c_thread_buf[Number{}]; - - static_for<0, BlkSize, 1>{}([&](auto j) { - c_blk_buf_(j) = c_blk.template AsType()[Number< - c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}]; - }); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex( - mr_i, nr_i, xdlops_i, blk_i); - - const index_t m_thread_data_on_global = - m_block_data_idx_on_global + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_global = - n_block_data_idx_on_global + c_thread_mtx_on_block[I1]; - - constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks = - CGlobalIteratorHacks{}; - - ThreadwiseDynamicTensorSliceTransfer_v1r3< - FloatC, - FloatC, - decltype(c_m0_m1_m2_n_thread_desc), - decltype(c_m0_m1_m2_n_global_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{c_m0_m1_m2_n_global_desc, - make_multi_index(m_thread_data_on_global / (M2 * M1), - m_thread_data_on_global % (M2 * M1) / M2, - m_thread_data_on_global % M2, - n_thread_data_on_global)} - .Run(c_m0_m1_m2_n_thread_desc, - make_tuple(I0, I0, I0, I0), - c_blk_buf_, - c_m0_m1_m2_n_global_desc, - c_global_buf, - c_m0_m1_m2_n_global_tensor_iterator_hacks); - }); - }); - }); - }); - } - } - - __device__ static void Run(const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const AGlobalDesc& a_k0_m_k1_global_desc, - const BGlobalDesc& b_k0_n_k1_global_desc, - const CGlobalDesc& c_m0_m1_m2_n_global_desc, - const CBlockClusterDesc& c_block_cluster_desc) - { - constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - Run(p_a_global, - p_b_global, - p_c_global, - a_k0_m_k1_global_desc, - b_k0_n_k1_global_desc, - c_m0_m1_m2_n_global_desc, - c_block_cluster_desc, - p_shared_block); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r2.hpp deleted file mode 100644 index 4e7e59d10c..0000000000 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r2.hpp +++ /dev/null @@ -1,509 +0,0 @@ -#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R2_HPP -#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R2_HPP - -#include "common_header.hpp" -#include "dynamic_multi_index_transform_helper.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_set.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_dynamic_gemm_xdlops_v2r2(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ 
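// NOTE (illustrative sketch, not from the original sources): GetSharedMemoryNumberOfByte()
// below returns the aligned A tile plus the aligned B tile. Schematically, with a
// hypothetical round_up helper standing in for math::integer_least_multiple:
//   auto round_up = [](size_t x, size_t a) { return (x + a - 1) / a * a; };
//   size_t lds_bytes = (round_up(a_tile_elems, align) + round_up(b_tile_elems, align))
//                      * sizeof(FloatAB);
// which is the size the kernel wrapper uses for its __shared__ scratch array.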
p_b_grid, - FloatC* __restrict__ p_c_grid, - const AK0MK1GridDesc a_k0_m_k1_grid_desc, - const BK0NK1GridDesc b_k0_n_k1_grid_desc, - const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, - const CBlockClusterAdaptor c_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_block_cluster_adaptor); -} - -template -struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - // K1 should be Number<...> - static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - __host__ __device__ static constexpr bool - CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc) - { - // TODO: turn on this - static_assert(is_known_at_compile_time>::value, - "wrong! 
K1 need to be known at compile-time"); - - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - - return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_k0_n_k1_grid_desc.GetLength(I0) && - K1 == a_k0_m_k1_grid_desc.GetLength(I2) && - K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && - (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) && - (MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0); - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - - __host__ __device__ static constexpr auto - MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto xdlops_gemm = XdlopsGemm{}; - - constexpr auto CLayout = xdlops_gemm.GetCLayout(); - - constexpr auto M0 = Number{}; - constexpr auto M1 = Number{}; - constexpr auto M2 = Number{}; - - const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor( - c_m_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(M / (M1 * M2), M1, M2)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{})); - - return c_m0_m1_m2_n_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return c_blockid_to_m0_n0_block_cluster_adaptor; - } - - using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); - - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc, - const CBlockClusterAdaptor& c_block_cluster_adaptor) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); - - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - - // divide block work by [M, N] - const auto block_work_idx = - 
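// NOTE (illustrative, not from the original sources): MakeCBlockClusterAdaptor above
// merges (M0, N0) = (M / MPerBlock, N / NPerBlock) into a single block id, so the
// CalculateBottomIndex call below undoes that merge. Assuming row-major merge order:
//   index_t m0 = block_id / N0;
//   index_t n0 = block_id % N0;
// and CalculateGridSize launches exactly M0 * N0 workgroups, one per C tile.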
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this forces m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k0_m_k1_grid_desc), - decltype(a_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - a_k0_m_k1_grid_desc, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_k0_m_k1_block_desc, - make_multi_index(0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k0_n_k1_grid_desc), - decltype(b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_k0_n_k1_grid_desc, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_k0_n_k1_block_desc, - make_multi_index(0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlock] is in LDS - // b_mtx[KPerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - static_assert(MPerBlock % (MPerWave * MRepeat) == 0 && - NPerBlock % (NPerWave * NRepeat) == 0, - "wrong!"); - - constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor( - a_k0_m_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor( - b_k0_n_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto blockwise_gemm = - BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{}; - - constexpr auto CLayout = blockwise_gemm.GetCLayout(); - - constexpr index_t BlkSize = CLayout.GetBlkSize(); - 
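// NOTE (illustrative, not from the original sources): the xdlops C layout is walked
// hierarchically below: each thread owns MRepeat * NRepeat * NumXdlops accumulator
// vectors, and each vector holds NumBlks blocks of BlkSize values, so per thread
//   c_elems = MRepeat * NRepeat * NumXdlops * NumBlks * BlkSize;
// matching the c_mr_nr_nx_desc and c_blk_nb_bs_desc descriptors declared next.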
constexpr index_t NumBlks = CLayout.GetNumBlks(); - constexpr index_t NumXdlops = CLayout.GetNumXdlops(); - - constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{}, Number{})); - - constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{})); - - StaticBuffer, - c_mr_nr_nx_desc.GetElementSpaceSize()> - c_thread_buf; - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{}; - constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack = - AGridMoveSliceWindowIteratorHacks{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack = - BGridMoveSliceWindowIteratorHacks{}; - - auto a_block_buf = make_dynamic_buffer( - p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); - - // preload data into LDS - { - a_blockwise_copy.RunRead( - a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); - b_blockwise_copy.RunRead( - b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks); - - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); - } - - // main body - index_t k_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, - a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, - b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_iterator_hack); - - a_blockwise_copy.RunRead( - a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); - - block_sync_lds(); - - b_blockwise_copy.RunRead( - b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); - - k_block_data_begin += KPerBlock; - } while(k_block_data_begin < (K0 - KPerBlock)); - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); - - constexpr auto c_m0_m1_m2_n_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number<1>{}, Number{}, Number<1>{})); - - StaticBuffer c_blk_buf_; - - static_for<0, MRepeat, 1>{}([&](auto mr_i) { - static_for<0, NRepeat, 1>{}([&](auto nr_i) { - static_for<0, NumXdlops, 
1>{}([&](auto xdlops_i) { - static_for<0, NumBlks, 1>{}([&](auto blk_i) { - auto c_blk = c_thread_buf[Number{}]; - - static_for<0, BlkSize, 1>{}([&](auto j) { - c_blk_buf_(j) = c_blk.template AsType()[Number< - c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}]; - }); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex( - mr_i, nr_i, xdlops_i, blk_i); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = - CGridIteratorHacks{}; - - ThreadwiseDynamicTensorSliceTransfer_v1r3< - FloatC, - FloatC, - decltype(c_m0_m1_m2_n_thread_desc), - decltype(c_m0_m1_m2_n_grid_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{c_m0_m1_m2_n_grid_desc, - make_multi_index(m_thread_data_on_grid / (M2 * M1), - m_thread_data_on_grid % (M2 * M1) / M2, - m_thread_data_on_grid % M2, - n_thread_data_on_grid)} - .Run(c_m0_m1_m2_n_thread_desc, - make_tuple(I0, I0, I0, I0), - c_blk_buf_, - c_m0_m1_m2_n_grid_desc, - c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); - }); - }); - }); - }); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp index d15ec86800..3b1dc9cea1 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp @@ -12,6 +12,7 @@ namespace ck { +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void __CONSTANT__* p_a_k0_m_k1_grid_desc, + const void __CONSTANT__* p_b_k0_n_k1_grid_desc, + const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc, + const void __CONSTANT__* p_c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + const auto a_k0_m_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); + const auto b_k0_n_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); + const auto c_m0_m1_m2_n_grid_desc = + *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); + const auto c_block_cluster_adaptor = + *reinterpret_cast((const void*)p_c_block_cluster_adaptor); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); +} +#endif template {}; // K1 should be Number<...> - static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); + static constexpr auto K1 = Number{}; __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -160,7 +206,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const auto M = c_m_n_grid_desc.GetLength(I0); const auto N = 
c_m_n_grid_desc.GetLength(I1); - constexpr auto xdlops_gemm = XdlopsGemm{}; + constexpr auto xdlops_gemm = XdlopsGemm{}; constexpr auto CLayout = xdlops_gemm.GetCLayout(); @@ -267,7 +313,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4, + Sequence, ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, @@ -294,7 +340,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4, + Sequence, BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, @@ -354,7 +400,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 decltype(b_k0_n0_n1_k1_block_desc), MPerWave, NPerWave, - K1.value>{}; + K1>{}; constexpr auto CLayout = blockwise_gemm.GetCLayout(); diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index 32e2abd99f..ad38d0461c 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -39,7 +39,6 @@ #if CK_USE_AMD_XDLOPS #include "amd_xdlops.hpp" -#include "amd_xdlops_inline_asm.hpp" #endif #endif diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp new file mode 100644 index 0000000000..b9a835336b --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp @@ -0,0 +1,359 @@ +#include "common_header.hpp" +#include "type_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" + +using namespace ck; + +using FloatAB = typename get_type_from_type_id(CK_PARAM_IN_WEI_DATATYPE)>::type; +using FloatC = typename get_type_from_type_id(CK_PARAM_OUT_DATATYPE)>::type; +using FloatAcc = typename get_type_from_type_id(CK_PARAM_CONV_COMPTYPE)>::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr index_t MPerBlock = CK_PARAM_MPerBlock; +constexpr index_t NPerBlock = CK_PARAM_NPerBlock; +constexpr index_t KPerBlock = CK_PARAM_KPerBlock; + +constexpr index_t MPerWave = CK_PARAM_MPerWave; +constexpr index_t NPerWave = CK_PARAM_NPerWave; +constexpr index_t MRepeat = CK_PARAM_MRepeat; +constexpr index_t NRepeat = CK_PARAM_NRepeat; +constexpr index_t K1 = CK_PARAM_K1; + +using ABlockTransferThreadSliceLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = + Sequence; +using ABlockTransferSrcAccessOrder = Sequence; + +constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; +constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; +constexpr index_t ABlockTransferDstScalarPerVector_K1 = + CK_PARAM_ABlockTransferDstScalarPerVector_K1; +constexpr bool AThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); + +using BBlockTransferThreadSliceLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterLengths_K0_N_K1 = + Sequence; +using 
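// NOTE (illustrative, not from the original sources): the _prepare kernel below uses
// the standard dilated-window output-size formula,
//   ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
// e.g. hi = 28, pads = 1/1, dilation = 1, y = 3, stride = 1 gives
//   ho = (28 + 2 - 2 - 1) / 1 + 1 = 28,
// consistent with the hard-coded 28x28 descriptors used further down, which exist
// only so the descriptor types can be deduced at compile time.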
BBlockTransferThreadClusterArrangeOrder = + Sequence; +using BBlockTransferSrcAccessOrder = Sequence; + +constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; +constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; +constexpr index_t BBlockTransferDstScalarPerVector_K1 = + CK_PARAM_BBlockTransferDstScalarPerVector_K1; +constexpr bool BThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); + +using CThreadTransferSrcDstAccessOrder = Sequence; +constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +extern "C" __global__ void +dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare( + int n, + int c, + int hi, + int wi, + int k, + int y, + int x, + int convStrideH, + int convStrideW, + int convDilationY, + int convDilationX, + int leftPadH, + int leftPadW, + int rightPadH, + int rightPadW, + void* p_a_k0_m_k1_grid_desc, + void* p_b_k0_n_k1_grid_desc, + void* p_c_m0_m1_m2_n_grid_desc, + void* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; + const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi)); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x)); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo)); + + const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(convStrideH, convStrideW), + make_tuple(convDilationY, convDilationX), + make_tuple(leftPadH, leftPadW), + make_tuple(rightPadH, rightPadW), + Number{}); + + const auto a_k0_m_k1_grid_desc = descs[I0]; + const auto b_k0_n_k1_grid_desc = descs[I1]; + const auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using AGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using BGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + 
Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using GridwiseGemm = + GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + if(hipThreadIdx_x == 0) + { + *static_cast*>(p_a_k0_m_k1_grid_desc) = + a_k0_m_k1_grid_desc; + *static_cast*>(p_b_k0_n_k1_grid_desc) = + b_k0_n_k1_grid_desc; + *static_cast(p_c_m0_m1_m2_n_grid_desc) = + c_m0_m1_m2_n_grid_desc; + *static_cast( + p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; + } +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void __CONSTANT__* p_a_k0_m_k1_grid_desc, + const void __CONSTANT__* p_b_k0_n_k1_grid_desc, + const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc, + const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + constexpr auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3)); + constexpr auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + + constexpr auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + Number{}); + + constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; + constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; + constexpr auto c_m_n_grid_desc = descs[I2]; + + using AGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using BGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 
0, 0>; + using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using GridwiseGemm = + GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + constexpr auto c_m0_m1_m2_n_grid_desc_tmp = + GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); + + const auto a_k0_m_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); + const auto b_k0_n_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); + const auto c_m0_m1_m2_n_grid_desc = + *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); +}; diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp new file mode 100644 index 0000000000..9e8de0ac8e --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp @@ -0,0 +1,359 @@ +#include "common_header.hpp" +#include "type_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" + +using namespace ck; + +using FloatAB = typename get_type_from_type_id(CK_PARAM_IN_WEI_DATATYPE)>::type; +using FloatC = typename get_type_from_type_id(CK_PARAM_OUT_DATATYPE)>::type; +using FloatAcc = typename get_type_from_type_id(CK_PARAM_CONV_COMPTYPE)>::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr index_t MPerBlock = CK_PARAM_MPerBlock; +constexpr index_t NPerBlock = CK_PARAM_NPerBlock; +constexpr index_t KPerBlock = CK_PARAM_KPerBlock; + +constexpr index_t MPerWave = CK_PARAM_MPerWave; +constexpr index_t NPerWave = CK_PARAM_NPerWave; +constexpr index_t MRepeat = CK_PARAM_MRepeat; +constexpr index_t NRepeat = CK_PARAM_NRepeat; +constexpr index_t K1 = CK_PARAM_K1; + +using ABlockTransferThreadSliceLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = + Sequence; +using ABlockTransferSrcAccessOrder = Sequence; + +constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; +constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; +constexpr index_t ABlockTransferDstScalarPerVector_K1 = + CK_PARAM_ABlockTransferDstScalarPerVector_K1; +constexpr bool AThreadTransferSrcResetCoordinateAfterRun = + 
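// NOTE (illustrative sketch, not from the original sources): both wrappers follow the
// same two-kernel protocol: the _prepare kernel builds the runtime tensor descriptors
// on the device and thread 0 writes them to plain device buffers; the main kernel then
// takes those buffers as __CONSTANT__ void* and casts them back to the statically
// deduced descriptor types. A hedged host-side sketch with hypothetical names:
//   hipMalloc(&p_desc_a, desc_a_bytes); // one buffer per descriptor
//   hipLaunchKernelGGL(conv_prepare, dim3(1), dim3(1), 0, 0,
//                      n, hi, wi, c, k, y, x, /* strides, dilations, pads ... */
//                      p_desc_a, p_desc_b, p_desc_c, p_desc_adaptor);
//   hipLaunchKernelGGL(conv_main, grid, block, 0, 0,
//                      p_a, p_b, p_c, p_desc_a, p_desc_b, p_desc_c, p_desc_adaptor);
// so descriptor construction runs once per problem shape, off the timed hot path.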
static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); + +using BBlockTransferThreadSliceLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = + Sequence; +using BBlockTransferSrcAccessOrder = Sequence; + +constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; +constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; +constexpr index_t BBlockTransferDstScalarPerVector_K1 = + CK_PARAM_BBlockTransferDstScalarPerVector_K1; +constexpr bool BThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); + +using CThreadTransferSrcDstAccessOrder = Sequence; +constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +extern "C" __global__ void +dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare( + int n, + int hi, + int wi, + int c, + int k, + int y, + int x, + int convStrideH, + int convStrideW, + int convDilationY, + int convDilationX, + int leftPadH, + int leftPadW, + int rightPadH, + int rightPadW, + void* p_a_k0_m_k1_grid_desc, + void* p_b_k0_n_k1_grid_desc, + void* p_c_m0_m1_m2_n_grid_desc, + void* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; + const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, hi, wi, c)); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, y, x, c)); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, ho, wo, k)); + + const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + make_tuple(convStrideH, convStrideW), + make_tuple(convDilationY, convDilationX), + make_tuple(leftPadH, leftPadW), + make_tuple(rightPadH, rightPadW), + Number{}); + + const auto a_k0_m_k1_grid_desc = descs[I0]; + const auto b_k0_n_k1_grid_desc = descs[I1]; + const auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using BGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using AGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 
0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; + + using GridwiseGemm = + GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + if(hipThreadIdx_x == 0) + { + *static_cast*>(p_a_k0_m_k1_grid_desc) = + a_k0_m_k1_grid_desc; + *static_cast*>(p_b_k0_n_k1_grid_desc) = + b_k0_n_k1_grid_desc; + *static_cast(p_c_m0_m1_m2_n_grid_desc) = + c_m0_m1_m2_n_grid_desc; + *static_cast( + p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; + } +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void __CONSTANT__* p_a_k0_m_k1_grid_desc, + const void __CONSTANT__* p_b_k0_n_k1_grid_desc, + const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc, + const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256)); + constexpr auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 3, 3, 256)); + constexpr auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256)); + + constexpr auto descs = + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + Number{}); + + constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; + constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; + constexpr auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using BGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using AGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 
0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; + + using GridwiseGemm = + GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + constexpr auto c_m0_m1_m2_n_grid_desc_tmp = + GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); + + const auto a_k0_m_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); + const auto b_k0_n_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); + const auto c_m0_m1_m2_n_grid_desc = + *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); +}; diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 9800559fe9..ecc4d7091d 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -92,6 +92,8 @@ set(MCONV_KERNEL_INCLUDES set(MCONV_KERNELS ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp + ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp + ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp ) add_kernels("olCompiling/" "${MCONV_KERNELS}") diff --git a/driver/conv_bwd_data_driver_v2.cpp b/driver/conv_bwd_data_driver_v2.cpp index 4ae9c5c749..61c3fc385d 100644 --- a/driver/conv_bwd_data_driver_v2.cpp +++ b/driver/conv_bwd_data_driver_v2.cpp @@ -16,7 +16,7 @@ #include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" #define USE_DYNAMIC_MODE 1 -#define USE_CONV_BWD_V4R1_XDL_NHWC 0 +#define USE_CONV_BWD_V4R1_XDL_NHWC 1 #define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 enum ConvBackwardDataAlgo diff --git a/driver/conv_driver_v2.cpp b/driver/conv_driver_v2.cpp index 3b9fde9257..3574431556 100644 --- a/driver/conv_driver_v2.cpp +++ b/driver/conv_driver_v2.cpp @@ -19,8 +19,6 @@ #include "device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp" -#include 
"device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #define USE_DYNAMIC_MODE 1 @@ -30,9 +28,7 @@ #define USE_CONV_FWD_V4R5_NCHW 0 #define USE_CONV_FWD_V4R5R2_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0 -#define USE_CONV_FWD_V4R4_XDL_NCHW 1 -#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0 -#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0 +#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 enum ConvForwardAlgo @@ -43,10 +39,8 @@ enum ConvForwardAlgo V4R5NCHW, // 3 V4R5R2NCHW, // 4 V5R1NCHW, // 5 - V4R4XDLNCHW, // 6 - V4R4R2XDLNHWC, // 7 - V4R4R3XDLNHWC, // 8 - V4R4R4XDLNHWC // 9 + V4R4R2XDLNCHW, // 6 + V4R4R4XDLNHWC // 7 }; int main(int argc, char* argv[]) @@ -462,8 +456,8 @@ int main(int argc, char* argv[]) } #endif -#if USE_CONV_FWD_V4R4_XDL_NCHW - if(algo == ConvForwardAlgo::V4R4XDLNCHW) +#if USE_CONV_FWD_V4R4R2_XDL_NCHW + if(algo == ConvForwardAlgo::V4R4R2XDLNCHW) { if(layout != ConvTensorLayout::NCHW) { @@ -489,60 +483,6 @@ int main(int argc, char* argv[]) } #endif -#if USE_CONV_FWD_V4R4R2_XDL_NHWC - if(algo == ConvForwardAlgo::V4R4R2XDLNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V4R4R3_XDL_NHWC - if(algo == ConvForwardAlgo::V4R4R3XDLNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - #if USE_CONV_FWD_V4R4R4_XDL_NHWC if(algo == ConvForwardAlgo::V4R4R4XDLNHWC) { diff --git a/driver/conv_driver_v2_olc.cpp b/driver/conv_driver_v2_olc.cpp index d117dfdd31..14e3e95205 100644 --- a/driver/conv_driver_v2_olc.cpp +++ b/driver/conv_driver_v2_olc.cpp @@ -12,11 +12,17 @@ #include "conv_common.hpp" #include "host_conv.hpp" #include "device_tensor.hpp" + #include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp" +#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" +#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp" + #define USE_CONV_FWD_V4R4_NCHW 1 #define USE_CONV_FWD_V4R5_NCHW 1 +#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1 +#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1 #include "conv_tunables.hpp" #include "handle.hpp" @@ -24,10 +30,10 @@ enum ConvForwardAlgo { - V4R4NCHW, - V4R4NHWC, - V4R5NCHW, - V5R1NCHW + V4R4NCHW, // 0 + V4R5NCHW, // 1 + V4R4XDLNCHW, // 2 + V4R4XDLNHWC // 3 }; int main(int argc, char* argv[]) @@ -105,14 +111,16 @@ int main(int argc, char* argv[]) { case ConvTensorLayout::NCHW: // NCHW - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(C); - in_lengths_host[2] = static_cast(Hi); - in_lengths_host[3] = static_cast(Wi); + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + wei_lengths_host[0] = 
static_cast(K); wei_lengths_host[1] = static_cast(C); wei_lengths_host[2] = static_cast(Y); wei_lengths_host[3] = static_cast(X); + out_lengths_host[0] = static_cast(N); out_lengths_host[1] = static_cast(K); out_lengths_host[2] = static_cast(Ho); @@ -120,14 +128,16 @@ int main(int argc, char* argv[]) break; case ConvTensorLayout::NHWC: // NHWC - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(Hi); - in_lengths_host[2] = static_cast(Wi); - in_lengths_host[3] = static_cast(C); + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(Hi); + in_lengths_host[2] = static_cast(Wi); + in_lengths_host[3] = static_cast(C); + wei_lengths_host[0] = static_cast(K); wei_lengths_host[1] = static_cast(Y); wei_lengths_host[2] = static_cast(X); wei_lengths_host[3] = static_cast(C); + out_lengths_host[0] = static_cast(N); out_lengths_host[1] = static_cast(Ho); out_lengths_host[2] = static_cast(Wo); @@ -271,10 +281,80 @@ int main(int argc, char* argv[]) } #endif +#if USE_CONV_FWD_V4R4_XDLOPS_NCHW + if(algo == ConvForwardAlgo::V4R4XDLNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable = + &default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw; + + device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc( + handle, + tmp[I0], + tmp[I1], + tmp[I2], + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + in, + wei, + out_device, + tunable, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4_XDLOPS_NHWC + if(algo == ConvForwardAlgo::V4R4XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable = + &default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk; + + device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc( + handle, + tmp[I0], + tmp[I1], + tmp[I2], + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + in, + wei, + out_device, + tunable, + nrepeat); + } +#endif + if(do_verification) { - host_direct_convolution( - in, wei, out_host, conv_strides, conv_dilations, in_left_pads, in_right_pads); + host_direct_convolution(in, + wei, + out_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); check_error(out_host, out_device); diff --git a/driver/include/conv_tunables.hpp b/driver/include/conv_tunables.hpp index 33f791d289..0275a95f9a 100644 --- a/driver/include/conv_tunables.hpp +++ b/driver/include/conv_tunables.hpp @@ -50,6 +50,146 @@ static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r {0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2}, 5, 1}; +struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw +{ + ck::index_t BlockSize; // usually not tunable + + ck::index_t MPerBlock; + ck::index_t NPerBlock; + ck::index_t KPerBlock; + + ck::index_t MPerWave; + ck::index_t NPerWave; + ck::index_t K1; + + ck::index_t MRepeat; + ck::index_t NRepeat; + + std::array ABlockTransferThreadSliceLengths_K0_M_K1; + std::array ABlockTransferThreadClusterLengths_K0_M_K1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + ck::index_t ABlockTransferSrcVectorDim; + ck::index_t 
ABlockTransferSrcScalarPerVector;
+    ck::index_t ABlockTransferDstScalarPerVector_K1;
+    bool AThreadTransferSrcResetCoordinateAfterRun;
+
+    std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
+    std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
+    std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
+    std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
+    ck::index_t BBlockTransferSrcVectorDim;
+    ck::index_t BBlockTransferSrcScalarPerVector;
+    ck::index_t BBlockTransferDstScalarPerVector_K1;
+    bool BThreadTransferSrcResetCoordinateAfterRun;
+
+    std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
+    ck::index_t CThreadTransferSrcDstVectorDim;
+    ck::index_t CThreadTransferDstScalarPerVector;
+};
+
+static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
+    default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = {
+        256,        // BlockSize
+        128,        // MPerBlock,
+        128,        // NPerBlock,
+        4,          // KPerBlock,
+        32,         // MPerWave,
+        32,         // NPerWave,
+        4,          // K1,
+        2,          // MRepeat,
+        2,          // NRepeat,
+        {1, 2, 4},  // ABlockTransferThreadSliceLengths_K0_M_K1,
+        {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
+        {1, 0, 2},  // ABlockTransferThreadClusterArrangeOrder,
+        {1, 0, 2},  // ABlockTransferSrcAccessOrder,
+        2,          // ABlockTransferSrcVectorDim
+        1,          // ABlockTransferSrcScalarPerVector,
+        4,          // ABlockTransferDstScalarPerVector_K1,
+        false,      // AThreadTransferSrcResetCoordinateAfterRun,
+        {1, 2, 4},  // BBlockTransferThreadSliceLengths_K0_N_K1,
+        {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
+        {0, 2, 1},  // BBlockTransferThreadClusterArrangeOrder,
+        {1, 0, 2},  // BBlockTransferSrcAccessOrder,
+        1,          // BBlockTransferSrcVectorDim
+        1,          // BBlockTransferSrcScalarPerVector
+        4,          // BBlockTransferDstScalarPerVector_K1
+        false,      // BThreadTransferSrcResetCoordinateAfterRun
+        {3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
+        7,          // CThreadTransferSrcDstVectorDim,
+        1           // CThreadTransferDstScalarPerVector
+};
+
+struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
+{
+    ck::index_t BlockSize; // usually not tunable
+
+    ck::index_t MPerBlock;
+    ck::index_t NPerBlock;
+    ck::index_t KPerBlock;
+
+    ck::index_t MPerWave;
+    ck::index_t NPerWave;
+    ck::index_t K1;
+
+    ck::index_t MRepeat;
+    ck::index_t NRepeat;
+
+    std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
+    std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
+    std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
+    std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
+    ck::index_t ABlockTransferSrcVectorDim;
+    ck::index_t ABlockTransferSrcScalarPerVector;
+    ck::index_t ABlockTransferDstScalarPerVector_K1;
+    bool AThreadTransferSrcResetCoordinateAfterRun;
+
+    std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
+    std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
+    std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
+    std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
+    ck::index_t BBlockTransferSrcVectorDim;
+    ck::index_t BBlockTransferSrcScalarPerVector;
+    ck::index_t BBlockTransferDstScalarPerVector_K1;
+    bool BThreadTransferSrcResetCoordinateAfterRun;
+
+    std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
+    ck::index_t CThreadTransferSrcDstVectorDim;
+    ck::index_t CThreadTransferDstScalarPerVector;
+};
+
+static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
+    default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = {
+        256,        // BlockSize
+        128,        // MPerBlock,
+        128,        // NPerBlock,
+        4,          // KPerBlock,
+        32,         // MPerWave,
+        32,         // NPerWave,
+        4,          // K1,
+        2,          // MRepeat,
+        2,          // NRepeat,
+        {1, 2, 4},  // ABlockTransferThreadSliceLengths_K0_M_K1,
+        {4, 64, 1}, // 
ABlockTransferThreadClusterLengths_K0_M_K1, + {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector, + 4, // ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, + {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, + {1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_K1 + false, // BThreadTransferSrcResetCoordinateAfterRun + {2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder + 7, // CThreadTransferSrcDstVectorDim, + 1 // CThreadTransferDstScalarPerVector +}; + struct tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw { ck::index_t BlockSize; diff --git a/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index de48a0ea82..0ea190611b 100644 --- a/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -273,6 +273,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx GemmKPerBlock, GemmMPerWave, GemmNPerWave, + GemmK1, MRepeat, NRepeat, GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, diff --git a/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index 0711a5f262..315f201458 100644 --- a/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -245,6 +245,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k GemmKPerBlock, GemmMPerWave, GemmNPerWave, + GemmK1, MRepeat, NRepeat, GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 5890b12e00..0000000000 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,283 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" - -template -void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 
= Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); - -#if 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using 
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 4] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#endif - - const auto descs = -#if 1 - transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad -#else - transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1 -#endif - ( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - - for(index_t i = 0; i < 5; ++i) - { -#if 0 - float ave_time = launch_kernel_dynamic_gemm_xdlops_v1 -#else - float ave_time = launch_kernel_dynamic_gemm_xdlops_v2 -#endif - , - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_KPack, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - 
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 2, 1>, - Sequence<1, 0, 2>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_KPack, - false, // don't move back src coordinate after threadwise copy, which will be fused - // with MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 3, - GemmCThreadTransferDstScalarPerVector_GemmN1, - decltype(descs[I4]), - decltype(descs[I5]), - decltype(descs[I6]), - decltype(descs[I7]), - decltype(descs[I8])>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - descs[I0], - descs[I1], - descs[I2], - descs[I3], - descs[I4], - descs[I5], - descs[I6], - descs[I7], - descs[I8], - nrepeat); - - float perf = (float)calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index 4f531b55ae..035546d31a 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -2,7 +2,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" -#include "driver_dynamic_gemm_xdlops_v2r2.hpp" +#include "driver_dynamic_gemm_xdlops_v2r3.hpp" template ; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr 
index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 +#if 1 // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 constexpr index_t BlockSize = 256; @@ -120,12 +64,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; constexpr index_t GemmK1 = 8; - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; @@ -139,34 +83,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #endif @@ -200,10 +116,18 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}), make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{})); @@ -216,7 +140,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_gemm_xdlops_v2r2< + float ave_time = driver_dynamic_gemm_xdlops_v2r3< BlockSize, TInWei, TAcc, @@ -230,6 +154,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk GemmKPerBlock, GemmMPerWave, GemmNPerWave, + GemmK1, MRepeat, NRepeat, GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, @@ -248,26 +173,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk GemmBBlockTransferSrcScalarPerVector_GemmN, GemmBBlockTransferDstScalarPerVector_GemmK1, false, // don't 
move back src coordinate after threadwise copy - Sequence<3, 0, 1, 2>, - 3, + Sequence<3, 0, 1, 2, 7, 5, 4, 6>, + 7, GemmCThreadTransferDstScalarPerVector, decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), decltype(out_m0_m1_m2_n_grid_iterator_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>( - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - out_m0_m1_m2_n_grid_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, - nrepeat); + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + false>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + out_m0_m1_m2_n_grid_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); float perf = (float)calculate_convolution_flops( in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index bb37ac309f..0000000000 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,240 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_dynamic_gemm_xdlops_v2r2.hpp" - -template -void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - 
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); - -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_dynamic_gemm_xdlops_v2r2< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperation::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1>, - 2, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(out_m0_m1_m2_n_grid_iterator_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>( - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - out_m0_m1_m2_n_grid_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = in_n_hi_wi_c_lengths[I2]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff 
--git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index c1e63664e5..0000000000 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,305 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_dynamic_gemm_xdlops_v2r3.hpp" - -template -void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); - -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t 
GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - 
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_dynamic_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperation::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(out_m0_m1_m2_n_grid_iterator_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), - false // CAccessOrderMRepeatNRepeat - >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - out_m0_m1_m2_n_grid_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, - nrepeat); - - { - const auto N = 
out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = in_n_hi_wi_c_lengths[I2]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index f1a0bed7c0..0455f77718 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -223,34 +223,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #endif @@ -325,6 +297,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh GemmKPerBlock, GemmMPerWave, GemmNPerWave, + GemmK1, MRepeat, NRepeat, GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, diff --git a/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..2b653dbae1 --- /dev/null +++ b/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,376 @@ +#include "device.hpp" +#include "host_tensor.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +#include "olc_driver_common.hpp" +#include "conv_tunables.hpp" + +#include "handle.hpp" + +namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw { + +template +static std::string get_network_config_string_from_types() 
+{ + std::string out; + + out += static_cast(Driver::get_typeid_from_type()) + + static_cast(Driver::get_typeid_from_type()) + + static_cast(Driver::get_typeid_from_type()); + + return (out); +}; + +static std::string +get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt) +{ + std::string out("TUN_"); + + out += std::to_string(pt->BlockSize) + "_"; + + out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" + + std::to_string(pt->KPerBlock) + "_"; + out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" + + std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" + + std::to_string(pt->K1) + "_"; + + out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_"; + + out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_"; + + out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_"; + + out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" + + std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" + + std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_"; + + out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_"; + out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_"; + out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_"; + out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_"; + + out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_"; + + out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_"; + + out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_"; + + out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" + + std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" + + std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_"; + + out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_"; + out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_"; + out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_"; + out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_"; + + out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" + + 
std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_"; + + out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_"; + out += std::to_string(pt->CThreadTransferDstScalarPerVector); + + return (out); +}; + +template +static std::string get_definition_string_from_types() +{ + std::string out; + + out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()) + + " -DCK_PARAM_CONV_COMPTYPE=" + std::to_string(Driver::get_typeid_from_type()) + + " -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()); + + return (out); +}; + +static std::string +get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt) +{ + std::string out; + + out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize); + + out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) + + " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) + + " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock); + out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) + + " -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) + + " -DCK_PARAM_K1=" + std::to_string(pt->K1) + + " -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) + + " -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat); + + out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]); + + out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]); + + out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]); + + out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" + + std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," + + std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," + + std::to_string(pt->ABlockTransferSrcAccessOrder[2]); + + out += + " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim); + out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" + + std::to_string(pt->ABlockTransferSrcScalarPerVector); + out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" + + std::to_string(pt->ABlockTransferDstScalarPerVector_K1); + out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" + + std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun); + + out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]); + + out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]); + + out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," + + 
+template <typename TInWei, typename TAcc, typename TOut>
+static std::string get_definition_string_from_types()
+{
+    std::string out;
+
+    out += " -DCK_PARAM_IN_WEI_DATATYPE=" +
+           std::to_string(Driver::get_typeid_from_type<TInWei>()) +
+           " -DCK_PARAM_CONV_COMPTYPE=" + std::to_string(Driver::get_typeid_from_type<TAcc>()) +
+           " -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TOut>());
+
+    return (out);
+};
+
+static std::string
+get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt)
+{
+    std::string out;
+
+    out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);
+
+    out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) +
+           " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) +
+           " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
+    out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) +
+           " -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) +
+           " -DCK_PARAM_K1=" + std::to_string(pt->K1) +
+           " -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) +
+           " -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat);
+
+    out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" +
+           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," +
+           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," +
+           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]);
+
+    out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" +
+           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," +
+           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," +
+           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]);
+
+    out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
+           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
+           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
+           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]);
+
+    out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
+           std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
+           std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
+           std::to_string(pt->ABlockTransferSrcAccessOrder[2]);
+
+    out +=
+        " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
+    out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
+           std::to_string(pt->ABlockTransferSrcScalarPerVector);
+    out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" +
+           std::to_string(pt->ABlockTransferDstScalarPerVector_K1);
+    out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
+           std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);
+
+    out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" +
+           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," +
+           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," +
+           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]);
+
+    out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" +
+           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," +
+           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," +
+           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]);
+
+    out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
+           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
+           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
+           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]);
+
+    out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
+           std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
+           std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
+           std::to_string(pt->BBlockTransferSrcAccessOrder[2]);
+
+    out +=
+        " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
+    out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
+           std::to_string(pt->BBlockTransferSrcScalarPerVector);
+    out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" +
+           std::to_string(pt->BBlockTransferDstScalarPerVector_K1);
+    out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
+           std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);
+
+    out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]);
+
+    out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
+           std::to_string(pt->CThreadTransferSrcDstVectorDim);
+    out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
+           std::to_string(pt->CThreadTransferDstScalarPerVector);
+
+    return (out);
+};
+
+} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
+
+template <typename TInWei,
+          typename TAcc,
+          typename TOut,
+          typename InLengths,
+          typename WeiLengths,
+          typename OutLengths,
+          typename ConvStrides,
+          typename ConvDilations,
+          typename InLeftPads,
+          typename InRightPads>
+void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc(
+    olCompile::Handle* handle,
+    const InLengths& in_n_c_hi_wi_lengths,
+    const WeiLengths& wei_k_c_y_x_lengths,
+    const OutLengths& out_n_k_ho_wo_lengths,
+    const ConvStrides& conv_strides,
+    const ConvDilations& conv_dilations,
+    const InLeftPads& in_left_pads,
+    const InRightPads& in_right_pads,
+    const Tensor<TInWei>& in_n_c_hi_wi,
+    const Tensor<TInWei>& wei_k_c_y_x,
+    Tensor<TOut>& out_n_k_ho_wo,
+    const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable,
+    ck::index_t nrepeat)
+{
+    using namespace ck;
+    using namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
+    using size_t = std::size_t;
+
+    /////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+
+    const auto in_n_c_hi_wi_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
+    const auto wei_k_c_y_x_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
+    const auto out_n_k_ho_wo_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
+
+    const auto n  = in_n_c_hi_wi_desc.GetLength(I0);
+    const auto c  = in_n_c_hi_wi_desc.GetLength(I1);
+    const auto hi = in_n_c_hi_wi_desc.GetLength(I2);
+    const auto wi = in_n_c_hi_wi_desc.GetLength(I3);
+    const auto k  = wei_k_c_y_x_desc.GetLength(I0);
+    const auto y  = wei_k_c_y_x_desc.GetLength(I2);
+    const auto x  = wei_k_c_y_x_desc.GetLength(I3);
+    const auto ho = out_n_k_ho_wo_desc.GetLength(I2);
+    const auto wo = out_n_k_ho_wo_desc.GetLength(I3);
+
+    const auto M  = k;
+    const auto N  = n * ho * wo;
+    const auto K  = c * y * x;
+    const auto K0 = K / tunable->K1;
+
+    const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock);
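+    // Illustrative numbers (not from this patch): for a 3x3, stride-2, pad-1 convolution
+    // with n = 128, c = 192, hi = wi = 71 and k = 256, the output is ho = wo = 36, so
+    //   M = 256, N = 128 * 36 * 36 = 165888, K = 192 * 3 * 3 = 1728,
+    // and with MPerBlock = NPerBlock = 128:
+    //   grid_size = (256 / 128) * (165888 / 128) = 2 * 1296 = 2592 workgroups.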
+    /////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+    // these buffers are usually provided by the user application
+    DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
+    DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
+    DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
+
+    in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
+    wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
+    out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());
+
+    // these are workspace buffers that should be exposed to the user by the corresponding
+    // workspace API
+    DeviceMem workspace_buf(4096);
+
+    void* a_k_m0_m1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
+    void* b_k_n0_n1_grid_desc_dev_buf =
+        static_cast<void*>(static_cast<char*>(workspace_buf.GetDeviceBuffer()) + 1024);
+    void* c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf =
+        static_cast<void*>(static_cast<char*>(workspace_buf.GetDeviceBuffer()) + 2048);
+    void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf =
+        static_cast<void*>(static_cast<char*>(workspace_buf.GetDeviceBuffer()) + 3072);
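+    // Layout of the 4096-byte workspace: one fixed 1024-byte slot per device-side object
+    // that the *_prepare kernel writes and the main kernel reads back:
+    //   [   0, 1024)  a_k_m0_m1_grid_desc
+    //   [1024, 2048)  b_k_n0_n1_grid_desc
+    //   [2048, 3072)  c_m0_m10_m11_n0_n10_n11_grid_desc
+    //   [3072, 4096)  c_blockid_to_m0_n0_block_cluster_adaptor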
+
+    const std::vector<size_t> vld  = {static_cast<size_t>(tunable->BlockSize), 1, 1};
+    const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
+    const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};
+
+    std::string program_name =
+        "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp";
+    std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nchw";
+
+    std::string param = " -std=c++17 ";
+    std::string network_config;
+
+    param += get_definition_string_from_types<TInWei, TAcc, TOut>() + " -DCK_USE_AMD_XDLOPS " +
+             get_definition_string_from_tunable(tunable);
+
+    network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_" +
+                     get_network_config_string_from_tunable(tunable);
+
+    std::vector<float> kernel1_times;
+    std::vector<float> kernel2_times;
+
+    for(index_t i = 0; i < nrepeat; ++i)
+    {
+        KernelTimer timer1, timer2;
+        std::string kernel_name;
+
+        kernel_name =
+            "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare";
+        auto network_config_1 = network_config + "_1";
+
+        timer1.Start();
+        handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
+            static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
+            static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
+            static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
+            static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
+            static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
+            static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
+            static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
+            conv_strides[I0],
+            conv_strides[I1],
+            conv_dilations[I0],
+            conv_dilations[I1],
+            in_left_pads[I0],
+            in_left_pads[I1],
+            in_right_pads[I0],
+            in_right_pads[I1],
+            a_k_m0_m1_grid_desc_dev_buf,
+            b_k_n0_n1_grid_desc_dev_buf,
+            c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf,
+            c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf);
+        timer1.End();
+
+        kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw";
+        auto network_config_2 = network_config + "_2";
+
+        timer2.Start();
+        handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
+            reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
+            reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
+            reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
+            (const void*)(a_k_m0_m1_grid_desc_dev_buf),
+            (const void*)(b_k_n0_n1_grid_desc_dev_buf),
+            (const void*)(c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf),
+            (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf));
+        timer2.End();
+
+        kernel1_times.push_back(timer1.GetElapsedTime());
+        kernel2_times.push_back(timer2.GetElapsedTime());
+    }
+
+    {
+        auto ave_time1 = Driver::get_effective_average(kernel1_times);
+        auto ave_time2 = Driver::get_effective_average(kernel2_times);
+
+        const auto N = in_n_c_hi_wi_lengths[I0];
+        const auto C = in_n_c_hi_wi_lengths[I1];
+
+        const auto K  = out_n_k_ho_wo_lengths[I1];
+        const auto Ho = out_n_k_ho_wo_lengths[I2];
+        const auto Wo = out_n_k_ho_wo_lengths[I3];
+
+        const auto Y = wei_k_c_y_x_lengths[I2];
+        const auto X = wei_k_c_y_x_lengths[I3];
+
+        float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
+                     (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);
+
+        std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
+                  << ave_time2 << "), " << perf << " TFlop/s" << std::endl;
+    };
+
+    // copy result back to host
+    out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
+}
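A note on the units in the report above: the flop count 2*N*K*Ho*Wo*C*Y*X is pre-divided by 10^9 (giving Gflop) and then divided by a time in milliseconds, and Gflop/ms equals TFlop/s, so the printed label is consistent. A minimal self-contained check of the formula, with illustrative sizes:

    #include <cstddef>
    #include <cstdio>

    // Standalone check of the perf formula used by the drivers above (values illustrative).
    int main()
    {
        const std::size_t N = 128, K = 256, Ho = 36, Wo = 36, C = 192, Y = 3, X = 3;
        const float ave_time_ms = 1.5f; // hypothetical measured kernel time in milliseconds

        // 2*N*K*Ho*Wo*C*Y*X flops; dividing Gflop by milliseconds yields TFlop/s
        const float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
                           (std::size_t(1000) * 1000 * 1000) / ave_time_ms;

        std::printf("%f TFlop/s\n", perf);
        return 0;
    }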
diff --git a/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp
new file mode 100644
index 0000000000..073e90bd63
--- /dev/null
+++ b/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp
@@ -0,0 +1,379 @@
+#include "device.hpp"
+#include "host_tensor.hpp"
+#include "dynamic_tensor_descriptor.hpp"
+#include "dynamic_tensor_descriptor_helper.hpp"
+#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
+
+#include "olc_driver_common.hpp"
+#include "conv_tunables.hpp"
+
+#include "handle.hpp"
+
+namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk {
+
+template <typename TInWei, typename TAcc, typename TOut>
+static std::string get_network_config_string_from_types()
+{
+    std::string out;
+
+    // append one type-id character per type (appending separately avoids summing the
+    // chars into a single int)
+    out += static_cast<char>(Driver::get_typeid_from_type<TInWei>());
+    out += static_cast<char>(Driver::get_typeid_from_type<TAcc>());
+    out += static_cast<char>(Driver::get_typeid_from_type<TOut>());
+
+    return (out);
+};
+
+static std::string
+get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt)
+{
+    std::string out("TUN_");
+
+    out += std::to_string(pt->BlockSize) + "_";
+
+    out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" +
+           std::to_string(pt->KPerBlock) + "_";
+    out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" +
+           std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" +
+           std::to_string(pt->K1) + "_";
+
+    out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" +
+           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" +
+           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_";
+
+    out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" +
+           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" +
+           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_";
+
+    out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" +
+           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" +
+           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_";
+
+    out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" +
+           std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" +
+           std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_";
+
+    out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_";
+    out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_";
+    out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_";
+    out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_";
+
+    out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" +
+           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" +
+           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_";
+
+    out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" +
+           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" +
+           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_";
+
+    out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" +
+           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" +
+           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_";
+
+    out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" +
+           std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" +
+           std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_";
+
+    out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_";
+    out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_";
+    out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_";
+    out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_";
+
+    out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_";
+
+    out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_";
+    out += std::to_string(pt->CThreadTransferDstScalarPerVector);
+
+    return (out);
+};
+
+template <typename TInWei, typename TAcc, typename TOut>
+static std::string get_definition_string_from_types()
+{
+    std::string out;
+
+    out += " -DCK_PARAM_IN_WEI_DATATYPE=" +
+           std::to_string(Driver::get_typeid_from_type<TInWei>()) +
+           " -DCK_PARAM_CONV_COMPTYPE=" + std::to_string(Driver::get_typeid_from_type<TAcc>()) +
+           " -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TOut>());
+
+    return (out);
+};
+
+static std::string
+get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt)
+{
+    std::string out;
+
+    out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);
+
+    out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) +
+           " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) +
+           " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
+    out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) +
+           " -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) +
+           " -DCK_PARAM_K1=" + std::to_string(pt->K1) +
+           " -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) +
+           " -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat);
+
+    out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" +
+           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," +
+           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," +
+           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]);
+
+    out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" +
+           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," +
+           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," +
+           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]);
+
+    out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
+           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
+           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
+           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]);
+
+    out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
+           std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
+           std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
+           std::to_string(pt->ABlockTransferSrcAccessOrder[2]);
+
+    out +=
+        " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
+    out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
+           std::to_string(pt->ABlockTransferSrcScalarPerVector);
+    out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" +
+           std::to_string(pt->ABlockTransferDstScalarPerVector_K1);
+    out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
+           std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);
+
+    out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" +
+           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," +
+           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," +
+           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]);
+
+    out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" +
+           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," +
+           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," +
+           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]);
+
+    out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
+           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
+           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
+           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]);
+
+    out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
+           std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
+           std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
+           std::to_string(pt->BBlockTransferSrcAccessOrder[2]);
+
+    out +=
+        " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
+    out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
+           std::to_string(pt->BBlockTransferSrcScalarPerVector);
+    out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" +
+           std::to_string(pt->BBlockTransferDstScalarPerVector_K1);
+    out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
+           std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);
+
+    out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," +
+           std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]);
+
+    out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
+           std::to_string(pt->CThreadTransferSrcDstVectorDim);
+    out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
+           std::to_string(pt->CThreadTransferDstScalarPerVector);
+
+    return (out);
+};
+
+} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
+
+template <typename TInWei,
+          typename TAcc,
+          typename TOut,
+          typename InLengths,
+          typename WeiLengths,
+          typename OutLengths,
+          typename ConvStrides,
+          typename ConvDilations,
+          typename InLeftPads,
+          typename InRightPads>
+void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc(
+    olCompile::Handle* handle,
+    const InLengths& in_n_hi_wi_c_lengths,
+    const WeiLengths& wei_k_y_x_c_lengths,
+    const OutLengths& out_n_ho_wo_k_lengths,
+    const ConvStrides& conv_strides,
+    const ConvDilations& conv_dilations,
+    const InLeftPads& in_left_pads,
+    const InRightPads& in_right_pads,
+    const Tensor<TInWei>& in_n_hi_wi_c,
+    const Tensor<TInWei>& wei_k_y_x_c,
+    Tensor<TOut>& out_n_ho_wo_k,
+    const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable,
+    ck::index_t nrepeat)
+{
+    using namespace ck;
+    using namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk;
+    using size_t = std::size_t;
+
+    /////////////////////////////////////////////////////////////////////////////////////////////////////////////
+    // The following code is only used for computing grid_size, hasMainKBlockLoop and
+    // hasDoubleTailKBlockLoop
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+
+    const auto in_n_hi_wi_c_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
+    const auto wei_k_y_x_c_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
+    const auto out_n_ho_wo_k_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
+
+    const auto n  = in_n_hi_wi_c_desc.GetLength(I0);
+    const auto hi = in_n_hi_wi_c_desc.GetLength(I1);
+    const auto wi = in_n_hi_wi_c_desc.GetLength(I2);
+    const auto c  = in_n_hi_wi_c_desc.GetLength(I3);
+
+    const auto k = wei_k_y_x_c_desc.GetLength(I0);
+    const auto y = wei_k_y_x_c_desc.GetLength(I1);
+    const auto x = wei_k_y_x_c_desc.GetLength(I2);
+
+    const auto ho = out_n_ho_wo_k_desc.GetLength(I1);
+    const auto wo = out_n_ho_wo_k_desc.GetLength(I2);
+
+    const auto M  = k;
+    const auto N  = n * ho * wo;
+    const auto K  = c * y * x;
+    const auto K0 = K / tunable->K1;
+
+    const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock);
+
+    // these buffers are usually provided by the user application
+    DeviceMem in_n_hi_wi_c_dev_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
+    DeviceMem wei_k_y_x_c_dev_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
+    DeviceMem out_n_ho_wo_k_dev_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
+
+    in_n_hi_wi_c_dev_buf.ToDevice(in_n_hi_wi_c.mData.data());
+    wei_k_y_x_c_dev_buf.ToDevice(wei_k_y_x_c.mData.data());
+    out_n_ho_wo_k_dev_buf.ToDevice(out_n_ho_wo_k.mData.data());
+
+    // these are workspace buffers that should be exposed to the user by the corresponding
+    // workspace API
+    DeviceMem workspace_buf(4096);
+
+    void* a_k0_m_k1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
+    void* b_k0_n_k1_grid_desc_dev_buf =
+        static_cast<void*>(static_cast<char*>(workspace_buf.GetDeviceBuffer()) + 1024);
+    void* c_m0_m1_m2_n_grid_desc_dev_buf =
+        static_cast<void*>(static_cast<char*>(workspace_buf.GetDeviceBuffer()) + 2048);
+    void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf =
+        static_cast<void*>(static_cast<char*>(workspace_buf.GetDeviceBuffer()) + 3072);
+
+    const std::vector<size_t> vld  = {static_cast<size_t>(tunable->BlockSize), 1, 1};
+    const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
+    const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};
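+    // Note on the launch dimensions: vld/vgd are given in work-items, so vgd2 =
+    // grid_size * BlockSize work-items at vld = BlockSize per workgroup launches exactly
+    // grid_size workgroups for the main kernel, while vgd1 launches a single workgroup
+    // for the *_prepare kernel. E.g. with the illustrative grid_size = 2592 and
+    // BlockSize = 256 from the earlier sketch, vgd2 = {663552, 1, 1}.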
"dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp"; + std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nhwc"; + + std::string param = " -std=c++17 "; + std::string network_config; + + param += get_definition_string_from_types() + " -DCK_USE_AMD_XDLOPS "; + param += get_definition_string_from_tunable(tunable); + + network_config = get_network_config_string_from_types() + "_" + + get_network_config_string_from_tunable(tunable); + + std::vector kernel1_times; + std::vector kernel2_times; + + KernelTimer timer1, timer2; + std::string kernel_name; + + kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare"; + auto network_config_1 = network_config + "_1"; + + timer1.Start(); + for(index_t i = 0; i < nrepeat; ++i) + { + handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)( + static_cast(in_n_hi_wi_c_lengths[I0]), + static_cast(in_n_hi_wi_c_lengths[I1]), + static_cast(in_n_hi_wi_c_lengths[I2]), + static_cast(in_n_hi_wi_c_lengths[I3]), + static_cast(wei_k_y_x_c_lengths[I0]), + static_cast(wei_k_y_x_c_lengths[I1]), + static_cast(wei_k_y_x_c_lengths[I2]), + conv_strides[I0], + conv_strides[I1], + conv_dilations[I0], + conv_dilations[I1], + in_left_pads[I0], + in_left_pads[I1], + in_right_pads[I0], + in_right_pads[I1], + a_k0_m_k1_grid_desc_dev_buf, + b_k0_n_k1_grid_desc_dev_buf, + c_m0_m1_m2_n_grid_desc_dev_buf, + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf); + } + timer1.End(); + + kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk"; + auto network_config_2 = network_config + "_2"; + + timer2.Start(); + for(index_t i = 0; i < nrepeat; ++i) + { + handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)( + reinterpret_cast(in_n_hi_wi_c_dev_buf.GetDeviceBuffer()), + reinterpret_cast(wei_k_y_x_c_dev_buf.GetDeviceBuffer()), + reinterpret_cast(out_n_ho_wo_k_dev_buf.GetDeviceBuffer()), + (const void*)(a_k0_m_k1_grid_desc_dev_buf), + (const void*)(b_k0_n_k1_grid_desc_dev_buf), + (const void*)(c_m0_m1_m2_n_grid_desc_dev_buf), + (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf)); + } + timer2.End(); + + { + auto ave_time1 = timer1.GetElapsedTime() / nrepeat; + auto ave_time2 = timer2.GetElapsedTime() / nrepeat; + + const auto N = in_n_hi_wi_c_lengths[I0]; + const auto C = in_n_hi_wi_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + const auto K = out_n_ho_wo_k_lengths[I3]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time2; + + std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " + << ave_time2 << "), " << perf << " TFlop/s" << std::endl; + }; + + // copy result back to host + out_n_ho_wo_k_dev_buf.FromDevice(out_n_ho_wo_k.mData.data()); +}