From 627d8ef35a6da8ad268b5197e3045ccdfb4ac684 Mon Sep 17 00:00:00 2001 From: ltqin Date: Tue, 31 Aug 2021 11:49:17 +0800 Subject: [PATCH] Backward weight v4r4r2 with xdlops (#18) * start * modify transformat * modify device convolutiion * modify host * added host conv bwd and wrw * remove bwd, seperate wrw * clean * hacall k to zero * out log * fixed * fixed * change to (out in wei) * input hack * hack to out * format * fix by comments * change wei hacks(wei transform has not merge) * fix program once issue * fix review comment * fix vector load issue * tweak Co-authored-by: ltqin Co-authored-by: Jing Zhang Co-authored-by: Chao Liu --- ...lution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp | 129 ++++++++ host/driver_offline/CMakeLists.txt | 3 + ...plicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp | 24 +- ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 228 +++++++++++++ ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 85 +++-- ...icit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp | 302 ------------------ ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 32 +- .../src/conv_bwd_driver_offline.cpp | 48 +-- .../src/conv_fwd_driver_offline.cpp | 80 ++--- .../src/conv_wrw_driver_offline.cpp | 281 ++++++++++++++++ ..._conv_wrw.hpp => host_conv_bwd_weight.hpp} | 0 11 files changed, 777 insertions(+), 435 deletions(-) create mode 100644 composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp delete mode 100644 host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/src/conv_wrw_driver_offline.cpp rename host/host_tensor/include/{host_conv_wrw.hpp => host_conv_bwd_weight.hpp} (100%) diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..949f044b7d --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -0,0 +1,129 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmK = N * Ho * Wo +// GemmN = C * Y * X +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = C * Y * X; + const auto GemmK = N * Ho * Wo; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(out_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt index fec11e99af..8dec70d03f 100644 --- a/host/driver_offline/CMakeLists.txt +++ b/host/driver_offline/CMakeLists.txt @@ -13,9 +13,12 @@ include_directories(BEFORE set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) +set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp) add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) +add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE}) target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) +target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor) diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 7196f3c179..8f49473563 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -208,20 +208,20 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( // HACK: hacks that control index calculation when iterating over A, B, C matrix constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 // clang-format off constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..e97bc9c1c7 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,228 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); + +#if 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + // using vector load 4, so config's wo*ho must be a multiple of 4 + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + // using vector load 4, so config's wo*ho must be a multiple of 4 + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 1, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1 + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 1, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 0, 1, 2, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false>(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast(calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index dc4f5eafb6..d65ecadb4d 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -47,7 +47,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); -#if 1 +#if 0 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 constexpr index_t BlockSize = 256; @@ -74,6 +74,34 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #endif @@ -92,36 +120,39 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 692751bfb3..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,302 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat - >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = in_n_hi_wi_c_lengths[I2]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 5ff8dfb665..52432664de 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -250,22 +250,22 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp index 67cea94813..4e93ada859 100644 --- a/host/driver_offline/src/conv_bwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_bwd_driver_offline.cpp @@ -41,7 +41,7 @@ int main(int argc, char* argv[]) // dynamic mode if(argc != 22) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); exit(1); } @@ -79,7 +79,7 @@ int main(int argc, char* argv[]) // static mode if(argc < 7) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); exit(1); } @@ -90,28 +90,28 @@ int main(int argc, char* argv[]) const bool do_log = std::stoi(argv[5]); const int nrepeat = std::stoi(argv[6]); - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 71; - constexpr index_t Wi = 71; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; + constexpr auto N = Number<128>{}; + constexpr auto C = Number<192>{}; + constexpr auto Hi = Number<71>{}; + constexpr auto Wi = Number<71>{}; + constexpr auto K = Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; - const index_t conv_stride_h = 2; - const index_t conv_stride_w = 2; - const index_t conv_dilation_h = 1; - const index_t conv_dilation_w = 1; - const index_t in_left_pad_h = 1; - const index_t in_left_pad_w = 1; - const index_t in_right_pad_h = 1; - const index_t in_right_pad_w = 1; + constexpr auto conv_stride_h = I2; + constexpr auto conv_stride_w = I2; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; #endif #if 0 @@ -119,9 +119,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #endif std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 21acb35732..34d7247f3c 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -19,7 +19,7 @@ #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" -#define USE_MODE 1 +#define USE_DYNAMIC_MODE 1 #define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V6R1_NCHW 0 @@ -49,11 +49,11 @@ int main(int argc, char* argv[]) constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; -#if USE_MODE +#if USE_DYNAMIC_MODE // dynamic mode if(argc != 22) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); exit(1); } @@ -91,7 +91,7 @@ int main(int argc, char* argv[]) // static mode if(argc < 7) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); exit(1); } @@ -102,28 +102,28 @@ int main(int argc, char* argv[]) const bool do_log = std::stoi(argv[5]); const int nrepeat = std::stoi(argv[6]); - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 71; - constexpr index_t Wi = 71; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; + constexpr auto N = Number<128>{}; + constexpr auto C = Number<192>{}; + constexpr auto Hi = Number<71>{}; + constexpr auto Wi = Number<71>{}; + constexpr auto K = Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; - const index_t conv_stride_h = 2; - const index_t conv_stride_w = 2; - const index_t conv_dilation_h = 1; - const index_t conv_dilation_w = 1; - const index_t in_left_pad_h = 1; - const index_t in_left_pad_w = 1; - const index_t in_right_pad_h = 1; - const index_t in_right_pad_w = 1; + constexpr auto conv_stride_h = I2; + constexpr auto conv_stride_w = I2; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; #endif #if 0 @@ -131,9 +131,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 1 using in_data_t = int8_t; using acc_data_t = int32_t; @@ -228,7 +228,6 @@ int main(int argc, char* argv[]) } auto f_make_for_device_nchw = [&]() { -#if USE_MODE const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); const auto wei_lengths_dev = make_tuple(K, C, Y, X); const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); @@ -236,19 +235,6 @@ int main(int argc, char* argv[]) const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); -#else - const auto in_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto out_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto conv_strides_dev = make_tuple(Number{}, Number{}); - const auto conv_dilations_dev = - make_tuple(Number{}, Number{}); - const auto in_left_pads_dev = make_tuple(Number{}, Number{}); - const auto in_right_pads_dev = - make_tuple(Number{}, Number{}); -#endif return make_tuple(in_lengths_dev, wei_lengths_dev, @@ -260,7 +246,6 @@ int main(int argc, char* argv[]) }; auto f_make_for_device_nhwc = [&]() { -#if USE_MODE const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); const auto wei_lengths_dev = make_tuple(K, Y, X, C); const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); @@ -268,19 +253,6 @@ int main(int argc, char* argv[]) const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); -#else - const auto in_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto out_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto conv_strides_dev = make_tuple(Number{}, Number{}); - const auto conv_dilations_dev = - make_tuple(Number{}, Number{}); - const auto in_left_pads_dev = make_tuple(Number{}, Number{}); - const auto in_right_pads_dev = - make_tuple(Number{}, Number{}); -#endif return make_tuple(in_lengths_dev, wei_lengths_dev, diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp new file mode 100644 index 0000000000..13c73abf30 --- /dev/null +++ b/host/driver_offline/src/conv_wrw_driver_offline.cpp @@ -0,0 +1,281 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "host_conv_bwd_weight.hpp" +#include "device_tensor.hpp" +#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" + +#define USE_DYNAMIC_MODE 1 +#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1 + +enum ConvBackwardWeightAlgo +{ + V4R4R2XDLNCHW, +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + +#if USE_DYNAMIC_MODE + // dynamic mode + if(argc != 22) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardWeightAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + const index_t N = std::stoi(argv[7]); + const index_t K = std::stoi(argv[8]); + const index_t C = std::stoi(argv[9]); + const index_t Y = std::stoi(argv[10]); + const index_t X = std::stoi(argv[11]); + const index_t Hi = std::stoi(argv[12]); + const index_t Wi = std::stoi(argv[13]); + + const index_t conv_stride_h = std::stoi(argv[14]); + const index_t conv_stride_w = std::stoi(argv[15]); + const index_t conv_dilation_h = std::stoi(argv[16]); + const index_t conv_dilation_w = std::stoi(argv[17]); + const index_t in_left_pad_h = std::stoi(argv[18]); + const index_t in_left_pad_w = std::stoi(argv[19]); + const index_t in_right_pad_h = std::stoi(argv[20]); + const index_t in_right_pad_w = std::stoi(argv[21]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#else + // static mode + if(argc < 7) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardWeightAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + constexpr auto N = Number<128>{}; + constexpr auto C = Number<128>{}; + constexpr auto Hi = Number<14>{}; + constexpr auto Wi = Number<14>{}; + constexpr auto K = Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + + constexpr auto conv_stride_h = I1; + constexpr auto conv_stride_w = I1; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; + + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; + + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; +#endif + +#if 1 + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#elif 1 + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; +#endif + + std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); + + if(layout == ConvTensorLayout::NCHW) + { + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(C); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + } + else if(layout == ConvTensorLayout::NHWC) + { + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(Hi); + in_lengths_host[2] = static_cast(Wi); + in_lengths_host[3] = static_cast(C); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(Y); + wei_lengths_host[2] = static_cast(X); + wei_lengths_host[3] = static_cast(C); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(Ho); + out_lengths_host[2] = static_cast(Wo); + out_lengths_host[3] = static_cast(K); + } + else + { + std::runtime_error("wrong! not implemented"); + } + + Tensor in(in_lengths_host); + Tensor wei_device(wei_lengths_host); + Tensor wei_host(wei_lengths_host); + Tensor out(out_lengths_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei_host.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: "); + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 5: + in.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); + out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_out = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + }; + out.GenerateTensorValue(gen_out, num_thread); + } + + auto f_make_for_device_nchw = [&]() { + const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); + const auto wei_lengths_dev = make_tuple(K, C, Y, X); + const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + +#if USE_CONV_WRW_V4R4R2_XDL_NCHW + if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution_backward_weights(out, + in, + wei_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + check_error(wei_host, wei_device); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei_device: ", wei_device.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei_host : ", wei_host.mData, ",") << std::endl; + } + } +} diff --git a/host/host_tensor/include/host_conv_wrw.hpp b/host/host_tensor/include/host_conv_bwd_weight.hpp similarity index 100% rename from host/host_tensor/include/host_conv_wrw.hpp rename to host/host_tensor/include/host_conv_bwd_weight.hpp